Incompatibility with jax 0.7.0 #2051
Status: Closed
Labels: bug (Something isn't working)
Description
`import numpyro` fails under jax 0.7.0 with the following error:
>>> import numpyro
ImportError: cannot import name 'pjit_p' from 'jax.experimental.pjit' (/Users/damon/Documents/GitHub/stf-pyrenew-added-value/.venv/lib/python3.13/site-packages/jax/experimental/pjit.py)

This is caused by the following change in the jax 0.7.0 release notes:
> The `jax.extend.core.primitives.pjit_p` primitive has been renamed to `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`. This affects the string representations of jaxprs. The same primitive is no longer exported from the `jax.experimental.pjit` module.
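A minimal sketch of a version-tolerant lookup for downstream code, assuming only the two locations named in the release note and traceback: the new `jax.extend.core.primitives.jit_p` and the old `jax.experimental.pjit.pjit_p`. The helpers `import_first`, `get_jit_primitive`, and `is_jit_primitive` are hypothetical names for illustration, not part of the numpyro or jax API, and this is not necessarily how numpyro fixed the issue. Because the primitive's `name` also changed, code that matches jaxpr equations by name probably needs to accept both spellings.

```python
import importlib
from typing import Any, Iterable, Tuple

# Candidate (module, attribute) locations for the jit primitive, newest first.
# Assumption: these are the only two locations that need to be supported.
JIT_PRIMITIVE_CANDIDATES = [
    ("jax.extend.core.primitives", "jit_p"),  # jax >= 0.7.0, per the release note
    ("jax.experimental.pjit", "pjit_p"),      # jax < 0.7.0, the import that now fails
]

# Names the primitive may report: "pjit" before 0.7.0, "jit" from 0.7.0 on.
JIT_PRIMITIVE_NAMES = frozenset({"pjit", "jit"})


def import_first(candidates: Iterable[Tuple[str, str]]) -> Any:
    """Return the first attribute that imports successfully from (module, attr) pairs."""
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            # A module that imports but lacks the attribute raises AttributeError,
            # e.g. jax.experimental.pjit on jax 0.7.0 no longer exports pjit_p.
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc!r}")
    raise ImportError("no candidate could be imported:\n" + "\n".join(errors))


def get_jit_primitive() -> Any:
    """Resolve the jit primitive lazily, on whichever jax version is installed."""
    return import_first(JIT_PRIMITIVE_CANDIDATES)


def is_jit_primitive(primitive: Any) -> bool:
    """Match a primitive by its `name`, accepting both the old and the new spelling."""
    return getattr(primitive, "name", None) in JIT_PRIMITIVE_NAMES
```

Calling `get_jit_primitive()` should return `jit_p` on jax 0.7.0 and fall back to `pjit_p` on older releases. The lookup is deferred to a function so that importing the helper module cannot itself raise the `ImportError` shown above.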