Revert "Support lse in trtllm paged attn kernels"#3079
Conversation
This reverts commit 25b324d.
Code Review: Revert "Support lse in trtllm paged attn kernels". This PR reverts #3058, removing LSE support from the TRT-LLM paged attention kernels. 🐛 Critical Bug: Wrong dtype in LSE validation (flashinfer/decode.py).
Code Review
This pull request removes Log-Sum-Exp (LSE) support and its associated parameters from the TRT-LLM attention kernels and their Python wrappers, including paged attention and MLA implementations. The changes involve simplifying function signatures, removing LSE-related workspace allocations, and updating tests to reflect the removal of LSE return values. Feedback includes a correction for a type check in flashinfer/decode.py where the LSE tensor should be validated against torch.float32 rather than the query's data type, and a suggestion to replace magic numbers in the test suite with named constants for better maintainability.
```python
check_shape_dtype_device(
    lse,
    (q_nope.size(0), q_nope.size(1)),
    q_nope.dtype,
    q_nope.device,
    "lse",
)
```
The check_shape_dtype_device for lse is using q_nope.dtype as the expected data type. The log-sum-exp tensor (lse) should have a high-precision float type, typically torch.float32, regardless of the query's data type. Using q_nope.dtype could lead to incorrect type checks when q_nope is a lower precision type like float16 or bfloat16. This appears to be a reintroduction of a bug that might have been fixed in the reverted changes.
```diff
 check_shape_dtype_device(
     lse,
     (q_nope.size(0), q_nope.size(1)),
-    q_nope.dtype,
+    torch.float32,
     q_nope.device,
     "lse",
 )
```
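To see why the stricter expectation matters, here is a minimal standalone sketch (illustrative only; it assumes the helper compares the tensor's dtype against the expected value):

```python
import torch

q_nope = torch.randn(4, 8, dtype=torch.bfloat16)
lse = torch.empty(4, 8, dtype=torch.float32)  # correctly preallocated LSE buffer

# Original check: expected dtype is q_nope.dtype (bfloat16), so a valid
# float32 buffer is rejected for bf16/fp16 queries.
print(lse.dtype == q_nope.dtype)   # False
# Corrected check: expected dtype is torch.float32, so the buffer passes.
print(lse.dtype == torch.float32)  # True
```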
```python
).all()
# check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
# note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter
# workspace; size might change in the future
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
```
The size of the workspace buffer being checked, 8192 * 256 * 4, is a magic number. This value corresponds to the size of the counter workspace. To improve readability and maintainability, it would be better to define this as a constant, for example TRTLLM_GEN_COUNTER_WORKSPACE_BYTES, and use that constant here and on line 744. The comment on line 665 already indicates that this size might change in the future, which further strengthens the case for using a named constant.
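A rough sketch of that suggestion (the constant name is the reviewer's proposal, not an existing symbol in the test suite):

```python
# Hypothetical named constant for the counter-workspace prefix; the actual
# size may change with future kernel revisions, so keep it in one place.
TRTLLM_GEN_COUNTER_WORKSPACE_BYTES = 8192 * 256 * 4

# Used wherever the counter region is checked:
assert (
    workspace_buffer[:TRTLLM_GEN_COUNTER_WORKSPACE_BYTES].cpu().numpy() == 0
).all()
```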
📝 Walkthrough

This PR removes LSE (log-sum-exp) buffer support from the trtllm-gen paged attention implementation. Changes include eliminating LSE parameters from kernel launcher signatures, removing LSE workspace allocations, and updating the Python decode/prefill/MLA APIs to no longer accept or return LSE tensors.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Warning: Review ran into problems

🔥 Problems

Git: Failed to clone repository.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/attention/test_trtllm_gen_mla.py (1)
665-675: ⚠️ Potential issue | 🟠 Major

Re-zero the shared TRT-LLM workspace before each sparse run.

Unlike `trtllm_batch_decode_mla()`, this path reuses `global_trtllm_gen_fmha_workspace_buffer` without resetting it first. Because the buffer is global, earlier parametrized cases can dirty the counter region and make the zero-region assertion at Line 697 flaky.

♻️ Suggested fix
```diff
 if global_trtllm_gen_fmha_workspace_buffer is None:
     global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
         workspace_size, dtype=torch.int8, device=device
     )
 workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+workspace_buffer.zero_()
 # workspace_buffer_ref = global_workspace_buffer
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_mla.py` around lines 665 - 675, Reset the shared TRT-LLM workspace buffer before each sparse run by explicitly zeroing global_trtllm_gen_fmha_workspace_buffer prior to setting workspace_buffer; locate the block that sets global_trtllm_gen_fmha_workspace_buffer and workspace_buffer in test_trtllm_gen_mla.py and call an in-place zeroing operation (e.g., fill_(0) or torch.zeros_like assignment) on global_trtllm_gen_fmha_workspace_buffer so the counter region is cleared before reuse (similar to how trtllm_batch_decode_mla() initializes its buffer).

tests/attention/test_trtllm_gen_attention.py (1)
639-666: ⚠️ Potential issue | 🟠 Major

These fixed zero-slice assertions need a per-call workspace reset.

`create_workspace_buffers()` returns a global TRT-LLM workspace, and the decode helper also reuses that buffer for the XQA path, which writes semaphore state into this exact prefix. After switching to a fixed zero-region assert, these tests become order-dependent unless the TRT-LLM workspace is `zero_()`ed before each kernel invocation.

♻️ Suggested fix
```diff
 workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE)
+workspace_buffer.zero_()
```

Apply the same reset in both the prefill and decode helpers before the TRT-LLM call.
Also applies to: 1095-1125
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_attention.py` around lines 639 - 666, The tests assume the TRT-LLM global workspace is zeroed but it’s reused across calls; before each kernel invocation (e.g., before calling trtllm_batch_context_with_kv_cache in the prefill helper and similarly in the decode helper), explicitly reset the shared workspace returned by create_workspace_buffers() by calling its zeroing method (e.g., workspace_buffer.zero_() or .zero()) so the fixed zero-region assertion on workspace_buffer[:8192 * 256 * 4] is valid and tests are order-independent; apply the same reset in both the prefill and decode helper locations referenced in the diff.
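One way to make the reset systematic rather than per-call would be an autouse pytest fixture that clears the shared buffer before every test. This is a hypothetical sketch, not code from this PR; it assumes a module-level buffer mirroring the `global_trtllm_gen_fmha_workspace_buffer` used in these tests:

```python
import pytest

# Module-level shared workspace (assumed to mirror the test file's global).
global_trtllm_gen_fmha_workspace_buffer = None

@pytest.fixture(autouse=True)
def reset_trtllm_workspace():
    # Zero the shared TRT-LLM workspace before each test so counter/semaphore
    # state written by earlier parametrized cases cannot leak into later ones.
    if global_trtllm_gen_fmha_workspace_buffer is not None:
        global_trtllm_gen_fmha_workspace_buffer.zero_()
    yield
```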
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 1966-1980: The code allocates lse as torch.float32 when return_lse
is true but validates a caller-supplied lse against q_nope.dtype (which may be
bf16/fp16), causing valid float32 buffers to fail; update the validation to
expect torch.float32 instead of q_nope.dtype by calling
check_shape_dtype_device(lse, (q_nope.size(0), q_nope.size(1)), torch.float32,
q_nope.device, "lse") (ensure you import/qualify torch.float32 if needed) while
keeping the shape and device checks the same.
- Around line 1367-1375: The TRT-LLM decode wrapper must refuse use of
return_lse/lse instead of letting the custom op's paged_run assert; add an
explicit check in the wrapper (the block handling return_lse and lse) that if
the backend is TRT-LLM (or when calling paged_run) and (return_lse is True or
lse is not None) raise a clear ValueError with a user-facing message; reference
the existing symbols return_lse, lse, paged_run and maybe_lse so you locate the
code path and replace the silent assertion with this explicit check.
---
Outside diff comments:
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 639-666: The tests assume the TRT-LLM global workspace is zeroed
but it’s reused across calls; before each kernel invocation (e.g., before
calling trtllm_batch_context_with_kv_cache in the prefill helper and similarly
in the decode helper), explicitly reset the shared workspace returned by
create_workspace_buffers() by calling its zeroing method (e.g.,
workspace_buffer.zero_() or .zero()) so the fixed zero-region assertion on
workspace_buffer[:8192 * 256 * 4] is valid and tests are order-independent;
apply the same reset in both the prefill and decode helper locations referenced
in the diff.
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 665-675: Reset the shared TRT-LLM workspace buffer before each
sparse run by explicitly zeroing global_trtllm_gen_fmha_workspace_buffer prior
to setting workspace_buffer; locate the block that sets
global_trtllm_gen_fmha_workspace_buffer and workspace_buffer in
test_trtllm_gen_mla.py and call an in-place zeroing operation (e.g., fill_(0) or
torch.zeros_like assignment) on global_trtllm_gen_fmha_workspace_buffer so the
counter region is cleared before reuse (similar to how trtllm_batch_decode_mla()
initializes its buffer).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c6857e0b-0eb4-486c-b26d-292c8254736e
📒 Files selected for processing (7)
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/mla/_core.py
- flashinfer/prefill.py
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- tests/attention/test_trtllm_gen_attention.py
- tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
```python
if return_lse:
    if lse is None:
        lse = torch.empty(
            (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
        )
    else:
        check_shape_dtype_device(
            lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
        )
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
set -e
sed -n '1218,1525p' flashinfer/decode.py
echo '---'
sed -n '2096,2160p' flashinfer/decode.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 16188
🏁 Script executed:
```bash
# Check for existing guards on lse/return_lse with trtllm-gen
rg -n "trtllm-gen" flashinfer/decode.py | head -20
```

Repository: flashinfer-ai/flashinfer
Length of output: 1599
🏁 Script executed:
```bash
# Check the complete argument assembly and how lse is passed to the backend
sed -n '1350,1450p' flashinfer/decode.py | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 4692
🏁 Script executed:
```bash
# Verify the custom op assert and surrounding context
sed -n '2145,2165p' flashinfer/decode.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 975
Block lse/return_lse on the TRT-LLM decode wrapper instead of asserting internally.
The public wrapper accepts return_lse=True and explicit lse tensors without checking the backend, but passes them to the custom op's paged_run() which asserts maybe_lse is None. This causes an AssertionError instead of a stable user-facing error, and fails silently under python -O.
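A tiny standalone illustration of the `-O` concern (hypothetical functions, not flashinfer code): `assert` statements are stripped when Python runs with optimizations, while an explicit `raise` always fires.

```python
def guard_with_assert(maybe_lse):
    # Stripped entirely under `python -O`, so an unsupported lse buffer
    # would silently reach the kernel.
    assert maybe_lse is None

def guard_with_raise(maybe_lse):
    # Survives `-O` and gives the caller an actionable message.
    if maybe_lse is not None:
        raise ValueError("trtllm-gen backend does not support lse/return_lse")
```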
♻️ Suggested fix
```diff
+        if self._backend == "trtllm-gen" and (return_lse or lse is not None):
+            raise ValueError(
+                "trtllm-gen backend does not support lse/return_lse"
+            )
         if return_lse:
             if lse is None:
                 lse = torch.empty(
                     (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
                 )
             else:
                 check_shape_dtype_device(
                     lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
                 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/decode.py` around lines 1367 - 1375, The TRT-LLM decode wrapper
must refuse use of return_lse/lse instead of letting the custom op's paged_run
assert; add an explicit check in the wrapper (the block handling return_lse and
lse) that if the backend is TRT-LLM (or when calling paged_run) and (return_lse
is True or lse is not None) raise a clear ValueError with a user-facing message;
reference the existing symbols return_lse, lse, paged_run and maybe_lse so you
locate the code path and replace the silent assertion with this explicit check.
```python
if return_lse:
    if lse is None:
        lse = torch.empty(
            (q_nope.size(0), q_nope.size(1)),
            dtype=torch.float32,
            device=device,
        )
    else:
        check_shape_dtype_device(
            lse,
            (q_nope.size(0), q_nope.size(1)),
            q_nope.dtype,
            q_nope.device,
            "lse",
        )
```
Validate caller-supplied MLA lse buffers as float32.
This branch allocates lse as torch.float32, but the explicit-buffer path validates against q_nope.dtype. A correctly preallocated float32 lse tensor will fail here for bf16/fp16 inputs.
♻️ Suggested fix
```diff
         check_shape_dtype_device(
             lse,
             (q_nope.size(0), q_nope.size(1)),
-            q_nope.dtype,
+            torch.float32,
             q_nope.device,
             "lse",
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/decode.py` around lines 1966 - 1980, The code allocates lse as
torch.float32 when return_lse is true but validates a caller-supplied lse
against q_nope.dtype (which may be bf16/fp16), causing valid float32 buffers to
fail; update the validation to expect torch.float32 instead of q_nope.dtype by
calling check_shape_dtype_device(lse, (q_nope.size(0), q_nope.size(1)),
torch.float32, q_nope.device, "lse") (ensure you import/qualify torch.float32 if
needed) while keeping the shape and device checks the same.
```python
if return_lse:
    if lse is None:
        lse = torch.empty(
            (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
        )
    else:
        check_shape_dtype_device(
            lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
        )
```
Preserve explicit lse validation and reject TRT-LLM LSE requests earlier.
lse is still forwarded to the backend even when return_lse is False, but this block now validates it only inside the return_lse path. That drops the shape/dtype/device check for callers reusing an explicit lse buffer, and on backend="trtllm-gen" it allocates lse here only to trip the internal assert at Line 682. Please validate any provided lse unconditionally, then raise a real NotImplementedError/ValueError before allocation when the selected backend does not support LSE.
Suggested fix
```diff
-        if return_lse:
-            if lse is None:
-                lse = torch.empty(
-                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
-                )
-            else:
-                check_shape_dtype_device(
-                    lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
-                )
+        if lse is not None:
+            check_shape_dtype_device(
+                lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
+            )
+        if return_lse:
+            if self._backend == "trtllm-gen":
+                raise NotImplementedError(
+                    "return_lse is not supported for backend='trtllm-gen'."
+                )
+            if lse is None:
+                lse = torch.empty(
+                    (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+                )
```
Reverts #3058
Summary by CodeRabbit
Refactor
Tests