
[MoE Refactor] MXFP4 Cutlass Experts to MK #34542

Merged
vllm-bot merged 29 commits into vllm-project:main from
zyongye:mxfp4_refactor_cutlass_experts
Feb 26, 2026

Conversation

@zyongye
Member

@zyongye zyongye commented Feb 13, 2026

Purpose

Refactor the MXFP4 CUTLASS backend for the ongoing MoE refactor.

Also add testing infrastructure.

Test Plan

Run GPQA benchmarks with medium reasoning effort.

gpt-oss-120b, TP=2 on GB200, with the tested kernel enabled:

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 vllm serve openai/gpt-oss-120b -tp 2

gpt-oss-120b, TEP=2 on GB200, with the tested kernel enabled:

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 vllm serve openai/gpt-oss-120b -tp 2 -ep

gpt-oss-20b, TP=2 on H200, with the tested kernel enabled:

VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 vllm serve openai/gpt-oss-20b -tp 2

Test command: Follow the recipe

Test Result

GB200: GPQA with medium reasoning effort on 120b: 0.727. Matches the recipe.

H200: GPQA with medium reasoning effort on 20b: 0.6641. Matches the recipe.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MXFP4 cutlass backend for MoE layers, improving modularity and adding support for new quantization schemes. The changes are well-structured and consistent across the modified files. The refactoring in vllm/model_executor/layers/quantization/mxfp4.py to use the modular kernel framework is a significant improvement. I've identified one high-severity performance issue related to object instantiation within the forward pass and provided a suggestion to address it.

Comment on lines +964 to +984
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )
        from vllm.model_executor.layers.fused_moe.prepare_finalize import (
            MoEPrepareAndFinalizeNoEP,
        )

            return output
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        self.kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(moe_config=self.moe, quant_config=self.moe_quant_config),
            shared_experts=None,
        )
        return self.kernel(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
Contributor


high

Creating the FusedMoEModularKernel on every forward pass can introduce unnecessary overhead. It's better to initialize it once and cache it for subsequent calls. This can be done with lazy initialization within the apply method.

        if not hasattr(self, "_kernel"):
            self._kernel = None

        if self._kernel is None:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )
            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                MoEPrepareAndFinalizeNoEP,
            )
            # ... build the FusedMoEModularKernel once here and cache it in
            # self._kernel for subsequent forward passes.

@zyongye
Member Author

zyongye commented Feb 15, 2026

/gemini review

@zyongye zyongye marked this pull request as ready for review February 15, 2026 17:32
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MXFP4 cutlass backend to use the modular kernel interface, which is a positive step towards better code organization and maintainability. The changes primarily involve moving backend-specific logic into dedicated classes and leveraging the FusedMoEModularKernel. While the overall direction is good, I've identified two critical issues that could lead to incorrect behavior due to the refactoring. One issue involves incorrect data type casting for weights on a specific backend path, and the other relates to a quantization parameter not being correctly propagated in the new generic implementation. Please address these points to ensure the correctness of the refactored code.

Comment on lines +301 to +321
            fc1_expert_weights = w1.view(torch.long)
            fc2_expert_weights = w2.view(torch.long)
            if self.quant_dtype == "mxfp8":
                fake_input_scale = torch.ones(
                    self.moe_config.num_experts, device=hidden_states.device
                )
                quant_scales = [
                    self.w1_scale.view(torch.int32),
                    fake_input_scale,
                    self.w2_scale.view(torch.int32),
                    fake_input_scale,
                ]
                use_mxfp8_act_scaling = True
            else:
                assert hidden_states.dtype == torch.bfloat16
                quant_scales = [
                    self.w1_scale,
                    self.w2_scale,
                ]
                a1q_scale = None
                use_w4_group_scaling = True
Contributor


critical

The weight tensors w1 and w2 are unconditionally cast to torch.long. However, this cast should only occur when self.quant_dtype == 'mxfp8' (i.e., for the use_mxfp8_act_scaling=True path). For the else branch (use_w4_group_scaling=True), the weights should be passed as-is (as torch.uint8), as was done in the previous implementation. This incorrect casting can lead to runtime errors or incorrect computation in the kernel.

            if self.quant_dtype == "mxfp8":
                fc1_expert_weights = w1.view(torch.long)
                fc2_expert_weights = w2.view(torch.long)
                fake_input_scale = torch.ones(
                    self.moe_config.num_experts, device=hidden_states.device
                )
                quant_scales = [
                    self.w1_scale.view(torch.int32),
                    fake_input_scale,
                    self.w2_scale.view(torch.int32),
                    fake_input_scale,
                ]
                use_mxfp8_act_scaling = True
            else:
                fc1_expert_weights = w1
                fc2_expert_weights = w2
                assert hidden_states.dtype == torch.bfloat16
                quant_scales = [
                    self.w1_scale,
                    self.w2_scale,
                ]
                a1q_scale = None
                use_w4_group_scaling = True
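To make the dtype issue concrete, here is a minimal pure-Python sketch (using `struct` as a stand-in for `Tensor.view(torch.long)`; this is a toy illustration, not vLLM code): reinterpreting a uint8 buffer as int64 keeps the bytes but changes the element count and meaning, so a kernel indexing uint8 elements would read the wrong data from a tensor viewed this way.

```python
import struct

# Toy stand-in for torch's Tensor.view(torch.long) on packed uint8 weights:
# 8 consecutive bytes are reinterpreted as one little-endian int64.
packed_uint8 = bytes(range(16))  # 16 "uint8 weights"

# Reinterpret as int64: 16 bytes -> 2 elements.
as_int64 = struct.unpack("<2q", packed_uint8)
print(len(packed_uint8), len(as_int64))  # 16 2

# The byte content is unchanged -- only the element type and count differ,
# which is why the cast must only happen on the path whose kernel expects
# int64-packed weights.
roundtrip = struct.pack("<2q", *as_int64)
assert roundtrip == packed_uint8
```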

Comment on lines 984 to 990
return self.kernel(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
Contributor


critical

This refactoring seems to have introduced a potential issue. The previous implementation for the SM100_FI_MXFP4_MXFP8_CUTLASS backend called mxfp8_quantize(x, True, 32), with is_sf_swizzled_layout=True. The new modular kernel path, via MoEPrepareAndFinalizeNoEP and moe_kernel_quantize_input, effectively calls mxfp8_quantize with is_sf_swizzled_layout=False. This discrepancy might lead to incorrect behavior or performance degradation for this backend, as the activation scales will not have the swizzled layout expected by the kernel when weight scales are swizzled. Please ensure the is_sf_swizzled_layout flag is correctly propagated or handled for this backend.
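The layout concern can be illustrated with a toy model (a simple 2x2 tile reordering stands in for the hardware-specific swizzle; the real FlashInfer layout is different and this is not vLLM code): if the producer writes scale factors row-major but the kernel reads them as swizzled, the same flat offset yields a different scale.

```python
def swizzle_2x2(flat: list[int], cols: int) -> list[int]:
    """Reorder a row-major matrix into 2x2 tiles, tile by tile (toy swizzle)."""
    rows = len(flat) // cols
    out = []
    for tr in range(0, rows, 2):
        for tc in range(0, cols, 2):
            for r in (tr, tr + 1):
                for c in (tc, tc + 1):
                    out.append(flat[r * cols + c])
    return out

scales = list(range(16))          # 4x4 scale matrix, row-major
swizzled = swizzle_2x2(scales, 4)

# A kernel expecting swizzled scales reads flat offset 4 as a different
# element than a row-major producer wrote there:
print(scales[4], swizzled[4])     # 4 2  -> layout mismatch
```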

@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 15, 2026
@zyongye
Member Author

zyongye commented Feb 16, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MXFP4 CUTLASS backend for MoE layers to use the modular kernel framework, which improves code organization and maintainability. It also introduces a comprehensive testing infrastructure for GPQA evaluation, making the tests more robust and easier to configure. The changes are well-structured and the refactoring correctly encapsulates the kernel-specific logic within the FlashInferExperts class. The new testing setup is a great addition for ensuring correctness and performance on different hardware. Overall, this is a solid improvement to the codebase.

@mergify
Contributor

mergify bot commented Feb 16, 2026

Hi @zyongye, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Comment thread .buildkite/test_areas/lm_eval.yaml Outdated

- label: GPQA Eval (GPT-OSS) (H200)
timeout_in_minutes: 120
device: h200
Collaborator


switch to h100 due to resource constraints (we have many more h100s in the CI)

Member


ditto, I think H200 is only for 8xH200

Member Author


Changed

@@ -635,6 +660,9 @@ def mxfp4_w4a16_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
gemm1_alpha: torch.Tensor | None = None,
Collaborator


These are just hardcoded values right? (AFAICT they are: https://github.com/zyongye/vllm/blob/d7d68c3127bc27d97b20ceb901068e709c430bd5/vllm/model_executor/layers/quantization/mxfp4.py#L419-L430)

In that case, I think we should avoid passing these via the quant config and instead just have these parameters in the kernel itself. This will make things clearer and reduce the surface of the API contract.
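The suggested design can be sketched roughly as follows (hypothetical class and field names; the 1.702/7.0 constants are assumptions mirroring the linked code, not an authoritative reading of it): fixed kernel constants live in the kernel's `__init__`, while the quant config carries only genuinely per-layer data.

```python
class QuantConfig:
    """Only per-layer data travels through the config (hypothetical sketch)."""
    def __init__(self, w1_scale, w2_scale):
        self.w1_scale = w1_scale
        self.w2_scale = w2_scale


class Experts:
    # Hardcoded activation constants live with the kernel that needs them,
    # keeping them out of the config API contract.
    GEMM1_ALPHA = 1.702  # assumed constant, mirroring the linked code
    GEMM1_LIMIT = 7.0

    def __init__(self, quant_config: QuantConfig):
        self.quant_config = quant_config
        self.gemm1_alpha = self.GEMM1_ALPHA
        self.gemm1_limit = self.GEMM1_LIMIT


experts = Experts(QuantConfig(w1_scale=[1.0], w2_scale=[1.0]))
print(experts.gemm1_alpha)  # 1.702
```

Callers now construct `QuantConfig` without ever seeing the kernel constants, which is the reduced API surface the comment asks for.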

Member Author


I moved these inside FlashInferExperts, into the init phase.

Comment thread vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -746,6 +746,30 @@ def _interleave_mxfp4_cutlass_sm90(w):
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False
)

assert not self.moe.use_ep, (
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat Feb 16, 2026


I don't think this is needed. I think this kernel does support EP.

NOTE: the NoEP thing here is misnamed. It should be NoDPEP.

Member Author


The kernel interface actually dispatches to multiple kernels. It errors out when I run EP.

@@ -193,13 +193,10 @@ def _mxfp4_quantize(
def _mxfp8_e4m3_quantize(
A: torch.Tensor,
A_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: list[int] | None = None,
is_sf_swizzled_layout: bool,
Collaborator


I think we should preserve the existing args just to avoid future footguns.

Member Author


I changed it back. Earlier I thought we should align this with the nvfp4 quantization function signature.
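One way to preserve the existing args while still adding the new flag can be sketched like this (a hypothetical helper, not the actual vLLM signature): keep the original positional parameters so all quantize helpers stay call-compatible, and add the new flag as keyword-only with a default.

```python
def mxfp8_e4m3_quantize_sketch(
    a,
    a_scale=None,
    per_act_token_quant=False,
    block_shape=None,
    *,
    is_sf_swizzled_layout=False,
):
    # Existing callers that pass only the original positional args keep
    # working; new callers opt in to the swizzled scale layout by keyword.
    # (Returns the resolved options rather than quantizing -- a sketch only.)
    return {
        "per_act_token_quant": per_act_token_quant,
        "block_shape": block_shape,
        "is_sf_swizzled_layout": is_sf_swizzled_layout,
    }


# Old call site: unchanged.
old = mxfp8_e4m3_quantize_sketch([1.0], None, False, [1, 32])
# New call site: opts in explicitly.
new = mxfp8_e4m3_quantize_sketch([1.0], is_sf_swizzled_layout=True)
print(old["is_sf_swizzled_layout"], new["is_sf_swizzled_layout"])  # False True
```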

@robertgshaw2-redhat robertgshaw2-redhat changed the title Mxfp4 refactor cutlass experts [MoE Refactor] MXFP4 Cutlass Experts to MK Feb 16, 2026
@robertgshaw2-redhat
Collaborator

this looks good. minor nits other than the stuff about the gemm_alpha

@zyongye zyongye force-pushed the mxfp4_refactor_cutlass_experts branch from ca4fc4c to f744783 Compare February 17, 2026 20:40
@zyongye zyongye removed the ready ONLY add when PR is ready to merge/full CI is needed label Feb 17, 2026
@mergify
Contributor

mergify bot commented Feb 17, 2026

Hi @zyongye, the pre-commit checks have failed. Please run the same pre-commit steps as in the earlier comment, then commit the changes and push to your branch.

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye force-pushed the mxfp4_refactor_cutlass_experts branch from 014d4ef to 2618960 Compare February 25, 2026 21:38
Member

@mgoin mgoin left a comment


LGTM! I kicked off the GPQA Eval tests manually now to see that they work

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Feb 26, 2026
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Feb 26, 2026
@vllm-bot vllm-bot merged commit 1976356 into vllm-project:main Feb 26, 2026
70 of 71 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 26, 2026
haanjack pushed a commit to haanjack/vllm that referenced this pull request Feb 26, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
@zyongye zyongye deleted the mxfp4_refactor_cutlass_experts branch March 12, 2026 21:14

Labels

ci/build, gpt-oss (Related to GPT-OSS models), nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants