I'm trying to understand JAX's auto-vectorization capabilities using `vmap` and implemented a minimal working example based on JAX's documentation.
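As background, the basic contract of `vmap` can be sketched with a toy function. `affine` and the shapes below are hypothetical and chosen only for illustration. Mapping over axis 0 of a batch should give the same result as calling the single-example function on each row in a Python loop and stacking the outputs:

```python
import jax.numpy as jnp
from jax import vmap

def affine(W, x):
    # Hypothetical single-example function: expects a 1-D vector x
    return jnp.dot(W, x) + 1.0

W = jnp.arange(6.0).reshape(3, 2)   # shared weights, shape (3, 2)
xs = jnp.arange(8.0).reshape(4, 2)  # batch of 4 vectors, shape (4, 2)

# Manual batching: apply affine to each row and stack the results
looped = jnp.stack([affine(W, x) for x in xs])

# vmap: W is not mapped (None), xs is mapped along its axis 0
vectorized = vmap(affine, in_axes=(None, 0))(W, xs)

print(vectorized.shape)                          # (4, 3)
print(bool(jnp.allclose(looped, vectorized)))    # True
```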
I don't understand how to use `in_axes` correctly. In the example below, I can set `in_axes=(None, 0)` or `in_axes=(None, 1)` and get the same results. Why is that? And why do I have to use `in_axes=(None, 0)` rather than something like `in_axes=(0, )`?
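For context on the second question, `in_axes` given as a tuple has one entry per positional argument of the mapped function: `None` passes that argument through unchanged, and an integer names the axis to map over. The sketch below uses a hypothetical two-argument function `layer` with made-up shapes (not from the original example) to show that a one-element tuple like `(0,)` does not match a two-argument call, and that a bare `in_axes=0` tries to map the weights too:

```python
import jax.numpy as jnp
from jax import vmap

def layer(W, x):
    # Hypothetical single-example function with two positional arguments
    return jnp.dot(W, x)

W = jnp.ones((3, 2))   # shared weights, should not be mapped
xs = jnp.ones((4, 2))  # batch of 4 input vectors of length 2

# One in_axes entry per positional argument: W -> None (not mapped), xs -> axis 0
ok = vmap(layer, in_axes=(None, 0))(W, xs)
print(ok.shape)  # (4, 3)

failed = []
# (0,) has one entry, but layer receives two positional arguments
try:
    vmap(layer, in_axes=(0,))(W, xs)
except ValueError:
    failed.append("(0,)")

# in_axes=0 maps axis 0 of every argument: W has size 3 there, xs has size 4
try:
    vmap(layer, in_axes=0)(W, xs)
except ValueError:
    failed.append("0")

print(failed)  # ['(0,)', '0']
```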
```python
import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    # Forward pass of a small MLP for a single (unbatched) input vector
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    # Return the pre-activation outputs of the last layer
    return outputs


if __name__ == "__main__":
    # Parameters
    dims = [2, 3, 5]
    input_dims = dims[0]
    batch_size = 2

    # Weights: one (W, b) pair per layer, all initialized to ones
    params = []
    for dims_in, dims_out in zip(dims, dims[1:]):
        params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))

    # Input data: shape (batch_size, input_dims)
    input_batch = jnp.ones((batch_size, input_dims))

    # With vmap: params is not mapped (None), input_batch is mapped along axis 0
    predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
    print(predictions)
```
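One way to probe the first question is to reuse the `predict` function and layer sizes from the snippet above, but add a hypothetical non-constant batch built with `jnp.arange` (not part of the original example). The question's `input_batch` has shape (2, 2) and contains only ones, so its rows and its columns are identical vectors. The sketch compares mapping over axis 0 and axis 1 on both batches:

```python
import jax.numpy as jnp
from jax import vmap

def predict(params, input_vec):
    # Same single-example MLP as in the question
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs

dims = [2, 3, 5]
params = [(jnp.ones((d_out, d_in)), jnp.ones((d_out,)))
          for d_in, d_out in zip(dims, dims[1:])]

# The question's batch: every row equals every column (all ones)
ones_batch = jnp.ones((2, 2))
same_0 = vmap(predict, in_axes=(None, 0))(params, ones_batch)
same_1 = vmap(predict, in_axes=(None, 1))(params, ones_batch)
print(bool(jnp.allclose(same_0, same_1)))  # True

# Hypothetical batch with distinct entries: rows [0, 1], [2, 3]; columns [0, 2], [1, 3]
batch = jnp.arange(4.0).reshape(2, 2)
rows = vmap(predict, in_axes=(None, 0))(params, batch)  # maps over rows
cols = vmap(predict, in_axes=(None, 1))(params, batch)  # maps over columns
print(bool(jnp.allclose(rows, cols)))  # False

# Mapping over axis 1 matches mapping over axis 0 of the transposed batch
cols_via_T = vmap(predict, in_axes=(None, 0))(params, batch.T)
print(bool(jnp.allclose(cols, cols_via_T)))  # True
```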