[WIP] Refactor: simplify torch -> cute-dsl boilerplate and enable tvm-ffi for cute-dsl kernels #2279
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

This change converts two RMSNorm FP4-quant kernels from pointer-based host/kernel bindings to TVM-FFI tensor-based invocation, updating host signatures, compilation scaffolding (fake/symbolic tensors and M), and runtime launch flows with swizzle-aware scale handling.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host Code
    participant TVM as TVM-FFI
    participant Compiler as Kernel Compiler
    participant CUDA as CUDA Runtime
    rect rgb(245,250,255)
        Note over Host,Compiler: Compile-time (fake/symbolic tensors)
        Host->>Host: Create symbolic M & fake cute.Tensor fixtures + fake stream
        Host->>Compiler: Invoke compiler with TVM-FFI enabled and fake tensors
        Compiler->>TVM: Register tensor signatures for TVM-FFI
        Compiler->>CUDA: Emit compiled kernel artifact
    end
    rect rgb(245,255,245)
        Note over Host,CUDA: Runtime (real tensor passing)
        Host->>Host: Prepare real tensors (mX, mW, mY, mS[, mR, mGlobalScale])<br/>flatten/contiguate scale if swizzled
        Host->>TVM: Pass real tensors + stream via TVM-FFI
        TVM->>CUDA: Launch kernel with tensor-backed inputs
        CUDA->>CUDA: Execute kernel on device memory
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the existing cute-dsl kernels, specifically for RMSNorm FP4 quantization, to leverage TVM-FFI. This integration aims to simplify the interaction between PyTorch tensors and the underlying CUDA kernels by allowing direct tensor passing, thereby reducing boilerplate code and improving the overall efficiency and developer experience when working with CUTLASS's cute-dsl.
Code Review
This pull request refactors the cute-dsl kernels to leverage tvm-ffi, which is a significant improvement. By enabling tvm-ffi, the code is simplified by allowing torch.Tensor objects to be passed directly to the kernels, removing the boilerplate for manual pointer creation and management. The changes in add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py are consistent and correctly use cute.runtime.make_fake_compact_tensor with symbolic dimensions for compilation. My review includes a couple of suggestions to correct misleading comments for better code clarity. Overall, this is a great change that improves maintainability.
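For context, here is a minimal sketch of the compile-time half of that pattern. Only `cute.sym_int`, `cute.runtime.make_fake_compact_tensor`, `cute.compile`, and the `--enable-tvm-ffi` option are named in this PR; the dtype, `HIDDEN_SIZE`, and the exact call signatures are assumptions for illustration:

```python
import cutlass
import cutlass.cute as cute

HIDDEN_SIZE = 4096  # illustrative; the real kernels take this from the input shape

# Symbolic batch dimension: one compiled artifact then serves any runtime M.
sym_m = cute.sym_int()

# Fake tensors carry only dtype/shape metadata -- no device memory is allocated.
x_fake = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (sym_m, HIDDEN_SIZE))
w_fake = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (HIDDEN_SIZE,))

# Compile once against the fake tensors with TVM-FFI enabled; real torch tensors
# are substituted at call time. (kernel_host_fn stands in for the @cute.jit host
# entry point; how the --enable-tvm-ffi option is passed is simplified here.)
# compiled = cute.compile(kernel_host_fn, x_fake, w_fake, options="--enable-tvm-ffi")
```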
```python
# Scale factor tensor layout depends on swizzle mode
if is_sf_swizzled_layout:
    # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
```
The comment incorrectly states that the swizzled size is independent of M. The number of M-tiles (num_m_tiles) is derived from M (the batch dimension), so the total swizzled size is dependent on M. The implementation correctly uses a symbolic integer for this dynamic size, but the comment is misleading and should be corrected for clarity.
Suggested change:

```diff
-    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
+    # Size is `num_m_tiles * num_k_tiles * 512`, which depends on the `M` dimension.
```
```python
# Scale factor tensor layout depends on swizzle mode
if is_sf_swizzled_layout:
    # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
```
This comment is misleading. The swizzled size is dependent on M because num_m_tiles is calculated based on M. While the code correctly uses a symbolic size, the comment should be updated to reflect this dependency to avoid confusion.
Suggested change:

```diff
-    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
+    # Size is `num_m_tiles * num_k_tiles * 512`, which depends on the `M` dimension.
```
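To make the dependency concrete, a small arithmetic sketch (the 128-row M-tile and the 512 factor are quoted from the comments in this thread; treating `num_k_tiles` as a free parameter is an assumption, since in the kernels it is fixed by the hidden size rather than by M):

```python
import math

def swizzled_scale_size(M: int, num_k_tiles: int) -> int:
    """Total swizzled scale-factor buffer size: num_m_tiles * num_k_tiles * 512."""
    num_m_tiles = math.ceil(M / 128)  # grows with the batch dimension M
    return num_m_tiles * num_k_tiles * 512

print(swizzled_scale_size(128, num_k_tiles=2))  # 1024 -- one M-tile
print(swizzled_scale_size(256, num_k_tiles=2))  # 2048 -- two M-tiles, so size tracks M
```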
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
1658-1661: Outdated section header. The section header on line 1659 says "Pointer-based Compilation" but the code now uses tensor-based TVM-FFI compilation. This should be updated for consistency.

Suggested fix:

```diff
 # =============================================================================
-# PyTorch API Functions - Streamlined with Pointer-based Compilation
+# PyTorch API Functions - Streamlined with TVM-FFI Tensor Compilation
 # =============================================================================
```
🧹 Nitpick comments (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
1706-1713: Minor: Misleading comment about M-independence. The comment states the swizzled size "is independent of M", but `num_m_tiles = ceil(M / 128)`, so the size actually depends on M. The implementation using a separate symbolic variable is correct, but the comment is confusing.

Suggested fix:

```diff
 if is_sf_swizzled_layout:
     # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
-    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
-    # Use a separate symbolic variable for this size
+    # Size is: num_m_tiles * num_k_tiles * 512
+    # Use a separate symbolic variable since this has different shape semantics
     sym_swizzled_size = cute.sym_int()
```

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
39-671: Consider extracting shared intrinsics and utilities to a common module. Both `rmsnorm_fp4quant.py` and `add_rmsnorm_fp4quant.py` share substantial duplicate code (~800+ lines):

- PTX intrinsics (`set_block_rank`, `store_shared_remote`, `ld_global_v4_u32`, etc.)
- Half2/BFloat2 SIMD intrinsics
- FP8/UE8M0 conversion intrinsics
- Reduction utilities (`warp_reduce`, `block_reduce`, `cluster_reduce`)
- `get_sm_version` function

Extracting these to a shared module (e.g., `flashinfer/cute_dsl/intrinsics.py`) would reduce maintenance burden and ensure consistency; a sketch of the resulting layout follows this list.
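As an illustration of the proposed layout, a minimal sketch (the module path and function names come from the list above; the bodies are elided placeholders rather than the real PTX intrinsics, and `get_sm_version`'s signature is assumed):

```python
# flashinfer/cute_dsl/intrinsics.py (hypothetical shared module)
"""Shared PTX intrinsics and reduction utilities for the cute-dsl kernels."""


def warp_reduce(val, op):
    """Reduce val across the lanes of a warp with binary operator op (body elided)."""
    ...


def block_reduce(val, op):
    """Block-wide reduction built on top of warp_reduce (body elided)."""
    ...


def get_sm_version(major: int, minor: int) -> int:
    """Fold a compute capability such as (10, 0) into the integer 100."""
    return major * 10 + minor


# rmsnorm_fp4quant.py and add_rmsnorm_fp4quant.py would then drop their local
# copies and import the shared definitions:
# from flashinfer.cute_dsl.intrinsics import warp_reduce, block_reduce, get_sm_version
```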
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `flashinfer/cute_dsl/add_rmsnorm_fp4quant.py`
- `flashinfer/cute_dsl/rmsnorm_fp4quant.py`
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`flashinfer/**/*.py`: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
Files:
- `flashinfer/cute_dsl/add_rmsnorm_fp4quant.py`
- `flashinfer/cute_dsl/rmsnorm_fp4quant.py`
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
🧬 Code graph analysis (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)

- `csrc/tvm_ffi_utils.h` (1): `Tensor` (304-306)
- `include/flashinfer/trtllm/fused_moe/runner.h` (1): `hidden_size` (265-265)

flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)

- `csrc/tvm_ffi_utils.h` (1): `Tensor` (304-306)
- `include/flashinfer/trtllm/fused_moe/runner.h` (1): `hidden_size` (265-265)
🔇 Additional comments (5)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (3)

1001-1018: LGTM - Kernel interface refactored to tensor-based API. The signature change from pointer-based to tensor-based inputs aligns with the TVM-FFI refactoring objective. The docstrings accurately describe the expected tensor shapes and layouts.

1739-1760: LGTM - Runtime tensor API correctly handles tensor passing. The `tensor_api` closure appropriately handles the scale tensor layout (flatten for swizzled, contiguous for non-swizzled). The caller `rmsnorm_fp4quant` ensures input tensors are contiguous before passing them.

1726-1737: LGTM - TVM-FFI compilation setup. The compilation correctly uses a fake stream with `use_tvm_ffi_env_stream=True` to capture the environment stream at runtime, and enables TVM-FFI with the `--enable-tvm-ffi` option.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)

2103-2171: LGTM - TVM-FFI compilation setup consistent with rmsnorm_fp4quant.py. The fake tensor creation and compilation approach is consistent with the pattern in `rmsnorm_fp4quant.py`. The implementation correctly uses symbolic dimensions and TVM-FFI options. Note: the same minor comment about "independent of M" at line 2141 applies here as well.

2173-2196: LGTM - Runtime tensor API correctly handles tensor passing. The `tensor_api` closure appropriately handles the scale tensor layout and matches the pattern from `rmsnorm_fp4quant.py`; see the sketch below for the overall pattern.
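As a reference for the pattern these comments describe, here is an illustrative runtime wrapper mirroring the `tensor_api` closure. The flatten-vs-contiguous scale handling and the environment-stream behavior come from this review; the wrapper name, argument order, and the `compiled_kernel` callable are simplified assumptions:

```python
import torch


def make_tensor_api(compiled_kernel, is_sf_swizzled_layout: bool):
    """Build a runtime entry point that forwards torch tensors via TVM-FFI."""

    def tensor_api(x: torch.Tensor, w: torch.Tensor,
                   y: torch.Tensor, scale: torch.Tensor) -> None:
        # Swizzled scale factors are passed as a flat 1D buffer; otherwise we
        # only guarantee a contiguous layout, matching the review notes above.
        s = scale.flatten() if is_sf_swizzled_layout else scale.contiguous()
        # Because the kernel was compiled with use_tvm_ffi_env_stream=True, the
        # launch picks up the caller's current CUDA stream; tensors cross the
        # boundary via DLPack with no manual pointer plumbing.
        compiled_kernel(x, w, y, s)

    return tensor_api
```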
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…er into cute-dsl-refactor
Actionable comments posted: 2
♻️ Duplicate comments (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
1733-1734: Comment is misleading about M-independence. The swizzled size depends on `M` because `num_m_tiles = (batch_size + 127) // 128` is computed from the batch dimension. While the code correctly uses a symbolic variable, the comment should be corrected.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

2182-2183: Comment incorrectly claims M-independence. Same issue as in `rmsnorm_fp4quant.py`: the swizzled size depends on `M` via `num_m_tiles`.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `flashinfer/cute_dsl/add_rmsnorm_fp4quant.py`
- `flashinfer/cute_dsl/rmsnorm_fp4quant.py`
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`flashinfer/**/*.py`: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
Files:
- `flashinfer/cute_dsl/rmsnorm_fp4quant.py`
- `flashinfer/cute_dsl/add_rmsnorm_fp4quant.py`
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
🧬 Code graph analysis (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
- `csrc/tvm_ffi_utils.h` (1): `Tensor` (304-306)
- `include/flashinfer/trtllm/fused_moe/runner.h` (1): `hidden_size` (265-265)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
1001-1041: LGTM! Clean transition to TVM-FFI tensor-based interface. The signature change from pointer-based to tensor-based inputs is well-structured. The docstring accurately describes the tensor shapes and the TVM-FFI approach.

1765-1788: Verify the runtime API aligns with the corrected compilation signature. Once the missing `global_scale_fake` tensor is added to the `cute.compile` call, ensure this `tensor_api` function continues to pass tensors in the correct order matching the kernel's `__call__` signature.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)

1005-1027: LGTM! Docstring correctly describes tensor inputs. The signature change to tensor-based inputs is well-structured. The docstring accurately describes tensor shapes without the previously flagged incorrect claim about in-place mR updates.

2215-2240: Verify tensor_api aligns with corrected compilation after fix. Once the missing `global_scale_fake` tensor is added to the `cute.compile` call, this `tensor_api` function should correctly pass tensors in the expected order.
/bot run
[FAILED] Pipeline #41059115: 12/20 passed
cc @tqchen on the failure, looks similar to what we observed on FA4.
The cuteDSL-related ARM failure should be resolved by cuteDSL 4.3.4.
Referenced PR: Update minimal version requirement of nvidia-cutlass-dsl to 4.3.4, which should resolve the arm issue in #2279.
/bot run
Hi @bkryu, the cu129 unittest on GB300 failed. Do you think it's relevant?
The failure was unrelated. I relaunched the test. Will keep an eye on it and then approve.
[SUCCESS] Pipeline #41205461: 8/20 passed
bkryu left a comment:
The failed GB300 cu129 unit test passed after retry. LGTM.
📌 Description
cute-dsl has supported compiling with tvm-ffi since the 4.3 release (https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html), which allows users to pass torch tensors directly with negligible DLPack conversion cost, without the need to manually create cute tensors from cute pointers.
In this PR we refactored the existing cute-dsl kernels to enable tvm-ffi and simplify the torch -> cute-dsl boilerplate.
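Schematically, the removed boilerplate looks like the following before/after; `make_ptr`, the dtypes, and the argument lists are illustrative stand-ins rather than lines copied from the diff:

```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_ptr


def launch_before(compiled, x: torch.Tensor, y: torch.Tensor, stream) -> None:
    # Before: wrap each torch.Tensor's raw pointer into a cute pointer by hand.
    x_ptr = make_ptr(cutlass.BFloat16, x.data_ptr(), cute.AddressSpace.gmem,
                     assumed_align=16)
    y_ptr = make_ptr(cutlass.BFloat16, y.data_ptr(), cute.AddressSpace.gmem,
                     assumed_align=16)
    compiled(x_ptr, y_ptr, x.shape[0], stream)


def launch_after(compiled, x: torch.Tensor, y: torch.Tensor) -> None:
    # After: a TVM-FFI-compiled kernel takes torch tensors directly; DLPack
    # handles the conversion and the current CUDA stream is used implicitly.
    compiled(x, y)
```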
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes