Describe the bug
The tensorflow_probability package does not work with JAX 0.7.0, as described here. The suggested fix is to replace tensorflow_probability with tfp-nightly.
Expected behavior
I should be able to run 'import jaxns' without errors.
Observed behavior
Running 'import jaxns' fails with the following error:
$ uv venv --python 3.12
$ source .venv/bin/activate
$ uv pip install jaxns
$ python -c 'import jaxns'
⋮
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
However, if I replace tensorflow_probability with tfp-nightly, the import succeeds:
$ uv pip uninstall tensorflow_probability
$ uv pip install tfp-nightly
$ python -c 'import jaxns'
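The stable tensorflow-probability package and tfp-nightly both install the same tensorflow_probability import package. Having both installed at once is therefore likely to conflict, which is presumably why the workaround uninstalls the stable package first. The following is a small Python sketch that reports which of the two distributions is present. The helper names (installed_version, tfp_status) are hypothetical and are not part of jaxns or TFP.

```python
from importlib import metadata
from typing import Dict, Optional

# Both distributions provide the `tensorflow_probability` import package.
TFP_DISTRIBUTIONS = ("tensorflow-probability", "tfp-nightly")


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def tfp_status(versions: Dict[str, Optional[str]]) -> str:
    """Summarise which TFP distribution(s) are installed."""
    present = {name: v for name, v in versions.items() if v is not None}
    if not present:
        return "no TensorFlow Probability distribution installed"
    if len(present) > 1:
        # Two distributions writing the same package directory is a conflict.
        return "conflict: both " + " and ".join(sorted(present)) + " are installed"
    ((name, version),) = present.items()
    return f"{name} {version}"


if __name__ == "__main__":
    print(tfp_status({d: installed_version(d) for d in TFP_DISTRIBUTIONS}))
```

After the uninstall and install steps above, the script should report only tfp-nightly.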
Minimal Verifiable Complete Example
$ uv venv --python 3.12
$ source .venv/bin/activate
$ uv pip install jaxns
$ python -c 'import jaxns'
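The error message states that jax.interpreters.xla.pytype_aval_mappings was removed in JAX v0.7.0. A minimal Python sketch for checking whether an environment is on an affected JAX release before running the import reads the installed jax version with importlib.metadata and compares it against 0.7.0. The helper names (parse_release, lacks_pytype_aval_mappings, installed_version) are hypothetical. The comparison ignores pre-release and dev suffixes, so it is only a rough check and does not replace actually running 'import jaxns'.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# From the error above: jax.interpreters.xla.pytype_aval_mappings was
# deprecated in JAX v0.5.0 and removed in JAX v0.7.0.
REMOVAL_VERSION = (0, 7, 0)


def parse_release(version: str) -> Tuple[int, ...]:
    """Leading numeric part of a version, e.g. '0.7.0rc1' -> (0, 7, 0).

    Pre-release and dev suffixes are ignored, so this is only approximate.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def lacks_pytype_aval_mappings(jax_version: str) -> bool:
    """True if this JAX release no longer has xla.pytype_aval_mappings."""
    release = parse_release(jax_version)
    # Pad short versions so that '0.7' compares like (0, 7, 0).
    padded = release + (0,) * (len(REMOVAL_VERSION) - len(release))
    return padded >= REMOVAL_VERSION


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_version = installed_version("jax")
    if jax_version is None:
        print("jax is not installed")
    elif lacks_pytype_aval_mappings(jax_version):
        print(f"jax {jax_version}: xla.pytype_aval_mappings has been removed")
    else:
        print(f"jax {jax_version}: xla.pytype_aval_mappings is still present")
```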
JAXNS version
Output of pip freeze | grep jaxns: