fix: Fix trtllm-gen prefill IMA when batch_size==1 #1912

yzh119 merged 12 commits into flashinfer-ai:main
Conversation
```diff
@@ -932,6 +928,12 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
     v_fp8 = (v_data / v_scale).to(kv_dtype)
     kv_cache = torch.cat([k_fp8, v_fp8], dim=1)

+    if batch_size == 1:
+        # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
```
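For context, a hypothetical sketch of how the truncated hunk might continue, reusing the `q_lens`/`s_qo` names that appear later in this thread (the actual diff is cut off above):

```python
# Hypothetical continuation of the truncated hunk: before the kernel-side fix,
# max_q_len (s_qo) had to equal the single query's length when batch_size == 1.
if batch_size == 1:
    # trtllm kernel requires max_q_len to equal the query's seqlen for batch_size=1
    s_qo = q_lens[0]
```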
Why could `qo_indptr[-1]` differ from `s_qo`? Is it because we want to be compatible with CUDA graphs and `s_qo` will always be the maximum length?
Short answer is yes.
Longer answer: In a batch_size > 1 situation, the CUDA graph containing `prefill.trtllm_batch_context_with_kv_cache()` can be reused with multiple sequence lengths, but not when batch_size==1. For example:

- If batch_size is 3 and we have two batches with query lengths `[100, 200, 300]` and `[16, 500, 1024]`, we can set `s_qo=1024` when we construct the CUDA graph and use the same CUDA graph for the two batches.
- However, for batch_size=1, where we have batches of query lengths `[100]` and `[1024]`, a CUDA graph must be constructed each time -- first with `s_qo=100` and second with `s_qo=1024`.

Not sure whether the above is a real concern at the framework level. Nevertheless, `s_qo` goes in as the `max_q_len` input argument, where it is the max sequence length for the query. We may at least want to consider whether the wording in the documentation is clear 😄 (A sketch of the graph-reuse pattern follows below.)
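To make the reuse argument concrete, here is a minimal sketch of the capture-and-replay pattern, assuming a placeholder `run_prefill` in place of the real `prefill.trtllm_batch_context_with_kv_cache()` call (buffer shapes and names are illustrative):

```python
# Minimal sketch: one CUDA graph captured with a padded s_qo serves many
# batches; run_prefill is a stand-in for the real trtllm prefill launch.
import torch

s_qo = 1024  # padded max query length, baked into the captured graph
static_q = torch.zeros(3, s_qo, 8, 128, device="cuda")   # batch_size=3
static_out = torch.empty_like(static_q)

def run_prefill(q, out, max_q_len):
    out.copy_(q)  # placeholder for the kernel launch; max_q_len is fixed here

# Warm up once, then record the graph with max_q_len = s_qo.
run_prefill(static_q, static_out, max_q_len=s_qo)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    run_prefill(static_q, static_out, max_q_len=s_qo)

# Replay for a new batch (e.g. query lengths [16, 500, 1024]): copy padded
# inputs into the static buffers; s_qo=1024 stays valid because every actual
# length is <= s_qo. With batch_size=1, the old kernel instead required
# max_q_len == the actual length, so the graph could not be reused.
static_q.copy_(torch.randn_like(static_q))
g.replay()
```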
Force-pushed 4dade1b to 197a7a0
Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

/bot run

[FAILED] Pipeline #36750562: 1/17 passed
Walkthrough: Re-enables trtllm-gen-native for batch_size==1 in benchmark routines, updates three TRTLLM_GEN_FMHA artifact hash constants, adds `max_q_len`/`max_kv_len` test parameters plus a dedicated batch_size==1 prefill test, and sets `mUseBlockSparseAttention = false` when building KernelParams.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Runner
    participant Bench as Benchmark Routine
    participant Selector as Backend Selector
    participant Backend as trtllm-gen-native
    participant Kernel as KernelParams
    Note over Test,Bench: Parametrized batch-prefill tests (including bs==1)
    Test->>Bench: invoke testBatchPrefill(max_q_len, max_kv_len, ...)
    Bench->>Selector: request eligible backends (batch_size considered)
    Note right of Selector: bs==1 no longer auto-skipped
    Selector-->>Backend: select trtllm-gen-native when constraints met
    Backend->>Kernel: build KernelParams (mUseBlockSparseAttention = false)
    Kernel-->>Backend: return params
    Backend-->>Bench: execute prefill using returned params
    Bench-->>Test: return results
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
/bot run
Force-pushed 57e47ea to 003ef55
Actionable comments posted: 0
🧹 Nitpick comments (2)
benchmarks/README.md (1)
19-19: LGTM! Documentation correctly updated. The documentation now accurately reflects that `BatchPrefillWithRaggedKVCacheWrapper` supports `trtllm_ragged_attention_deepseek` for ragged attention operations.

Optional: Fix list indentation for consistency. The static analysis tool flags that this line uses 8 spaces of indentation instead of the expected 4 for nested list items:

```diff
-        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
```

tests/attention/test_trtllm_gen_attention.py (1)
348-361: LGTM! Function signature correctly updated. The new `max_q_len` and `max_kv_len` parameters are properly integrated into the function signature and correctly passed to `generate_seq_lens_prefill`.

Optional: Prefix unused variable with underscore. Line 360 unpacks `in_kv_lens` from `generate_seq_lens_prefill`, but the variable is never used in the function body. Consider prefixing it with an underscore to indicate it's intentionally unused:

```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_prefill(
         batch_size, max_q_len, max_kv_len
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 57e47ea93fccf21f1ea7ddf3cea23566c1458367 and 003ef55b24ea764fd621792e017dcf2ed84bc5a8.
📒 Files selected for processing (5)
- benchmarks/README.md (1 hunks)
- benchmarks/routines/attention.py (0 hunks)
- flashinfer/artifacts.py (3 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (2 hunks)
- tests/attention/test_trtllm_gen_attention.py (3 hunks)
💤 Files with no reviewable changes (1)
- benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- include/flashinfer/trtllm/fmha/kernelParams.h
- flashinfer/artifacts.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
`get_compute_capability` (lines 251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
🪛 Ruff (0.14.0)
tests/attention/test_trtllm_gen_attention.py
360-360: Unpacked variable in_kv_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_trtllm_gen_attention.py (2)
334-335: LGTM! Parameterization enhances test flexibility. Adding `max_q_len` and `max_kv_len` as test parameters allows testing different sequence length combinations, which is essential for validating the batch_size==1 fix across various configurations.
530-578: LGTM! Dedicated batch_size=1 test addresses PR objective. The new `test_trtllm_batch_prefill_bs1` function specifically tests the batch_size==1 scenario with large sequence lengths (8192), which directly addresses the issue described in #1898. The test properly delegates to the main test function with appropriate parameters and minimal configuration to focus on the batch_size==1 edge case (see the sketch below).
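The delegation pattern looks roughly like the sketch below; the real signature in tests/attention/test_trtllm_gen_attention.py carries more parameters, so everything beyond `batch_size`, `max_q_len`, and `max_kv_len` is an assumption here.

```python
# Hedged sketch of the bs=1 regression test's delegation pattern; it reuses
# the main parametrized test with the configuration that used to trigger the
# IMA in issue #1898 (a single sequence with long q/kv).
import pytest

@pytest.mark.parametrize("batch_size", [1])
def test_trtllm_batch_prefill_bs1(batch_size):
    # test_trtllm_batch_prefill is the main test defined earlier in the file.
    test_trtllm_batch_prefill(
        batch_size=batch_size,
        max_q_len=8192,
        max_kv_len=8192,
    )
```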
[FAILED] Pipeline #36805526: 1/17 passed
@nvmbreughe, can I get a review on this PR? Zihao and Perkz have already approved, but due to code owner review requirements, it seems I need a review from you.
Force-pushed fa62171 to 1b7f9e8
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
360-362: Optional: Consider using underscore for unused unpacked variable. The `in_kv_lens` variable is unpacked but never used. Consider using `_` to signal this is intentional:

```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, max_q_len, max_kv_len
-    )
+    q_lens, _, seq_lens = generate_seq_lens_prefill(
+        batch_size, max_q_len, max_kv_len
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 003ef55b24ea764fd621792e017dcf2ed84bc5a8 and 1b7f9e8.
📒 Files selected for processing (5)
- benchmarks/README.md (1 hunks)
- benchmarks/routines/attention.py (0 hunks)
- flashinfer/artifacts.py (3 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (2 hunks)
- tests/attention/test_trtllm_gen_attention.py (3 hunks)
💤 Files with no reviewable changes (1)
- benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
`get_compute_capability` (lines 251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
🪛 Ruff (0.14.1)
tests/attention/test_trtllm_gen_attention.py
360-360: Unpacked variable in_kv_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (3)
benchmarks/README.md (1)
19-19: LGTM! Documentation correctly reflects deepseek support. The addition of `trtllm_ragged_attention_deepseek` to the supported operations for `BatchPrefillWithRaggedKVCacheWrapper` is accurate and aligns with the PR's objective to fix and enable trtllm-gen prefill with deepseek attention.

tests/attention/test_trtllm_gen_attention.py (2)
334-335: Good parameterization for flexible sequence length testing. Adding `max_q_len` and `max_kv_len` parameters allows testing different sequence length scenarios while preserving existing test behavior with sensible defaults. This enables the new batch_size=1 test to use longer sequences. Also applies to: 348-349, 361.
530-578: Well-targeted regression test for batch_size=1 with large sequences. This test specifically validates the kernel fix for batch_size==1 by using large sequence lengths (8192+8192=16384 total), which was the failing scenario described in issue #1898. The narrow parameter space is appropriate for a focused regression test.
## 📌 Description

In #1898, it was raised that trtllm-gen's attention kernels fail for batch size 1. The prefill kernel was fixed in #1912 and prefill tests have been enabled. Further updates to trtllm-gen kernels have also fixed the decode batch size 1 issue. The current PR re-enables testing.

## Summary by CodeRabbit

- **Tests**
  - Expanded batch_decode test scenarios to cover additional small-batch and page-size combinations.
  - Increased coverage for max_in_kv_len by testing multiple length options instead of a single value.
  - Restored a previously marked-as-expected-failure case to run normally, improving overall test pass coverage.

Co-authored-by: Zihao Ye <expye@outlook.com>
📌 Description
Current PR fixes the IMAs (illegal memory accesses) hit by the test and benchmark code when running trtllm-gen paged & ragged prefill with batch size 1 -- the issue was described in #1898.
Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and `flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require `max_q_len` to match the length of the query when batch size is 1.

Updated PR:
The issue has been addressed on the kernel side, so the "`max_q_len` must match the length of the query when batch size is 1" requirement is no longer imposed. The current PR updates the trtllm-gen FMHA cubins to the latest version and brings minor updates to the kernel metadata.
Unit test results after PR:
Description of previous solution:
Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the `trtllm_ragged_attention_deepseek` or `trtllm_batch_context_with_kv_cache` functions is not a viable option, because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was thus to update the test & benchmark code that calls the trtllm prefill functions, and to state clearly in the docstring that when batch_size == 1, max_q_len must match the query size. A sketch of the capture failure follows below.
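A minimal sketch, assuming plain PyTorch (the flashinfer-internal call sites are not shown): reading a device tensor with `.item()` forces a device-to-host synchronization, which is disallowed inside CUDA graph capture.

```python
# Sketch: .item() triggers a device->host sync, which typically raises a
# RuntimeError during CUDA graph capture and invalidates the graph.
import torch

cum_seq_lens_q = torch.tensor([0, 1024], device="cuda", dtype=torch.int32)

g = torch.cuda.CUDAGraph()
try:
    with torch.cuda.graph(g):
        max_q_len = cum_seq_lens_q[-1].item()  # CPU sync -> breaks capture
except RuntimeError as err:
    print("capture failed:", err)

# Viable pattern: the caller supplies max_q_len as a host-side int known
# before capture, so nothing inside the graph needs a device read.
max_q_len = 1024
```

🔍 Related Issues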
#1898
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit

- Bug Fixes
- New Features
- Documentation
- Tests