[Blocked] [call_jax] Support multi-pod JAX/pallas mesh usage in call_jax #8972

@tengyifei

Description

xb.call_jax should update the JAX mesh context object to contain the same devices, in the same order, before entering the JAX function. This is necessary to ensure that any SPMD computation in JAX has the same semantics as the SPMD computation in torch_xla.

There is some logic for that in the splash attention kernel and we can factor it out.
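The conversion could be factored out along these lines. This is a rough sketch only: `torch_xla_mesh_to_jax` is a hypothetical helper name, and it assumes the PyTorch/XLA mesh exposes its flat device IDs, shape, and axis names; the actual logic would come from the splash attention kernel mentioned above.

```python
import numpy as np
import jax

def torch_xla_mesh_to_jax(device_ids, mesh_shape, axis_names):
    """Build a jax.sharding.Mesh whose device order matches PyTorch/XLA's.

    device_ids: flat list of global device IDs in PyTorch/XLA mesh order
                (hypothetical input; torch_xla's Mesh stores device IDs).
    mesh_shape: logical mesh shape, e.g. (num_slices, devices_per_slice).
    axis_names: mesh axis names, e.g. ("data", "model").
    """
    by_id = {d.id: d for d in jax.devices()}
    # Reorder JAX's devices into the exact order torch_xla uses, so SPMD
    # axis semantics agree between the two frameworks.
    devices = np.asarray([by_id[i] for i in device_ids]).reshape(mesh_shape)
    return jax.sharding.Mesh(devices, axis_names)
```

The returned `Mesh` would then be installed as the contextual mesh around the traced JAX function (e.g. `with mesh: ...`) inside `call_jax`.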


Impact

Currently, call_jax does not set the JAX mesh, and even if we manually convert the PyTorch/XLA mesh to a JAX mesh, that only works on a single pod.

As a result, Pallas kernels that use SPMD features only work on a single TPU pod. For example, the splash_attention kernel only works on a single pod.


Blocked status

In #9038, we realized this is blocked on letting JAX discover the same devices as PyTorch/XLA.

In multi-slice training, JAX currently doesn't discover all the slices, due to b/374631442 and #8609 (comment).

We can't automatically set the JAX contextual mesh as that would break too many multi-slice workflows.
