Use TE NVFP4 reference quantizer and add bitwise exact tests #1054
yueming-yuan merged 5 commits into radixark:main from
Conversation
Hi @Zhichenzzz @yueming-yuan, let me know if you have any workloads that use the nvfp4 quantizers and want to test them. Thank you.
Code Review
This pull request refactors NVFP4 quantization logic across the processor and conversion tools to use the TransformerEngine reference implementation, adds MoE gate/up pair validation, and introduces unit tests for MXFP8 and NVFP4. Feedback highlights a regression in zero global_amax handling and suggests optimizing performance by moving constant definitions and imports out of frequently called functions.
            torch.finfo(torch.float32).max,
            device=global_encode_scale.device,
            dtype=torch.float32,
        ),
    )
    if global_encode_scale.numel() == 1:
        if global_encode_scale == torch.tensor(0.0, device=global_amax.device, dtype=torch.float32):
            global_encode_scale = torch.tensor(1.0, device=global_amax.device, dtype=torch.float32)
    else:
        global_encode_scale = torch.where(
            global_encode_scale == 0.0,
            torch.ones_like(global_encode_scale),
            global_encode_scale,
        )
    return torch.div(1.0, global_encode_scale)
def _quantize_nvfp4_1d(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
    group_size: int = NVFP4_GROUP_SIZE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
The implementation of _nvfp4_global_decode_scale_te introduces a regression for tensors with zero global_amax. In the previous version (line 201 on the LEFT), a zero global_amax resulted in a global_scale of 1.0. With the new logic, if global_amax is 0.0, global_encode_scale becomes inf (clamped to finfo.max), which leads to a global_decode_scale of approximately 0.0. While this might not affect the final dequantized values (which remain zero), it is inconsistent with the previous behavior and could lead to numerical issues if the scale is used in other contexts.
Additionally, creating tensors for constants like FP4_E2M1_MAX and FP8_E4M3_MAX inside the function is inefficient. You can use the scalars directly in torch.div or define them as tensors once at the module level.
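To make the regression concrete, here is a minimal repro sketch of the zero-amax path; the FP4_E2M1_MAX and FP8_E4M3_MAX constants are assumed to be the standard format maxima (6.0 and 448.0) rather than taken from this repository:

# Repro sketch only; constants are assumed (E2M1 max = 6.0, E4M3 max = 448.0).
import torch

FP4_E2M1_MAX = 6.0
FP8_E4M3_MAX = 448.0

global_amax = torch.tensor(0.0)
encode = torch.div(FP8_E4M3_MAX * FP4_E2M1_MAX, global_amax)               # inf
encode = torch.min(encode, torch.tensor(torch.finfo(torch.float32).max))   # ~3.40e38
decode = torch.div(1.0, encode)                                            # ~2.94e-39, not 1.0

The suggestion below folds both the zero-amax (inf scale) and inf-amax (zero scale) cases into a single torch.where: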
def _nvfp4_global_decode_scale_te(global_amax: torch.Tensor) -> torch.Tensor:
    scale_factor = FP8_E4M3_MAX * FP4_E2M1_MAX
    global_encode_scale = torch.div(scale_factor, global_amax.to(torch.float32))
    # Handle zero amax (inf scale) and inf amax (zero scale) by defaulting to 1.0
    # to match previous behavior and ensure numerical stability.
    is_invalid = torch.isinf(global_encode_scale) | (global_encode_scale == 0)
    global_encode_scale = torch.where(
        is_invalid,
        torch.ones_like(global_encode_scale),
        global_encode_scale,
    )
    # Clamp to float32 max to prevent overflow
    finfo_max = torch.finfo(torch.float32).max
    global_encode_scale = torch.clamp(global_encode_scale, max=finfo_max)
    return torch.div(1.0, global_encode_scale)

        global_amax,
        NVFP4_GROUP_SIZE,
        1,
        pow_2_scales=False,
Importing NVFP4QuantizerRef inside _quantize_nvfp4_1d is inefficient because this function is called within a loop when processing 3D weights (e.g., MoE experts). Although Python caches imports, performing the lookup repeatedly adds unnecessary overhead. Consider moving this import to the top of the file or at least outside the loop in quantize_nvfp4.
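A minimal sketch of the hoisting idea follows; the import path and the surrounding function bodies are placeholders, not the actual TransformerEngine module path or this repository's code:

# Sketch only: resolve the reference-quantizer import once at module scope so
# the per-expert helper does not repeat the lookup on every call.
from transformer_engine_ref import NVFP4QuantizerRef  # placeholder import path

def _quantize_nvfp4_1d(weight, global_amax=None, group_size=16):  # 16 assumed as the NVFP4 group size
    # No in-function import here; the module-level import above is resolved once.
    ...

def quantize_nvfp4(weights_3d, global_amax):
    # Per-expert loop over a 3D MoE weight; each iteration reuses the module-level
    # NVFP4QuantizerRef binding instead of re-triggering the import machinery.
    return [_quantize_nvfp4_1d(weights_3d[i], global_amax[i]) for i in range(weights_3d.shape[0])]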
def _nvfp4_global_decode_scale_te(global_amax: torch.Tensor) -> torch.Tensor:
    fp4_max = torch.tensor(FP4_E2M1_MAX, device=global_amax.device, dtype=torch.float32)
    fp8_max = torch.tensor(FP8_E4M3_MAX, device=global_amax.device, dtype=torch.float32)
    global_encode_scale = torch.div(fp8_max * fp4_max, global_amax.to(torch.float32))
    global_encode_scale = torch.min(
        global_encode_scale,
        torch.tensor(
            torch.finfo(torch.float32).max,
            device=global_encode_scale.device,
            dtype=torch.float32,
        ),
    )
    if global_encode_scale.numel() == 1:
        if global_encode_scale == torch.tensor(0.0, device=global_amax.device, dtype=torch.float32):
            global_encode_scale = torch.tensor(1.0, device=global_amax.device, dtype=torch.float32)
    else:
        global_encode_scale = torch.where(
            global_encode_scale == 0.0,
            torch.ones_like(global_encode_scale),
            global_encode_scale,
        )
    return torch.div(1.0, global_encode_scale)
This function duplicates the logic in the processor and shares the same issues: regression for zero global_amax and inefficient constant tensor creation. Please apply the same improvements suggested for the processor version.
def _nvfp4_global_decode_scale_te(global_amax: torch.Tensor) -> torch.Tensor:
    scale_factor = FP8_E4M3_MAX * FP4_E2M1_MAX
    global_encode_scale = torch.div(scale_factor, global_amax.to(torch.float32))
    is_invalid = torch.isinf(global_encode_scale) | (global_encode_scale == 0)
    global_encode_scale = torch.where(
        is_invalid,
        torch.ones_like(global_encode_scale),
        global_encode_scale,
    )
    finfo_max = torch.finfo(torch.float32).max
    global_encode_scale = torch.clamp(global_encode_scale, max=finfo_max)
    return torch.div(1.0, global_encode_scale)
Thank you @zianglih for the great work! Could you also share some accuracy validation experiments on top of this PR? That way, we can reproduce and re-verify it. Thanks!
@Zhichenzzz There is no E2E RL validation, but the quantizer is bitwise exact with TE's ground-truth reference implementation. You can check the new tests I have added.
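For readers who want the flavor of such a check, a minimal bitwise-exactness test sketch follows; the function names and return layout here are assumptions, not the actual tests added in this PR:

# Sketch of a bitwise-exactness check against a reference quantizer.
# quantize_impl / quantize_ref and their (payload, block scales, global scale)
# return layout are assumptions; the real tests compare against TE's reference.
import torch

def assert_bitwise_exact(quantize_impl, quantize_ref, weight: torch.Tensor) -> None:
    q_impl, scales_impl, gscale_impl = quantize_impl(weight)
    q_ref, scales_ref, gscale_ref = quantize_ref(weight)
    # torch.equal requires identical shapes and exact element-wise equality,
    # which for the uint8-packed FP4 payload amounts to a bitwise comparison.
    assert torch.equal(q_impl, q_ref), "packed FP4 payload differs"
    assert torch.equal(scales_impl, scales_ref), "per-block E4M3 scales differ"
    assert torch.equal(gscale_impl, gscale_ref), "global decode scale differs"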
Zhichenzzz left a comment
LGTM! Due to resource constraints, I verified this with kernel-level matching tests.
This implementation should have better numerics and is less error-prone.
Current FlashInfer quantizer does not have rn for e4m3 scales and is not bitwise exact with TE, so it is not yet safe to use.

Redo the implementation for the following flags to mirror the well-tested mxfp8 implementation introduced in #614 (a rough sketch of the intended semantics follows below):

- --num-layers-at-start-in-bf16
- --num-layers-at-end-in-bf16
- --extra-high-precision-layers-hf
- --extra-high-precision-layers-megatron
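A rough sketch of what the layer-selection semantics behind those flags might look like; the function name, argument handling, and numbering assumptions are mine, not the #614 implementation:

# Hypothetical sketch of interpreting the bf16 / high-precision layer flags;
# names and semantics are assumptions, not the actual mxfp8 implementation.
def keep_layer_in_high_precision(
    layer_idx: int,
    num_layers: int,
    num_layers_at_start_in_bf16: int,
    num_layers_at_end_in_bf16: int,
    extra_high_precision_layers: set[int],
) -> bool:
    # First N and last M layers stay in bf16; explicitly listed layers
    # (HF or Megatron numbering, depending on the flag) also stay high precision.
    if layer_idx < num_layers_at_start_in_bf16:
        return True
    if layer_idx >= num_layers - num_layers_at_end_in_bf16:
        return True
    return layer_idx in extra_high_precision_layers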