
[API change] Allow using torch.Tensor for scales for trtllm-gen attention #2084

Merged
jiahanc merged 11 commits into main from trtllm-gen-attention-allow-device-scales
Nov 18, 2025

Conversation

IwakuraRein (Collaborator) commented Nov 13, 2025

📌 Description

  • Change bmm1_scale and bmm2_scale to Union[float, torch.Tensor]. Note that when a tensor is used, the log2e factor must be applied to it (see the sketch below).
  • Remove the bmm1_scale_log2_tensor and bmm2_scale_tensor arguments from xqa_batch_decode_with_kv_cache_mla.
  • Update the trtllm-gen FMHA kernels.

TODO: apply the same refactor to the XQA kernels; support for device-side scales there was removed in #2033.
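
As an illustration, a minimal sketch of the two call styles the changed APIs now accept (the surrounding setup is hypothetical; only the bmm1_scale/bmm2_scale parameter change comes from this PR):

    import math
    import torch

    # Assumed head_dim = 128 purely for illustration.
    sm_scale = 1.0 / math.sqrt(128)

    # Existing path: a host float scalar.
    bmm1_scale = sm_scale

    # New path: a 1-element float32 CUDA tensor, so the scale can be produced
    # on-device (e.g., by a quantization kernel) without a host sync.
    bmm1_scale_tensor = torch.full((1,), sm_scale, dtype=torch.float32, device="cuda")

Either value can then be passed as bmm1_scale; bmm2_scale behaves the same way.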

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Attention scale parameters now accept either floats or 1-element tensors across prefill, decode, and runtime; tensor scales are validated and applied on-device, and pointer-backed scale paths are supported.
  • Chores

    • Updated FMHA artifact path and checksum constants; added a public utility import and removed an obsolete inline comment.
  • Tests

    • Updated tests to exercise device/tensor-or-scalar scale flows, removed legacy per-tensor call-site args, and added device-scale parametrization for several test variants.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
coderabbitai Bot commented Nov 13, 2025

Walkthrough

Adds tensor-or-scalar support for attention scaling across Python APIs and C++ FMHA launchers using tvm::ffi::Variant; accepts device-resident scale tensors (passed as float pointers) or host scalars, applies on-device log2e scaling for tensor inputs, and updates FMHA cubin artifact path and checksum constants.

Changes

  • C++ kernel launcher interface updates — csrc/trtllm_fmha_kernel_launcher.cu: added #include <tvm/ffi/container/variant.h> and using tvm::ffi::Variant; changed launcher signatures to accept Variant<double, ffi::Tensor> for bmm scales; added const float* bmm1_scale_log2_ptr and const float* bmm2_scale_ptr params; extraction/validation logic for Variant values and wiring into runner_params (scaleSoftmaxLog2Ptr, outputScalePtr).
  • Ragged/paged attention callsites — csrc/trtllm_fmha_kernel_launcher.cu (trtllm_paged_attention_decode, trtllm_paged_attention_context, trtllm_ragged_attention, trtllm_paged_attention_launcher, trtllm_ragged_attention_launcher): replaced scalar bmm1/bmm2 with Variant<double, ffi::Tensor> in public callsites; extract concrete doubles and optional device pointers; added/adjusted args (e.g., max_kv_len, lse, attention_sinks); updated launcher invocations to prefer pointer-based scales when present.
  • Python decode API updates — flashinfer/decode.py: broadened bmm scale parameters to Union[float, torch.Tensor] across decode entry points and _paged_run; imported log2e; validate tensor dtype (float32), compute the log2-scaled tensor on-device (multiply by log2e), and pass tensors or scalars to the C++ paths.
  • Python prefill API updates — flashinfer/prefill.py: broadened bmm scale parameters to Union[float, torch.Tensor] for prefill/context functions; imported log2e; when a tensor is provided, assert float32 and apply log2e scaling on-device while preserving device/dtype; removed .item() scalar extraction.
  • Artifact metadata — flashinfer/artifacts.py: updated the ArtifactPath.TRTLLM_GEN_FMHA path string and CheckSumHash.TRTLLM_GEN_FMHA checksum constant.
  • Tests — tests/attention/test_trtllm_gen_mla.py, tests/attention/test_trtllm_gen_attention.py: test_trtllm_gen_mla.py removes construction/passing of on-device log2 scale tensors and the corresponding args; test_trtllm_gen_attention.py adds device_scale parameterization and logic to materialize bmm scales as scalars or CUDA tensors, with call sites updated accordingly.
  • Minor header cleanup — include/flashinfer/trtllm/fmha/kernelParams.h: removed an inline TODO comment; no functional changes.
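
Taken together, the decode/prefill rows above imply wrapper logic along these lines (a simplified sketch, not the actual flashinfer source; the helper name is hypothetical):

    import math
    import torch

    log2e = math.log2(math.e)  # the kernel consumes the softmax scale in log2 domain

    def _prepare_bmm1_scale(bmm1_scale):
        if isinstance(bmm1_scale, torch.Tensor):
            # Tensor path: validate dtype, then fold in log2e on-device.
            assert bmm1_scale.dtype == torch.float32
            return bmm1_scale * log2e  # out-of-place; see the review comments below
        # Scalar path: pass through as a host double.
        return float(bmm1_scale)

The scalar path keeps the old behavior, while the tensor path yields a buffer whose pointer can be forwarded to runner_params.scaleSoftmaxLog2Ptr.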

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Py as Python API
    participant Bind as FFI Binder (C++)
    participant Launcher as trtllm_*_launcher
    participant Runner as FMHA Runner

    Py->>Bind: call API with bmm1_scale, bmm2_scale (float or Tensor)
    alt Tensor inputs
        Note right of Bind: assert dtype float32<br/>compute tensor * log2e on device
        Bind->>Launcher: Variant(tensor) + provide bmm1_scale_log2_ptr & bmm2_scale_ptr
    else Scalar inputs
        Note right of Bind: keep/convert as double
        Bind->>Launcher: Variant(double) + pass nullptr for scale pointers
    end
    Launcher->>Runner: set runner_params.scaleSoftmaxLog2Ptr & outputScalePtr
    Runner->>Runner: execute FMHA using pointer scales if present else scalar scales

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas needing extra attention:
    • csrc/trtllm_fmha_kernel_launcher.cu — Variant extraction, pointer lifetime/null handling, and wiring into runner_params.
    • Python bindings (flashinfer/decode.py, flashinfer/prefill.py) — device/dtype assertions, on-device log2e scaling, and consistent pointer vs scalar branching.
    • Tests — ensure new parameterization covers both tensor and scalar paths and removed args are reconciled.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • cyx-6
  • nvmbreughe
  • yzh119

Poem

🐇 I hop with scales both small and grand,
Variant in paw and pointer in hand,
log2e twinkles on-device tonight,
Kernels lean in as pointers take flight,
A rabbit cheers — attention tuned right.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 30.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title check — ✅ Passed: the PR title clearly and concisely describes the main change: enabling torch.Tensor support for bmm scales in trtllm-gen attention alongside existing float support.
  • Description check — ✅ Passed: the PR description covers the key changes (tensor scale support with log2e application, removal of legacy parameters, kernel updates) and confirms pre-commit checks and testing completed.


Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein IwakuraRein marked this pull request as ready for review November 13, 2025 19:02
IwakuraRein (Collaborator, Author) commented:

/bot run

flashinfer-bot (Collaborator) commented:

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

jiahanc (Collaborator) commented Nov 13, 2025

/bot run

flashinfer-bot (Collaborator) commented:

GitLab MR !136 has been created, and the CI pipeline #38436074 is currently running. I'll report back once the pipeline job completes.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
flashinfer-bot (Collaborator) commented:

[CANCELING] Pipeline #38436074: canceled

coderabbitai Bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)

1883-1901: In-place scale multiply causes drift

Same issue here: bmm1_scale *= log2e updates the caller’s tensor. If the caller caches that buffer (common in decode loops), it compounds every step. Please switch to a non-in-place multiply or clone first. (docs.pytorch.org)
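
A standalone sketch of the drift (illustrative only, not from the PR):

    import torch

    log2e = 1.4426950408889634
    scale = torch.tensor([0.5], dtype=torch.float32)  # stands in for a cached scale buffer
    for _ in range(3):
        scale *= log2e  # in-place: mutates the cached buffer on every call
    print(scale)  # ~1.5014 after three calls, instead of the intended 0.7213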

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6765cad and 2f39e1f.

📒 Files selected for processing (5)
  • csrc/trtllm_fmha_kernel_launcher.cu (13 hunks)
  • flashinfer/artifacts.py (2 hunks)
  • flashinfer/decode.py (13 hunks)
  • flashinfer/prefill.py (9 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (0 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/trtllm/fmha/kernelParams.h

Comment on lines +260 to +277
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
              "bmm1_scale must be either a double or a tensor");
TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
              "bmm2_scale must be either a double or a tensor");
double bmm1_scale_value =
    maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
double bmm2_scale_value =
    maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
float* bmm1_scale_log2_ptr =
    maybe_bmm1_scale_log2_tensor.has_value()
        ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
        : nullptr;
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
                            ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
                            : nullptr;

⚠️ Potential issue | 🟠 Major

Guard tensor-based scales with dtype checks

When bmm*_scale comes in as a tensor, we immediately reinterpret the storage as float*. Callers can legally hand us torch.float16/torch.bfloat16 today, so this reinterpret cast will read garbage and corrupt the softmax/output scales. Please gate the tensor branch with a dtype == dl_float32 check (and emit a clear error otherwise) before taking the pointer, and apply the same fix in the context and ragged code paths.

@@
-  float* bmm1_scale_log2_ptr =
-      maybe_bmm1_scale_log2_tensor.has_value()
-          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
-          : nullptr;
-  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
-                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
-                              : nullptr;
+  float* bmm1_scale_log2_ptr = nullptr;
+  if (maybe_bmm1_scale_log2_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm1_scale_log2_tensor.value().dtype(), dl_float32)
+        << "bmm1_scale tensor must be float32";
+    bmm1_scale_log2_ptr =
+        static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr());
+  }
+  float* bmm2_scale_ptr = nullptr;
+  if (maybe_bmm2_scale_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm2_scale_tensor.value().dtype(), dl_float32)
+        << "bmm2_scale tensor must be float32";
+    bmm2_scale_ptr =
+        static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr());
+  }

Please mirror this guard in trtllm_paged_attention_context and trtllm_ragged_attention.

Also applies to: 338-356, 503-521

🤖 Prompt for AI Agents
csrc/trtllm_fmha_kernel_launcher.cu lines 260-277: when bmm1_scale or bmm2_scale
is a tensor the code currently reinterpret_casts data_ptr() to float* without
checking dtype which will misread half/bfloat tensors; modify the tensor branch
to first check the tensor dtype is float32 (dl_float32) and TVM_FFI_CHECK/throw
a clear error if not, then take the data_ptr() as float*; apply the identical
dtype-guard and error message to the similar blocks at lines 338-356 and 503-521
and also mirror these dtype guards in the corresponding
trtllm_paged_attention_context and trtllm_ragged_attention code paths.
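
For reference, an equivalent guard on the Python side might look like this (a hypothetical sketch, not part of the suggested diff):

    import torch

    def _check_scale_dtype(name: str, scale) -> None:
        # Reject non-float32 tensors before their storage is reinterpreted
        # as float* on the C++ side.
        if isinstance(scale, torch.Tensor) and scale.dtype != torch.float32:
            raise TypeError(f"{name} tensor must be float32, got {scale.dtype}")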

(Three additional comment threads on flashinfer/prefill.py, collapsed.)

coderabbitai Bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2f39e1f and d8f6387.

📒 Files selected for processing (1)
  • flashinfer/decode.py (13 hunks)

(Two additional comment threads on flashinfer/decode.py, collapsed.)
Comment on lines +2296 to 2301
if isinstance(bmm1_scale, torch.Tensor):
    assert bmm1_scale.dtype == torch.float32
    bmm1_scale *= log2e
if isinstance(bmm2_scale, torch.Tensor):
    assert bmm2_scale.dtype == torch.float32


⚠️ Potential issue | 🔴 Critical

Don’t mutate caller tensors when applying log2e

Same issue here: bmm1_scale *= log2e alters the input tensor in place, so repeated invocations accumulate the scaling and yield incorrect kernels. Use an out-of-place multiply before launching the kernel.

-        if isinstance(bmm1_scale, torch.Tensor):
-            assert bmm1_scale.dtype == torch.float32
-            bmm1_scale *= log2e
+        if isinstance(bmm1_scale, torch.Tensor):
+            assert bmm1_scale.dtype == torch.float32
+            bmm1_scale = bmm1_scale * log2e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

(before)
if isinstance(bmm1_scale, torch.Tensor):
    assert bmm1_scale.dtype == torch.float32
    bmm1_scale *= log2e
if isinstance(bmm2_scale, torch.Tensor):
    assert bmm2_scale.dtype == torch.float32

(after)
if isinstance(bmm1_scale, torch.Tensor):
    assert bmm1_scale.dtype == torch.float32
    bmm1_scale = bmm1_scale * log2e
if isinstance(bmm2_scale, torch.Tensor):
    assert bmm2_scale.dtype == torch.float32
🤖 Prompt for AI Agents
In flashinfer/decode.py around lines 2296 to 2301, the code currently does an
in-place scale (bmm1_scale *= log2e) which mutates the caller's tensor; change
this to an out-of-place multiplication and reassign the result to bmm1_scale
(for example use torch.mul or the * operator) so a new tensor is produced on the
same dtype/device and the original caller tensor is not modified; ensure the
result remains float32 and mirrored onto the correct device; also review nearby
bmm2_scale handling and apply the same out-of-place pattern if it will be scaled
later.

(One additional comment thread on flashinfer/decode.py, collapsed.)

jiahanc (Collaborator) commented Nov 13, 2025

/bot run

flashinfer-bot (Collaborator) commented:

GitLab MR !136 has been updated with latest changes, and the CI pipeline #38436713 is currently running. I'll report back once the pipeline job completes.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
flashinfer-bot (Collaborator) commented:

[CANCELING] Pipeline #38436713: canceled

IwakuraRein (Collaborator, Author) commented:

/bot run

flashinfer-bot (Collaborator) commented:

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

jiahanc (Collaborator) commented Nov 13, 2025

/bot run

flashinfer-bot (Collaborator) commented:

GitLab MR !136 has been updated with latest changes, and the CI pipeline #38440534 is currently running. I'll report back once the pipeline job completes.

flashinfer-bot (Collaborator) commented:

[SUCCESS] Pipeline #38646833: 10/18 passed

yzh119 (Collaborator) left a comment

Failed UTs are not relevant (will be fixed in #2097) and this PR itself LGTM, thanks for your contributions.

@jiahanc jiahanc merged commit a9f71bd into main Nov 18, 2025
4 checks passed
@jiahanc jiahanc deleted the trtllm-gen-attention-allow-device-scales branch November 18, 2025 07:53
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
[API change] Allow using torch.Tensor for scales for trtllm-gen attention (flashinfer-ai#2084)

(The commit message duplicates the PR description and CodeRabbit summary above.)