[Flex Attn][CPU] support flash decoding for cpu #159835
Valentine233 wants to merge 32 commits into main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159835
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5d66b9f with merge base d67b279.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@jianan-gu @CaoE Please help review, thanks~
    )
    return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)
if (
    query.data.data.layout.size[2] == 1
We may want to double-check the condition for selecting the flash decoding path, ref: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex/flex_decoding.py#L34
Thanks! A function to choose the flex template has been added.
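A minimal sketch of what such a selection function could look like (hypothetical names, not the PR's actual code): the decoding template is chosen only when the query sequence length is 1.

```python
# Hypothetical sketch of a flex-template chooser. The function name and
# parameters are illustrative; the PR's real check lives inside the
# CPU flex-attention lowering.

def choose_flex_template(q_seq_len: int, supports_flash_decoding: bool) -> str:
    """Return which template to render for this flex-attention call."""
    if q_seq_len == 1 and supports_flash_decoding:
        # Single-token decode: parallelize over the KV length instead.
        return "flash_decoding"
    # General case: classic flash-attention tiling over the query length.
    return "flash_attention"
```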
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
# In flash decoding, the partition size of doing the parallelism on KV length dim
PARTITION_SIZE = kernel_options.get("PARTITION_SIZE", 128)
Can we add more PARTITION_SIZE values for testing?
Why does flash decoding need a user-set PARTITION_SIZE instead of automatically selecting a suitable one?
Thanks, a UT for PARTITION_SIZE has been added.
Yes, it is possible to automatically select the PARTITION_SIZE; the choice depends on the PAGE_SIZE, the input shape, and the number of threads. The same applies to PAGE_SIZE, which is a fixed value for now. We can do a round of tuning for PAGE_SIZE and PARTITION_SIZE as future work.
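To make the trade-off concrete, here is a hedged sketch of how the KV length splits into partitions, and one *possible* auto-selection heuristic based on block size and thread count (the PR itself uses a fixed default of 128; function names and the heuristic are illustrative, not the PR's code):

```python
# Sketch: partitioning the KV length for flash decoding.

def num_partitions(k_seq_len: int, partition_size: int) -> int:
    # Ceiling division: each partition covers at most partition_size KV tokens.
    return (k_seq_len + partition_size - 1) // partition_size

def auto_partition_size(k_seq_len: int, kv_block_size: int, num_threads: int) -> int:
    # One possible heuristic (an assumption, not the PR's behavior): aim for
    # roughly one partition per thread, rounded up to a multiple of the KV
    # cache block size, since flash decoding requires that alignment.
    target = max(k_seq_len // max(num_threads, 1), kv_block_size)
    return ((target + kv_block_size - 1) // kv_block_size) * kv_block_size
```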
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
false);
We may also need to add back the need_pack check, depending on the qsize threshold mentioned below.
We only let qsize = 1 enter flash decoding for now, and in that case we do not need packing.
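For context, a need_pack-style check is typically a threshold on the query size: packing the GEMM operand only pays off when enough query rows reuse it. The sketch below is purely illustrative — the function name, signature, and threshold value are assumptions, not the template's actual logic:

```python
# Hypothetical need_pack check (illustrative threshold; not the PR's code).

def need_pack(q_size: int, qsize_threshold: int = 32) -> bool:
    # Packing amortizes its cost only when many query rows reuse the packed
    # operand; a single decode token (q_size == 1) never meets the threshold,
    # which is why the flash-decoding path can skip packing entirely.
    return q_size >= qsize_threshold
```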
Force-pushed from cbfa6e4 to b08d898.
@ZainRizvi Hi, could you please help check if the UT duration is acceptable with this PR? Previously reverted in #158617.
Hi @drisspg, |
Force-pushed from 87dcbad to 1746bb2.
self.partition_size % self.kv_block_size == 0
and q_seq_len == 1
and num_threads > q_batch_size * q_num_heads
and k_seq_len / q_batch_size >= max(self.partition_size * 2, 512)
Add comments to explain this formula?
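A commented restatement of the condition may help here. This is a sketch (standalone function, variable names taken from the diff; the rationale in the comments is my reading of the heuristic, not the author's wording):

```python
def use_flash_decoding(partition_size, kv_block_size, q_seq_len,
                       num_threads, q_batch_size, q_num_heads, k_seq_len):
    return (
        # Partitions must align with KV cache blocks.
        partition_size % kv_block_size == 0
        # Only the single-token decode case is supported for now.
        and q_seq_len == 1
        # With q_seq_len == 1 there are only batch * heads parallel units;
        # split the KV dim only when threads would otherwise sit idle.
        and num_threads > q_batch_size * q_num_heads
        # The per-batch KV length must be long enough to amortize the extra
        # cross-partition reduction: at least two partitions and >= 512 tokens.
        and k_seq_len / q_batch_size >= max(partition_size * 2, 512)
    )
```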
def score_mod(score, b, h, m, n):
    return score * 2

self.run_test_with_paged_attention(
Force-pushed from d79d17b to 5d66b9f.
Description:
- Flash decoding is supported in CppFlexAttentionTemplate. We prefer to choose flash decoding instead of flash attention when the query length is 1.
- Add a kernel option PARTITION_SIZE to define the partition size of the parallelism on the KV length dimension. The default value is 128, which should be a multiple of the KV cache block size to use flash decoding.
- The flex_attn UTs for the CPU backend were disabled because of their long duration. Here we re-enable them on CPU-only machines. (Already merged in Enable XPU path for FlexAttention #143553)

Performance:
Here are the E2E results for Llama3.1-8B decoding, validated on a GNR machine with 6 NUMA nodes, where we can see E2E improvements from 114% to 121%.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @chenyang78
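For readers unfamiliar with the technique, here is an illustrative NumPy sketch of the flash-decoding reduction this PR implements for CPU: each partition of the KV length computes an unnormalized partial output plus its running max and sum, and the partials are combined with a numerically stable rescale. This is a reference-style sketch, not the template's C++ code.

```python
import numpy as np

def flash_decoding(q, k, v, partition_size=128):
    # q: (d,) single query token (decoding); k, v: (T, d) KV cache.
    outs, maxes, sums = [], [], []
    for s in range(0, k.shape[0], partition_size):
        scores = k[s:s + partition_size] @ q           # partial logits
        m = scores.max()                               # per-partition max
        p = np.exp(scores - m)
        outs.append(p @ v[s:s + partition_size])       # unnormalized partial out
        maxes.append(m)
        sums.append(p.sum())
    # Combine partitions: rescale each partial by exp(m_i - global_max).
    g = max(maxes)
    scale = [np.exp(m - g) for m in maxes]
    denom = sum(sc * s for sc, s in zip(scale, sums))
    num = sum(sc * o for sc, o in zip(scale, outs))
    return num / denom
```

The result matches a direct softmax-attention computation up to floating-point tolerance, which is what makes the per-partition parallelism on the KV length dimension safe.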