fix: guard MXFP8 fc1 weight shape check for non-gated activations #3082

Merged
aleozlx merged 1 commit into flashinfer-ai:main from ianliuy:fix/issue-2731-gated-activation-guard
Apr 24, 2026

@ianliuy ianliuy commented Apr 15, 2026

Fixes #2731.

What's broken?

When using the CUTLASS fused MoE backend with non-gated activations (e.g., Relu2, Gelu, Silu) and MXFP8 quantization, the fc1 weight shape validation unconditionally rejects the input — even when the shape is correct.

Who is affected?

Anyone using the CUTLASS fused MoE path with:

  • Quantization: WMxfp8AMxfp8, WMxfp4AFp8, or WMxfp4AMxfp8
  • Activation: any non-gated type (Relu2, Gelu, Silu, etc.)

Not affected: gated activations (Swiglu, Geglu, SwigluBias), or other quant modes (NVFP4 already handles this correctly).

Where is the bug?

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu, inside getQuantParams() — the fc1 weight block N-dimension check hardcodes * 2 at three MXFP8 branches (~L898, ~L1004, ~L1063).

Why does it happen?

PR #2581 introduced MXFP8 support when only gated activations (Swiglu) existed, so inter_size * 2 was correct. Later, non-gated activation support was added to the trtllm-gen backend (PR #2707), but the CUTLASS backend's validation was never updated. The NVFP4 path in the same file (line ~1131) already handles this correctly with an if (isGatedActivation(...)) guard.

How did we fix it?

For each of the 3 MXFP8 quant branches:

  1. Extract int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;
  2. Replace the hardcoded * 2 with * fc1_n_mult
  3. Update error messages: gated shows "inter_size * 2", non-gated shows "inter_size"

Before:

fc1_weight_block.size(1) == alignToSfDim(inter_size, ...) * 2

After:

int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;
fc1_weight_block.size(1) == alignToSfDim(inter_size, ...) * fc1_n_mult
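
The effect of the multiplier on the expected N dimension can be sketched in isolation. This is a minimal illustration, not the code from the binding file: the `ActivationType` enum here is a stand-in listing only the activations named in this PR, `alignUp` is a hypothetical replacement for `alignToSfDim` (whose real arguments are elided above), and the alignment value is chosen purely for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in enum (assumption): only the activations named in this PR.
enum class ActivationType { Gelu, Silu, Relu2, Swiglu, Geglu, SwigluBias };

// Gated activations per the PR description: Swiglu, Geglu, SwigluBias.
constexpr bool isGatedActivation(ActivationType t) {
  return t == ActivationType::Swiglu || t == ActivationType::Geglu ||
         t == ActivationType::SwigluBias;
}

// Hypothetical helper standing in for alignToSfDim: round dim up to a
// multiple of alignment.
constexpr int64_t alignUp(int64_t dim, int64_t alignment) {
  return (dim + alignment - 1) / alignment * alignment;
}

// Expected size(1) of the fc1 weight block: doubled only when the
// activation is gated, because only then does fc1 produce two halves.
constexpr int64_t expectedFc1BlockN(int64_t inter_size, int64_t alignment,
                                    ActivationType act) {
  int const fc1_n_mult = isGatedActivation(act) ? 2 : 1;
  return alignUp(inter_size, alignment) * fc1_n_mult;
}
```

With an illustrative `inter_size` of 1000 and alignment of 128, the aligned size is 1024, so a gated activation expects 2048 and a non-gated one expects 1024.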

How do we know it works?

  • pre-commit run passes (clang-format, lint, etc.)
  • Gated activations (default Swiglu): fc1_n_mult = 2 — identical to old behavior, no regression
  • Non-gated activations: fc1_n_mult = 1 — shape check now accepts correct inter_size dimension
  • Full GPU test suite requires CI (@flashinfer-bot run)
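
Pending the GPU run, the intended accept/reject behavior can be sketched as a host-only check. This is an assumption-laden illustration: `checkFc1BlockN` is a hypothetical function, it takes the already-aligned `inter_size` as a plain integer, and it returns an error string in place of the `TVM_FFI_ICHECK` used in the real file. The message wording is also illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical check: returns std::nullopt when the shape is accepted,
// otherwise an error message naming the expected dimension.
std::optional<std::string> checkFc1BlockN(int64_t actual_n,
                                          int64_t aligned_inter_size,
                                          bool is_gated) {
  int const fc1_n_mult = is_gated ? 2 : 1;
  if (actual_n == aligned_inter_size * fc1_n_mult) {
    return std::nullopt;
  }
  // Gated activations report "inter_size * 2"; non-gated report "inter_size".
  std::string const expected = is_gated ? "inter_size * 2" : "inter_size";
  return "fc1 weight block size(1) must be " + expected + ", got " +
         std::to_string(actual_n);
}
```

A non-gated call with `actual_n == aligned_inter_size` now passes, while the old hardcoded `* 2` corresponds to always calling this with `is_gated = true`, which rejects that same shape.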

Related

cc @aleozlx @nv-yunzheq

Summary by CodeRabbit

  • Bug Fixes
    • Fixed weight block size validation for Mixture of Experts (MOE) to correctly handle both gated and non-gated activation types, ensuring proper support across different activation configurations.

The fc1 weight block shape validation in getQuantParams() hardcodes
'* 2' for the N-dimension check, assuming gated activations. This
causes non-gated activations (Relu2, Gelu, etc.) to fail validation
even with correct shapes.

Extract the gated-activation multiplier to a local variable
(fc1_n_mult) and use it in all three MXFP8 quant branches
(WMxfp8AMxfp8, WMxfp4AFp8, WMxfp4AMxfp8). Also update error
messages to display the actual expected multiplier.

Fixes flashinfer-ai#2731

Tested: lint + pre-commit pass; full pytest requires Linux/GPU (CI)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Yiyang Liu <37043548+ianliuy@users.noreply.github.com>
coderabbitai Bot commented Apr 15, 2026

No actionable comments were generated in the recent review. 🎉

📥 Commits

Reviewing files that changed from the base of the PR and between 25b324d and e88d0c0.

📒 Files selected for processing (1)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu

Walkthrough

A bug fix refines the CUTLASS MOE backend's weight block dimension validation by introducing conditional multiplier logic. The hard-coded factor * 2 is now applied only to gated activations via fc1_n_mult, correcting overly strict validation for non-gated activation types.

Changes

Cohort / File(s) | Summary
CUTLASS MOE Backend Validation (csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu) | Replaced hard-coded * 2 multiplier in fc1_weight_block dimension validation with computed fc1_n_mult variable that equals 2 for gated activations and 1 otherwise. Updated corresponding TVM_FFI_ICHECK error messages to conditionally display " * 2" only for gated activation types.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

Suggested labels

run-ci, op: moe

Suggested reviewers

  • yzh119
  • sricketts
  • nv-yunzheq
  • bkryu
  • samuellees
  • jiahanc

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

Check name | Status | Explanation
Title check | ✅ Passed | The title 'fix: guard MXFP8 fc1 weight shape check for non-gated activations' accurately and concisely describes the main change of making the fc1 weight shape validation conditional based on activation type.
Description check | ✅ Passed | The PR description is comprehensive and well-structured, containing clear sections explaining the bug, affected users, root cause, fix implementation, and verification steps.
Linked Issues check | ✅ Passed | The code changes directly address issue #2731 by conditioning the fc1 weight shape multiplier on activation type (2 for gated, 1 for non-gated), enabling non-gated activations to pass validation.
Out of Scope Changes check | ✅ Passed | All changes are scoped to fixing the MXFP8 fc1 weight validation in three branches for gated/non-gated activations; no unrelated modifications are present.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the FusedMoeRunner in the CUTLASS backend to support both gated and non-gated activations. It replaces the hardcoded multiplier of 2 for the fc1 weight block size with a dynamic fc1_n_mult variable determined by the activation type. Additionally, the error messages for shape validation have been updated to dynamically reflect the expected dimensions. I have no feedback to provide.

@aleozlx aleozlx added the run-ci label Apr 24, 2026
aleozlx commented Apr 24, 2026

/bot run

@flashinfer-bot
GitLab MR !600 has been created, and the CI pipeline #49350454 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx enabled auto-merge (squash) April 24, 2026 02:02
@aleozlx aleozlx merged commit 0798a7d into flashinfer-ai:main Apr 24, 2026
31 of 44 checks passed
Development

Successfully merging this pull request may close these issues.

Potentially superfluous check that disables non gated activations in the cutlass fused moe API

3 participants