
[Flex Attn][CPU] support flash decoding for cpu#159835

Open
Valentine233 wants to merge 32 commits into main from flash_decoding_cpu

Conversation

Collaborator

@Valentine233 Valentine233 commented Aug 5, 2025

Description:

  1. Support flash decoding in `CppFlexAttentionTemplate`. We prefer flash decoding over flash attention when the query length is 1.
  2. For flash decoding, we add a kernel option `PARTITION_SIZE` to define the partition size used when parallelizing over the KV length dimension. The default value is 128, which should be a multiple of the KV cache block size for flash decoding to be used.
  3. As mentioned in "Fix large_tensor_test skipping cpu" #158617, the flex_attn UTs for the CPU backend were disabled because of their long duration. Here we re-enable them on CPU-only machines. (Already merged in "Enable XPU path for FlexAttention" #143553)
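The flow described in point 2 (split the KV length into `PARTITION_SIZE` chunks, compute a partial output and logsumexp per chunk, then merge the partials) can be sketched in plain Python. This is an illustrative model of the flash-decoding algorithm, not the PR's C++ template; the scalar "vectors" and function names are simplifications for readability.

```python
import math

def flash_decoding_1d(q, keys, values, partition_size=4):
    """Illustrative flash decoding for a single query (qsize == 1).

    Splits the KV length into partitions, computes a partial
    attention output and a logsumexp per partition, then merges
    the partials into the exact full-softmax result.
    """
    partials = []  # (logsumexp, partial_output) per partition
    for start in range(0, len(keys), partition_size):
        ks = keys[start:start + partition_size]
        vs = values[start:start + partition_size]
        scores = [q * k for k in ks]              # "dot product" in 1-D
        m = max(scores)                           # local max for stability
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        out = sum(e * v for e, v in zip(exps, vs)) / denom
        partials.append((m + math.log(denom), out))
    # Merge: weight each partial output by its share of the global
    # softmax mass, recovered from the per-partition logsumexp.
    g = max(lse for lse, _ in partials)
    weights = [math.exp(lse - g) for lse, _ in partials]
    total = sum(weights)
    return sum(w * o for w, (_, o) in zip(weights, partials)) / total

def attention_1d(q, keys, values):
    """Reference: single-pass softmax attention over the full KV length."""
    scores = [q * k for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)
```

The merge step is why the partitioned computation is exact rather than approximate: the per-partition logsumexp carries enough information to renormalize the partial outputs.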

Performance:

Here are the E2E results for Llama3.1-8B decoding, validated on a GNR machine with 6 NUMA nodes, where we see E2E speedups ranging from 113.82% to 121.39%.

| Data Type | Input/Output tokens | Batch Size | W/O Flash Decoding (tokens/s) | With Flash Decoding (tokens/s) | Speedup |
| --- | --- | --- | --- | --- | --- |
| BF16 | 2016/32 | 25 | 892.196 | 1083.073 | 121.39% |
| FP16 | 2016/32 | 25 | 879.541 | 1015.593 | 115.47% |
| BF16 | 1024/128 | 30 | 1291.349 | 1529.251 | 118.42% |
| FP16 | 1024/128 | 30 | 1294.228 | 1473.049 | 113.82% |

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @chenyang78

@pytorch-bot

pytorch-bot bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159835

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5d66b9f with merge base d67b279:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Valentine233
Collaborator Author

Valentine233 commented Aug 6, 2025

@jianan-gu @CaoE Please help review, thanks~

@CaoE CaoE added the ciflow/trunk (Trigger trunk jobs on your pull request) label Aug 12, 2025
)
return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)
if (
query.data.data.layout.size[2] == 1
Contributor

we may double check the condition for selecting the flash decoding path, ref: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex/flex_decoding.py#L34

Collaborator Author

Thanks! A function to choose flex template is added.

SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
# In flash decoding, the partition size of doing the parallelism on KV length dim
PARTITION_SIZE = kernel_options.get("PARTITION_SIZE", 128)
Collaborator

Can we add more PARTITION_SIZE values for testing?

Collaborator

Why does flash decoding need PARTITION_SIZE to be set, instead of automatically selecting a suitable PARTITION_SIZE?

Collaborator Author

Thanks, the UT for PARTITION_SIZE is added.
Yes, it is possible to automatically select PARTITION_SIZE, which depends on PAGE_SIZE, the input shape, and the number of threads. The same applies to PAGE_SIZE, which is currently a fixed value. We can do a round of tuning for PAGE_SIZE and PARTITION_SIZE as future work.
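The auto-selection the author describes could look roughly like the following sketch. This is a hypothetical heuristic, not the PR's logic: it picks the largest candidate partition size that is a multiple of the KV-cache page size while still producing enough partitions that `batch * heads * num_partitions` covers the available threads.

```python
def pick_partition_size(k_seq_len, page_size, num_threads,
                        batch_size, num_heads, default=128):
    """Hypothetical heuristic for auto-selecting PARTITION_SIZE.

    A sketch of the trade-off only: the partition size must stay a
    multiple of the KV cache page size, and it should be small
    enough that splitting the KV length actually adds parallelism
    (batch * heads * num_partitions >= num_threads).
    """
    # Try the largest candidate first to keep per-partition work high.
    candidates = [c for c in (512, 256, 128, 64, 32) if c % page_size == 0]
    for size in candidates:
        num_partitions = -(-k_seq_len // size)  # ceiling division
        if batch_size * num_heads * num_partitions >= num_threads:
            return size
    # No candidate saturates the threads: take the smallest aligned
    # size, or fall back to the default when none is aligned.
    return candidates[-1] if candidates else default
```

With many free threads the heuristic shrinks the partition so the KV split yields more parallel work items; with few threads it keeps partitions large to reduce merge overhead.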

{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
false);
Contributor

we may also need to add back the need_pack check, depending on the qsize threshold mentioned below.

Collaborator Author

We only let qsize=1 enter flash decoding for now, and for this case we do not need packing.

@Valentine233 Valentine233 force-pushed the flash_decoding_cpu branch 2 times, most recently from cbfa6e4 to b08d898 Compare August 25, 2025 06:04
@Valentine233 Valentine233 requested review from CaoE, drisspg, jansel and jianan-gu and removed request for jianan-gu August 25, 2025 06:57
@Valentine233 Valentine233 changed the title [Flex Attn][CPU] support flash decoding for cpu [PyTorch2.9 Feature][Flex Attn][CPU] support flash decoding for cpu Aug 25, 2025
@Valentine233
Collaborator Author

@ZainRizvi Hi, could you please help check if the UT duration is acceptable with this PR? Previously reverted in #158617.

@Valentine233 Valentine233 marked this pull request as ready for review August 25, 2025 07:45
@soulitzer soulitzer added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Aug 26, 2025
@jianan-gu
Contributor

Hi @drisspg,
We are adding flash decoding for inductor CPU backend (and also UT changes mentioned in #158617 (comment)) , could you kindly help review? Thanks!

@Valentine233 Valentine233 force-pushed the flash_decoding_cpu branch 2 times, most recently from 87dcbad to 1746bb2 Compare August 27, 2025 05:54
self.partition_size % self.kv_block_size == 0
and q_seq_len == 1
and num_threads > q_batch_size * q_num_heads
and k_seq_len / q_batch_size >= max(self.partition_size * 2, 512)
Collaborator

Add comments to explain this formula?
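For readers, the condition in the snippet above can be restated as a standalone annotated predicate. This is one reading of the heuristic, not an authoritative explanation; the thresholds (the `* 2` factor and `512`) are taken directly from the snippet.

```python
def use_flash_decoding(partition_size, kv_block_size, q_seq_len,
                       num_threads, q_batch_size, q_num_heads, k_seq_len):
    """Annotated restatement of the flash-decoding selection condition."""
    return (
        # Partitions must align with KV cache blocks, so each
        # partition reads whole blocks of the paged KV cache.
        partition_size % kv_block_size == 0
        # Flash decoding only targets single-token (decode) queries.
        and q_seq_len == 1
        # Batch x heads alone cannot keep all threads busy, so
        # splitting the KV length adds useful parallelism.
        and num_threads > q_batch_size * q_num_heads
        # The KV length per batch element is long enough to yield at
        # least two partitions (and at least 512 tokens) per split,
        # so the extra reduction/merge pass pays off.
        and k_seq_len / q_batch_size >= max(partition_size * 2, 512)
    )
```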

def score_mod(score, b, h, m, n):
return score * 2

self.run_test_with_paged_attention(
Collaborator

Add mask_mod tests?

@Valentine233
Collaborator Author

Valentine233 commented Jan 20, 2026

@drisspg @malfet @jansel
Hi, this feature is planned to target PyTorch 2.11. The PR has been rebased; please help review again!


Labels

ciflow/inductor · ciflow/trunk (Trigger trunk jobs on your pull request) · module: inductor · open source · topic: not user facing (topic category) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
