From the JAX docs:
import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
selu(x)
"The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions."
The docs then proceed to wrap selu in jit:
selu_jit = jax.jit(selu)
selu_jit(x)
And for some reason this improves performance significantly.
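For anyone who wants to reproduce the difference, here is a minimal timing sketch. It assumes the same selu as above; the repetition count of 20 and the names eager_s and jit_s are my own choices, not from the docs. The numbers depend heavily on the backend (CPU, GPU, or TPU), so I am not quoting any.

```python
import timeit

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)
x = jnp.arange(1000000)

# Warm-up call so that compilation time is not counted in the jit measurement.
selu_jit(x).block_until_ready()

# JAX dispatches work asynchronously, so block_until_ready() is needed
# to measure the actual computation rather than just the dispatch.
runs = 20  # arbitrary repetition count (assumption)
eager_s = timeit.timeit(lambda: selu(x).block_until_ready(), number=runs) / runs
jit_s = timeit.timeit(lambda: selu_jit(x).block_until_ready(), number=runs) / runs

print(f"eager: {eager_s * 1e3:.3f} ms per call")
print(f"jit:   {jit_s * 1e3:.3f} ms per call")
```

Both versions should return the same values; only the time per call is expected to differ.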
Why is jit even needed here? More specifically, why is the original code "sending one operation at a time to the accelerator"?
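For context, the individual operations that selu is built from can be listed with jax.make_jaxpr. This is only a sketch for inspecting the function: the small input jnp.arange(5.0) and the variable names are assumptions of mine, and the exact primitive names printed may vary between JAX versions.

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


# Small float input chosen only for illustration (assumption).
x_small = jnp.arange(5.0)

# make_jaxpr traces the function and returns its intermediate representation.
closed_jaxpr = jax.make_jaxpr(selu)(x_small)
print(closed_jaxpr)

# Collect the name of each top-level operation in the traced function.
primitives = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitives)
```

As far as I understand the quoted sentence, these are the separate operations it refers to: without jit they are executed one after another, whereas jit hands the whole function to XLA at once.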
I was under the impression that jax.numpy is meant for this exact purpose; otherwise we might as well be using plain old numpy. What was wrong with the original selu?
Thanks!