fix: guard MXFP8 fc1 weight shape check for non-gated activations #3082
Conversation
The fc1 weight block shape validation in `getQuantParams()` hardcodes `* 2` for the N-dimension check, assuming gated activations. This causes non-gated activations (Relu2, Gelu, etc.) to fail validation even with correct shapes.

Extract the gated-activation multiplier into a local variable (`fc1_n_mult`) and use it in all three MXFP8 quant branches (`WMxfp8AMxfp8`, `WMxfp4AFp8`, `WMxfp4AMxfp8`). Also update the error messages to display the actual expected multiplier.

Fixes flashinfer-ai#2731

Tested: lint + pre-commit pass; full pytest requires Linux/GPU (CI)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Yiyang Liu <37043548+ianliuy@users.noreply.github.com>
Code Review
This pull request updates the FusedMoeRunner in the CUTLASS backend to support both gated and non-gated activations. It replaces the hardcoded multiplier of 2 for the fc1 weight block size with a dynamic fc1_n_mult variable determined by the activation type. Additionally, the error messages for shape validation have been updated to dynamically reflect the expected dimensions. I have no feedback to provide.
/bot run
Fixes #2731.
What's broken?
When using the CUTLASS fused MoE backend with non-gated activations (e.g., Relu2, Gelu, Silu) and MXFP8 quantization, the fc1 weight shape validation unconditionally rejects the input — even when the shape is correct.
Who is affected?
Anyone using the CUTLASS fused MoE path with:
`WMxfp8AMxfp8`, `WMxfp4AFp8`, or `WMxfp4AMxfp8`.

Not affected: gated activations (Swiglu, Geglu, SwigluBias), or other quant modes (NVFP4 already handles this correctly).
Where is the bug?
`csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`, inside `getQuantParams()` — the fc1 weight block N-dimension check hardcodes `* 2` at three MXFP8 branches (~L898, ~L1004, ~L1063).

Why does it happen?
PR #2581 introduced MXFP8 support when only gated activations (Swiglu) existed, so `inter_size * 2` was correct. Later, non-gated activation support was added to the trtllm-gen backend (PR #2707), but the CUTLASS backend's validation was never updated. The NVFP4 path in the same file (line ~1131) already handles this correctly with an `if (isGatedActivation(...))` guard.

How did we fix it?
For each of the 3 MXFP8 quant branches:

- Introduce `int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;`
- Replace `* 2` with `* fc1_n_mult`
- Update the error messages: gated shows `"inter_size * 2"`, non-gated shows `"inter_size"`

Before:
After:
How do we know it works?
- `pre-commit run` passes (clang-format, lint, etc.)
- Gated activations: `fc1_n_mult = 2` — identical to old behavior, no regression
- Non-gated activations: `fc1_n_mult = 1` — shape check now accepts the correct `inter_size` dimension
- Full pytest requires Linux/GPU (CI: `@flashinfer-bot run`)

Related
cc @aleozlx @nv-yunzheq