Implement overrides of NumPy's public API on JAX arrays#611
Implement overrides of NumPy's public API on JAX arrays#611shoyer wants to merge 37 commits intojax-ml:masterfrom
Conversation
`__array_ufunc__` allows for writing NumPy's ufuncs, e.g., `onp.sin()`. `__array_function__` is a new, experimental override for most other functions in NumPy public API, e.g., `onp.concatenate()`. It will be enabled by default in NumPy 1.17, but is also available in NumPy 1.16 if you set the environment variable `NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1` before importing NumPy. Together, these should allow users to stick with `import numpy as np` for use with JAX, instead of requiring `import jax.numpy as np`. I expect this will be particularly useful for projects that want to remain implementation agnostic, e.g., so they can write functions that will run without changes on JAX, CuPy and Dask arrays. Note: if you want to test this out in Colab, I think you need to install the development version of NumPy (e.g., `pip install -U git+https://github.com/numpy/numpy.git`). As far as I can tell, it isn't possible to set an environment variable from Colab before importing NumPy.
|
I added |
|
NumPy 1.17 is out, so these overrides will work by default now. It would be nice merge this soonish, if only so I don't need to continue to rebase :). Also I can't wait to stop writing |
|
Thanks, Stephan! Would we still need to More importantly, this would change the behavior for anyone using Those might be good things, but I just want to make sure I’m understanding. It would be a nontrivial api change. |
Yes, that would work. More practically, the easy way to ensure functions get executed with JAX is to use
That's right, this would be a breaking change for such code. Instead, users will need to write For the most part I expect this will should be fine -- JAX has mostly equivalent implementations of most commonly used NumPy functions. One noteworthy case are functions like |
__array_ufunc__allows for writing NumPy's ufuncs, e.g.,onp.sin().__array_function__is a new, experimental override for most other functions in NumPy public API, e.g.,onp.concatenate(). It will be enabled by default in NumPy 1.17, but is also available in NumPy 1.16 if you set the environment variableNUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1before importing NumPy.Together, these should allow users to stick with
import numpy as npfor use with JAX, instead of requiringimport jax.numpy as np. I expect this will be particularly useful for projects that want to remain implementation agnostic, e.g., so they can write functions that will run without changes on JAX, CuPy and Dask arrays.Note: if you want to test this out in Colab, I think you need to install the development version of NumPy (e.g.,
pip install -U git+https://github.com/numpy/numpy.git). As far as I can tell, it isn't possible to set an environment variable from Colab before importing NumPy.