Support lse in trtllm paged attn kernels #3058
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
Code Review
This pull request introduces support for returning the log-sum-exp (LSE) of attention logits in the TRT-LLM paged attention and MLA backends. The changes include updating the C++ kernel launchers to handle LSE pointers and strides, modifying the Python API to allow users to provide or receive LSE tensors, and updating relevant tests. Several critical issues were identified in the MLA implementation regarding the shape and flattening of the LSE tensor, which could lead to memory corruption or incorrect stride calculations. Additionally, a potential memory exhaustion issue was noted in the C++ workspace allocation for the context path.
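For readers unfamiliar with the quantity being plumbed through: the LSE is the log of the softmax denominator of the attention logits, computed per (token, head), and it is what lets callers re-normalize or merge partial attention outputs. A minimal pure-Python sketch of the numerically stable form (illustrative only; the kernels compute this on-chip):

```python
import math

def logsumexp(logits):
    """Numerically stable log-sum-exp: log(sum(exp(x_i)))."""
    m = max(logits)  # subtract the max so no exp() overflows
    return m + math.log(sum(math.exp(x - m) for x in logits))

# The naive form overflows for large logits; the stable form does not.
scores = [1000.0, 1001.0, 1002.0]
lse = logsumexp(scores)

# Softmax weights can be recovered as exp(x_i - lse) and must sum to 1.
weights = [math.exp(x - lse) for x in scores]
assert abs(sum(weights) - 1.0) < 1e-12
```
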
```python
lse = torch.empty(
    (query.size(0), query.size(1)),
    dtype=torch.float32,
    device=query.device,
)
```
The lse tensor is allocated with shape (query.size(0), query.size(1)), which corresponds to (batch_size, q_len_per_request). However, attention Log-Sum-Exp (LSE) is computed per query head. Since query has shape (batch_size, q_len_per_request, num_heads, head_dim_qk), the lse tensor should be allocated with shape (query.size(0), query.size(1), query.size(2)) to accommodate all heads and avoid memory corruption when the kernel writes LSE values.
```diff
-lse = torch.empty(
-    (query.size(0), query.size(1)),
-    dtype=torch.float32,
-    device=query.device,
-)
+lse = torch.empty(
+    (query.size(0), query.size(1), query.size(2)),
+    dtype=torch.float32,
+    device=query.device,
+)
```
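To make the sizing bug concrete, here is a small arithmetic sketch with made-up sizes (not taken from the PR) showing how far past a 2-D allocation a per-(token, head) write pattern reaches:

```python
# Hypothetical sizes, chosen only for illustration.
batch, q_len, num_heads = 2, 4, 8

underallocated = batch * q_len            # what a (size(0), size(1)) buffer holds
required = batch * q_len * num_heads      # one float per (batch, token, head)

# Highest flat index touched with a [batch, q_len, heads] layout:
last_index = ((batch - 1) * q_len + (q_len - 1)) * num_heads + (num_heads - 1)
assert last_index == required - 1
assert last_index >= underallocated       # i.e. out of bounds for the 2-D buffer
```
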
```python
lse,
(query.size(0), query.size(1)),
torch.float32,
query.device,
"lse",
)
```
```diff
@@ -795,9 +817,13 @@ def trtllm_batch_decode_with_kv_cache_mla(
             None,  # value_block_scales
             skip_softmax_threshold_scale_factor,
             uses_shared_paged_kv_idx,
+            lse,
```
```cpp
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
    sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
    "trtllm_gen_softmax_workspace");
```
The softmaxStatsPtr is now allocated for both Context and Generation modes. In Context mode, runner_params.mSumOfSeqLensQ (the total number of query tokens) can be very large, which may lead to excessive memory allocation in the workspace (e.g., for 128 heads and 128k tokens, this would require ~128MB). Since Context kernels typically do not require this workspace buffer for multi-block reduction (as mMultiCtasKvMode is false), this allocation could cause workspace exhaustion for long sequences. Consider moving this allocation inside the else block (Generation path) or making it conditional on mode == TllmPagedAttentionMode::ForGen.
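The "~128MB" figure above checks out; a quick back-of-the-envelope in Python with the same hypothetical numbers (128 heads, 128k total query tokens):

```python
# float2 is two 4-byte floats.
sizeof_float2 = 8
num_qo_heads = 128
sum_of_seq_lens_q = 128 * 1024  # total query tokens across the batch

workspace_bytes = sizeof_float2 * num_qo_heads * sum_of_seq_lens_q
assert workspace_bytes == 128 * 1024 * 1024  # exactly 128 MiB
```
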
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tests/attention/test_trtllm_gen_attention.py (1)
634-659: ⚠️ Potential issue | 🟠 Major

The new prefill/decode LSE path is still untested.
Both branches unpack `lse` but never assert on it, so a token/head-stride bug in the new kernel plumbing would ship unnoticed. Please validate `lse` against the reference wrappers where they can emit it, and at minimum check shape, dtype, and finiteness in the remaining cases. That also resolves the current RUF059 warnings.
Also applies to: 1112-1115
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_attention.py` around lines 634-659: the test calls flashinfer.prefill.trtllm_batch_context_with_kv_cache and unpacks lse but never validates it; add assertions to verify lse equals the reference lse when using the reference wrapper (compare values/close), and otherwise assert lse has the expected shape, dtype, and that all elements are finite (use torch.isfinite). Update the checks around the other prefill/decode call sites mentioned (the similar block at lines ~1112-1115) to perform the same validation so token/head-stride bugs are caught and RUF059 warnings are resolved.
flashinfer/prefill.py (1)
3971-4008: ⚠️ Potential issue | 🟠 Major

Validate provided `lse` buffers even when `return_lse` is false.
`lse` is forwarded to the kernel unconditionally, but shape/dtype/device checks only run inside `if return_lse`. A caller can pass a malformed `lse` with `return_lse=False` and bypass validation entirely.

Proposed fix

```diff
-    if return_lse:
-        if lse is None:
-            lse = torch.empty(
-                query.size(0), query.size(1), dtype=torch.float32, device=query.device
-            )
-        else:
-            check_shape_dtype_device(
-                lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse"
-            )
+    if lse is not None:
+        check_shape_dtype_device(
+            lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse"
+        )
+    elif return_lse:
+        lse = torch.empty(
+            query.size(0), query.size(1), dtype=torch.float32, device=query.device
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3971-4008: the code currently only validates the optional tensor lse when return_lse is true, but lse is passed to run_func unconditionally; update the logic so any non-None lse is validated regardless of return_lse: before calling run_func (or at top of this block) check if lse is not None and call check_shape_dtype_device(lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse"); keep the existing branch that allocates an empty lse when return_lse is true and lse is None, but ensure validation still happens for user-supplied lse even when return_lse is false so malformed buffers cannot be forwarded to run_func.
tests/attention/test_trtllm_gen_mla.py (1)
394-416: ⚠️ Potential issue | 🟠 Major

Assert on MLA `lse`, not just tuple unpacking.
This only proves that `return_lse=True` returns a second object. A bad LSE layout/stride would still pass because `lse` is never checked, and Ruff already flags it as unused. Please compare it against a reference LSE where available, or at least assert shape, dtype, and finiteness.
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_mla.py` around lines 394 - 416, The test currently only unpacks the second return (lse) from flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla but never validates it; update the backend == "trtllm-gen" branch to assert properties of lse: check that lse exists and has the expected shape (match seq_lens_tensor or other reference LSE shape), correct dtype (e.g., same as output or torch.float32), and that all values are finite (no NaN/Inf); if a reference LSE tensor is available in the test harness, compare lse against it (or at minimum validate shape/dtype/torch.isfinite) in addition to the existing workspace_buffer zero check, keeping the tuple unpacking of output, lse the same.
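The shape/dtype/finiteness checks requested in the comments above can be sketched in pure Python. This is an illustrative stand-in only: `lse` is modeled as a nested list, and in the real tests you would use `lse.shape`, `lse.dtype == torch.float32`, and `torch.isfinite(lse).all()` instead.

```python
import math

def check_lse(lse, expected_shape):
    """Minimal validation in the spirit of the review: shape + finiteness."""
    rows, cols = expected_shape
    assert len(lse) == rows and all(len(r) == cols for r in lse), "bad shape"
    assert all(math.isfinite(v) for r in lse for v in r), "non-finite LSE"

# A well-formed per-(token, head) LSE passes; a NaN entry would trip the assert.
check_lse([[0.5, 1.2], [2.0, 3.1], [0.0, -1.0]], (3, 2))
```
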
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 612-614: The code adds TRTLLM-only args lse and return_lse but the
xqa and cute-dsl branches still ignore them and return a plain tensor; update
the function handling to fail fast: at the start of the method (around the
signature with parameters lse and return_lse) add a guard that checks if
(return_lse or lse is not None) and the current backend is not 'trtllm-gen' (or
equivalent backend identifier used in this module), and raise a clear
ValueError/RuntimeError explaining that LSE/return_lse are only supported for
trtllm-gen; also add the same guard inside the xqa and cute-dsl branch handlers
(the code paths around the existing xqa and cute-dsl handling lines) so they
explicitly raise instead of returning plain tensors when LSE is requested.
- Around line 775-789: The LSE buffer is being allocated/validated with shape
(batch, q_len) but MLA decode flattens query and indexes LSE per token/head;
update the allocation/validation in the return_lse branch so lse has space for
all heads: use shape (query.size(0), query.size(1), query.size(2)) (i.e.,
[batch, q_len, heads]) or equivalently allocate/validate as a flat buffer of
length query.size(0)*query.size(1)*query.size(2) that the kernel can index, and
if you must keep a public 2-D shape then reshape the 3-D buffer to (batch,
q_len) only after the kernel call; adjust the check_shape_dtype_device call and
any downstream assumptions around lse accordingly (references: return_lse, lse,
query, check_shape_dtype_device).
In `@flashinfer/prefill.py`:
- Around line 3739-3743: The wrapper path doesn't forward return_lse so
BatchPrefillWithPagedKVCacheWrapper.run(return_lse=True) still takes the
trtllm-gen paged path which asserts maybe_lse is None; wire the return_lse flag
through the wrapper to the underlying prefill helper and paged path so the
paged/trtllm-gen code can return LSE when requested: update
BatchPrefillWithPagedKVCacheWrapper.run (and any intermediate wrapper functions)
to accept and pass the return_lse argument into the direct helper call, remove
or adjust the assertion on maybe_lse in the trtllm-gen paged path, and ensure
the function signature that was changed (the helper with lse:
Optional[torch.Tensor] and return_lse: bool) is invoked with the new flag so the
returned tuple (maybe_lse, ...) is propagated back to callers.
---
Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 3971-4008: The code currently only validates the optional tensor
lse when return_lse is true, but lse is passed to run_func unconditionally;
update the logic so any non-None lse is validated regardless of return_lse:
before calling run_func (or at top of this block) check if lse is not None and
call check_shape_dtype_device(lse, (query.size(0), query.size(1)),
torch.float32, query.device, "lse"); keep the existing branch that allocates an
empty lse when return_lse is true and lse is None, but ensure validation still
happens for user-supplied lse even when return_lse is false so malformed buffers
cannot be forwarded to run_func.
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 634-659: The test calls
flashinfer.prefill.trtllm_batch_context_with_kv_cache and unpacks lse but never
validates it; add assertions to verify lse equals the reference lse when using
the reference wrapper (compare values/close), and otherwise assert lse has the
expected shape, dtype, and that all elements are finite (use torch.isfinite).
Update the checks around the other prefill/decode call sites mentioned (the
similar block at lines ~1112-1115) to perform the same validation so
token/head-stride bugs are caught and RUF059 warnings are resolved.
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 394-416: The test currently only unpacks the second return (lse)
from flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla but never validates
it; update the backend == "trtllm-gen" branch to assert properties of lse: check
that lse exists and has the expected shape (match seq_lens_tensor or other
reference LSE shape), correct dtype (e.g., same as output or torch.float32), and
that all values are finite (no NaN/Inf); if a reference LSE tensor is available
in the test harness, compare lse against it (or at minimum validate
shape/dtype/torch.isfinite) in addition to the existing workspace_buffer zero
check, keeping the tuple unpacking of output, lse the same.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d0be76a5-7a89-46d7-a68d-1535c9c74a52
📥 Commits
Reviewing files that changed from the base of the PR and between 7c562d5 and e1a4a8a6abcdb89a70936cc8d3c2e98f9d8b16ed.
📒 Files selected for processing (7)
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/mla/_core.py
- flashinfer/prefill.py
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
```python
    lse: Optional[torch.Tensor] = None,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
Fail fast when LSE is requested on non-trtllm-gen backends.
The new args are documented as TRTLLM-only, but the xqa and cute-dsl branches still ignore them and return a plain tensor. out, lse = ... will therefore either raise or silently unpack batch slices depending on the output shape.
🛠️ Suggested guard
```diff
 if backend == "auto":
     backend = (
         "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
     )
+if backend != "trtllm-gen" and (return_lse or lse is not None):
+    raise ValueError(
+        "lse and return_lse are only supported by the trtllm-gen backend"
+    )
```

Also applies to: 659-663
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/mla/_core.py` around lines 612 - 614, The code adds TRTLLM-only
args lse and return_lse but the xqa and cute-dsl branches still ignore them and
return a plain tensor; update the function handling to fail fast: at the start
of the method (around the signature with parameters lse and return_lse) add a
guard that checks if (return_lse or lse is not None) and the current backend is
not 'trtllm-gen' (or equivalent backend identifier used in this module), and
raise a clear ValueError/RuntimeError explaining that LSE/return_lse are only
supported for trtllm-gen; also add the same guard inside the xqa and cute-dsl
branch handlers (the code paths around the existing xqa and cute-dsl handling
lines) so they explicitly raise instead of returning plain tensors when LSE is
requested.
thanks for the contrib. sounds similar to this one: #2332. i'm checking the difference, then will decide which to go forward with

Nice work! One suggestion: consider adding a

/bot run
```python
    )
    if backend == "trtllm-gen":
        output, lse = output
```
It seems there is no reference check or any kind of check for lse in test function. Should we add something to check if lse is legit or not?
@nv-yunzheq we check against the reference impl
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 562-568: The LSE tensor shape is computed incorrectly: change the
lse_shape calculation from using q_nope.shape[0] * num_heads to a shape of
(q_nope.shape[0], num_heads) so LSE matches the expected [num_tokens,
num_heads]; when creating or validating lse (in the block using lse, lse_shape,
torch.empty, check_shape_dtype_device) ensure dtype is torch.float32 and device
is q_nope.device (or device variable) as before.
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1134-1140: The test unpacks the tuple `output, lse = output` twice
causing a runtime error when `backend == "trtllm-gen"`; fix by removing the
duplicate unpack: only unpack `output` into `output, lse` once (e.g., inside the
`if should_check_lse:` block which is already true for trtllm-gen) and then
perform the `backend == "trtllm-gen"` workspace_buffer assertion without
re-unpacking; adjust the branching so `should_check_lse`, `backend`, `output`,
`lse`, and `workspace_buffer` are used correctly and no second unpack occurs.
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 415-422: The test unpacks the (output, lse) tuple twice which
causes a runtime error when should_check_lse is True and backend ==
"trtllm-gen"; remove the second unpack in the backend == "trtllm-gen" branch
(the line `output, lse = output`) and use the already-unpacked tensors (output,
lse) when asserting on workspace_buffer, keeping checks that verify lse
shape/dtype/finite values and the workspace_buffer assertion intact.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 05458218-7db9-4dda-87bb-562f3f841586
📥 Commits
Reviewing files that changed from the base of the PR and between e1a4a8a6abcdb89a70936cc8d3c2e98f9d8b16ed and 9a562c947b98e2f7c719f980bea50a8f0e4bc57f.
📒 Files selected for processing (5)
- flashinfer/decode.py
- flashinfer/mla/_core.py
- flashinfer/prefill.py
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/prefill.py
- flashinfer/decode.py
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)
392-411: Exercise the reject path for `return_lse` on unsupported backends.

This helper now avoids requesting LSE on xqa/cute-dsl, so the test suite no longer proves that those backends still raise `ValueError` for `return_lse=True`. A small `pytest.raises(...)` case here would lock in that contract and catch regressions back to silent wrong behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_mla.py` around lines 392 - 411, The test currently never asserts that requesting return_lse=True on unsupported backends raises an error; add an explicit pytest.raises(ValueError) case around the call to flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the same invocation using variables like backend, query, kv_cache, workspace_buffer, block_tables_kernel, seq_lens_tensor, max_seq_len, etc.) for the unsupported backends (e.g., "xqa" and "cute-dsl") to exercise the reject path; you can either parametrize backend over those values or add two small blocks that set backend to each unsupported value, set should_check_lse=True, and assert a ValueError is raised.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 392-411: The test currently never asserts that requesting
return_lse=True on unsupported backends raises an error; add an explicit
pytest.raises(ValueError) case around the call to
flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the same invocation
using variables like backend, query, kv_cache, workspace_buffer,
block_tables_kernel, seq_lens_tensor, max_seq_len, etc.) for the unsupported
backends (e.g., "xqa" and "cute-dsl") to exercise the reject path; you can
either parametrize backend over those values or add two small blocks that set
backend to each unsupported value, set should_check_lse=True, and assert a
ValueError is raised.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 334cd0a9-4a28-4e6c-9de1-897e6eaab398
📥 Commits
Reviewing files that changed from the base of the PR and between 9a562c947b98e2f7c719f980bea50a8f0e4bc57f and 492e7d7021480af631921492ead9f87877650208.
📒 Files selected for processing (2)
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/attention/test_trtllm_gen_attention.py
Force-pushed from 492e7d7 to a892361.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
660-685: ⚠️ Potential issue | 🟡 Minor

Please add one `lse=` buffer regression.
These new calls only cover the internally allocated path via `return_lse=True`. They never pass a preallocated `lse` tensor, so the explicit-buffer contract added by the Python APIs can regress without this file catching it. One case that passes `lse=torch.empty(...)` and checks the returned tensor aliases it would be enough.
Also applies to: 1138-1165
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_attention.py` around lines 660 - 685, Add a regression test that calls flashinfer.prefill.trtllm_batch_context_with_kv_cache with a preallocated lse tensor instead of relying on return_lse=True: allocate lse = torch.empty(..., device=GPU_DEVICE, dtype=out_dtype) and pass it as the lse= argument in one of the existing calls (the block invoking trtllm_batch_context_with_kv_cache), then assert the function populates and returns that same tensor (alias/identity check) and has correct contents; do the same for the other similar call range mentioned (around lines 1138-1165) to ensure the explicit-buffer contract is exercised.
♻️ Duplicate comments (2)
flashinfer/prefill.py (1)
693-717: ⚠️ Potential issue | 🔴 Critical

Thread `lse` through `_paged_run` before passing it here.
`paged_run_func` resolves to the `_paged_run` helper defined earlier in this file (Lines 249-310), and that helper still neither accepts `lse` nor forwards one to `op.trtllm_paged_attention_context`. The new keyword at Line 716 therefore raises `TypeError` as soon as `BatchPrefillWithPagedKVCacheWrapper.run(..., return_lse=True)` takes the `trtllm-gen` path.

🛠️ Minimal fix

```diff
 def _paged_run(
     query: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     workspace_buffer: torch.Tensor,
@@
     enable_pdl: bool,
     workspace_size: int,
     window_left: int = -1,
     out: Optional[torch.Tensor] = None,
+    lse: Optional[torch.Tensor] = None,
     sinks: Optional[torch.Tensor] = None,
     key_block_scales: Optional[torch.Tensor] = None,
     value_block_scales: Optional[torch.Tensor] = None,
@@
         key_block_scales,
         value_block_scales,
         skip_softmax_threshold_scale_factor,
         uses_shared_paged_kv_idx,
+        lse,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 693 - 717, Paged attention is being called with an lse keyword (lse) but _paged_run (the function paged_run_func resolves to) does not accept or forward lse to op.trtllm_paged_attention_context, causing a TypeError; update _paged_run to accept an lse parameter and pass it through to op.trtllm_paged_attention_context (and any intermediate call sites inside _paged_run) so that when BatchPrefillWithPagedKVCacheWrapper.run(..., return_lse=True) takes the trtllm-gen path the lse argument is threaded end-to-end.flashinfer/mla/_core.py (1)
561-568: ⚠️ Potential issue | 🔴 Critical

Use `[tokens, heads]` for wrapper LSE, not `[tokens*heads, head_dim]`.
`q_nope` here is `[tokens, num_heads, head_dim_ckv]`, so the kernel emits one LSE value per token/head. `lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2])` bakes `head_dim_ckv` into the layout and returns garbage-shaped stats. This also diverges from `flashinfer/decode.py` (Lines 1966-1972), which uses `(q_nope.size(0), q_nope.size(1))` for MLA LSE.

🛠️ Minimal fix

```diff
-    lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2])
+    lse_shape = (q_nope.shape[0], num_heads)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla/_core.py` around lines 561 - 568, The wrapper currently computes lse_shape as (q_nope.shape[0] * num_heads, q_nope.shape[2]) which flattens heads into tokens and yields wrong layout; update the logic in the return_lse block to use per-token-per-head shape (tokens, heads) by setting lse_shape = (q_nope.size(0), q_nope.size(1)), allocate torch.empty(lse_shape, dtype=torch.float32, device=device) when lse is None, and call check_shape_dtype_device(lse, lse_shape, torch.float32, q_nope.device, "lse") otherwise so the wrapper LSE matches the kernel output and aligns with flashinfer/decode.py’s MLA LSE shape.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 1965-1972: The validation for the optional MLA buffer lse should
require torch.float32 regardless of whether lse is provided; update the check in
the return_lse branch so when lse is not None you call
check_shape_dtype_device(lse, lse_shape, torch.float32, q_nope.device, "lse")
(i.e., replace q_nope.dtype with torch.float32) and ensure any created empty
buffer uses dtype=torch.float32; this ensures lse is accepted only if float32
and the branch still covers the explicit-supplied case.
- Around line 1367-1374: The caller-provided lse buffer must be validated even
when return_lse is False; change the logic around the lse handling so that
whenever lse is not None you call check_shape_dtype_device(lse, (q.size(0),
q.size(1)), torch.float32, q.device, "lse") before passing it to kernels, and
only allocate a new tensor when lse is None; apply this same change to the other
decode API blocks that currently guard validation with return_lse (the blocks
containing the if return_lse: / if lse is None: / check_shape_dtype_device(...)
pattern) so all code paths that accept an lse argument enforce shape, dtype, and
device correctness.
In `@flashinfer/mla/_core.py`:
- Around line 781-789: Validate any provided lse buffer regardless of
return_lse: compute lse_shape = (batch_size * max_q_len, num_qo_heads) as shown,
then if lse is not None call check_shape_dtype_device(lse, lse_shape,
torch.float32, query.device, "lse") to validate shape/dtype/device; only
allocate a new tensor (torch.empty(...)) when return_lse is True and lse is
None. In other words, move the validation out of the if return_lse block (or add
an explicit if lse is not None check) and keep allocation conditional on
return_lse so bad buffers are caught in Python before calling the kernel.
In `@include/flashinfer/trtllm/fmha/fmhaRunnerParams.h`:
- Around line 227-230: The new lseStrideTokens and lseStrideHeads fields in
fmhaRunnerParams are declared as int but must preserve 64-bit stride values from
the launcher (see csrc/trtllm_fmha_kernel_launcher.cu using int64_t
lse_stride_tokens/lse_stride_heads); either change the fmhaRunnerParams fields
lseStrideTokens and lseStrideHeads to int64_t, or if you must keep int, add a
range-check where the launcher/population code assigns to fmhaRunnerParams
(check against INT32_MAX/INT32_MIN) and fail fast with a clear error if the
64-bit stride doesn't fit in 32 bits so the kernel cannot write out-of-bounds.
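The narrowing guard suggested in the prompt above can be sketched in Python (the actual fix would live in C++; the function name here is hypothetical and only models the INT32 range check):

```python
INT32_MAX = 2**31 - 1
INT32_MIN = -(2**31)

def narrow_stride(stride64: int) -> int:
    """Fail fast if a 64-bit stride would be truncated when stored in int32."""
    if not (INT32_MIN <= stride64 <= INT32_MAX):
        raise OverflowError(f"lse stride {stride64} does not fit in int32")
    return stride64

# Small strides pass through unchanged; oversized ones raise instead of
# silently wrapping and letting the kernel write out-of-bounds.
assert narrow_stride(4096) == 4096
```
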
---
Outside diff comments:
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 660-685: Add a regression test that calls
flashinfer.prefill.trtllm_batch_context_with_kv_cache with a preallocated lse
tensor instead of relying on return_lse=True: allocate lse = torch.empty(...,
device=GPU_DEVICE, dtype=out_dtype) and pass it as the lse= argument in one of
the existing calls (the block invoking trtllm_batch_context_with_kv_cache), then
assert the function populates and returns that same tensor (alias/identity
check) and has correct contents; do the same for the other similar call range
mentioned (around lines 1138-1165) to ensure the explicit-buffer contract is
exercised.
---
Duplicate comments:
In `@flashinfer/mla/_core.py`:
- Around line 561-568: The wrapper currently computes lse_shape as
(q_nope.shape[0] * num_heads, q_nope.shape[2]) which flattens heads into tokens
and yields wrong layout; update the logic in the return_lse block to use
per-token-per-head shape (tokens, heads) by setting lse_shape = (q_nope.size(0),
q_nope.size(1)), allocate torch.empty(lse_shape, dtype=torch.float32,
device=device) when lse is None, and call check_shape_dtype_device(lse,
lse_shape, torch.float32, q_nope.device, "lse") otherwise so the wrapper LSE
matches the kernel output and aligns with flashinfer/decode.py’s MLA LSE shape.
In `@flashinfer/prefill.py`:
- Around line 693-717: Paged attention is being called with an lse keyword (lse)
but _paged_run (the function paged_run_func resolves to) does not accept or
forward lse to op.trtllm_paged_attention_context, causing a TypeError; update
_paged_run to accept an lse parameter and pass it through to
op.trtllm_paged_attention_context (and any intermediate call sites inside
_paged_run) so that when BatchPrefillWithPagedKVCacheWrapper.run(...,
return_lse=True) takes the trtllm-gen path the lse argument is threaded
end-to-end.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: de148d35-83f8-4f7f-a640-a6d708815ee0
📥 Commits
Reviewing files that changed from the base of the PR and between 492e7d7021480af631921492ead9f87877650208 and a892361.
📒 Files selected for processing (7)
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/mla/_core.py
- flashinfer/prefill.py
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/attention/test_trtllm_gen_mla.py
- csrc/trtllm_fmha_kernel_launcher.cu
sorry i missed there were test errors in the bot run... do you mind me reverting it and let's put it in an open PR again and re-run tests?

there existed at least one IMA on both B200 and B300

i'm wondering, have you observed that on your end, @murphymatt? thanks

@aleozlx can you link to this error? Is this from the current PR head or maybe some internal run you may have? I am seeing this pass locally on my machine (B200, cuda 12.8)

It's an internal link so i could only paste some content.

but it is about this PR

let me run it manually to double check

seeing clean results at 42ff2b0 (one commit before on main) but the following error log on 25b324d (this merge)

so the 1st encounter seems to be

Hi @murphymatt I had to commit the revert based on the above info to keep our CI green, but please let's try this again in an open PR and address such issues. Your contribution is much valued. Thank you!

@aleozlx is it possible for Nvidia to work on re-landing this change? I am not quite sure how we can ensure test correctness in your environment, given that it passes on our host.

Hi, this appears to be observed on all b200/b300/gb200/gb300 configurations as well as my local attempt with a gb100 node. So it didn't appear to me that it was a tricky environment thing. But ofc let me escalate it for visibility

Filed issue for it: #3114

some clues about the IMA have been posted in the issue
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
- `return_lse` option.

Refactor
Tests