[API change] Allow using torch.Tensor for scales for trtllm-gen attention#2084
Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Walkthrough

Adds tensor-or-scalar support for attention scaling across Python APIs and C++ FMHA launchers using tvm::ffi::Variant; accepts device-resident scale tensors (passed as float pointers) or host scalars, applies on-device log2e scaling for tensor inputs, and updates the FMHA cubin artifact path and checksum constants.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Py as Python API
    participant Bind as FFI Binder (C++)
    participant Launcher as trtllm_*_launcher
    participant Runner as FMHA Runner
    Py->>Bind: call API with bmm1_scale, bmm2_scale (float or Tensor)
    alt Tensor inputs
        note right of Bind #d6f5d6: assert dtype float32<br/>compute tensor * log2e on device
        Bind->>Launcher: Variant(tensor) + provide bmm1_scale_log2_ptr & bmm2_scale_ptr
    else Scalar inputs
        note right of Bind #f0f0f0: keep/convert as double
        Bind->>Launcher: Variant(double) + pass nullptr for scale pointers
    end
    Launcher->>Runner: set runner_params.scaleSoftmaxLog2Ptr & outputScalePtr
    Runner->>Runner: execute FMHA using pointer scales if present else scalar scales
```
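The dispatch in the diagram can be sketched in plain Python. This is an illustrative model only: `resolve_scale`, the list standing in for a device tensor, and the `(scalar, pointer)` return shape are assumptions for the sketch, not the actual FFI surface.

```python
import math

LOG2E = math.log2(math.e)  # log2(e), i.e. 1 / ln(2)

def resolve_scale(scale):
    """Sketch of the float-or-tensor Variant dispatch.

    A Python list stands in for a device tensor here; the real binder
    works on a torch.Tensor and hands the launcher a float* into
    device memory instead of a list.
    """
    if isinstance(scale, list):                  # "tensor" branch
        scaled = [v * LOG2E for v in scale]      # on-device log2e scaling
        return 1.0, scaled                       # scalar defaults to 1.0, pointer set
    return float(scale), None                    # scalar branch: nullptr pointer
```

In the real binder the tensor branch passes a `float*` into device memory and the scalar branch passes `nullptr`, so the kernel can pick pointer scales when present and fall back to scalar scales otherwise.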
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
/bot run

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>

[CANCELING] Pipeline #38436074: canceled
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)

1883-1901: In-place scale multiply causes drift

Same issue here: `bmm1_scale *= log2e` updates the caller's tensor. If the caller caches that buffer (common in decode loops), the scaling compounds every step. Please switch to a non-in-place multiply or clone first. (docs.pytorch.org)
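The compounding can be shown with a minimal sketch. The tiny stand-in class and the three-step loop are illustrative assumptions, not FlashInfer code; a real `torch.Tensor` behaves the same way under `*=`.

```python
import math

log2e = math.log2(math.e)

class FakeTensor:
    """Minimal stand-in for a cached scale tensor (illustration only)."""
    def __init__(self, v):
        self.v = v
    def __imul__(self, s):          # in-place: mutates the caller's buffer
        self.v *= s
        return self
    def __mul__(self, s):           # out-of-place: returns a fresh value
        return FakeTensor(self.v * s)

cached = FakeTensor(1.0)
for _ in range(3):                  # decode loop reusing the cached buffer
    cached *= log2e                 # drift: log2e compounds every step

fresh = FakeTensor(1.0)
for _ in range(3):
    launch_scale = fresh * log2e    # safe: caller's buffer stays 1.0
```

After the loop, `cached.v` has drifted to `log2e ** 3` while `fresh.v` is still `1.0`, which is exactly why the review asks for the out-of-place form.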
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- csrc/trtllm_fmha_kernel_launcher.cu (13 hunks)
- flashinfer/artifacts.py (2 hunks)
- flashinfer/decode.py (13 hunks)
- flashinfer/prefill.py (9 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (0 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/trtllm/fmha/kernelParams.h
```cpp
auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
              "bmm1_scale must be either a double or a tensor");
TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
              "bmm2_scale must be either a double or a tensor");
double bmm1_scale_value =
    maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
double bmm2_scale_value =
    maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
float* bmm1_scale_log2_ptr =
    maybe_bmm1_scale_log2_tensor.has_value()
        ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
        : nullptr;
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
                            ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
                            : nullptr;
```
Guard tensor-based scales with dtype checks
When bmm*_scale comes in as a tensor, we immediately reinterpret the storage as float*. Callers can legally hand us torch.float16/torch.bfloat16 today, so this reinterpret cast will read garbage and corrupt the softmax/output scales. Please gate the tensor branch with a dtype == dl_float32 check (and emit a clear error otherwise) before taking the pointer, and apply the same fix in the context and ragged code paths.
```diff
@@
-  float* bmm1_scale_log2_ptr =
-      maybe_bmm1_scale_log2_tensor.has_value()
-          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
-          : nullptr;
-  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
-                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
-                              : nullptr;
+  float* bmm1_scale_log2_ptr = nullptr;
+  if (maybe_bmm1_scale_log2_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm1_scale_log2_tensor.value().dtype(), dl_float32)
+        << "bmm1_scale tensor must be float32";
+    bmm1_scale_log2_ptr =
+        static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr());
+  }
+  float* bmm2_scale_ptr = nullptr;
+  if (maybe_bmm2_scale_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm2_scale_tensor.value().dtype(), dl_float32)
+        << "bmm2_scale tensor must be float32";
+    bmm2_scale_ptr =
+        static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr());
+  }
```

Please mirror this guard in trtllm_paged_attention_context and trtllm_ragged_attention.
Also applies to: 338-356, 503-521
🤖 Prompt for AI Agents
csrc/trtllm_fmha_kernel_launcher.cu lines 260-277: when bmm1_scale or bmm2_scale
is a tensor the code currently reinterpret_casts data_ptr() to float* without
checking dtype which will misread half/bfloat tensors; modify the tensor branch
to first check the tensor dtype is float32 (dl_float32) and TVM_FFI_CHECK/throw
a clear error if not, then take the data_ptr() as float*; apply the identical
dtype-guard and error message to the similar blocks at lines 338-356 and 503-521
and also mirror these dtype guards in the corresponding
trtllm_paged_attention_context and trtllm_ragged_attention code paths.
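The hazard this guard prevents can be illustrated without CUDA at all: reinterpreting float16 storage as float32 yields an unrelated value. This sketch uses Python's `struct` module as a stand-in for the `static_cast<float*>` on tensor storage; the specific values are illustrative.

```python
import struct

# Two float16 values of 1.0 occupy 4 bytes total (0x3C00 each).
half_bytes = struct.pack('<ee', 1.0, 1.0)       # '<ee': two little-endian float16
(reinterpreted,) = struct.unpack('<f', half_bytes)  # read the same 4 bytes as float32

# The float32 read does not see 1.0 — it sees an unrelated small number,
# which is exactly what the kernel would use as its softmax/output scale.
assert reinterpreted != 1.0
```

A dtype check before taking the pointer turns this silent corruption into a clear error at the FFI boundary.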
```python
if isinstance(bmm1_scale, torch.Tensor):
    assert bmm1_scale.dtype == torch.float32
    bmm1_scale *= log2e
if isinstance(bmm2_scale, torch.Tensor):
    assert bmm2_scale.dtype == torch.float32
```
Don’t mutate caller tensors when applying log2e
Same issue here: bmm1_scale *= log2e alters the input tensor in place, so repeated invocations accumulate the scaling and yield incorrect kernels. Use an out-of-place multiply before launching the kernel.
```diff
- if isinstance(bmm1_scale, torch.Tensor):
-     assert bmm1_scale.dtype == torch.float32
-     bmm1_scale *= log2e
+ if isinstance(bmm1_scale, torch.Tensor):
+     assert bmm1_scale.dtype == torch.float32
+     bmm1_scale = bmm1_scale * log2e
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if isinstance(bmm1_scale, torch.Tensor):
    assert bmm1_scale.dtype == torch.float32
    bmm1_scale = bmm1_scale * log2e
if isinstance(bmm2_scale, torch.Tensor):
    assert bmm2_scale.dtype == torch.float32
```
🤖 Prompt for AI Agents
In flashinfer/decode.py around lines 2296 to 2301, the code currently does an
in-place scale (bmm1_scale *= log2e) which mutates the caller's tensor; change
this to an out-of-place multiplication and reassign the result to bmm1_scale
(for example use torch.mul or the * operator) so a new tensor is produced on the
same dtype/device and the original caller tensor is not modified; ensure the
result remains float32 and mirrored onto the correct device; also review nearby
bmm2_scale handling and apply the same out-of-place pattern if it will be scaled
later.
/bot run

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>

[CANCELING] Pipeline #38436713: canceled

/bot run

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

[SUCCESS] Pipeline #38646833: 10/18 passed
…tion (flashinfer-ai#2084)

## 📌 Description

- change `bmm1_scale` and `bmm2_scale` to `Union[float, torch.Tensor]`. Notice that when using a tensor, log2e must already be applied to it.
- **remove the `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` in the `xqa_batch_decode_with_kv_cache_mla`**
- update trtllm-gen FMHA kernels

TODO: do the same refactor for xqa kernels. The support for the device side scales was removed in flashinfer-ai#2033

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

- **New Features**
  - Attention scale parameters now accept either floats or 1-element tensors across prefill, decode and runtime; tensor scales are validated and applied on-device and pointer-backed scale paths are supported.
- **Chores**
  - Updated FMHA artifact path and checksum constants; added a public utility import and removed an obsolete inline comment.
- **Tests**
  - Updated tests to exercise device/tensor-or-scalar scale flows, removed legacy per-tensor call-site args, and added device-scale parametrization for several test variants.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
📌 Description

- change `bmm1_scale` and `bmm2_scale` to `Union[float, torch.Tensor]`. Notice that when using a tensor, log2e must already be applied to it.
- remove the `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` in the `xqa_batch_decode_with_kv_cache_mla`

TODO: do the same refactor for xqa kernels. The support for the device side scales was removed in #2033

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes

Summary by CodeRabbit

New Features | Chores | Tests
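Per the description, a tensor-valued `bmm1_scale` must already carry the log2e factor. The following is a hedged sketch of how a caller might compute it; the head size of 128 and the `1/sqrt(head_dim)` softmax-scale formula are illustrative assumptions, and the real call site would place the result in a 1-element float32 `torch.Tensor` on the GPU.

```python
import math

head_dim = 128                          # example value, not mandated by the API
sm_scale = 1.0 / math.sqrt(head_dim)    # conventional attention softmax scale
log2e = math.log2(math.e)

# Plain floats are used so the sketch runs anywhere; the actual API would
# receive torch.tensor([bmm1_scale_value], dtype=torch.float32, device="cuda").
bmm1_scale_value = sm_scale * log2e
```

Passing a tensor without the log2e factor would silently produce a wrong softmax, since the kernel assumes the factor is already baked in.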