Minimal support for calling JAX from PyTorch/XLA #8781

Merged: tengyifei merged 6 commits into master from hanq_jax_torchxla_0302 (Mar 4, 2025)

Conversation

@tengyifei
Collaborator

The trick is to turn JAX into HLO and weave that into an existing PyTorch/XLA lazy tensor graph.

So far we just test a simple numerical function.

qihqi and others added 4 commits March 3, 2025 10:40
The key is to use `as_serialized_hlo_module_proto`. Our hand-rolled
`_xla_computation_text_to_proto` causes an undefined op reference error.
@tengyifei tengyifei requested a review from qihqi March 4, 2025 00:55
@tengyifei tengyifei marked this pull request as ready for review March 4, 2025 00:55
@tengyifei tengyifei enabled auto-merge (squash) March 4, 2025 03:50
@tengyifei
Collaborator Author

Our CPU CI doesn't have JAX and torchax, so I'm installing them like this: https://github.com/pytorch/xla/pull/8781/files#diff-d30e144e1a94b9125c97915d4bd55eeb00d136257f96c9c753c255b66b1c00b4

@tengyifei tengyifei merged commit 17270e2 into master Mar 4, 2025
@miladm miladm self-assigned this Mar 5, 2025
@miladm
Collaborator

miladm commented Mar 5, 2025

What happens if we tie torch.aotautograd.grad() into this? Won't it work?

@miladm
Collaborator

miladm commented Mar 5, 2025

I guess we are skipping this op :)
This is a great exploration.

  • I'd love to understand how non-functional code behaves in this code path.
  • How do we make tracing decisions on data-dependent conditional operations?

@qihqi @tengyifei

@tengyifei
Collaborator Author

tengyifei commented Mar 5, 2025

What happens if we tie torch.aotautograd.grad() into this? Won't it work?

We have the option to either use this with AOTAutograd or avoid AOTAutograd.

If we would like to use AOTAutograd, we can use AOTAutograd to turn a PyTorch function into an FX graph, then use torchax to turn the FX graph into a JAX function, then use call_jax to turn the JAX function into a LazyTensor node.

If we would like to avoid AOTAutograd, we'll use torchax to turn a PyTorch function into a JAX function, use jax.grad to get the backward pass (also as a JAX function), then use call_jax to turn the JAX function into a LazyTensor node.

In the latter approach, we'll be able to use powerful JAX remat features like checkpoint_name (checkpoint or offload an arbitrary tensor).
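Here's a minimal sketch of that latter path. It assumes torchax exposes something like `torchax.interop.jax_view` for viewing a PyTorch function as a JAX function, and that this PR's helper is reachable as `torch_xla.core.xla_builder.call_jax`; the exact names, import paths, and signatures may differ:

```python
import jax
import torch
import torch_xla
import torch_xla.core.xla_builder as xb  # assumes this PR exposes the helper as xb.call_jax
from torchax.interop import jax_view     # assumed torchax entry point; real name may differ

def torch_loss(w, x):
    # A purely functional PyTorch computation.
    return torch.sum(torch.sin(w @ x))

jax_loss = jax_view(torch_loss)   # view the PyTorch function as a JAX function
jax_grad = jax.grad(jax_loss)     # derive the backward pass with JAX

device = torch_xla.device()
w = torch.randn(4, 4, device=device)
x = torch.randn(4, 8, device=device)

# Each call lowers the JAX function to HLO and splices it into the
# lazy tensor graph as a single node.
loss = xb.call_jax(jax_loss, (w, x))
grad_w = xb.call_jax(jax_grad, (w, x))
```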

I'd love to understand how non-functional code behaves in this code path.

In the general case, this won't support non-functional code. For example, if some PyTorch function inserts a tensor into a global list, we'll just end up inserting a JAX tracer into the global list. But AOTAutograd has the same constraint. This limitation does not exist when only using LazyTensor, because that framework blurs the boundary between tracing and execution, at the cost of accidental graph breaks.
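A toy illustration of that failure mode in plain JAX (no PyTorch involved): whatever a side effect stores during tracing is a tracer, not a concrete array.

```python
import jax
import jax.numpy as jnp

captured = []  # global state mutated from inside the traced function

@jax.jit
def f(x):
    y = jnp.sin(x)
    captured.append(y)  # runs once, at trace time; stores a tracer
    return y * 2

f(jnp.ones(3))
# captured[0] is a leaked JAX tracer, not a concrete array;
# using it in a later computation raises an UnexpectedTracerError.
print(type(captured[0]))
```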

My hypothesis is that most high performance models (esp. ones in torchprime) won't do weird stuff like that.

We'll need to find some other solution to cover the long tail, e.g. use Dynamo.

How do we make tracing decisions on data-dependent conditional operations?

We can't do that (e.g. `if tensor: ...`) in jax.jit; it'll raise a Python exception. To be clear, we also can't do that efficiently with LazyTensor either, since it'd cause a graph break.

Similarly, my hypothesis is that high-performance models won't do data-dependent branches. They'd need to rewrite the `if` into a `jax.lax.cond`, as in the sketch below.
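For example, a Python-level branch on a traced value raises an error under `jax.jit`; the rewrite looks roughly like this:

```python
import jax
import jax.numpy as jnp

@jax.jit
def bad(x):
    if x.sum() > 0:   # data-dependent Python branch: errors under jax.jit
        return x * 2
    return x - 1

@jax.jit
def good(x):
    # Same logic expressed as an XLA conditional via jax.lax.cond.
    return jax.lax.cond(x.sum() > 0,
                        lambda v: v * 2,
                        lambda v: v - 1,
                        x)

good(jnp.arange(3.0))  # works; bad(...) would raise a tracer error
```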
