Support lse in trtllm paged attn kernels #3058

Merged
aleozlx merged 7 commits into flashinfer-ai:main from murphymatt:support-trtllm-attn-lse
Apr 15, 2026

Conversation

@murphymatt
Contributor

@murphymatt murphymatt commented Apr 14, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Optional support for computing and returning per-request attention Log-Sum-Exp (LSE) from paged attention; decode/context APIs accept an optional LSE tensor and a return_lse option. A usage sketch follows this summary.
  • Refactor

    • Public signatures and returns updated to pass/return LSE; runtime validation/allocation and stride-aware handling for LSE added. Generation workspace layout refined to reserve softmax stats separately.
  • Tests

    • Tests now request, validate, and compare LSE (shape/dtype/finiteness/values) and adjust workspace zeroing checks.
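
A minimal usage sketch of the new option (illustrative, not code from the PR): the positional arguments mirror the call site in tests/attention/test_trtllm_gen_attention.py quoted later in this thread, and the input tensors (query, kv_cache, page_table, sequence metadata) are assumed to be prepared as in those tests.

    import torch
    import flashinfer

    # Optional caller-owned LSE buffer: float32, one value per (token, head).
    provided_lse = torch.empty(
        query.size(0), query.size(1), dtype=torch.float32, device="cuda"
    )

    output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        query, kv_cache, workspace_buffer, page_table,
        seq_lens, max_q_len, max_kv_len,
        bmm1_scale, bmm2_scale, batch_size,
        q_indptr, kv_indptr, window_left,
        lse=provided_lse,   # omit to let the API allocate the buffer itself
        return_lse=True,    # switches the return type to (output, lse)
    )
    assert lse is provided_lse  # a provided buffer is filled and returned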

@coderabbitai

coderabbitai Bot commented Apr 14, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check ⚠️ Warning: The PR description consists solely of the template with only checklist items completed; it lacks any substantive implementation details, rationale, changed files, or reviewer notes. Resolution: replace the template with an actual description explaining what LSE support enables, why it is needed, and summarizing the key implementation changes across the modified files.
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 40.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check ✅ Passed: The title 'Support lse in trtllm paged attn kernels' is concise and clearly describes the main feature added: LSE support in the TRT-LLM paged attention kernels.



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for returning the log-sum-exp (LSE) of attention logits in the TRT-LLM paged attention and MLA backends. The changes include updating the C++ kernel launchers to handle LSE pointers and strides, modifying the Python API to allow users to provide or receive LSE tensors, and updating relevant tests. Several critical issues were identified in the MLA implementation regarding the shape and flattening of the LSE tensor, which could lead to memory corruption or incorrect stride calculations. Additionally, a potential memory exhaustion issue was noted in the C++ workspace allocation for the context path.
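
For context, LSE here is the standard per-query-token, per-head log-sum-exp of the scaled attention logits; this is the usual definition, not something restated in the PR itself. For query token $i$, head $h$, and the set $\mathcal{K}(i)$ of keys visible to $i$:

$$\mathrm{LSE}_{i,h} = \log \sum_{j \in \mathcal{K}(i)} \exp\!\left(\frac{q_{i,h}^{\top} k_{j,h}}{\sqrt{d}}\right)$$

The softmax weights can be recovered as $\exp(s_{i,j,h} - \mathrm{LSE}_{i,h})$, with $s_{i,j,h}$ the scaled logit, which is what makes LSE useful for numerically stable merging of partial attention outputs (e.g., chunked prefill or speculative decoding).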

Comment thread flashinfer/mla/_core.py Outdated
Comment on lines +777 to +781
lse = torch.empty(
    (query.size(0), query.size(1)),
    dtype=torch.float32,
    device=query.device,
)
Severity: high

The lse tensor is allocated with shape (query.size(0), query.size(1)), which corresponds to (batch_size, q_len_per_request). However, attention Log-Sum-Exp (LSE) is computed per query head. Since query has shape (batch_size, q_len_per_request, num_heads, head_dim_qk), the lse tensor should be allocated with shape (query.size(0), query.size(1), query.size(2)) to accommodate all heads and avoid memory corruption when the kernel writes LSE values.

Suggested change:

    # before
    lse = torch.empty(
        (query.size(0), query.size(1)),
        dtype=torch.float32,
        device=query.device,
    )
    # after
    lse = torch.empty(
        (query.size(0), query.size(1), query.size(2)),
        dtype=torch.float32,
        device=query.device,
    )

Comment thread flashinfer/mla/_core.py Outdated
Comment on lines +784 to +789
    lse,
    (query.size(0), query.size(1)),
    torch.float32,
    query.device,
    "lse",
)
Severity: high

The shape check for the provided lse tensor is incorrect as it misses the num_heads dimension. It should be updated to match the 3D shape (batch_size, q_len_per_request, num_heads).

                check_shape_dtype_device(
                    lse,
                    (query.size(0), query.size(1), query.size(2)),
                    torch.float32,
                    query.device,
                    "lse",
                )

Comment thread flashinfer/mla/_core.py
@@ -795,9 +817,13 @@ def trtllm_batch_decode_with_kv_cache_mla(
None, # value_block_scales
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
Severity: high

The lse tensor should be flattened to 2D (num_tokens, num_heads) before being passed to the kernel, as the TRT-LLM launcher expects the first dimension to be the token dimension for stride calculations.

Suggested change:

    # before
    lse,
    # after
    lse.flatten(0, 1) if lse is not None else None,

Comment on lines +200 to +202
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
    sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
    "trtllm_gen_softmax_workspace");
Severity: medium

The softmaxStatsPtr is now allocated for both Context and Generation modes. In Context mode, runner_params.mSumOfSeqLensQ (the total number of query tokens) can be very large, which may lead to excessive memory allocation in the workspace (e.g., for 128 heads and 128k tokens, this would require ~128MB). Since Context kernels typically do not require this workspace buffer for multi-block reduction (as mMultiCtasKvMode is false), this allocation could cause workspace exhaustion for long sequences. Consider moving this allocation inside the else block (Generation path) or making it conditional on mode == TllmPagedAttentionMode::ForGen.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tests/attention/test_trtllm_gen_attention.py (1)

634-659: ⚠️ Potential issue | 🟠 Major

The new prefill/decode LSE path is still untested.

Both branches unpack lse but never assert on it, so a token/head-stride bug in the new kernel plumbing would ship unnoticed. Please validate lse against the reference wrappers where they can emit it, and at minimum check shape, dtype, and finiteness in the remaining cases. That also resolves the current RUF059 warnings.

Also applies to: 1112-1115
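
A minimal sketch of the assertions being requested, assuming output and lse were unpacked from the kernel call and lse_ref comes from the FA2 reference wrapper when one is available; names and tolerances are illustrative:

    import torch

    expected_shape = (q.size(0), q.size(1))  # [num_tokens, num_qo_heads]
    assert lse.shape == expected_shape
    assert lse.dtype == torch.float32
    assert torch.isfinite(lse).all()  # catches stride bugs that write garbage
    if lse_ref is not None:  # reference wrapper can emit LSE for comparison
        torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)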

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_attention.py` around lines 634 - 659, The
test calls flashinfer.prefill.trtllm_batch_context_with_kv_cache and unpacks lse
but never validates it; add assertions to verify lse equals the reference lse
when using the reference wrapper (compare values/close), and otherwise assert
lse has the expected shape, dtype, and that all elements are finite (use
torch.isfinite). Update the checks around the other prefill/decode call sites
mentioned (the similar block at lines ~1112-1115) to perform the same validation
so token/head-stride bugs are caught and RUF059 warnings are resolved.
flashinfer/prefill.py (1)

3971-4008: ⚠️ Potential issue | 🟠 Major

Validate provided lse buffers even when return_lse is false.

lse is forwarded to the kernel unconditionally, but shape/dtype/device checks only run inside if return_lse. A caller can pass a malformed lse with return_lse=False and bypass validation entirely.

Proposed fix
-    if return_lse:
-        if lse is None:
-            lse = torch.empty(
-                query.size(0), query.size(1), dtype=torch.float32, device=query.device
-            )
-        else:
-            check_shape_dtype_device(
-                lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse"
-            )
+    if lse is not None:
+        check_shape_dtype_device(
+            lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse"
+        )
+    elif return_lse:
+        lse = torch.empty(
+            query.size(0), query.size(1), dtype=torch.float32, device=query.device
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3971 - 4008, The code currently only
validates the optional tensor lse when return_lse is true, but lse is passed to
run_func unconditionally; update the logic so any non-None lse is validated
regardless of return_lse: before calling run_func (or at top of this block)
check if lse is not None and call check_shape_dtype_device(lse, (query.size(0),
query.size(1)), torch.float32, query.device, "lse"); keep the existing branch
that allocates an empty lse when return_lse is true and lse is None, but ensure
validation still happens for user-supplied lse even when return_lse is false so
malformed buffers cannot be forwarded to run_func.
tests/attention/test_trtllm_gen_mla.py (1)

394-416: ⚠️ Potential issue | 🟠 Major

Assert on MLA lse, not just tuple unpacking.

This only proves that return_lse=True returns a second object. A bad LSE layout/stride would still pass because lse is never checked, and Ruff already flags it as unused. Please compare it against a reference LSE where available, or at least assert shape, dtype, and finiteness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 394 - 416, The test
currently only unpacks the second return (lse) from
flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla but never validates it;
update the backend == "trtllm-gen" branch to assert properties of lse: check
that lse exists and has the expected shape (match seq_lens_tensor or other
reference LSE shape), correct dtype (e.g., same as output or torch.float32), and
that all values are finite (no NaN/Inf); if a reference LSE tensor is available
in the test harness, compare lse against it (or at minimum validate
shape/dtype/torch.isfinite) in addition to the existing workspace_buffer zero
check, keeping the tuple unpacking of output, lse the same.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 612-614: The code adds TRTLLM-only args lse and return_lse but the
xqa and cute-dsl branches still ignore them and return a plain tensor; update
the function handling to fail fast: at the start of the method (around the
signature with parameters lse and return_lse) add a guard that checks if
(return_lse or lse is not None) and the current backend is not 'trtllm-gen' (or
equivalent backend identifier used in this module), and raise a clear
ValueError/RuntimeError explaining that LSE/return_lse are only supported for
trtllm-gen; also add the same guard inside the xqa and cute-dsl branch handlers
(the code paths around the existing xqa and cute-dsl handling lines) so they
explicitly raise instead of returning plain tensors when LSE is requested.
- Around line 775-789: The LSE buffer is being allocated/validated with shape
(batch, q_len) but MLA decode flattens query and indexes LSE per token/head;
update the allocation/validation in the return_lse branch so lse has space for
all heads: use shape (query.size(0), query.size(1), query.size(2)) (i.e.,
[batch, q_len, heads]) or equivalently allocate/validate as a flat buffer of
length query.size(0)*query.size(1)*query.size(2) that the kernel can index, and
if you must keep a public 2-D shape then reshape the 3-D buffer to (batch,
q_len) only after the kernel call; adjust the check_shape_dtype_device call and
any downstream assumptions around lse accordingly (references: return_lse, lse,
query, check_shape_dtype_device).

In `@flashinfer/prefill.py`:
- Around line 3739-3743: The wrapper path doesn't forward return_lse so
BatchPrefillWithPagedKVCacheWrapper.run(return_lse=True) still takes the
trtllm-gen paged path which asserts maybe_lse is None; wire the return_lse flag
through the wrapper to the underlying prefill helper and paged path so the
paged/trtllm-gen code can return LSE when requested: update
BatchPrefillWithPagedKVCacheWrapper.run (and any intermediate wrapper functions)
to accept and pass the return_lse argument into the direct helper call, remove
or adjust the assertion on maybe_lse in the trtllm-gen paged path, and ensure
the function signature that was changed (the helper with lse:
Optional[torch.Tensor] and return_lse: bool) is invoked with the new flag so the
returned tuple (maybe_lse, ...) is propagated back to callers.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d0be76a5-7a89-46d7-a68d-1535c9c74a52

📥 Commits

Reviewing files that changed from the base of the PR and between 7c562d5 and e1a4a8a6abcdb89a70936cc8d3c2e98f9d8b16ed.

📒 Files selected for processing (7)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla/_core.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py

Comment thread flashinfer/decode.py
Comment thread flashinfer/mla/_core.py
Comment on lines +612 to +614
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

⚠️ Potential issue | 🟠 Major

Fail fast when LSE is requested on non-trtllm-gen backends.

The new args are documented as TRTLLM-only, but the xqa and cute-dsl branches still ignore them and return a plain tensor. out, lse = ... will therefore either raise or silently unpack batch slices depending on the output shape.

🛠️ Suggested guard
     if backend == "auto":
         backend = (
             "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
         )
+    if backend != "trtllm-gen" and (return_lse or lse is not None):
+        raise ValueError(
+            "lse and return_lse are only supported by the trtllm-gen backend"
+        )

Also applies to: 659-663


Comment thread flashinfer/mla/_core.py Outdated
Comment thread flashinfer/prefill.py
@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

Thanks for the contribution.

This sounds similar to #2332.

I'm checking the difference, then will decide which one to go forward with.

@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

Nice work! One suggestion: consider adding a ValueError guard when return_lse=True with the XQA backend (which doesn't support LSE). PR #2332 had this check — would be good to include here too for a clear error message instead of silent incorrect behavior.

@aleozlx aleozlx added the run-ci label Apr 14, 2026
@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !544 has been created, and the CI pipeline #48457534 is currently running. I'll report back once the pipeline job completes.

)
if backend == "trtllm-gen":
    output, lse = output
Collaborator

It seems there is no reference check, or any kind of check, for lse in the test function. Should we add something to check whether lse is legit?

Contributor Author

@nv-yunzheq we check against the reference impl

@murphymatt murphymatt requested a review from qsang-nv as a code owner April 14, 2026 18:03

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 562-568: The LSE tensor shape is computed incorrectly: change the
lse_shape calculation from using q_nope.shape[0] * num_heads to a shape of
(q_nope.shape[0], num_heads) so LSE matches the expected [num_tokens,
num_heads]; when creating or validating lse (in the block using lse, lse_shape,
torch.empty, check_shape_dtype_device) ensure dtype is torch.float32 and device
is q_nope.device (or device variable) as before.

In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1134-1140: The test unpacks the tuple `output, lse = output` twice
causing a runtime error when `backend == "trtllm-gen"`; fix by removing the
duplicate unpack: only unpack `output` into `output, lse` once (e.g., inside the
`if should_check_lse:` block which is already true for trtllm-gen) and then
perform the `backend == "trtllm-gen"` workspace_buffer assertion without
re-unpacking; adjust the branching so `should_check_lse`, `backend`, `output`,
`lse`, and `workspace_buffer` are used correctly and no second unpack occurs.

In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 415-422: The test unpacks the (output, lse) tuple twice which
causes a runtime error when should_check_lse is True and backend ==
"trtllm-gen"; remove the second unpack in the backend == "trtllm-gen" branch
(the line `output, lse = output`) and use the already-unpacked tensors (output,
lse) when asserting on workspace_buffer, keeping checks that verify lse
shape/dtype/finite values and the workspace_buffer assertion intact.
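
A schematic before/after of the double-unpack bug flagged in these prompts, with names as in the tests; the second unpack fails (or silently mangles data) because output is already a plain tensor after the first one:

    # before (buggy): the (output, lse) tuple is unpacked twice
    if should_check_lse:
        output, lse = output
    if backend == "trtllm-gen":
        output, lse = output  # output is now a tensor -> runtime error

    # after (fixed): unpack exactly once, then use the names directly
    if should_check_lse:
        output, lse = output
    if backend == "trtllm-gen":
        assert not workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].any()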

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 05458218-7db9-4dda-87bb-562f3f841586

📥 Commits

Reviewing files that changed from the base of the PR and between e1a4a8a6abcdb89a70936cc8d3c2e98f9d8b16ed and 9a562c947b98e2f7c719f980bea50a8f0e4bc57f.

📒 Files selected for processing (5)
  • flashinfer/decode.py
  • flashinfer/mla/_core.py
  • flashinfer/prefill.py
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/prefill.py
  • flashinfer/decode.py

Comment thread flashinfer/mla/_core.py Outdated
Comment thread tests/attention/test_trtllm_gen_attention.py
Comment thread tests/attention/test_trtllm_gen_mla.py

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

392-411: Exercise the reject path for return_lse on unsupported backends.

This helper now avoids requesting LSE on xqa/cute-dsl, so the test suite no longer proves that those backends still raise ValueError for return_lse=True. A small pytest.raises(...) case here would lock in that contract and catch regressions back to silent wrong behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 392 - 411, The test
currently never asserts that requesting return_lse=True on unsupported backends
raises an error; add an explicit pytest.raises(ValueError) case around the call
to flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the same invocation
using variables like backend, query, kv_cache, workspace_buffer,
block_tables_kernel, seq_lens_tensor, max_seq_len, etc.) for the unsupported
backends (e.g., "xqa" and "cute-dsl") to exercise the reject path; you can
either parametrize backend over those values or add two small blocks that set
backend to each unsupported value, set should_check_lse=True, and assert a
ValueError is raised.
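
A minimal sketch of the requested reject-path test, assuming the guard raises ValueError and reusing the variable names from the prompt above; the exact positional order is illustrative:

    import pytest
    import flashinfer

    @pytest.mark.parametrize("backend", ["xqa", "cute-dsl"])
    def test_mla_return_lse_rejected(backend):
        # query, kv_cache, workspace_buffer, block_tables_kernel,
        # seq_lens_tensor, and max_seq_len are assumed to be set up
        # exactly as in the surrounding MLA tests.
        with pytest.raises(ValueError):
            flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
                query, kv_cache, workspace_buffer,
                block_tables_kernel, seq_lens_tensor, max_seq_len,
                backend=backend,
                return_lse=True,  # only supported on trtllm-gen
            )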

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 334cd0a9-4a28-4e6c-9de1-897e6eaab398

📥 Commits

Reviewing files that changed from the base of the PR and between 9a562c947b98e2f7c719f980bea50a8f0e4bc57f and 492e7d7021480af631921492ead9f87877650208.

📒 Files selected for processing (2)
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/attention/test_trtllm_gen_attention.py

@murphymatt murphymatt force-pushed the support-trtllm-attn-lse branch from 492e7d7 to a892361 Compare April 14, 2026 19:11

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/attention/test_trtllm_gen_attention.py (1)

660-685: ⚠️ Potential issue | 🟡 Minor

Please add one lse= buffer regression.

These new calls only cover the internally allocated path via return_lse=True. They never pass a preallocated lse tensor, so the explicit-buffer contract added by the Python APIs can regress without this file catching it. One case that passes lse=torch.empty(...) and checks the returned tensor aliases it would be enough.

Also applies to: 1138-1165

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_attention.py` around lines 660 - 685, Add a
regression test that calls flashinfer.prefill.trtllm_batch_context_with_kv_cache
with a preallocated lse tensor instead of relying on return_lse=True: allocate
lse = torch.empty(..., device=GPU_DEVICE, dtype=out_dtype) and pass it as the
lse= argument in one of the existing calls (the block invoking
trtllm_batch_context_with_kv_cache), then assert the function populates and
returns that same tensor (alias/identity check) and has correct contents; do the
same for the other similar call range mentioned (around lines 1138-1165) to
ensure the explicit-buffer contract is exercised.
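
A sketch of the requested explicit-buffer case, assuming the contract established elsewhere in this PR (float32, shape [num_tokens, num_qo_heads], returned buffer aliases the provided one); the merged test quoted later in this thread ends up doing exactly this:

    import torch

    provided_lse = torch.empty(
        q_input.size(0), q_input.size(1),  # [num_tokens, num_qo_heads]
        dtype=torch.float32, device=GPU_DEVICE,
    )
    output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        q_input, kv_cache_kernel, workspace_buffer, page_table_kernel,
        # ... remaining arguments as in the existing call sites ...
        lse=provided_lse,
        return_lse=True,
    )
    assert lse is provided_lse  # filled in place, not reallocated
    assert torch.isfinite(lse).all()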
♻️ Duplicate comments (2)
flashinfer/prefill.py (1)

693-717: ⚠️ Potential issue | 🔴 Critical

Thread lse through _paged_run before passing it here.

paged_run_func resolves to the _paged_run helper defined earlier in this file (Lines 249-310), and that helper still neither accepts lse nor forwards one to op.trtllm_paged_attention_context. The new keyword at Line 716 therefore raises TypeError as soon as BatchPrefillWithPagedKVCacheWrapper.run(..., return_lse=True) takes the trtllm-gen path.

🛠️ Minimal fix
 def _paged_run(
     query: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     workspace_buffer: torch.Tensor,
@@
     enable_pdl: bool,
     workspace_size: int,
     window_left: int = -1,
     out: Optional[torch.Tensor] = None,
+    lse: Optional[torch.Tensor] = None,
     sinks: Optional[torch.Tensor] = None,
     key_block_scales: Optional[torch.Tensor] = None,
     value_block_scales: Optional[torch.Tensor] = None,
@@
             key_block_scales,
             value_block_scales,
             skip_softmax_threshold_scale_factor,
             uses_shared_paged_kv_idx,
+            lse,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 693 - 717, Paged attention is being
called with an lse keyword (lse) but _paged_run (the function paged_run_func
resolves to) does not accept or forward lse to
op.trtllm_paged_attention_context, causing a TypeError; update _paged_run to
accept an lse parameter and pass it through to op.trtllm_paged_attention_context
(and any intermediate call sites inside _paged_run) so that when
BatchPrefillWithPagedKVCacheWrapper.run(..., return_lse=True) takes the
trtllm-gen path the lse argument is threaded end-to-end.
flashinfer/mla/_core.py (1)

561-568: ⚠️ Potential issue | 🔴 Critical

Use [tokens, heads] for wrapper LSE, not [tokens*heads, head_dim].

q_nope here is [tokens, num_heads, head_dim_ckv], so the kernel emits one LSE value per token/head. lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2]) bakes head_dim_ckv into the layout and returns garbage-shaped stats. This also diverges from flashinfer/decode.py (Lines 1966-1972), which uses (q_nope.size(0), q_nope.size(1)) for MLA LSE.

🛠️ Minimal fix
-            lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2])
+            lse_shape = (q_nope.shape[0], num_heads)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/_core.py` around lines 561 - 568, The wrapper currently
computes lse_shape as (q_nope.shape[0] * num_heads, q_nope.shape[2]) which
flattens heads into tokens and yields wrong layout; update the logic in the
return_lse block to use per-token-per-head shape (tokens, heads) by setting
lse_shape = (q_nope.size(0), q_nope.size(1)), allocate torch.empty(lse_shape,
dtype=torch.float32, device=device) when lse is None, and call
check_shape_dtype_device(lse, lse_shape, torch.float32, q_nope.device, "lse")
otherwise so the wrapper LSE matches the kernel output and aligns with
flashinfer/decode.py’s MLA LSE shape.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Around line 1965-1972: The validation for the optional MLA buffer lse should
require torch.float32 regardless of whether lse is provided; update the check in
the return_lse branch so when lse is not None you call
check_shape_dtype_device(lse, lse_shape, torch.float32, q_nope.device, "lse")
(i.e., replace q_nope.dtype with torch.float32) and ensure any created empty
buffer uses dtype=torch.float32; this ensures lse is accepted only if float32
and the branch still covers the explicit-supplied case.
- Around line 1367-1374: The caller-provided lse buffer must be validated even
when return_lse is False; change the logic around the lse handling so that
whenever lse is not None you call check_shape_dtype_device(lse, (q.size(0),
q.size(1)), torch.float32, q.device, "lse") before passing it to kernels, and
only allocate a new tensor when lse is None; apply this same change to the other
decode API blocks that currently guard validation with return_lse (the blocks
containing the if return_lse: / if lse is None: / check_shape_dtype_device(...)
pattern) so all code paths that accept an lse argument enforce shape, dtype, and
device correctness.

In `@flashinfer/mla/_core.py`:
- Around line 781-789: Validate any provided lse buffer regardless of
return_lse: compute lse_shape = (batch_size * max_q_len, num_qo_heads) as shown,
then if lse is not None call check_shape_dtype_device(lse, lse_shape,
torch.float32, query.device, "lse") to validate shape/dtype/device; only
allocate a new tensor (torch.empty(...)) when return_lse is True and lse is
None. In other words, move the validation out of the if return_lse block (or add
an explicit if lse is not None check) and keep allocation conditional on
return_lse so bad buffers are caught in Python before calling the kernel.

In `@include/flashinfer/trtllm/fmha/fmhaRunnerParams.h`:
- Around line 227-230: The new lseStrideTokens and lseStrideHeads fields in
fmhaRunnerParams are declared as int but must preserve 64-bit stride values from
the launcher (see csrc/trtllm_fmha_kernel_launcher.cu using int64_t
lse_stride_tokens/lse_stride_heads); either change the fmhaRunnerParams fields
lseStrideTokens and lseStrideHeads to int64_t, or if you must keep int, add a
range-check where the launcher/population code assigns to fmhaRunnerParams
(check against INT32_MAX/INT32_MIN) and fail fast with a clear error if the
64-bit stride doesn't fit in 32 bits so the kernel cannot write out-of-bounds.

end-to-end.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: de148d35-83f8-4f7f-a640-a6d708815ee0

📥 Commits

Reviewing files that changed from the base of the PR and between 492e7d7021480af631921492ead9f87877650208 and a892361.

📒 Files selected for processing (7)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla/_core.py
  • flashinfer/prefill.py
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_mla.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/attention/test_trtllm_gen_mla.py
  • csrc/trtllm_fmha_kernel_launcher.cu

Comment thread flashinfer/decode.py Outdated
Comment thread flashinfer/decode.py Outdated
Comment thread flashinfer/mla/_core.py Outdated
Comment thread include/flashinfer/trtllm/fmha/fmhaRunnerParams.h Outdated

@qsang-nv qsang-nv left a comment


LGTM.

@aleozlx aleozlx merged commit 25b324d into flashinfer-ai:main Apr 15, 2026
31 of 32 checks passed
@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

Sorry, I missed that there were test errors in the bot run... do you mind if I revert it? Let's put it up in an open PR again and re-run the tests.

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

There was at least one IMA (illegal memory access) on both B200 and B300:

FAILED tests/attention/test_trtllm_gen_attention.py::test_trtllm_batch_decode_spec[True-True-110-False-None-fp8-fp8-nvfp4-127-256-16-32-2-8-256-NHD-trtllm-gen] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
========= 73964 failed, 12 passed, 100192 skipped in 218.63s (0:03:38) =========
❌ FAILED: tests/attention/test_trtllm_gen_attention.py

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

I'm wondering, have you observed that on your end, @murphymatt? Thanks.

@murphymatt
Contributor Author

@aleozlx can you link to this error? Is this from the current PR head, or maybe from some internal run you may have? I am seeing this pass locally on my machine (B200, CUDA 12.8).

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

It's an internal link, so I could only paste some content.
Since an IMA is a sticky error, we can only attribute it to the test file (each file is run in an isolated process), not to the specific test combination, as the IMA could have propagated from another test within the same process.

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

But it is about this PR.

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

Let me run it manually to double-check.

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

I'm seeing clean results at 42ff2b0 (one commit before this on main), but the following error log on 25b324d (this merge):

/workspace/flashinfer$ pytest tests/attention/test_trtllm_gen_attention.py -x
===================================================== test session starts =====================================================
platform linux -- Python 3.12.13, pytest-9.0.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 174168 items

tests/attention/test_trtllm_gen_attention.py ............F

========================================================== FAILURES ===========================================================
_____________ test_trtllm_batch_prefill[True-False-False-128-2047-511-True-None-bf16-bf16-bf16--1-256-16-4-8-HND] _____________

kv_layout = 'HND', batch_size = 256, page_size = 16, num_kv_heads = 4, head_grp_size = 8, window_left = -1, q_dtype = 'bf16'
o_dtype = 'bf16', kv_dtype = 'bf16', enable_pdl = None, enable_sink = True, max_q_len = 511, max_kv_len = 2047, head_dim = 128
non_contiguous_query = False, skips_softmax = False, uses_shared_paged_kv_idx = True

    @pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
    @pytest.mark.parametrize(
        "batch_size,page_size,num_kv_heads,head_grp_size",
        [
            (4, 16, 2, 1),
            (4, 32, 4, 5),
            (4, 64, 4, 8),
            (128, 16, 2, 5),
            (128, 32, 4, 1),
            (128, 64, 2, 8),
            (256, 16, 4, 8),
            (256, 32, 2, 8),
            (256, 64, 4, 1),
            (256, 64, 4, 5),
        ],
    )
    @pytest.mark.parametrize("window_left", [-1])  # todo(Siyuan): add 127 window_left
    @pytest.mark.parametrize(
        "q_dtype,kv_dtype,o_dtype",
        [
            ("bf16", "bf16", "bf16"),
            ("fp16", "fp16", "fp16"),
            ("fp8", "fp8", "bf16"),
            ("fp8", "fp8", "fp16"),
            ("fp8", "fp8", "fp8"),
            ("fp8", "fp8", "nvfp4"),
            ("fp8", "nvfp4", "fp8"),
        ],
    )
    @pytest.mark.parametrize("enable_pdl", [None])
    @pytest.mark.parametrize("enable_sink", [True, False])
    @pytest.mark.parametrize("max_q_len", [511])
    @pytest.mark.parametrize("max_kv_len", [2047])
    @pytest.mark.parametrize("head_dim", [128, 256])
    @pytest.mark.parametrize("non_contiguous_query", [False, True])
    @pytest.mark.parametrize("skips_softmax", [False, True])
    @pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False])
    def test_trtllm_batch_prefill(
        kv_layout: str,
        batch_size: int,
        page_size: int,
        num_kv_heads: int,
        head_grp_size: int,
        window_left: int,
        q_dtype: str,
        o_dtype: str,
        kv_dtype: str,
        enable_pdl: bool,
        enable_sink: bool,
        max_q_len: int,
        max_kv_len: int,
        head_dim: int,
        non_contiguous_query: bool,
        skips_softmax: bool,
        uses_shared_paged_kv_idx: bool,
    ):
>       _test_trtllm_batch_prefill(
            kv_layout,
            batch_size,
            page_size,
            num_kv_heads,
            head_grp_size,
            window_left,
            q_dtype,
            o_dtype,
            kv_dtype,
            enable_pdl,
            enable_sink,
            max_q_len,
            max_kv_len,
            kv_dtype in ("fp8", "nvfp4"),
            head_dim,
            non_contiguous_query=non_contiguous_query,
            skips_softmax=skips_softmax,
            uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
        )

tests/attention/test_trtllm_gen_attention.py:852:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

kv_layout = 'HND', batch_size = 256, page_size = 16, num_kv_heads = 4, head_grp_size = 8, window_left = -1, q_dtype = 'bf16'
o_dtype = 'bf16', kv_dtype = 'bf16', enable_pdl = None, enable_sink = True, max_q_len = 511, max_kv_len = 2047
device_scale = False, head_dim = 128, non_contiguous_query = False, skips_softmax = False, uses_shared_paged_kv_idx = True

    def _test_trtllm_batch_prefill(
        kv_layout: str,
        batch_size: int,
        page_size: int,
        num_kv_heads: int,
        head_grp_size: int,
        window_left: int,
        q_dtype: str,
        o_dtype: str,
        kv_dtype: str,
        enable_pdl: bool,
        enable_sink: bool,
        max_q_len: int,
        max_kv_len: int,
        device_scale: float,
        head_dim: int,
        non_contiguous_query: bool = False,
        skips_softmax: bool = False,
        uses_shared_paged_kv_idx: bool = True,
    ):
        compute_capability = get_compute_capability(torch.device(device="cuda"))
        if compute_capability[0] != 10:
            pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")

        if skips_softmax and q_dtype != kv_dtype:
            pytest.skip(
                "skips_softmax does not currently support Q and Kv types being different"
            )

        # NVFP4 KV cache constraints
        if kv_dtype == "nvfp4":
            if q_dtype != "fp8":
                pytest.skip("NVFP4 KV cache requires FP8 query")
            if o_dtype != "fp8":
                pytest.skip("NVFP4 KV cache only supports FP8 output")

        # Set up test parameters
        torch.manual_seed(0)

        # Generate random sequence lengths
        num_qo_heads = num_kv_heads * head_grp_size
        q_lens, _, seq_lens = generate_seq_lens_prefill(batch_size, max_q_len, max_kv_len)

        # Create query tensor and related data
        q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype)
        q_indptr = generate_cumsum_lens(q_lens)

        # Create KV cache and related data
        kv_cache, k_scale, v_scale, ref_kv_cache, kv_cache_sf = create_kv_cache(
            batch_size,
            seq_lens,
            page_size,
            num_kv_heads,
            head_dim,
            kv_dtype,
            "bf16" if q_dtype == "fp8" or kv_dtype == "nvfp4" else q_dtype,
            kv_layout,
        )
        page_table, all_page_ids, page_per_seq = create_page_table(
            batch_size, seq_lens, page_size
        )
        kv_indptr = generate_cumsum_lens(page_per_seq)
        kv_last_page_len = get_last_page_len(seq_lens, page_size)

        kv_cache_kernel, page_table_kernel, kv_cache_sf_kernel = (
            prepare_paged_kv_for_kernel(
                kv_cache, page_table, uses_shared_paged_kv_idx, kv_cache_sf
            )
        )

        workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE)

        # Create output tensor and related data
        create_out_tensor = flip_coin(
            batch_size, page_size, num_kv_heads, head_grp_size, o_dtype
        )
        can_infer_type = q.dtype == DTYPE_MAP[o_dtype] or create_out_tensor
        create_out_dtype = not can_infer_type or flip_coin(
            batch_size, page_size, num_kv_heads, head_grp_size, o_dtype, q_dtype
        )
        out, out_dtype, o_scale, o_sf_scale, o_sf_vec_size = create_output(
            q, o_dtype, create_out_tensor, create_out_dtype
        )

        sm_scale = float(1.0 / (head_dim**0.5))
        lse_ref = None

        # Build reference output
        plan_params = {
            "qo_indptr": q_indptr,
            "paged_kv_indptr": kv_indptr,
            "paged_kv_indices": all_page_ids,
            "paged_kv_last_page_len": kv_last_page_len.to(GPU_DEVICE),
            "num_qo_heads": num_qo_heads,
            "num_kv_heads": num_kv_heads,
            "head_dim_qk": head_dim,
            "page_size": page_size,
            "causal": True,
            "pos_encoding_mode": "NONE",
            "logits_soft_cap": 0.0,
            "q_data_type": ref_q.dtype,
            "kv_data_type": ref_kv_cache.dtype,
            "window_left": window_left,
        }
        if not enable_sink:
            wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer_ref, kv_layout
            )
            wrapper_ref.plan(**plan_params)
            output_ref, lse_ref = wrapper_ref.run(ref_q, ref_kv_cache, return_lse=True)
        else:
            # Construct flat K/V via helper
            k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv(
                ref_kv_cache,
                page_table,
                seq_lens.to(GPU_DEVICE),
                page_size,
                kv_last_page_len,
                kv_layout,
            )
            sink = torch.rand(num_qo_heads, device=GPU_DEVICE, dtype=torch.float32) * 5
            output_ref = sink_attention_unified(
                ref_q,
                k_flat,
                v_flat,
                sink,
                window_left,
                True,
                sm_scale,
                mode="varlen",
                batch_size=batch_size,
                qo_indptr=q_indptr,
                kv_indptr=kv_indptr_tokens,
            )

        # Run trtllm-gen function call
        bmm1_scale = q_scale * k_scale * sm_scale
        bmm2_scale = v_scale / o_scale
        if isinstance(bmm1_scale, torch.Tensor) and not device_scale:
            bmm1_scale = bmm1_scale.item()
        elif not isinstance(bmm1_scale, torch.Tensor) and device_scale:
            bmm1_scale = torch.tensor(bmm1_scale, device=GPU_DEVICE, dtype=torch.float32)
        if isinstance(bmm2_scale, torch.Tensor) and not device_scale:
            bmm2_scale = bmm2_scale.item()
        elif not isinstance(bmm2_scale, torch.Tensor) and device_scale:
            bmm2_scale = torch.tensor(bmm2_scale, device=GPU_DEVICE, dtype=torch.float32)

        # Optionally make query non-contiguous for testing stride support
        if non_contiguous_query:
            q_input = make_query_non_contiguous(q, num_qo_heads, head_dim)
        else:
            q_input = q.contiguous()

        # Using a tiny threshold should give the same result as normal attention.
        skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None
        softmax_end_bytes = trtllm_gen_workspace_softmax_end_bytes_context(
            workspace_buffer,
            num_qo_heads=q_input.size(1),
            sum_seq_q=q_input.size(0),
        )
        workspace_check_end_bytes = min(
            softmax_end_bytes + TRTLLM_GEN_WORKSPACE_CHECK_BYTES, workspace_buffer.numel()
        )
        workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_()
        provided_lse = torch.empty(
            (q_input.size(0), q_input.size(1)),
            device=GPU_DEVICE,
            dtype=torch.float32,
        )

        output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            q_input,
            kv_cache_kernel,
            workspace_buffer,
            page_table_kernel,
            seq_lens.to(GPU_DEVICE),
            torch.max(q_lens).item(),
            torch.max(seq_lens).item(),
            bmm1_scale,  # bmm1_scale
            bmm2_scale,  # bmm2_scale
            batch_size,
            q_indptr,
            kv_indptr,
            window_left,  # window_left
            out=out,
            out_dtype=out_dtype,
            o_sf_scale=o_sf_scale,
            o_sf_vec_size=o_sf_vec_size,
            kv_layout=kv_layout,
            enable_pdl=enable_pdl,
            sinks=(sink if enable_sink else None),
            kv_cache_sf=kv_cache_sf_kernel,
            skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
            uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
            lse=provided_lse,
            return_lse=True,
        )
        assert (
            workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].cpu().numpy() == 0
        ).all()

        if o_dtype == "nvfp4":
            output, output_ref = unpack_compare_nvfp4(
                output, output_ref, o_sf_scale, o_sf_vec_size
            )
            assert o_scale == 1.0
            rtol, atol = 4e-1, 1e0
        elif q_dtype == "fp8" and o_dtype == "fp8":
            rtol, atol = 5e-2, 7e-2
        elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]:
            rtol, atol = 4e-2, 6e-2
        else:
            rtol, atol = 1e-2, 1e-2

        # NVFP4 KV cache has significant quantization error, especially with
        # outlier channels that create large per-block dynamic range.
        if kv_dtype == "nvfp4":
            rtol, atol = 5e-1, 5e-1

        # NVFP4 KV cache has higher mismatch rate due to 4-bit quantization noise,
        # especially with outlier channels that stress per-block scaling.
        allowed_mismatch_rate = 0.10 if kv_dtype == "nvfp4" else 1e-7
        # Calculate max allowed mismatched elements based on tensor size
        total_elements = (output.float() * o_scale).numel()
        max_mismatched_elements = int(allowed_mismatch_rate * total_elements)

        # convert to float32 for fp8 is not supported by assert_close
        assert_close_with_mismatch_tolerance(
            output.float() * o_scale,
            output_ref.float(),
            rtol=rtol,
            atol=atol,
            max_mismatched_elements=max_mismatched_elements,
        )

        # NVFP4 KV cache: use cosine similarity to catch block-scale mismatches
        # (e.g. wrong swizzling) that element-wise tolerances miss.
        if kv_dtype == "nvfp4":
            cos = torch.nn.functional.cosine_similarity(
                (output.float() * o_scale).reshape(-1),
                output_ref.float().reshape(-1),
                dim=0,
            )
            assert cos.item() > 0.86, (
                f"NVFP4 KV cache attention: cosine similarity {cos:.4f} < 0.86. "
                f"Block scale factors may be mismatched to FP4 data blocks."
            )

        expected_lse_shape = (q_input.size(0), q_input.size(1))
        assert lse is provided_lse
        assert lse.shape == expected_lse_shape
        assert lse.dtype == torch.float32
        assert torch.isfinite(lse).all()
        if lse_ref is not None:
            lse_rtol, lse_atol = get_lse_test_tolerances(q_dtype, kv_dtype)
            torch.testing.assert_close(lse, lse_ref, rtol=lse_rtol, atol=lse_atol)

        if (
            o_dtype != "nvfp4" and kv_dtype != "nvfp4" and uses_shared_paged_kv_idx
        ):  # wrapper api does not support fp4 output/kv or separate KV page indices yet.
            # test wrapper with trtllm-gen backend
            wrapper_trtllm_gen = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer, kv_layout, backend="trtllm-gen"
            )
            plan_params["q_data_type"] = q.dtype
            plan_params["kv_data_type"] = kv_cache.dtype
            plan_params["o_data_type"] = DTYPE_MAP[o_dtype]
            wrapper_trtllm_gen.plan(**plan_params)
            workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_()
            output_wrapper = wrapper_trtllm_gen.run(
                q_input,
                kv_cache,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale / o_scale,
                enable_pdl=enable_pdl,
                sinks=(sink if enable_sink else None),
            )
            # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel.
            if v_scale == o_scale == 1.0:
>               assert (output_wrapper == output).all()
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E               torch.AcceleratorError: CUDA error: an illegal memory access was encountered
E               Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
E               CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E               For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E               Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

tests/attention/test_trtllm_gen_attention.py:785: AcceleratorError
=================================================== short test summary info ===================================================
FAILED tests/attention/test_trtllm_gen_attention.py::test_trtllm_batch_prefill[True-False-False-128-2047-511-True-None-bf16-bf16-bf16--1-256-16-4-8-HND] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
================================================ 1 failed, 12 passed in 28.75s

@aleozlx
Collaborator

aleozlx commented Apr 15, 2026

So the first encounter seems to be:

FAILED tests/attention/test_trtllm_gen_attention.py::test_trtllm_batch_prefill[True-False-False-128-2047-511-True-None-bf16-bf16-bf16--1-256-16-4-8-HND] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered

aleozlx added a commit that referenced this pull request Apr 16, 2026
@aleozlx
Collaborator

aleozlx commented Apr 16, 2026

Hi @murphymatt, I had to commit the revert based on the above info to keep our CI green, but let's please try this again in an open PR and address these issues. Your contribution is much valued. Thank you!

@murphymatt
Contributor Author

@aleozlx is it possible for Nvidia to work on re-landing this change? I am not quite sure how we can ensure test correctness in your environment, given that it passes on our host.

ziang-and pushed a commit to zianglih/flashinfer that referenced this pull request Apr 17, 2026
@aleozlx
Collaborator

aleozlx commented Apr 18, 2026

Hi, this appears on all B200/B300/GB200/GB300 configurations, as well as in my local attempt with a GB100 node, so it doesn't appear to be a tricky environment thing. But of course, let me escalate it for visibility.

@aleozlx
Collaborator

aleozlx commented Apr 18, 2026

Filed an issue for it: #3114.

@aleozlx
Collaborator

aleozlx commented Apr 18, 2026

Some clues about the IMA have been posted in the issue.
