
Indexing numpy array with DeviceArray: index interpreted as tuple #620

@joschu


When you index a NumPy ndarray with a JAX DeviceArray, NumPy interprets the JAX array as a tuple of indices (one per axis) instead of as an integer index array.

```python
import numpy as onp
import jax.numpy as np

x = onp.zeros((5, 7))
np_idx = onp.array([1, 2, 3])
jax_idx = np.array([1, 2, 3])

x[np_idx]   # works: selects rows 1, 2, 3, result has shape (3, 7)
x[jax_idx]  # <- raises IndexError
```

Workaround: wrap jax_idx in a singleton tuple, i.e. x[(jax_idx,)].
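A defensive variant of that workaround, sketched below, also converts the index to a real ndarray with `onp.asarray` before wrapping it in the tuple, so NumPy never has to guess how to read a foreign array type. To keep the sketch runnable without JAX, `FakeDeviceArray` is a hypothetical stand-in (not JAX's actual class) that is sequence-like and exposes `__array__`; `index_rows` is likewise a hypothetical helper name. The `onp.asarray` conversion is an addition to the workaround above, not something stated in the issue.

```python
import numpy as onp


class FakeDeviceArray:
    """Hypothetical stand-in for a non-ndarray array type (NOT JAX's real class).

    It is sequence-like (__len__/__getitem__) and exposes __array__,
    roughly how an array object from another library looks to NumPy.
    """

    def __init__(self, values):
        self._data = onp.asarray(values)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, i):
        return self._data[i]

    def __array__(self, dtype=None, copy=None):
        # Called by onp.asarray; return the underlying ndarray.
        return self._data if dtype is None else self._data.astype(dtype)


def index_rows(x, idx):
    """Hypothetical helper: select rows of a NumPy array with any array-like index."""
    # Convert to a real ndarray, then wrap it in a one-element tuple so the
    # index is unambiguously "one integer array on axis 0".
    return x[(onp.asarray(idx),)]


x = onp.zeros((5, 7))
jax_like_idx = FakeDeviceArray([1, 2, 3])
rows = index_rows(x, jax_like_idx)
print(rows.shape)  # (3, 7)
```

The explicit conversion keeps the call site independent of how a given NumPy version treats non-ndarray, non-tuple index objects.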

This bug led to a confusing situation: my function worked when decorated with jax.jit but hit a shape mismatch when called on a NumPy array.
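One plausible way a tuple interpretation yields a shape mismatch instead of an error: with the 2-D `x` from the repro, reading `[1, 2, 3]` as one index per axis gives too many indices and raises IndexError, but on an array with at least three dimensions the same reading silently selects a single element. The NumPy-only sketch below spells out both readings explicitly; the array shape is an arbitrary assumption, and the issue does not say which shapes were involved in the original function.

```python
import numpy as onp

# A 3-D array, so a 3-element index can be read as one index per axis
# without raising an error.
x3 = onp.arange(4 * 5 * 6).reshape(4, 5, 6)
idx = [1, 2, 3]

as_array_index = x3[onp.array(idx)]  # rows 1, 2, 3 along axis 0
as_tuple_index = x3[tuple(idx)]      # the single element x3[1, 2, 3]

print(as_array_index.shape)       # (3, 5, 6)
print(onp.shape(as_tuple_index))  # ()
print(as_tuple_index)             # 45, i.e. 1*30 + 2*6 + 3
```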

Labels: bug
