[MoE Refactor] MXFP4 Cutlass Experts to MK #34542

vllm-bot merged 29 commits into `vllm-project:main`.
Conversation
Code Review
This pull request refactors the MXFP4 cutlass backend for MoE layers, improving modularity and adding support for new quantization schemes. The changes are well-structured and consistent across the modified files. The refactoring in vllm/model_executor/layers/quantization/mxfp4.py to use the modular kernel framework is a significant improvement. I've identified one high-severity performance issue related to object instantiation within the forward pass and provided a suggestion to address it.
```python
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
...
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    FlashInferExperts(moe_config=self.moe, quant_config=self.moe_quant_config),
    shared_experts=None,
)
return self.kernel(
    hidden_states=x,
    w1=layer.w13_weight,
    w2=layer.w2_weight,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
)
```
Creating the FusedMoEModularKernel on every forward pass introduces unnecessary overhead. It's better to initialize it once and cache it for subsequent calls. This can be done with lazy initialization within the apply method:
```python
if not hasattr(self, "_kernel"):
    self._kernel = None
if self._kernel is None:
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
        FlashInferExperts,
    )
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP,
    )
    self._kernel = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        FlashInferExperts(moe_config=self.moe, quant_config=self.moe_quant_config),
        shared_experts=None,
    )
```
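The caching pattern behind this suggestion can be shown in isolation. The class and method names below are illustrative stand-ins, not vllm's actual API; the expensive kernel object is built on the first call and reused afterwards:

```python
class MoELayer:
    """Illustrative stand-in for the quantization method object."""

    def __init__(self):
        self._kernel = None   # cached modular kernel, built lazily
        self.build_count = 0  # only for demonstrating single construction

    def _build_kernel(self):
        # stands in for the mk.FusedMoEModularKernel(...) construction
        self.build_count += 1
        return lambda x: x * 2

    def apply(self, x):
        # lazy initialization: construct once, reuse on every later call
        if self._kernel is None:
            self._kernel = self._build_kernel()
        return self._kernel(x)
```

Calling `apply` twice constructs the kernel exactly once, which is the overhead the review comment is targeting.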
/gemini review
Code Review
This pull request refactors the MXFP4 cutlass backend to use the modular kernel interface, which is a positive step towards better code organization and maintainability. The changes primarily involve moving backend-specific logic into dedicated classes and leveraging the FusedMoEModularKernel. While the overall direction is good, I've identified two critical issues that could lead to incorrect behavior due to the refactoring. One issue involves incorrect data type casting for weights on a specific backend path, and the other relates to a quantization parameter not being correctly propagated in the new generic implementation. Please address these points to ensure the correctness of the refactored code.
```python
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
if self.quant_dtype == "mxfp8":
    fake_input_scale = torch.ones(
        self.moe_config.num_experts, device=hidden_states.device
    )
    quant_scales = [
        self.w1_scale.view(torch.int32),
        fake_input_scale,
        self.w2_scale.view(torch.int32),
        fake_input_scale,
    ]
    use_mxfp8_act_scaling = True
else:
    assert hidden_states.dtype == torch.bfloat16
    quant_scales = [
        self.w1_scale,
        self.w2_scale,
    ]
    a1q_scale = None
    use_w4_group_scaling = True
```
The weight tensors w1 and w2 are unconditionally cast to torch.long. However, this cast should only occur when self.quant_dtype == 'mxfp8' (i.e., for the use_mxfp8_act_scaling=True path). For the else branch (use_w4_group_scaling=True), the weights should be passed as-is (as torch.uint8), as was done in the previous implementation. This incorrect casting can lead to runtime errors or incorrect computation in the kernel.
```python
if self.quant_dtype == "mxfp8":
    fc1_expert_weights = w1.view(torch.long)
    fc2_expert_weights = w2.view(torch.long)
    fake_input_scale = torch.ones(
        self.moe_config.num_experts, device=hidden_states.device
    )
    quant_scales = [
        self.w1_scale.view(torch.int32),
        fake_input_scale,
        self.w2_scale.view(torch.int32),
        fake_input_scale,
    ]
    use_mxfp8_act_scaling = True
else:
    fc1_expert_weights = w1
    fc2_expert_weights = w2
    assert hidden_states.dtype == torch.bfloat16
    quant_scales = [
        self.w1_scale,
        self.w2_scale,
    ]
    a1q_scale = None
    use_w4_group_scaling = True
```

```python
return self.kernel(
    hidden_states=x,
    w1=layer.w13_weight,
    w2=layer.w2_weight,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
)
```
This refactoring seems to have introduced a potential issue. The previous implementation for the SM100_FI_MXFP4_MXFP8_CUTLASS backend called mxfp8_quantize(x, True, 32), with is_sf_swizzled_layout=True. The new modular kernel path, via MoEPrepareAndFinalizeNoEP and moe_kernel_quantize_input, effectively calls mxfp8_quantize with is_sf_swizzled_layout=False. This discrepancy might lead to incorrect behavior or performance degradation for this backend, as the activation scales will not have the swizzled layout expected by the kernel when weight scales are swizzled. Please ensure the is_sf_swizzled_layout flag is correctly propagated or handled for this backend.
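The shape of the fix being asked for can be sketched in miniature: thread the layout flag through the prepare/finalize object explicitly instead of relying on the quantizer's default. All names here are illustrative stand-ins, not vllm's actual API:

```python
def quantize_stub(x, is_sf_swizzled_layout=False):
    # Stand-in for mxfp8_quantize: returns the value together with the
    # scale-factor layout it actually used, so a mismatch with the
    # weight-scale layout becomes observable.
    layout = "swizzled" if is_sf_swizzled_layout else "linear"
    return x, layout

class PrepareFinalizeSketch:
    def __init__(self, act_scale_swizzled):
        # The flag is fixed at construction and propagated explicitly,
        # rather than silently falling back to the quantizer's default.
        self.act_scale_swizzled = act_scale_swizzled

    def prepare(self, x):
        return quantize_stub(x, is_sf_swizzled_layout=self.act_scale_swizzled)
```

A prepare path constructed with `act_scale_swizzled=True` then matches a kernel whose weight scales are swizzled, while the bare default call diverges silently.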
/gemini review
Code Review
This pull request refactors the MXFP4 CUTLASS backend for MoE layers to use the modular kernel framework, which improves code organization and maintainability. It also introduces a comprehensive testing infrastructure for GPQA evaluation, making the tests more robust and easier to configure. The changes are well-structured and the refactoring correctly encapsulates the kernel-specific logic within the FlashInferExperts class. The new testing setup is a great addition for ensuring correctness and performance on different hardware. Overall, this is a solid improvement to the codebase.
Hi @zyongye, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
```yaml
- label: GPQA Eval (GPT-OSS) (H200)
  timeout_in_minutes: 120
  device: h200
```
switch to h100 due to resource constraints (we have many more h100s in the ci)
ditto, I think H200 is only for 8xH200
```diff
@@ -635,6 +660,9 @@ def mxfp4_w4a16_moe_quant_config(
     w2_scale: Union[torch.Tensor, "PrecisionConfig"],
     w1_bias: torch.Tensor | None = None,
     w2_bias: torch.Tensor | None = None,
+    gemm1_alpha: torch.Tensor | None = None,
```
These are just hardcoded values right? (AFAICT they are: https://github.com/zyongye/vllm/blob/d7d68c3127bc27d97b20ceb901068e709c430bd5/vllm/model_executor/layers/quantization/mxfp4.py#L419-L430)
In that case, I think we should avoid passing these via the quant config and instead just keep these parameters in the kernel itself. This will make things clearer and reduce the surface of the API contract.
I moved them inside FlashInferExperts, into the init phase.
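The resolution above can be sketched generically: the hardcoded values become attributes materialized in the experts class's `__init__`, and the quant config keeps only genuinely per-checkpoint data. Constant values and all names here are illustrative placeholders, not the real ones from mxfp4.py:

```python
class FlashInferExpertsSketch:
    # Illustrative hardcoded activation parameters (placeholder values,
    # not the actual constants in vllm).
    _GEMM1_ALPHA = 1.0
    _GEMM1_LIMIT = 7.0

    def __init__(self, num_experts):
        # Materialized once at init; the quant config never carries them,
        # keeping the API contract surface small.
        self.gemm1_alpha = [self._GEMM1_ALPHA] * num_experts
        self.gemm1_limit = [self._GEMM1_LIMIT] * num_experts

def mxfp4_w4a16_moe_quant_config_sketch(w1_scale, w2_scale):
    # The config carries only checkpoint-dependent tensors/scales.
    return {"w1_scale": w1_scale, "w2_scale": w2_scale}
```

The design trade-off is the one the reviewer names: constants that are invariant for the kernel belong to the kernel, not to the cross-cutting config object.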
```diff
@@ -746,6 +746,30 @@ def _interleave_mxfp4_cutlass_sm90(w):
     layer.w2_weight_scale = torch.nn.Parameter(
         w2_scales_interleaved, requires_grad=False
     )
+
+    assert not self.moe.use_ep, (
```
I don't think this is needed. I think this kernel does support EP.
NOTE: the NoEP naming here is misleading. It should be NoDPEP.
The kernel interface actually dispatches to multiple kernels. It errors out when I run EP.
```diff
@@ -193,13 +193,10 @@ def _mxfp4_quantize(
 def _mxfp8_e4m3_quantize(
     A: torch.Tensor,
     A_scale: torch.Tensor | None,
-    per_act_token_quant: bool,
-    block_shape: list[int] | None = None,
+    is_sf_swizzled_layout: bool,
```
I think we should preserve the existing args to just avoid future footguns
I changed it back. Earlier I thought we should align this with the nxfp4 quantization function signature.
this looks good. minor nits other than the stuff about the gemm_alpha
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Purpose
Refactor the MXFP4 cutlass backend as part of the ongoing MoE refactor.
Also add testing infrastructure.
Test Plan
Run GPQA benchmarks with medium reasoning effort:
- gpt-oss-120b TP=2 on GB200 with the tested kernel enabled
- gpt-oss-120b TEP=2 on GB200 with the tested kernel enabled
- gpt-oss-20b TP=2 on H200 with the tested kernel enabled
Test command: follow the recipe.
Test Result
- GB200: GPQA with medium reasoning effort on 120b: 0.727. Matches the recipe.
- H200: GPQA with medium reasoning effort on 20b: 0.6641. Matches the recipe.