
Fix an edge case issue when storing output in the paged attention. #8431

Merged
vanbasten23 merged 4 commits into master from xiowei/fixStoreToOutputEdgeCase on Dec 5, 2024

Conversation

@vanbasten23 (Collaborator) commented Dec 3, 2024

Fix an edge case in the condition that decides when to run store_to_output() in the internal flash attention kernel.

(1) If kv_len=512 and kv_seq_len_per_kv_compute_blk=256, then the last kv_blk_idx at which we need to call store_to_output is 1.
(2) If kv_len=513 and kv_seq_len_per_kv_compute_blk=256, then the last kv_blk_idx at which we need to call store_to_output is 2.

Without the fix, case (1) would fail; a sketch of the corrected last-block computation follows.
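
As an illustration, here is a minimal sketch (plain Python, not the Pallas kernel itself) of computing the last KV block index with ceiling division, which fires on the true final block in both cases. The helper name `should_store_to_output` is hypothetical; `kv_blk_idx`, `kv_len`, and `kv_seq_len_per_kv_compute_blk` follow the cases above.

```python
def should_store_to_output(kv_blk_idx, kv_len, kv_seq_len_per_kv_compute_blk):
    # The last block index is ceil(kv_len / block_size) - 1; ceiling
    # division handles both the evenly divisible case and the
    # remainder case.
    last_kv_blk_idx = -(-kv_len // kv_seq_len_per_kv_compute_blk) - 1
    return kv_blk_idx == last_kv_blk_idx

# Case (1): kv_len=512 -> blocks 0 and 1; the last block index is 1.
assert should_store_to_output(1, 512, 256)
assert not should_store_to_output(2, 512, 256)
# Case (2): kv_len=513 -> blocks 0, 1, 2; the last block index is 2.
assert should_store_to_output(2, 513, 256)
```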

Test plan: `python pytorch/xla/test/test_tpu_paged_attention_kernel.py PagedAttentionKernelTest.test_paged_attention_store_to_output_correctly`

@vanbasten23 requested a review from Liyang90 on Dec 4, 2024
@vanbasten23 marked this pull request as ready for review on Dec 4, 2024
@vanbasten23 (Collaborator, Author) commented:

Thanks for the review, Liyang!

@vanbasten23 merged commit a1a1145 into master on Dec 5, 2024
@miladm added the pallas label on Mar 13, 2025