[Pallas] Support Flash Attention backward kernels #6870
Conversation
xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_q.shape, (3, 2, 128, 4))
If we run the forward pass and then call res.backward(), the grad on q should match, right?
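The cross-check suggested here can be sketched on CPU with a plain NumPy reference (the function names below are illustrative, not the actual test code): compute the analytical dL/dQ for standard dot-product attention and compare it against a finite-difference estimate, which is the same pattern the test would follow against the Pallas kernel's grad_q.

```python
import numpy as np

def softmax(x):
    # Numerically stable row-wise softmax.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    # Reference (non-flash) dot-product attention.
    return softmax(q @ k.T) @ v

def attention_grad_q(q, k, v, do):
    # Analytical dL/dQ for L = sum(attention(q, k, v) * do),
    # using the softmax Jacobian-vector product.
    p = softmax(q @ k.T)
    dp = do @ v.T
    ds = p * (dp - (dp * p).sum(axis=-1, keepdims=True))
    return ds @ k

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
do = rng.standard_normal((8, 4))

ana = attention_grad_q(q, k, v, do)

# Finite-difference estimate of dL/dQ, entry by entry.
num = np.zeros_like(q)
eps = 1e-6
for i in range(q.shape[0]):
    for j in range(q.shape[1]):
        qp, qm = q.copy(), q.copy()
        qp[i, j] += eps
        qm[i, j] -= eps
        num[i, j] = ((attention(qp, k, v) * do).sum()
                     - (attention(qm, k, v) * do).sum()) / (2 * eps)
```

In float64 the two gradients agree tightly; against a real kernel in bf16/float32 the comparison would need the loose tolerances discussed below.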
The softmax is done differently. I don't think there are any guarantees.
The result should still be somewhat close, right? We can tune down the precision. If the result returned by this kernel is dramatically different from the one computed using dot attention, that seems wrong...
https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html
Softmax needs all the elements of a row to produce its result, but flash attention chunks the data into blocks and uses a technique called tiling to keep the softmax numerically stable. Since there is no full aggregation over the row at once, I don't see how the tiled softmax could produce exactly the same results as the regular one.
In JAX, I have to use atol=1e-01, rtol=1e-01 to do the comparisons...
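For intuition: the tiled ("online") softmax used by flash attention is algebraically equivalent to the regular one; it processes blocks sequentially while rescaling a running max and running sum. The discrepancies that force loose tolerances come from reordering floating-point operations at low precision (bf16/float32 on hardware), not from the algorithm itself. A minimal sketch of the rescaling trick in NumPy (block size chosen arbitrarily):

```python
import numpy as np

def online_softmax(x, block=32):
    # Streaming softmax over a 1-D vector: keep a running max `m`
    # and a running rescaled sum `s`, updating them block by block.
    m, s = -np.inf, 0.0
    for i in range(0, len(x), block):
        blk = x[i:i + block]
        m_new = max(m, blk.max())
        # Rescale the old sum to the new max before accumulating.
        s = s * np.exp(m - m_new) + np.exp(blk - m_new).sum()
        m = m_new
    return np.exp(x - m) / s

def reference_softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.random.default_rng(0).standard_normal(100)
tiled = online_softmax(x)
ref = reference_softmax(x)
```

In float64 the two agree to near machine precision; the same reordering at lower precision, compounded across the full attention computation, is why comparisons against the non-flash path need atol/rtol on the order of 1e-1.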
Thanks, Jack!
Summary:
This change refactors custom_kernel.py to support all three new Pallas kernels involved in the Flash Attention backward computation.
The refactoring includes:
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py