
Showcase jax.grad in torch_xla #8800

Open

zpcore wants to merge 3 commits into master from piz/imp_test

Conversation

@zpcore (Member) commented Mar 5, 2025

Demonstrate that we can run jax.grad with call_jax. This can be helpful for porting the Splash Attention kernel from JAX into torch_xla directly.
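To make the pattern concrete, here is a minimal sketch of what calling jax.grad through call_jax could look like. It assumes torch_xla exposes call_jax in torch_xla.core.xla_builder with a (function, args) signature; f_jax, the shapes, and the argnums are illustrative, not the PR's actual test.

```python
import jax
import jax.numpy as jnp
import torch
import torch_xla
import torch_xla.core.xla_builder as xb


# A scalar-valued JAX function, so jax.grad can differentiate it.
def f_jax(a, b):
  return jnp.sum(a @ b)


# Gradient w.r.t. both inputs.
grad_f = jax.grad(f_jax, argnums=(0, 1))

device = torch_xla.device()
a = torch.randn(3, 3, device=device)
b = torch.randn(3, 3, device=device)

# call_jax traces the JAX callable and lowers it into the XLA graph,
# so the gradients come back as ordinary XLA tensors.
grad_a, grad_b = xb.call_jax(grad_f, (a, b))
torch_xla.sync()
```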

@zpcore zpcore requested review from qihqi and tengyifei March 5, 2025 22:03
@zpcore zpcore changed the title from "show case use jax.grad in torch_xla" to "Showcase jax.grad in torch_xla" Mar 5, 2025
Comment thread test/test_jax_interop.py
@zpcore zpcore marked this pull request as ready for review March 6, 2025 00:53
Comment thread test/test_jax_interop.py

def test_call_jax_grad(self):
"""
Test that we can call a JAX function and retrive grad from PyTorch/XLA lazy
Collaborator

nit: retrieve

Comment thread test/test_jax_interop.py
out.backward()
return a.grad, b.grad

_a = a.clone().requires_grad_(True)
Collaborator

What's the purpose of cloning from a to _a? Usually, variable names like _a are reserved for variables that are never read. Also, I wonder if we could avoid the cloning and just use a. If you wanted to avoid accumulating gradient into a, then the right incantation is a.clone().detach().requires_grad_(True).
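As a standalone illustration of the difference (plain CPU tensors; not the test's code):

```python
import torch

a = torch.randn(3, requires_grad=True)

# clone() alone stays attached to a's autograd graph: a backward pass
# through the clone also accumulates into a.grad.
(a.clone() * 2).sum().backward()
print(a.grad)  # tensor([2., 2., 2.])

# clone().detach().requires_grad_(True) makes an independent leaf:
# gradients accumulate into a_leaf.grad and never reach a.
a.grad = None
a_leaf = a.clone().detach().requires_grad_(True)
(a_leaf * 2).sum().backward()
print(a.grad)       # None
print(a_leaf.grad)  # tensor([2., 2., 2.])
```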

Comment thread test/test_jax_interop.py
out_grad_torch = f_backward_torch(f_torch, _a, _b)

torch_xla.sync()
out_torch.detach()
Collaborator

This detach is a no-op. detach returns a new tensor and keeps the original tensor unchanged. This line can probably be removed.
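A quick standalone illustration of why a bare detach() statement does nothing:

```python
import torch

t = torch.randn(2, requires_grad=True) + 1  # non-leaf tensor with a grad_fn
t.detach()                # returns a new tensor; discarding it changes nothing
print(t.requires_grad)    # True: t itself is unchanged
d = t.detach()            # the detached copy is the new tensor
print(d.requires_grad)    # False
```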

Comment thread test/test_jax_interop.py

torch_xla.sync()
out_torch.detach()
out_grad_torch = [g.detach() for g in out_grad_torch]
Collaborator

If the only thing we're doing next is comparing these tensors against the call_jax version, then we probably don't need to detach either. detach returns a new tensor and also stops the gradient flow. If we're no longer doing any gradient propagation, then it's useless.
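For instance, a self-contained sketch (g1 and g2 are stand-ins for the gradients being compared):

```python
import torch

g1 = torch.randn(3, requires_grad=True) * 2  # non-leaf, requires grad
g2 = g1.clone()

# Tensors that require grad can be compared directly; detach is only
# needed if a later backward pass must not flow through them.
torch.testing.assert_close(g1, g2)
```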

tengyifei added a commit that referenced this pull request Mar 13, 2025
Fixes #8794.

I also took a test case in #8800 and made it pass.
