feat: unit-test and api change, w4a8 grouped-gemm fused MoE for SM90 #2193

yzh119 merged 5 commits into flashinfer-ai:main from
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new boolean flag, `use_packed_weights`, threaded through the fused MoE Python API, the module binding, and the C++/CUDA `FusedMoeRunner`.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as Python API
    participant MoERunner as MoERunner (cache)
    participant Module as Fused MoE Module (Python binding)
    participant CUDABackend as FusedMoeRunner (C++ / CUDA)
    User->>MoERunner: request runner (use_packed_weights flag)
    MoERunner->>MoERunner: include flag in instance_key (cache lookup)
    alt cache miss
        MoERunner->>Module: init(..., use_packed_weights)
        Module->>CUDABackend: construct FusedMoeRunner(..., use_packed_weights)
        CUDABackend-->>Module: created (mUsePackedWeights set)
        Module-->>MoERunner: module ready
        MoERunner-->>User: runner instance
    else cache hit
        MoERunner-->>User: cached runner
    end
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ Failed checks (2 warnings) | ✅ Passed checks (1 passed)
/bot run
Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the fused MoE functionality by integrating support for W4A8 quantization, which uses INT4 weights and FP8 activations. This optimization is particularly beneficial for SM90 architectures, aiming to improve performance and efficiency. The changes involve modifying both the core C++ binding and the Python interface to accommodate a new `use_packed_weights` flag.

Highlights
Code Review
This pull request introduces support for W4A8 grouped-GEMM fused MoE for SM90, along with corresponding API changes and a new unit test. The changes primarily involve adding a use_packed_weights flag through the API layers and implementing a comprehensive test case for this new functionality. The implementation appears solid. I have provided a few suggestions to enhance code clarity and maintainability, focusing on the C++ constructor's initialization order and improving the new test case's data types and documentation of magic numbers.
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/moe/test_trtllm_cutlass_fused_moe.py (2)
1493-1501: Minor: `torch.randint` upper bound is exclusive.

The weight initialization uses `torch.randint(-128, 127, ...)`, which generates values in `[-128, 126]` since `torch.randint` excludes the upper bound. If you intend to cover the full int8 range including 127, use `torch.randint(-128, 128, ...)`. This likely has negligible impact on test coverage.
```diff
- w1_weight = torch.randint(
-     -128, 127, (e, n, k // 2), dtype=torch.int8, device="cuda"
- )
+ w1_weight = torch.randint(
+     -128, 128, (e, n, k // 2), dtype=torch.int8, device="cuda"
+ )
```
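A standalone snippet (not part of the test) illustrating the exclusive bound:

```python
import torch

# high is exclusive, so 127 never appears with high=127.
vals = torch.randint(-128, 127, (100_000,), dtype=torch.int8)
assert int(vals.max()) <= 126
# Using high=128 covers the full int8 range, including 127.
vals_full = torch.randint(-128, 128, (100_000,), dtype=torch.int8)
```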
1552-1566: SM90 bit pattern conversion looks intentional but warrants a brief clarifying comment.

The bfloat16 reinterpretation via `.to(torch.bfloat16).view(dtype)` is a technique to encode bfloat16 scale factors in a different dtype's bit representation. Consider expanding the comment to briefly explain why SM90 requires this, for future maintainability.
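A minimal sketch of the reinterpretation trick, using made-up scale values (only the `.to(...).view(...)` mechanics matter here):

```python
import torch

scales = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float32)
# Carry bfloat16 bits inside a tensor of another 2-byte dtype...
packed = scales.to(torch.bfloat16).view(torch.float16)
# ...and recover them losslessly by viewing back.
recovered = packed.view(torch.bfloat16).to(torch.float32)
assert torch.equal(recovered, scales)
```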
flashinfer/fused_moe/core.py (1)

733-733: Missing documentation for `use_packed_weights` parameter.

The `use_packed_weights` parameter is added to the public `cutlass_fused_moe` API but is not documented in the docstring. Consider adding documentation explaining:
- What packed weights means (INT4 vs WFP4A16)
- When to use this flag
- SM90 requirement
Add to the Parameters section of the docstring:
```
use_packed_weights : bool = False
    Whether weights are in packed INT4 format for W4A8 quantization.
    When True, enables the INT4 weight path with FP8 activations (SM90 only).
    When False (default), uses the WFP4A16 path if use_w4_group_scaling is True.
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (4 hunks)
- flashinfer/fused_moe/core.py (11 hunks)
- tests/moe/test_trtllm_cutlass_fused_moe.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)
flashinfer/fused_moe/core.py (2)
cutlass_fused_moe (495-649), cutlass_fused_moe (707-925)
flashinfer/fused_moe/core.py (2)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (2)
dtype (50-77), dtype (50-50)
flashinfer/logits_processor/types.py (1)
dtype (126-130)
🪛 Ruff (0.14.8)
flashinfer/fused_moe/core.py
680-680: Unused function argument: use_packed_weights
(ARG001)
🔇 Additional comments (9)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (2)
1179-1183: Quantization predicate logic is clear and mutually exclusive.

The updates correctly distinguish between WFP4A16 (unpacked uint8 with group scaling) and INT4 (packed uint8) quantization paths using the `mUsePackedWeights` flag. The predicates are now mutually exclusive, as sketched below.
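An illustrative Python rendering of the two uint8 paths (variable names are assumptions, not the binding's actual members):

```python
def select_weight_path(weight_is_uint8: bool, use_w4_group_scaling: bool,
                       use_packed_weights: bool) -> str:
    # Unpacked uint8 with group scaling -> WFP4A16; packed uint8 -> INT4 (W4A8).
    is_wfp4a16 = weight_is_uint8 and use_w4_group_scaling and not use_packed_weights
    is_int4_packed = weight_is_uint8 and use_packed_weights
    assert not (is_wfp4a16 and is_int4_packed)  # mutually exclusive by construction
    if is_wfp4a16:
        return "WFP4A16"
    if is_int4_packed:
        return "INT4 (W4A8)"
    return "other"
```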
120-125: Parameter threading looks correct.

The `use_packed_weights` parameter is properly propagated from the `init` function through to the `FusedMoeRunner` constructor and stored in `mUsePackedWeights`. The default value of `false` preserves backward compatibility.

Also applies to: 1196-1201
tests/moe/test_trtllm_cutlass_fused_moe.py (4)
119-124: INT4 unpacking implementation is correct.

The signed 4-bit conversion logic (subtracting 16 when value >= 8) correctly converts from unsigned nibble representation to signed two's complement values (-8 to 7 range).
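A self-contained sketch of that conversion (the helper name and nibble order are illustrative, not the test's actual code):

```python
import torch

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Each byte holds two 4-bit values: low nibble first, then high nibble.
    as_uint8 = packed.view(torch.uint8)
    low = (as_uint8 & 0x0F).to(torch.int16)
    high = (as_uint8 >> 4).to(torch.int16)
    nibbles = torch.stack([low, high], dim=-1).flatten(-2)
    # Unsigned nibble -> signed two's complement: subtract 16 when value >= 8.
    return torch.where(nibbles >= 8, nibbles - 16, nibbles).to(torch.int8)

packed = torch.tensor([0x8F], dtype=torch.uint8)  # low nibble 0xF, high nibble 0x8
print(unpack_int4(packed))  # tensor([-1, -8], dtype=torch.int8)
```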
127-140: Dequantization helper is clean and correct.

The function correctly unpacks INT4 weights and applies per-group scaling. The optional `weight_scale_2` is applied as a divisor after the primary scaling, which aligns with the reference implementation in `torch_moe_w4a8`.
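The scaling order it describes, in a standalone sketch (the helper name and group size are assumptions; `w_int4` holds already-unpacked int4 values):

```python
import torch

def dequantize_grouped(w_int4: torch.Tensor, scale: torch.Tensor,
                       group_size: int = 128, weight_scale_2=None) -> torch.Tensor:
    n, k = w_int4.shape
    w = w_int4.float().view(n, k // group_size, group_size)
    w = w * scale.view(n, k // group_size, 1)  # primary per-group scaling
    w = w.view(n, k)
    if weight_scale_2 is not None:
        w = w / weight_scale_2  # optional secondary scale, applied as a divisor
    return w
```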
211-279: Reference implementation follows expected W4A8 computation pattern.

The implementation correctly models the W4A8 quantization flow:
- Pre-quant scaling → FP8 quantization → matmul → rescale
- SwiGLU activation on intermediate
- Second layer follows same pattern
- Final routing weight combination
The FP8 clamping range (±448) correctly matches the `torch.float8_e4m3fn` max value. A condensed sketch of this flow appears below.
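A single-expert sketch of the pattern above (scale names and the gate/up split order are assumptions of this sketch; `w13` and `w2` are already-dequantized float weights):

```python
import torch
import torch.nn.functional as F

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def w4a8_expert_ref(x, w13, w2, in_scale1, in_scale2):
    # Pre-quant scaling -> FP8 quantization (clamped to +-448) -> matmul -> rescale
    xq = torch.clamp(x / in_scale1, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    h = (xq.to(torch.float32) @ w13.t()) * in_scale1
    gate, up = h.chunk(2, dim=-1)  # SwiGLU on the intermediate
    h = F.silu(gate) * up
    # Second layer follows the same quantize -> matmul -> rescale pattern
    hq = torch.clamp(h / in_scale2, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return (hq.to(torch.float32) @ w2.t()) * in_scale2
# Per-expert outputs are then combined with the routing weights.
```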
1637-1637: Test tolerance is acceptable for quantized computation.

The tolerances (`rtol=1e-2`, `atol=1e-1`) are consistent with other quantized MoE tests in this file and reasonable given the cumulative error from multiple quantization/dequantization steps in W4A8.

flashinfer/fused_moe/core.py (3)
376-416: MoERunner correctly incorporates `use_packed_weights` into caching and initialization.

The `use_packed_weights` flag is properly:

- Accepted as a constructor parameter
- Stored on the instance
- Included in `instance_key` for runner caching (important to avoid mixing runners with different weight formats; see the toy sketch below)
- Passed to the underlying module initialization
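A toy illustration (hypothetical names) of why the flag belongs in the cache key:

```python
_runner_cache: dict = {}

def get_runner(x_dtype: str, use_packed_weights: bool):
    instance_key = (x_dtype, use_packed_weights)  # flag is part of the key
    if instance_key not in _runner_cache:
        _runner_cache[instance_key] = object()    # stand-in for module init
    return _runner_cache[instance_key]

# Runners that differ only in the flag must not be shared.
assert get_runner("fp8", True) is not get_runner("fp8", False)
```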
680-680: Unused parameter in fake op is expected behavior.

The static analysis hint flagging `use_packed_weights` as unused in `_fake_cutlass_fused_moe` is a false positive. Fake ops are used for shape inference during `torch.compile` and must match the real op's signature, but they only need to return correct output shapes; they don't need to use all parameters.
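A minimal sketch of that convention (signature and names are illustrative, not flashinfer's actual registration):

```python
import torch

def _fake_cutlass_fused_moe_sketch(
    hidden_states: torch.Tensor, use_packed_weights: bool = False
) -> torch.Tensor:
    # use_packed_weights is intentionally unused: only output shapes matter here.
    return hidden_states.new_empty(hidden_states.shape)
```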
997-1010: `use_packed_weights` stored but unused in TRT-LLM MoERunner.

The TRT-LLM `MoERunner` accepts and stores `use_packed_weights` (lines 997, 1010), but the flag is never used in any of the TRT-LLM MoE operation paths (`trtllm_bf16_moe`, `trtllm_fp8_*_moe`, `trtllm_fp4_*_moe`, `trtllm_mxint4_*_moe`). Is this intended as preparation for future W4A8 support in the TRT-LLM backend, or should it be wired through to the relevant operations?
aleozlx left a comment
lgtm. remember to address open comments
[FAILED] Pipeline #39922596: 15/20 passed
/bot run
[SUCCESS] Pipeline #39974155: 15/20 passed
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features

- Added packed-weights support for fused MoE calls with a new public flag to select packed vs. unpacked weight handling; runtime paths now distinguish runners by this flag.
- Exposed the flag across public wrappers so callers can opt into packed-weights behavior.
Tests

- Added W4A8 quantization pathway tests and helpers to validate INT4 packing/dequantization and end-to-end W4A8 behavior.