[Pallas] Make FlashAttention a torch.autograd.Function #6886
alanwaketan merged 9 commits into master from
Conversation
```python
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
```
We shouldn't need this here, right? The forward should have been called at this point, and fwd/bwd happen in the same process.
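For reference, a minimal sketch of the guard pattern under discussion. It assumes `jax_import_guard` lives in `torch_xla.experimental.custom_kernel` and that JAX exposes the Pallas kernel at the path below (both version-dependent); the helper name is illustrative:

```python
from torch_xla.experimental.custom_kernel import jax_import_guard

# Importing JAX at module scope would initialize the TPU runtime in the
# parent process, which can hang children created later by xmp.spawn.
# Instead, the import happens at call time, after the guard has let
# PyTorch/XLA grab the TPU before JAX locks it.


def _load_flash_attention_kernel():
  jax_import_guard()
  from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
  return flash_attention
```

On the reviewer's point: by the time `backward` runs, `forward` has already executed this dance in the same process, so a second guard call there should be redundant.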
| "block_q_major", "block_k_major", "block_k", "sm_scale", "causal", | ||
| "mask_value", "debug" | ||
| ]) | ||
| grad_q = torch.empty(q.shape, dtype=q.dtype).to(xm.xla_device()) |
It should be fine for the most part, but I think it is better to do .to(q.device) instead of moving to xm.xla_device(). Should we add a check somewhere to make sure all tensors are on the XLA device?
Fixed the .to(). For the second question, I guess Mosaic already guards it?
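As a sketch of the two allocation styles being compared (the helper name and the assert are illustrative, not from the PR):

```python
import torch


def _alloc_grad(q: torch.Tensor) -> torch.Tensor:
  # Before: torch.empty(q.shape, dtype=q.dtype).to(xm.xla_device())
  # always lands on the process-default XLA device.
  # After (reviewer's suggestion): follow the input tensor, which also
  # behaves correctly when q lives on a non-default XLA device.
  assert q.device.type == 'xla', 'expected q on an XLA device'
  return torch.empty(q.shape, dtype=q.dtype).to(q.device)
```

`torch.empty_like(q)` would express the same allocation (shape, dtype, and device inherited from q) more compactly.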
```python
    o.to(torch.float32) * grad_output.to(torch.float32),
    axis=-1)  # [batch_size, num_heads, q_seq_len]

expanded_l = l.unsqueeze(-1).expand(3, 2, 128,
```
Oops, I shouldn't hardcode the shape...
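A hypothetical fix that reads the sizes off `l` itself instead of spelling out the test tensor's dimensions (`expand_l` and `head_dim` are illustrative names, not from the PR):

```python
import torch


def expand_l(l: torch.Tensor, head_dim: int) -> torch.Tensor:
  # l: [batch_size, num_heads, q_seq_len] -> [..., head_dim], with the
  # sizes taken from l rather than hardcoded literals like (3, 2, 128, ...).
  return l.unsqueeze(-1).expand(*l.shape, head_dim)
```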
```python
mse = torch.nn.MSELoss()
for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
  self.assertTrue(mse(i[0].grad.cpu(), i[1].cpu()) < 1e-4)
```
I am not sure what this part is checking; do you mind explaining a bit?
Since the gradients are slightly off, it's hard to use torch.allclose here. I'm using MSE to measure the difference and check that it's close to zero.
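Restated as a hypothetical helper, in case the intent helps (`grads_close` is an illustrative name):

```python
import torch


def grads_close(a: torch.Tensor, b: torch.Tensor, tol: float = 1e-4) -> bool:
  # torch.allclose can fail on a single element-wise outlier even when the
  # overall error is tiny; a mean-squared-error threshold tolerates small
  # per-element noise while still catching a genuinely wrong gradient.
  return torch.nn.functional.mse_loss(a.cpu(), b.cpu()).item() < tol
```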
Summary:
This pull request makes the flash attention kernel a torch.autograd.Function so that we can enable backward on the kernel.
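A rough sketch of what that wrapping looks like, assuming the forward kernel returns the output plus the softmax residuals (l, m) that backward consumes; the `_flash_attention_*` helpers are hypothetical stand-ins for the Pallas kernel launches, not the PR's exact code:

```python
import torch
from torch_xla.experimental.custom_kernel import jax_import_guard


class FlashAttention(torch.autograd.Function):

  @staticmethod
  def forward(ctx, q, k, v):
    jax_import_guard()
    # Hypothetical helper that traces and launches the Pallas forward
    # kernel, returning the output o and the softmax residuals l and m.
    o, l, m = _flash_attention_forward(q, k, v)
    ctx.save_for_backward(q, k, v, o, l, m)
    return o

  @staticmethod
  def backward(ctx, grad_output):
    q, k, v, o, l, m = ctx.saved_tensors
    # Hypothetical helper wrapping the Pallas backward kernel; it returns
    # one gradient per forward input.
    grad_q, grad_k, grad_v = _flash_attention_backward(
        q, k, v, o, l, m, grad_output)
    return grad_q, grad_k, grad_v
```

Callers would then invoke `FlashAttention.apply(q, k, v)`, and autograd routes grad_output into `backward` automatically.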
Test Plan:
```
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_backward
```