[Pallas] Make FlashAttention as torch.autograd.Function #6886

Merged
alanwaketan merged 9 commits into master from alanwaketan/fa_autograd on Apr 4, 2024

Conversation

@alanwaketan (Collaborator)

Summary:
This pull request wraps the flash attention kernel in a torch.autograd.Function so that we can enable backward on the kernel.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_backward
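
For reference, a minimal usage sketch of what the new backward support enables. The tensor shapes are illustrative, and flash_attention refers to the wrapper exported from torch_xla.experimental.custom_kernel:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# Illustrative shapes: [batch, heads, seq_len, head_dim].
q = torch.randn(3, 2, 128, 4, device=device, requires_grad=True)
k = torch.randn(3, 2, 128, 4, device=device, requires_grad=True)
v = torch.randn(3, 2, 128, 4, device=device, requires_grad=True)

o = flash_attention(q, k, v)  # forward runs the Pallas kernel
o.sum().backward()            # backward now flows through the custom kernel
print(q.grad.shape, k.grad.shape, v.grad.shape)
```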

@alanwaketan requested review from JackCaoG and lsy323 on April 4, 2024 01:11
@alanwaketan self-assigned this on Apr 4, 2024
@alanwaketan force-pushed the alanwaketan/fa_autograd branch from 5b316ba to 02a15f0 on April 4, 2024 01:14
Comment thread: torch_xla/experimental/custom_kernel.py, lines 237 to 239 (Outdated)
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()

Collaborator
We shouldn't need this, right? The forward should have been called at this point, and fwd/bwd happen in the same process.

@alanwaketan (Collaborator, Author)
Right!
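
For context, a minimal sketch of the pattern being discussed, with jax_import_guard() invoked only in forward. The class body uses scaled_dot_product_attention as a stand-in for the actual Pallas kernel calls, and the import path for jax_import_guard is an assumption:

```python
import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import jax_import_guard  # assumed import path


class FlashAttentionSketch(torch.autograd.Function):
  """Illustrative skeleton; the real implementation calls the Pallas kernels."""

  @staticmethod
  def forward(ctx, q, k, v):
    # The guard only needs to run before JAX is first imported. Since backward
    # always runs after forward in the same process, forward is the only place
    # that needs it.
    jax_import_guard()
    # Stand-in for the Pallas forward kernel.
    o = F.scaled_dot_product_attention(q, k, v)
    ctx.save_for_backward(q, k, v)
    return o

  @staticmethod
  def backward(ctx, grad_output):
    q, k, v = ctx.saved_tensors
    # No jax_import_guard() here, per the discussion above.
    # Stand-in for the Pallas backward kernel.
    with torch.enable_grad():
      q2, k2, v2 = (t.detach().requires_grad_() for t in (q, k, v))
      o = F.scaled_dot_product_attention(q2, k2, v2)
      grad_q, grad_k, grad_v = torch.autograd.grad(o, (q2, k2, v2), grad_output)
    return grad_q, grad_k, grad_v
```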

Comment thread: torch_xla/experimental/custom_kernel.py (Outdated)
"block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
"mask_value", "debug"
])
grad_q = torch.empty(q.shape, dtype=q.dtype).to(xm.xla_device())

Collaborator
It should be fine for the most part, but I think it is better to do .to(q.device) instead of moving to xm.xla_device(). Should we add a check somewhere to make sure all tensors are on an XLA device?

@alanwaketan (Collaborator, Author)
Fixed the .to(). For the second question, I guess Mosaic already guards it?
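
For reference, a one-line sketch of the allocation change discussed above, reusing the variable names from the snippet:

```python
# Follow the suggestion above: place grad_q on the same device as q rather
# than the default XLA device.
grad_q = torch.empty(q.shape, dtype=q.dtype).to(q.device)
```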

Comment thread: torch_xla/experimental/custom_kernel.py (Outdated)
o.to(torch.float32) * grad_output.to(torch.float32),
axis=-1) # [batch_size, num_heads, q_seq_len]

expanded_l = l.unsqueeze(-1).expand(3, 2, 128,

@alanwaketan (Collaborator, Author)
Oops, I shouldn't hardcode the shape...
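
For reference, one shape-agnostic way to write the expand; head_dim is a hypothetical stand-in for whatever trailing size the kernel expects:

```python
# Derive the leading sizes from l itself instead of hardcoding (3, 2, 128, ...).
expanded_l = l.unsqueeze(-1).expand(*l.shape, head_dim)
```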

Comment thread: test/test_pallas.py, lines 487 to 489
mse = torch.nn.MSELoss()
for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(mse(i[0].grad.cpu(), i[1].cpu()) < 1e-4)

Collaborator
I am not sure what this part is checking, do you mind explaining a bit?

@alanwaketan (Collaborator, Author)
Since the gradients are a little bit off, it's hard to use torch.allclose. I'm just trying to use MSE to calculate the difference to see if it's close to zero.

@alanwaketan merged commit 0c704cf into master on Apr 4, 2024