Conversation
    def test_call_jax_grad(self):
      """
      Test that we can call a JAX function and retrieve grad from PyTorch/XLA lazy
      out.backward()
      return a.grad, b.grad
|
|
    _a = a.clone().requires_grad_(True)
What's the purpose of cloning from a to _a? Variable names like _a are usually reserved for variables that are never read. Also, I wonder if we could avoid the cloning and just use a. If you want to avoid accumulating gradients into a, the right incantation is a.clone().detach().requires_grad_(True).
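A minimal standalone sketch of the difference, using toy tensors rather than the test's (the shape and the sum() loss here are made up): a plain clone() stays attached to the autograd graph, so backward through it also accumulates into a.grad, whereas clone().detach().requires_grad_(True) creates a fresh leaf and leaves a.grad alone.

    import torch

    a = torch.randn(3, requires_grad=True)

    # clone() stays connected to `a` in the autograd graph, so gradients
    # flowing through the clone also accumulate into a.grad.
    _a = a.clone()
    _a.sum().backward()
    print(a.grad is None)   # False: a.grad was populated

    # clone().detach() cuts the graph; requires_grad_(True) turns the copy
    # into a fresh leaf, so a.grad stays untouched.
    a.grad = None
    a2 = a.clone().detach().requires_grad_(True)
    a2.sum().backward()
    print(a.grad is None)   # True: the gradient stayed on a2
    print(a2.grad is None)  # False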
    out_grad_torch = f_backward_torch(f_torch, _a, _b)
|
|
    torch_xla.sync()
    out_torch.detach()
This detach is a no-op. detach returns a new tensor and keeps the original tensor unchanged. This line can probably be removed.
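As a quick standalone illustration (toy tensors, not the test's): detach() builds a new tensor that is cut from the autograd graph and leaves the original tensor exactly as it was, so calling it as a bare statement and discarding the result changes nothing.

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x * 2).sum()

    # detach() returns a *new* tensor that shares storage with y but is
    # disconnected from the autograd graph; y itself is not modified.
    z = y.detach()
    print(z.requires_grad)  # False
    print(y.requires_grad)  # True: y is unchanged

    # As a bare statement the result is simply dropped, so nothing happens.
    y.detach()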
|
|
    torch_xla.sync()
    out_torch.detach()
    out_grad_torch = [g.detach() for g in out_grad_torch]
If the only thing we're doing next is comparing these tensors against the call_jax version, then we probably don't need to detach either. detach returns a new tensor and also stops the gradient flow; if we're no longer doing any gradient propagation, it's useless.
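To make that concrete, value comparisons such as torch.allclose or torch.testing.assert_close only read the tensor values and behave the same whether or not the inputs still require grad (toy tensors below, not the test's):

    import torch

    expected = torch.randn(3, requires_grad=True) * 2  # non-leaf, requires grad
    actual = expected.clone()                          # also still in the graph

    # No detach needed: comparison ops only read values and never backprop.
    print(torch.allclose(actual, expected))        # True
    torch.testing.assert_close(actual, expected)   # passes without detaching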
Demonstrate that we can run jax.grad with call_jax. This can be helpful for porting the Splash Attention kernel from JAX into torch_xla directly.