[Pallas] Support Flash Attention backward kernels#6870

Merged
alanwaketan merged 4 commits into master from alanwaketan/fa_backward
Apr 2, 2024

Conversation

@alanwaketan
Collaborator

Summary:
This change refactors custom_kernel.py to support the three new Pallas kernels involved in the Flash Attention backward pass.

The refactoring includes:

  1. Adds support for static_argnums, which excludes selected positional arguments from JAX tracing.
  2. Separates the JAX tracing step out so that tracing can be performed on its own.
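As a rough illustration of point 1, static arguments can be split off before tracing so the tracer only sees tensor arguments. This is a hypothetical sketch; `split_static` and the argument names are illustrative, not the actual custom_kernel.py code:

```python
def split_static(args, static_argnums):
    # Static positions (e.g. block sizes, flags) are captured as plain
    # Python constants; the remaining arguments go to the tracer.
    static = tuple(args[i] for i in static_argnums)
    dynamic = tuple(a for i, a in enumerate(args) if i not in static_argnums)
    return static, dynamic

static, dynamic = split_static(("q", "k", "v", 128, True), static_argnums=(3, 4))
print(static)   # (128, True)
print(dynamic)  # ('q', 'k', 'v')
```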

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py

@alanwaketan alanwaketan requested review from JackCaoG and lsy323 April 2, 2024 05:50
@alanwaketan alanwaketan self-assigned this Apr 2, 2024
Comment thread on test/test_pallas.py:
xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_q.shape, (3, 2, 128, 4))
Collaborator


If we run the forward pass, call res.backward(), and then check the grad on q, shouldn't they match?
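The suggested check could look roughly like the following sketch. The flash-attention call is left as a comment because the exact kernel API is an assumption here; only the plain dot-attention reference is exercised:

```python
import torch

def dot_attention(q, k, v):
    # Plain scaled dot-product attention as the reference implementation.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

q = torch.randn(3, 2, 128, 4, requires_grad=True)
k = torch.randn(3, 2, 128, 4)
v = torch.randn(3, 2, 128, 4)

# Reference gradient on q from plain attention.
dot_attention(q, k, v).sum().backward()
expected_grad_q = q.grad.clone()

# With the flash-attention kernel one would compute grad_q the same way
# and compare with loose tolerances, e.g.:
# torch.testing.assert_close(grad_q, expected_grad_q, atol=1e-1, rtol=1e-1)
print(expected_grad_q.shape)
```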

Collaborator Author


The softmax is done differently. I don't think there are any guarantees.

Collaborator


The result should still be somewhat close, right? We can tune down the precision. If the result returned by this kernel is dramatically different from the one computed using dot attention, that seems wrong.

Collaborator Author


https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html

Softmax needs all the elements to produce its result, but flash attention chunks the data into blocks and uses a technique called tiling to keep the softmax numerically stable. Since there is no global aggregation, I don't see how the tiled softmax could produce exactly the same results as the regular one.

In JAX, I have to use atol=1e-01, rtol=1e-01 to do the comparisons...
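For context, here is a minimal NumPy sketch (not the Pallas kernel) of the tiled/online softmax flash attention relies on. In exact arithmetic it agrees with the regular softmax; the discrepancies come from floating-point accumulation order and low precision, which is consistent with needing loose tolerances:

```python
import numpy as np

def softmax(x):
    # Standard softmax over the last axis, stabilized by the global max.
    m = x.max(axis=-1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)

def tiled_softmax(x, block=32):
    # Online (tiled) softmax: process one block at a time while carrying
    # a running max m and running denominator d, rescaling earlier blocks
    # whenever the running max increases.
    m = np.full(x.shape[:-1] + (1,), -np.inf)
    d = np.zeros(x.shape[:-1] + (1,))
    out_blocks = []
    for start in range(0, x.shape[-1], block):
        b = x[..., start:start + block]
        m_new = np.maximum(m, b.max(axis=-1, keepdims=True))
        d = d * np.exp(m - m_new) + np.exp(b - m_new).sum(axis=-1, keepdims=True)
        out_blocks = [o * np.exp(m - m_new) for o in out_blocks]
        out_blocks.append(np.exp(b - m_new))
        m = m_new
    return np.concatenate(out_blocks, axis=-1) / d

x = np.random.randn(4, 128)
ref = softmax(x)
tiled = tiled_softmax(x)
print(np.allclose(ref, tiled, atol=1e-6))
```

In float64 the two agree to near machine precision; on TPU with lower-precision accumulation the gap widens, hence atol/rtol of 1e-1.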

@alanwaketan
Collaborator Author

Thanks, Jack!

@alanwaketan alanwaketan merged commit c54367c into master Apr 2, 2024