feat: Batch-size invariant FA2 Prefill & Decode #1675
yzh119 merged 22 commits into flashinfer-ai:main
Conversation
Summary of Changes
Hello @Edenzzzz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the FlashInfer library by introducing a mechanism to achieve deterministic and batch-invariant behavior for prefill and decode operations. It allows users to specify a fixed split size for processing KV caches, which helps in producing consistent results regardless of batch size, thereby mitigating non-determinism in LLM inference. The changes involve modifications across the C++ backend, Python bindings, and the core attention scheduler logic, along with the addition of a dedicated test to validate the new functionality.
Highlights
- Batch-Invariant Prefill and Decode: Introduces a new `fixed_split_size` parameter to the prefill and decode operations, allowing for deterministic and batch-size invariant outputs by fixing the split size for KV cache processing (see the toy sketch after this list). This addresses non-determinism issues in LLM inference.
- CUDA Graph Compatibility: Explicitly notes that compatibility with CUDA graphs is not guaranteed when `fixed_split_size` is enabled, as varying sequence lengths can still lead to a different number of launched CTAs (Cooperative Thread Arrays), even with a fixed batch size.
- New Test Case: Adds a new test file, `tests/test_invariant_batch_decode.py`, to verify the batch-invariant behavior of the decode operation when `fixed_split_size` is utilized.
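To make the idea concrete, here is a toy sketch (not FlashInfer's actual scheduler code) of why a fixed split size removes the batch-size dependence; with adaptive splitting, the split size is chosen from the total work in the batch, so these boundaries would move as requests are added:

```python
def kv_split_points(kv_len: int, split_size: int) -> list:
    # Toy model: with a fixed split size, the chunk boundaries of one
    # request's KV cache depend only on its own length, never on how
    # many other requests happen to be in the batch.
    return list(range(0, kv_len, split_size))

# A request with kv_len=1200 is partitioned the same way whether it is
# scheduled alone or together with 63 other requests:
assert kv_split_points(1200, 512) == [0, 512, 1024]
```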
Code Review
This pull request introduces a fixed_split_size parameter to enable batch-invariant attention computation in prefill and tensor-core decode paths, addressing potential non-determinism. The changes are well-implemented across the C++ backend, Python wrappers, and TVM bindings. The addition of a new test to verify batch-invariance for the decode path is a great step. My review includes a few suggestions to improve documentation consistency and test coverage.
```python
# Excerpt from tests/test_invariant_batch_decode.py
import torch
import flashinfer


def test_batch_decode_tensor_cores(
    batch_size: int,
    invariant_bs: int,
    kv_len: int,
    fixed_split_size: int,
    page_size: int,
    num_kv_heads: int,
    group_size: int,
    head_dim: int,
    kv_layout: str,
    pos_encoding_mode: str,
):
    num_qo_heads = num_kv_heads * group_size
    q = torch.randn(
        batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16
    )
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    kv_data = (
        torch.randn(
            total_num_pages,
            2,
            num_kv_heads,
            page_size,
            head_dim,
            device="cuda:0",
            dtype=torch.float16,
        )
        / 10
        if kv_layout == "HND"
        else torch.randn(
            total_num_pages,
            2,
            page_size,
            num_kv_heads,
            head_dim,
            device="cuda:0",
            dtype=torch.float16,
        )
        / 10
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
    )

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0")

    wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout, use_tensor_cores=True
    )
    wrapper_tensor_cores.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode=pos_encoding_mode,
        data_type=torch.float16,
        q_data_type=torch.float16,
        fixed_split_size=fixed_split_size,
    )
    o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run(
        q, kv_data, return_lse=True
    )

    kv_indptr_invariant = kv_indptr[: invariant_bs + 1]
    kv_last_page_len_invariant = kv_last_page_len[:invariant_bs]
    wrapper_tensor_cores.plan(
        kv_indptr_invariant,
        kv_indices,
        kv_last_page_len_invariant,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode=pos_encoding_mode,
        data_type=torch.float16,
        q_data_type=torch.float16,
        fixed_split_size=fixed_split_size,
    )
    o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run(
        q[:invariant_bs], kv_data, return_lse=True
    )
    torch.testing.assert_close(
        o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7
    )
    torch.testing.assert_close(
        lse_tensor_cores[:invariant_bs],
        lse_tensor_cores_invariant,
        rtol=1e-7,
        atol=1e-7,
    )
```
This new test is great for verifying the batch-invariance of the decode path with tensor cores. Since the PR also adds fixed_split_size to the prefill path (BatchPrefillWithPagedKVCacheWrapper), it would be beneficial to add a similar test for prefill to ensure its batch-invariance as well. The current test only covers the decode case where query length is 1 for each request, while prefill handles variable query lengths.
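For illustration, a rough outline of what such a prefill test could look like, reusing the tensors from the decode test above (the exact prefill `plan` signature with `fixed_split_size` should be checked against this PR; `qo_len` and the variable names below are assumptions):

```python
qo_len = 128  # illustrative query length per request
qo_indptr = (
    torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
)
q_prefill = torch.randn(
    batch_size * qo_len, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16
)
wrapper_prefill = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout
)
wrapper_prefill.plan(
    qo_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    q_data_type=torch.float16,
    fixed_split_size=fixed_split_size,
)
o_full, _ = wrapper_prefill.run(q_prefill, kv_data, return_lse=True)
# Re-plan with only the first invariant_bs requests and compare slices,
# mirroring the decode assertions above.
```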
```python
    # test that without fixed split size, precision is different
    # TODO: this works for the first 29 cases, but then fails with "illegal memory access"..?

    # wrapper_tensor_cores.plan(
```
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
cc @yzh119

The easiest way to turn off split-k in flashinfer is to change these two lines to …
Yes, but simply turning it off would hurt performance |
```python
o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run(
    q[:invariant_bs], kv_data, return_lse=True
)
torch.testing.assert_close(
```
Can we add a utility function for a bitwise-identical comparison check?
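For reference, one possible shape for such a helper (a sketch; `assert_bitwise_identical` is a hypothetical name, not an existing FlashInfer utility):

```python
import torch

def assert_bitwise_identical(a: torch.Tensor, b: torch.Tensor) -> None:
    # Compare raw bit patterns instead of values, so that results that are
    # merely close (or -0.0 vs 0.0) are reported as mismatches.
    assert a.shape == b.shape and a.dtype == b.dtype
    int_dtype = {
        torch.float16: torch.int16,
        torch.bfloat16: torch.int16,
        torch.float32: torch.int32,
    }[a.dtype]
    a_bits = a.contiguous().view(int_dtype)
    b_bits = b.contiguous().view(int_dtype)
    assert torch.equal(a_bits, b_bits), "tensors are not bitwise identical"
```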
One minor note: in the attention merge kernel (`flashinfer/include/flashinfer/attention/cascade.cuh`, lines 342 to 467 in 88e333e), under the current reduction order we reduce the partial states into merge(out[0], out[1], ..., out[15]), and the order might change for different kv-lengths.
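For context, a minimal sketch of the LSE-weighted merge the comment refers to (the shapes in the comment are assumptions for illustration):

```python
import torch

def merge_states(o1, lse1, o2, lse2):
    # LSE-weighted combination of two partial attention outputs, the style
    # of reduction used when split-KV partial results are merged.
    # Assumed shapes: o* is [num_heads, head_dim], lse* is [num_heads].
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse)
    w2 = torch.exp(lse2 - lse)
    return o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1), lse
```

Because floating-point arithmetic is non-associative, folding 16 such states left-to-right versus in a different grouping can produce bitwise-different results, which is why the reduction order matters for reproducibility.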
Though I thought the threads in a block just collectively reduce along the head dim, one position & head at a time, in …
@happierpig has done some work on changing the reduction order before.
It matters if you want to guarantee reproducibility across prefill and decode; my question is what kind of reproducibility you want to achieve in this PR?
We want to ensure that changing the batch size (adding requests) does not change the output of individual requests, using one kernel. I'm not sure how the reduction affects that? Both cases use the same reduction and batch prefill kernel.
I think the current …
Do we need to add a flag for disabling split-kv? Since eventually we want to use CUDA graph for decoding.
It should work now |
To be more clear on my comments about "prefill and decode consistency": I'm worried about the use cases where we manually merge attention outputs from different KV components (e.g. in chunked prefill or speculative decoding). It's okay to ignore these cases at the moment, but we should know their possible effect on reproducibility.
Yes, I think if the chunk size changes, then setting a fixed split size can still lead to differences in the tail chunk for each request.
In these cases we need to consistently use one BatchPagedPrefill kernel, and avoid merging different chunks using a separate …
For deterministic inference we can add some assumptions, like disabling radix cache, disabling speculative decoding, or even disabling chunked prefill. Some speed can be sacrificed for more stable output.
yzh119 left a comment:
LGTM in general, left some minor suggestions.
```diff
 at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
 int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-int64_t head_dim_vo, bool causal, int64_t window_left) {
+int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size = -1,
```
I don't encourage setting default values in C++ APIs (as these APIs are directly exposed as Python bindings); is there any use case for them?
Just example default values for Python interface writers; removing them now.
```diff
 IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
 int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-int64_t head_dim_vo, bool causal, int64_t window_left, TVMStreamHandle cuda_stream) {
+int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size,
```
For TVM bindings, @MasterJH5574 could probably help with checking them.
Sorry, I just noticed this change. On the TVM side we don't use `fixed_split_size` at the moment, so we will need to remove it from the parameter list and use a default value. We will submit a fix later.
Follow-up PR: uses the default value -1 for `fixed_split_size` introduced in PR #1675, to keep the interface consistent with the TVM side.
📌 Description
As mentioned in https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/, split-kv can cause non-determinism. This PR introduces a fixed split size to produce bitwise-identical outputs when the batch size changes.
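A toy illustration of the underlying issue (standalone PyTorch, not the kernel code): floating-point reduction is non-associative, so changing the number of KV splits changes the reduction tree and can flip low-order bits:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096, dtype=torch.float16)

# Two split counts -> two different reduction trees over the same data.
two_chunks = x[:2048].float().sum() + x[2048:].float().sum()
four_chunks = sum(x[i : i + 1024].float().sum() for i in range(0, 4096, 1024))

# The results are numerically close but need not be bitwise equal.
print(torch.equal(two_chunks, four_chunks))
```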
Compatibility with CUDA graph is left for future work, as varied sequence lengths can cause the number of launched CTAs to exceed the number captured in the graph. Also adds a `disable_split_kv` flag for CUDA Graph mode. cc @Fridge003, I will likely be busy with some papers this semester, but this should be a useful starting point.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes