When you index a NumPy `ndarray` with a `DeviceArray`, NumPy tries to interpret the JAX array as a tuple of per-axis indices.
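```python
import numpy as onp
import jax.numpy as np  # JAX's NumPy API, imported as `np` here

x = onp.zeros((5, 7))
np_idx = onp.array([1, 2, 3])   # NumPy index array
jax_idx = np.array([1, 2, 3])   # JAX DeviceArray with the same values

x[np_idx]   # works: shape (3, 7)
x[jax_idx]  # <- raises IndexError
```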
Workaround: put `jax_idx` in a singleton tuple: `x[(jax_idx,)]`.
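Another option is to convert the index to a NumPy array before indexing, so NumPy never sees the JAX array as a sequence. The sketch below is an illustration, not part of JAX or NumPy: `index_rows` is a hypothetical helper name, and it assumes the index is anything `onp.asarray` accepts (JAX arrays implement `__array__`, so they convert to host NumPy arrays).

```python
import numpy as onp


def index_rows(x, idx):
    """Index NumPy array `x` along axis 0 with an integer index array `idx`.

    Hypothetical helper: `idx` may be a NumPy array, a list, or any object
    implementing `__array__` (such as a JAX array). Converting it with
    `onp.asarray` first means NumPy receives a plain ndarray instead of a
    sequence it might unpack into a tuple of per-axis indices.
    """
    return x[onp.asarray(idx)]


x = onp.arange(35).reshape(5, 7)
rows = index_rows(x, onp.array([1, 2, 3]))
print(rows.shape)  # (3, 7)

# The singleton-tuple workaround selects the same rows.
print(onp.array_equal(rows, x[(onp.array([1, 2, 3]),)]))  # True
```

With JAX installed, `index_rows(x, jax_idx)` should then behave like `x[np_idx]`, at the cost of a device-to-host copy of the index.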
This bug caused a confusing situation: my function worked when decorated with `jax.jit`, but hit a shape mismatch when called on a NumPy array.