xb.call_jax should update the JAX mesh context object to contain the same devices, in the same order, as the PyTorch/XLA SPMD mesh before entering the JAX function. This is necessary to ensure that any SPMD computation in JAX has the same semantics as the corresponding SPMD computation in torch_xla.
There is already some logic for this in the splash attention kernel, and we can factor it out.
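A rough sketch of what that factored-out helper could look like, assuming the torch_xla Mesh exposes device_ids, mesh_shape, and axis_names, and assuming JAX enumerates the same devices with the same global ids as PyTorch/XLA (which is exactly what the blocked-status section below is about). The helper names here are illustrative, not the actual torch_xla API:

```python
import numpy as np
import jax
from jax.sharding import Mesh


def _jax_mesh_from_torch_xla(torch_mesh):
    """Builds a jax.sharding.Mesh mirroring a torch_xla SPMD mesh.

    Assumes jax.devices() returns the same devices, keyed by the same
    global ids, that torch_xla uses (not yet true across pods/slices).
    """
    jax_devices = {d.id: d for d in jax.devices()}
    devices = np.asarray(
        [jax_devices[i] for i in np.ravel(torch_mesh.device_ids)]
    ).reshape(torch_mesh.mesh_shape)
    return Mesh(devices, torch_mesh.axis_names)


def call_jax_with_mesh(jax_fn, args, torch_mesh):
    # Entering the Mesh makes it the contextual ("physical") mesh, so SPMD
    # constructs inside jax_fn see the same device order as torch_xla.
    with _jax_mesh_from_torch_xla(torch_mesh):
        return jax_fn(*args)
```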
Impact
Currently, call_jax does not set the JAX mesh, and even if we manually convert the PyTorch/XLA mesh to the JAX mesh, that only works on a single pod: 1.
As a result, Pallas kernels that use SPMD features only work on a single TPU pod. For example, the splash_attention kernel 2 only works on a single pod.
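For illustration, a minimal sketch of the kind of call this affects, assuming a "data" mesh axis on the torch_xla side and a call_jax(fn, args) calling convention; the example is hypothetical rather than taken from the splash attention kernel:

```python
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import jax
from jax.sharding import PartitionSpec as P


def jax_fn(x):
    # Resolving a bare PartitionSpec requires a contextual JAX mesh; since
    # call_jax does not set one, this only works if the caller entered a
    # manually built (single-pod) mesh beforehand.
    return jax.lax.with_sharding_constraint(x, P("data")) * 2


t = torch.ones(8, 128, device=torch_xla.device())
out = xb.call_jax(jax_fn, (t,))
```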
Blocked status
In #9038, we realized this is blocked on letting JAX discover the same devices as PyTorch/XLA.
In multi-slice training, JAX currently doesn't discover all the slices, due to b/374631442 and #8609 (comment).
We can't automatically set the JAX contextual mesh as that would break too many multi-slice workflows.