
tensorflow_probability is incompatible with jax=0.7.0 #235

@tare

Describe the bug
The tensorflow_probability package does not work with JAX 0.7.0, as described here. The suggested fix is to replace tensorflow_probability with tfp-nightly.

Expected behavior
I should be able to run import jaxns without errors.

Observed behavior
import jaxns fails with the following AttributeError:

$ uv venv --python 3.12
$ source .venv/bin/activate
$ uv pip install jaxns
$ python -c 'import jaxns'
⋮
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.

However, if I replace tensorflow_probability with tfp-nightly, the import succeeds:

$ uv pip uninstall tensorflow_probability
$ uv pip install tfp-nightly
$ python -c 'import jaxns'
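To make it easier to tell whether an environment needs this workaround before importing jaxns, here is a minimal sketch of a pre-import check. It rests on assumptions taken only from this report: that the installed stable tensorflow-probability distribution is the affected one, and that any JAX release at or above 0.7 triggers the AttributeError. The helpers `installed`, `major_minor` and `needs_tfp_nightly` are hypothetical and are not part of jaxns; the script only reads package metadata through the standard library's `importlib.metadata`.

```python
"""Hypothetical pre-import check for the JAX 0.7.0 / tensorflow_probability clash."""
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile
from typing import Optional


def installed(dist: str) -> Optional[str]:
    # Return the installed version of a distribution, or None if it is absent.
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def major_minor(v: str) -> tuple[int, int]:
    # "0.7.0" -> (0, 7); patch and pre-release suffixes are ignored.
    major, minor = (v.split(".") + ["0"])[:2]
    minor_digits = "".join(takewhile(str.isdigit, minor)) or "0"
    return int(major), int(minor_digits)


def needs_tfp_nightly(jax_version: Optional[str], tfp_stable: Optional[str]) -> bool:
    # Assumption from this issue: stable tensorflow-probability still uses
    # jax.interpreters.xla.pytype_aval_mappings, which JAX 0.7.0 removed.
    # tfp-nightly installs under a different distribution name, so it is
    # not reported as "tensorflow-probability" and is not flagged here.
    if jax_version is None or tfp_stable is None:
        return False
    return major_minor(jax_version) >= (0, 7)


if __name__ == "__main__":
    jax_v = installed("jax")
    tfp_v = installed("tensorflow-probability")
    if needs_tfp_nightly(jax_v, tfp_v):
        print(f"jax {jax_v} with tensorflow-probability {tfp_v}: "
              "replace it with tfp-nightly before importing jaxns")
    else:
        print("no known jax / tensorflow-probability conflict detected")
```

Because tfp-nightly provides the same tensorflow_probability import module under a different distribution name, the check reports no conflict once the swap above has been made.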

Minimal Verifiable Complete Example

$ uv venv --python 3.12
$ source .venv/bin/activate
$ uv pip install jaxns
$ python -c 'import jaxns'

JAXNS version
Output of pip freeze | grep jaxns:

jaxns==2.6.8
