
call_jax doesn't take jax config into hashing. #8963

@zpcore

🐛 Bug

call_jax does not include the JAX config in its hash.

Detail

Changes to the JAX config (https://github.com/jax-ml/jax/blob/3864c4f335d1d236d5367264f3885dfce8721d9d/jax/_src/config.py#L254) are not reflected in the arguments that call_jax hashes. The config is nevertheless embedded in the generated HLO (for example, it can determine data precision), so computations run under different JAX configs may incorrectly reuse the same cached HLO.


Labels

bug (Something isn't working), tracing (Lazy Tensor tracing)