
feat: unit-test and api change, w4a8 grouped-gemm fused MoE for SM90#2193

Merged
yzh119 merged 5 commits into flashinfer-ai:main from jimmyzho:grouped-gemm on Dec 11, 2025

Conversation

@jimmyzho
Contributor

@jimmyzho jimmyzho commented Dec 10, 2025

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added packed‑weights support for fused MoE calls with a new public flag to select packed vs. unpacked weight handling; runtime paths now distinguish runners by this flag.
    • Exposed the flag across public wrappers so callers can opt into packed‑weights behavior.
  • Tests

    • Added W4A8 quantization pathway tests and helpers to validate INT4 packing/dequantization and end‑to‑end W4A8 behavior.


@coderabbitai
Contributor

coderabbitai Bot commented Dec 10, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new boolean flag use_packed_weights across the fused MoE codepath: C++ FusedMoeRunner now stores and uses the flag to gate quantization checks; Python MoERunner and public APIs accept and propagate the flag (including in runner cache keys); tests add INT4 dequant/depack helpers and a W4A8 reference test.

Changes

  • C++ backend change — csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu
    Added use_packed_weights parameter to FusedMoeRunner constructor and init(). Introduced private member mUsePackedWeights (default false). Quantization checks updated so isWFP4A16Quant requires !mUsePackedWeights and isInt4Quant requires mUsePackedWeights. Object creation sites updated to pass the flag.
  • Python core integration — flashinfer/fused_moe/core.py
    Extended MoERunner to accept/store use_packed_weights; included it in the instance cache key. Public wrappers cutlass_fused_moe() and _fake_cutlass_fused_moe() now accept and forward use_packed_weights. TRT-LLM MoERunner paths updated to propagate the flag.
  • Tests & utilities — tests/moe/test_trtllm_cutlass_fused_moe.py
    Added break_int4_bytes_to_int8() and dequantize_int4_to_dtype() helpers for INT4 unpack/dequant. Implemented torch_moe_w4a8() reference W4A8 path and added test_moe_w4a8 to validate the W4A8 flow.

Sequence Diagram(s)

sequenceDiagram
    participant User as Python API
    participant MoERunner as MoERunner (cache)
    participant Module as Fused MoE Module (Python binding)
    participant CUDABackend as FusedMoeRunner (C++ / CUDA)

    User->>MoERunner: request runner (use_packed_weights flag)
    MoERunner->>MoERunner: include flag in instance_key (cache lookup)
    alt cache miss
        MoERunner->>Module: init(..., use_packed_weights)
        Module->>CUDABackend: construct FusedMoeRunner(..., use_packed_weights)
        CUDABackend-->>Module: created (mUsePackedWeights set)
        Module-->>MoERunner: module ready
        MoERunner-->>User: runner instance
    else cache hit
        MoERunner-->>User: cached runner
    end
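For illustration, a minimal Python sketch of the cache-lookup step in the diagram: the flag has to be part of the runner cache key so a packed-weights request is never served by a runner built for unpacked weights. The names get_runner and init_runner below are hypothetical stand-ins, not FlashInfer's actual API.

    # Minimal sketch of the cache-key step shown in the diagram above.
    # `init_runner` is a hypothetical stand-in for the real module init;
    # the actual MoERunner in flashinfer/fused_moe/core.py keys on more fields.
    _runner_cache: dict = {}

    def get_runner(dtype: str, use_packed_weights: bool, init_runner):
        instance_key = (dtype, use_packed_weights)   # flag distinguishes runners
        if instance_key not in _runner_cache:        # cache miss -> build and store
            _runner_cache[instance_key] = init_runner(dtype, use_packed_weights)
        return _runner_cache[instance_key]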

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Areas to inspect:
    • All MoERunner instantiation paths and cache key composition for correct propagation of use_packed_weights
    • C++ quantization gating logic around mUsePackedWeights (ensure no unintended paths allowed)
    • Call sites where FusedMoeRunner is constructed (bindings/ffi) to confirm flag is always forwarded
    • New test coverage for correctness and edge cases (both flag values)

Suggested reviewers

  • yongwww
  • djmmoss
  • cyx-6
  • yzh119
  • wenscarl
  • jiahanc

Poem

🐇 A tiny flag hops into place,
Packed weights tucked in a cozy space,
From Python cache down to CUDA core,
INT4 whispers, W4A8 explores,
Hooray — the rabbit cheers, and hops once more!

Pre-merge checks

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 16.67%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check ⚠️ Warning — The PR description is incomplete: the Description and Related Issues sections are left as empty comments. Fill in the Description section with a clear explanation of what changed and why, and link any related issues in the Related Issues section.
✅ Passed checks (1 passed)
  • Title check ✅ Passed — The title clearly describes the main change: adding W4A8 grouped-gemm fused MoE support for SM90 with unit tests and API changes.


@jimmyzho
Contributor Author

/bot run

@gemini-code-assist
Contributor

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the fused MoE functionality by integrating support for W4A8 quantization, which uses INT4 weights and FP8 activations. This optimization is particularly beneficial for SM90 architectures, aiming to improve performance and efficiency. The changes involve modifying both the core C++ binding and the Python interface to accommodate a new use_packed_weights parameter, ensuring proper handling of different INT4 weight representations. A dedicated unit test has been added to rigorously verify the correctness of this new quantization scheme.

Highlights

  • W4A8 Grouped-GEMM Fused MoE Support: Introduced support for W4A8 (INT4 weights, FP8 activations) grouped-GEMM fused Mixture-of-Experts (MoE) operations, specifically targeting NVIDIA's SM90 (Hopper) architecture.
  • API and Backend Changes for Packed Weights: Added a use_packed_weights boolean parameter across the C++ backend and Python API to correctly handle and distinguish between packed and unpacked INT4 weight formats, updating relevant quantization logic.
  • Comprehensive Unit Testing: A new unit test (test_moe_w4a8) has been added to validate the W4A8 grouped-GEMM fused MoE implementation, including helper functions for INT4 unpacking/dequantization and a PyTorch reference model, with a specific check for SM90 device compatibility.

@flashinfer-bot
Collaborator

GitLab MR !187 has been created, and the CI pipeline #39922596 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for W4A8 grouped-GEMM fused MoE for SM90, along with corresponding API changes and a new unit test. The changes primarily involve adding a use_packed_weights flag through the API layers and implementing a comprehensive test case for this new functionality. The implementation appears solid. I have provided a few suggestions to enhance code clarity and maintainability, focusing on the C++ constructor's initialization order and improving the new test case's data types and documentation of magic numbers.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/moe/test_trtllm_cutlass_fused_moe.py (2)

1493-1501: Minor: torch.randint upper bound is exclusive.

The weight initialization uses torch.randint(-128, 127, ...), which generates values in [-128, 126] since torch.randint excludes the upper bound. If you intend to cover the full int8 range including 127, use torch.randint(-128, 128, ...).

This likely has negligible impact on test coverage.

-    w1_weight = torch.randint(
-        -128, 127, (e, n, k // 2), dtype=torch.int8, device="cuda"
-    )
+    w1_weight = torch.randint(
+        -128, 128, (e, n, k // 2), dtype=torch.int8, device="cuda"
+    )

1552-1566: SM90 bit pattern conversion looks intentional but warrants a brief clarifying comment.

The bfloat16 reinterpretation via .to(torch.bfloat16).view(dtype) is a technique to encode bfloat16 scale factors in a different dtype's bit representation. Consider expanding the comment to briefly explain why SM90 requires this, for future maintainability.
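For illustration, a small self-contained sketch of that bit-pattern trick (assuming dtype here is another 16-bit type such as torch.float16, so element sizes match and view() only relabels the dtype without touching the bits):

    import torch

    # Scale factors computed in float32, then encoded as bfloat16 bit patterns.
    scales = torch.rand(8, 16, dtype=torch.float32)

    # .view(dtype) keeps the raw bits and only changes the dtype tag, so the
    # numeric values are bfloat16 even though the tensor is typed float16.
    encoded = scales.to(torch.bfloat16).view(torch.float16)

    # Reading the bits back as bfloat16 recovers the original (rounded) scales.
    decoded = encoded.view(torch.bfloat16).to(torch.float32)
    assert torch.equal(decoded, scales.to(torch.bfloat16).to(torch.float32))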

flashinfer/fused_moe/core.py (1)

733-733: Missing documentation for use_packed_weights parameter.

The use_packed_weights parameter is added to the public cutlass_fused_moe API but is not documented in the docstring. Consider adding documentation explaining:

  • What packed weights means (INT4 vs WFP4A16)
  • When to use this flag
  • SM90 requirement

Add to the Parameters section of the docstring:

    use_packed_weights : bool = False
        Whether weights are in packed INT4 format for W4A8 quantization.
        When True, enables the INT4 weight path with FP8 activations (SM90 only).
        When False (default), uses the WFP4A16 path if use_w4_group_scaling is True.
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 6bb01d1 and 31d7993.

📒 Files selected for processing (3)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (4 hunks)
  • flashinfer/fused_moe/core.py (11 hunks)
  • tests/moe/test_trtllm_cutlass_fused_moe.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)
  flashinfer/fused_moe/core.py (2)
    • cutlass_fused_moe (495-649)
    • cutlass_fused_moe (707-925)
flashinfer/fused_moe/core.py (2)
  csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (2)
    • dtype (50-77)
    • dtype (50-50)
  flashinfer/logits_processor/types.py (1)
    • dtype (126-130)
🪛 Ruff (0.14.8)
flashinfer/fused_moe/core.py

680-680: Unused function argument: use_packed_weights

(ARG001)

🔇 Additional comments (9)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (2)

1179-1183: Quantization predicate logic is clear and mutually exclusive.

The updates correctly distinguish between WFP4A16 (unpacked uint8 with group scaling) and INT4 (packed uint8) quantization paths using the mUsePackedWeights flag. The predicates are now mutually exclusive.
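For readability, a paraphrased Python rendering of that gating (the real checks are member functions of the C++ FusedMoeRunner; the condition names here are assumptions, not the actual code):

    # Paraphrased gating logic; not the actual binding code.
    def is_wfp4a16_quant(weight_is_uint8: bool, w4_group_scaling: bool,
                         use_packed_weights: bool) -> bool:
        # WFP4A16: unpacked uint8 weights with group scaling
        return weight_is_uint8 and w4_group_scaling and not use_packed_weights

    def is_int4_quant(weight_is_uint8: bool, use_packed_weights: bool) -> bool:
        # INT4 (W4A8): packed uint8 weights
        return weight_is_uint8 and use_packed_weights

    # For a fixed flag value the two predicates can never both be true.
    for packed in (False, True):
        assert not (is_wfp4a16_quant(True, True, packed)
                    and is_int4_quant(True, packed))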


120-125: Parameter threading looks correct.

The use_packed_weights parameter is properly propagated from the init function through to the FusedMoeRunner constructor and stored in mUsePackedWeights. The default value of false preserves backward compatibility.

Also applies to: 1196-1201

tests/moe/test_trtllm_cutlass_fused_moe.py (4)

119-124: INT4 unpacking implementation is correct.

The signed 4-bit conversion logic (subtracting 16 when value >= 8) correctly converts from unsigned nibble representation to signed two's complement values (-8 to 7 range).
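For reference, a self-contained sketch of that nibble-to-signed conversion (nibble order and layout are assumptions; the test's break_int4_bytes_to_int8() may differ):

    import torch

    def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
        """Unpack two signed 4-bit values from each int8 byte (illustrative)."""
        bits = packed.view(torch.uint8)               # reinterpret raw bytes
        low = (bits & 0x0F).to(torch.int16)
        high = (bits >> 4).to(torch.int16)
        low = torch.where(low >= 8, low - 16, low)    # map [8, 15] -> [-8, -1]
        high = torch.where(high >= 8, high - 16, high)
        return torch.stack((low, high), dim=-1).flatten(-2).to(torch.int8)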


127-140: Dequantization helper is clean and correct.

The function correctly unpacks INT4 weights and applies per-group scaling. The optional weight_scale_2 is applied as a divisor after the primary scaling, which aligns with the reference implementation in torch_moe_w4a8.
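And a sketch of the group-wise scaling it describes, operating on already-unpacked int8 values (names and argument order are assumptions, not the helper's actual signature):

    import torch

    def dequant_groupwise(w_int8, scales, group_size, weight_scale_2=None,
                          dtype=torch.float32):
        """Illustrative group-wise dequantization of unpacked INT4 values."""
        w = w_int8.to(dtype)
        w = w.reshape(*w.shape[:-1], -1, group_size)    # (..., k // g, g)
        w = w * scales.unsqueeze(-1).to(dtype)          # one scale per group
        if weight_scale_2 is not None:
            w = w / weight_scale_2                      # secondary scale as divisor
        return w.reshape(*w.shape[:-2], -1)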


211-279: Reference implementation follows expected W4A8 computation pattern.

The implementation correctly models the W4A8 quantization flow:

  1. Pre-quant scaling → FP8 quantization → matmul → rescale
  2. SwiGLU activation on intermediate
  3. Second layer follows same pattern
  4. Final routing weight combination

The FP8 clamping range (±448) correctly matches torch.float8_e4m3fn max value.
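To make the pattern concrete, a minimal per-expert sketch of that flow (layouts, the gate/up split, and scale handling are assumptions, and the routing-weight combination is omitted; the actual torch_moe_w4a8 in the test differs):

    import torch
    import torch.nn.functional as F

    FP8_MAX = 448.0  # max finite value of torch.float8_e4m3fn

    def fake_fp8(x, scale):
        # Pre-quant scale, clamp to the e4m3 range, cast to FP8, widen back.
        return (x * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn).to(torch.float32)

    def reference_expert(x, w1_deq, w2_deq, a1_scale=1.0, a2_scale=1.0):
        h = fake_fp8(x, a1_scale) @ w1_deq.t()     # first GEMM on quantized activations
        h = h / a1_scale                           # rescale after the matmul
        gate, up = h.chunk(2, dim=-1)
        act = F.silu(gate) * up                    # SwiGLU
        out = fake_fp8(act, a2_scale) @ w2_deq.t() # second GEMM
        return out / a2_scale                      # second rescale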


1637-1637: Test tolerance is acceptable for quantized computation.

The tolerances (rtol=1e-2, atol=1e-1) are consistent with other quantized MoE tests in this file and reasonable given the cumulative error from multiple quantization/dequantization steps in W4A8.

flashinfer/fused_moe/core.py (3)

376-416: MoERunner correctly incorporates use_packed_weights into caching and initialization.

The use_packed_weights flag is properly:

  1. Accepted as a constructor parameter
  2. Stored on the instance
  3. Included in instance_key for runner caching (important to avoid mixing runners with different weight formats)
  4. Passed to the underlying module initialization

680-680: Unused parameter in fake op is expected behavior.

The static analysis hint flagging use_packed_weights as unused in _fake_cutlass_fused_moe is a false positive. Fake ops are used for shape inference during torch.compile and must match the real op's signature, but they only need to return correct output shapes—they don't need to use all parameters.


997-1010: use_packed_weights stored but unused in TRT-LLM MoERunner.

The TRT-LLM MoERunner accepts and stores use_packed_weights (lines 997, 1010), but the flag is never used in any of the TRT-LLM MoE operation paths (trtllm_bf16_moe, trtllm_fp8_*_moe, trtllm_fp4_*_moe, trtllm_mxint4_*_moe).

Is this intended as preparation for future W4A8 support in the TRT-LLM backend, or should it be wired through to the relevant operations?

Collaborator

@aleozlx aleozlx left a comment


lgtm. remember to address open comments

@flashinfer-bot
Collaborator

[FAILED] Pipeline #39922596: 15/20 passed

@jimmyzho
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !187 has been updated with latest changes, and the CI pipeline #39974155 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #39974155: 15/20 passed

@yzh119 yzh119 merged commit fae79ab into flashinfer-ai:main Dec 11, 2025
3 checks passed
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…lashinfer-ai#2193)

