np.newaxis with Numba nopython

Is there a way to use np.newaxis with Numba in nopython mode, so that I can apply broadcasting without falling back to Python (object mode)?

For example:

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

Thanks
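For reference, here is what the indexing in the example does in plain NumPy, outside Numba (a sketch for illustration only; it does not exercise nopython mode). np.newaxis inserts a length-1 axis, and broadcasting then stretches the (20, 1) and (1, 10) operands so their product matches the (20, 10) shape of a.

```python
import numpy as np

a = np.random.randn(20, 10)
b = np.random.randn(20)
c = np.random.randn(10)

b_col = b[:, np.newaxis]  # shape (20, 1): b as a column
c_row = c[np.newaxis, :]  # shape (1, 10): c as a row
print(b_col.shape, c_row.shape)  # (20, 1) (1, 10)

outer = b_col * c_row  # broadcasts to (20, 10): the outer product of b and c
d = a - outer
print(d.shape)  # (20, 10)

# Same values as an explicit outer product
print(np.allclose(outer, np.outer(b, c)))  # True
```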

Motch asked 4/8, 2016 at 7:21

In my case (numba 0.35, numpy 1.14.0), np.expand_dims works fine:

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    # expand_dims(b, -1) has shape (20, 1); expand_dims(c, 0) has shape (1, 10)
    d = a - np.expand_dims(b, -1) * np.expand_dims(c, 0)
    return d

Of course, the second expand_dims can be omitted: broadcasting already treats c, with shape (10,), as if it had shape (1, 10).
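A plain-NumPy check of that remark (run outside Numba here, so it illustrates the broadcasting rule rather than Numba's own behaviour): only b needs the extra axis, and both versions give the same result.

```python
import numpy as np

a = np.random.randn(20, 10)
b = np.random.randn(20)
c = np.random.randn(10)

# Both operands expanded explicitly
d_both = a - np.expand_dims(b, -1) * np.expand_dims(c, 0)

# Only b expanded: c (shape (10,)) broadcasts as if it had shape (1, 10)
d_one = a - np.expand_dims(b, -1) * c

print(np.allclose(d_both, d_one))  # True
```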

Brutify answered 10/5, 2019 at 12:48

You can accomplish this using reshape; it looks like [:, None] indexing isn't currently supported. Note that this probably won't be much faster than doing it in plain Python/NumPy, since the expression is already vectorized.

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    # reshape((-1, 1)) gives shape (20, 1); reshape((1, -1)) gives shape (1, 10)
    d = a - b.reshape((-1, 1)) * c.reshape((1, -1))
    return d
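A plain-NumPy sketch (checked outside Numba) of why the reshape calls stand in for the newaxis indexing: reshape((-1, 1)) and b[:, None] produce the same (20, 1) array, and for a contiguous input the reshape is a view rather than a copy.

```python
import numpy as np

b = np.random.randn(20)
c = np.random.randn(10)

b_col = b.reshape((-1, 1))  # -1 lets NumPy infer the length: shape (20, 1)
c_row = c.reshape((1, -1))  # shape (1, 10)

print(np.array_equal(b_col, b[:, None]))  # True
print(np.array_equal(c_row, c[None, :]))  # True

# For a contiguous array, reshape shares memory with the original (no copy)
print(np.shares_memory(b, b_col))  # True
```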
Bushmaster answered 4/8, 2016 at 11:44
I have tried it, but I get: reshape() supports contiguous array only. And of course, toto() is an example, not my actual function. - Motch
You could do b.copy().reshape((-1, 1)). If your array isn't contiguous, I believe this would have copied anyway, though I'm not 100% sure. - Bushmaster
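To make the contiguity issue concrete, here is a plain-NumPy sketch (not run under Numba). A strided slice such as base[::2] is not C-contiguous, which is presumably the kind of input the reshape error refers to, while .copy() returns a fresh C-contiguous array that can then be reshaped. The base[::2] input is a hypothetical example, not taken from the question.

```python
import numpy as np

base = np.random.randn(40)
b = base[::2]  # every other element: a strided view with 20 elements
print(b.flags['C_CONTIGUOUS'])  # False

b_col = b.copy().reshape((-1, 1))  # copy() yields a contiguous array first
print(b_col.shape)  # (20, 1)
print(b_col.flags['C_CONTIGUOUS'])  # True

# The copy owns its own data, so it no longer overlaps base
print(np.shares_memory(b_col, base))  # False
```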

This can be done with the newest version of Numba (0.27) and NumPy's stride_tricks. You need to be careful with this, and it's a bit ugly. Read the docstring for as_strided to make sure you understand what's going on: it isn't "safe", because it doesn't check the shape or the strides.

import numpy as np
import numba as nb

a = np.random.randn(20, 10)
b = np.random.randn(20) 
c = np.random.randn(10)

def toto(a, b, c):

    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

@nb.jit(nopython=True)
def toto2(a, b, c):
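    # View b as shape (n, 1): step b.strides[0] bytes per row, 0 along the length-1 axis
    _b = np.lib.stride_tricks.as_strided(b, shape=(b.shape[0], 1), strides=(b.strides[0], 0))
    # View c as shape (1, m): 0 along the length-1 axis, c.strides[0] bytes per column
    _c = np.lib.stride_tricks.as_strided(c, shape=(1, c.shape[0]), strides=(0, c.strides[0]))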
    d = a - _b * _c

    return d

x = toto(a,b,c)
y = toto2(a,b,c)
print(np.allclose(x, y))  # True
Customs answered 5/8, 2016 at 16:36
