🐛 Bug
call_jax does not include the JAX config in its hashing.
Detail
Changes to the JAX config (https://github.com/jax-ml/jax/blob/3864c4f335d1d236d5367264f3885dfce8721d9d/jax/_src/config.py#L254) are not reflected in the arguments passed to call_jax, so they do not affect its hash. However, the config is embedded at the HLO level (e.g., data precision), so computations run under different JAX configs may incorrectly reuse the same HLO.