Why is JAX jit needed for jax.numpy operations?
From the JAX docs:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)

selu(x)

"The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions."

The docs then proceed to wrap selu in jit:

selu_jit = jax.jit(selu)

selu_jit(x)

And for some reason this improves performance significantly.

Why is jit even needed here? More specifically, why is the original code "sending one operation at a time to the accelerator"?

I was under the impression that jax.numpy was designed for exactly this purpose; otherwise we might as well be using plain old numpy. What was wrong with the original selu?

Thanks!

Narine answered 16/1, 2023 at 18:43

Edit: after a short discussion below I realized a more concise answer to the original question: JAX uses eager computations by default; if you want lazy evaluation—what's sometimes called graph mode in other packages—you can specify this by wrapping your function in jax.jit.
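A minimal sketch of the two modes, assuming a standard JAX install (timings are omitted here since they vary by backend, but on an accelerator the jitted version is typically much faster):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000, dtype=jnp.float32)

# Eager (default): each jnp operation dispatches to the device immediately.
eager_result = selu(x)

# Compiled: the whole function is traced once and compiled by XLA;
# the first call compiles, subsequent calls reuse the compiled code.
selu_jit = jax.jit(selu)
jit_result = selu_jit(x)

# Both paths compute the same values.
print(jnp.allclose(eager_result, jit_result))
```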


Python is an interpreted language, which means that statements are executed one at a time. This is the sense in which the un-jitted code is sending one operation at a time to the accelerator: each statement must execute and return a value before the interpreter runs the next.

Within a jit-compiled function, JAX replaces arrays with abstract tracers in order to determine the full sequence of operations in the function and to send them all to XLA for compilation, where the compiler may rearrange or fuse operations to make the overall execution more efficient.
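You can see this tracing machinery directly with jax.make_jaxpr, which runs the function on abstract tracers and returns the full sequence of primitive operations that jit would hand to XLA (a sketch; the exact jaxpr printout depends on your JAX version):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Prints the traced program: a sequence of primitives such as
# gt, exp, mul, sub, select_n, mul -- several ops for one line of Python.
print(jax.make_jaxpr(selu)(1.0))
```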

The reason we use jax.numpy rather than normal numpy is because jax.numpy operations work with the JIT tracer machinery, whereas normal numpy operations do not.
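A small sketch of the difference: a jnp operation traces fine inside jit, while the equivalent np operation tries to read concrete values out of the tracer and fails (the specific error class is jax.errors.TracerArrayConversionError in current JAX versions):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f_jnp(x):
    return jnp.sin(x)  # jnp ops understand JAX tracers

@jax.jit
def f_np(x):
    return np.sin(x)   # np ops try to convert the tracer to a concrete array

print(f_jnp(jnp.arange(3.0)))  # traces and compiles fine

try:
    f_np(jnp.arange(3.0))
except jax.errors.TracerArrayConversionError as e:
    print("numpy inside jit failed:", type(e).__name__)
```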

For a high-level intro to how JAX and its transforms work, a good place to start is How To Think In JAX.

Cribriform answered 16/1, 2023 at 20:28
Thanks Jake, glad this caught your attention :) I guess the main source of confusion for me is which operations are being executed one at a time? From what I can tell, there are only 2 operations in the function: 1. jnp.where(...) 2. lambda_ * #1. I interpret those as single "atomic" operations that JAX can handle efficiently (rather than 1,000,000 operations). Is this correct?Narine
I count six or more distinct jax.numpy operations in that one line: jnp.exp, jnp.multiply, jnp.subtract, jnp.gt, jnp.where, and jnp.multiply again, plus several implicit dtype conversions. In normal (interpreted) execution, each of these executes individually in sequence, and each returns its output in a freshly-allocated buffer that is passed to the next operation. When you JIT-compile it, XLA can fuse and rearrange operations, and only needs to allocate a single new buffer for the final output.Cribriform
Ah, I was under the impression that JAX knows how to fuse things like alpha * jnp.exp(x) into one operation, that it builds a graph on its own. Isn't it technically feasible?Narine
alpha * jnp.exp(x) is two sequential Python operations: tmp = jnp.exp(x) followed by alpha * tmp. Because of the way Python evaluates programs, JAX cannot do any optimization across the boundary of these operations, except via jit, which is the mechanism specifically designed to allow JAX to do this sort of optimization.Cribriform
Technically you could design a JAX-like library to execute in a manner similar to Tensorflow v1's "graph mode", where normal operations don't compute anything but rather build a graph, with some eval step at the end to mark it for execution, but that's not how JAX is designed. In this parlance, JAX is effectively eager-mode by default, with the jit transform to explicitly opt-in to graph mode.Cribriform
This was indeed what I assumed, that a lazy operations graph was being built even without jit. Thanks!Narine
