Fix a bug in flash attention where kv_seq_len should divide block_k_major.#8671
zpcore merged 8 commits into pytorch:master from
Conversation
Thanks for this change!
    k_padded if k_pad_size > 0 else k,
    v_padded if k_pad_size > 0 else v,
    ab_padded if k_pad_size > 0 and ab is not None else ab,
I think we can let k, v, ab all go through _pad_to_block_size() and use k_padded, v_padded, ab_padded afterwards to simplify the logic.
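For illustration, a minimal sketch of what that simplification might look like (the call sites and the ab padding dimension are assumptions, not the PR's actual code):

# Pad k, v, ab unconditionally and use the returned tensors directly,
# assuming _pad_to_block_size returns the input unchanged when the length
# is already a multiple of the block size, so no branching is needed later.
k, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
v, _ = _pad_to_block_size(v, max(block_k_major, block_k), 2)
if ab is not None:
  # The kv dimension of ab is assumed to be the last one here.
  ab, _ = _pad_to_block_size(ab, max(block_k_major, block_k), 3)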
zpcore left a comment
I remember we saw the same issue in a stable diffusion run. This is great, thanks for the fix!
block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2])
block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
k_padded, k_pad_size = _pad_to_block_size(k, max(block_k_major, block_k), 2)
We probably need to do the same padding for the backward pass. Let's see if the test can pass or not.
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self):
Can we also add a backward pass test to make sure the q, v, k, ab grads are the same as those from the self._attention output (similar to test_flash_attention_backward_aot_autograd_traceable)? I feel the backward pass needs the same update. Thanks
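For reference, a rough sketch of the kind of backward check being asked for; the input helper and exact comparison are assumptions, not the PR's actual test:

def test_flash_attention_backward_kv_and_ab_padding(self):
  # Hypothetical input helper; kv_seq_len=4992 is deliberately not a
  # multiple of block_k_major=512.
  q, k, v, ab = self._make_padded_inputs(kv_seq_len=4992)
  for t in (q, k, v, ab):
    t.requires_grad_(True)

  # Gradients through the Pallas flash attention kernel.
  flash_attention(q, k, v, ab=ab).sum().backward()
  flash_grads = [t.grad.clone() for t in (q, k, v, ab)]

  # Reset grads and compare against the reference attention helper
  # used elsewhere in the test file.
  for t in (q, k, v, ab):
    t.grad = None
  self._attention(q, k, v, ab=ab).sum().backward()
  for got, want in zip(flash_grads, (q.grad, k.grad, v.grad, ab.grad)):
    torch.testing.assert_close(got, want, atol=1e-5, rtol=1e-5)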
I'm not very familiar with the backward pass. Can I just fix this forward bug first?
@zpcore Hi, I found that the tpu-test failed again. I've run the unit test locally, but I still have some questions. I added a print, and the output shows the diff is well above the tolerance of 1e-5. The output diff is similar, and the test method also fails. So I'm confused about why the diff is so large. Can you give me some advice?
From the test log: is this an existing bug, rather than one introduced by this fix?
I ran your code with test_pallas, and the test passes.
zpcore left a comment
LGTM! Thanks for fixing the shape.
OK, thanks for accepting the PR.
When generating images with flash attention on TPU, a bug occurs with the following error message:
Cause:
This bug happens when the image resolution is not divisible by 512 on at least one side. Specifically, the sequence length (kv_seq_len) should be divisible by the block size (block_k_major, which is 512) for the flash attention mechanism to work correctly. In the error above, kv_seq_len=4992 is not divisible by 512, leading to this exception.
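A quick worked example of the failing case, using the values from the error described above:

kv_seq_len, block_k_major = 4992, 512
remainder = kv_seq_len % block_k_major                   # 4992 = 9 * 512 + 384 -> 384
pad_size = (block_k_major - remainder) % block_k_major   # 512 - 384 = 128
print(kv_seq_len + pad_size)                             # 5120 = 10 * 512, now divisible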
Solution:
To resolve this issue, we pad the k, v, and ab tensors along the kv sequence dimension so that their lengths are divisible by the block sizes.
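As a rough illustration (not the PR's actual helper), padding to a multiple of the block size could look like this in PyTorch:

import torch

def pad_to_block_size(tensor, block_size, dim):
  # Zero-pad `tensor` along `dim` so its length becomes a multiple of
  # `block_size`; also return the amount of padding that was added.
  length = tensor.shape[dim]
  pad_size = (block_size - length % block_size) % block_size
  if pad_size == 0:
    return tensor, 0
  pad_shape = list(tensor.shape)
  pad_shape[dim] = pad_size
  padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
  return torch.cat([tensor, padding], dim=dim), pad_size

Padding ab along its kv dimension as well keeps its shape consistent with the padded k and v, so the kernel sees block-aligned inputs throughout.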