
Use TE NVFP4 reference quantizer and add bitwise exact tests #1054

Merged
yueming-yuan merged 5 commits into radixark:main from zianglih:bitwise
May 1, 2026

Conversation

@zianglih
Contributor

zianglih commented Apr 29, 2026

This implementation should have better numerics and be less error-prone.

The current FlashInfer quantizer does not apply round-to-nearest (rn) when encoding the e4m3 scales and is not bitwise exact with TE, so it is not yet safe to use.
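
For intuition, here is a minimal standalone sketch (not the FlashInfer or TE code path) of why the rounding mode of the e4m3 block scales matters; PyTorch's float8 cast stands in for a round-to-nearest encoder, and the value 0.3 is arbitrary:

import torch

# 0.3 sits between the e4m3-representable neighbours 0.28125 and 0.3125.
scale = torch.tensor(0.3, dtype=torch.float32)

# PyTorch casts to float8 with round-to-nearest, so the scale lands on 0.3125.
print(scale.to(torch.float8_e4m3fn).float())                   # tensor(0.3125)

# A truncating (round-toward-zero) encoder would emit the lower neighbour,
# shown here by casting that exactly representable value directly.
print(torch.tensor(0.28125).to(torch.float8_e4m3fn).float())   # tensor(0.2812)

A one-code difference in a block scale changes every element quantized with it, so the packed outputs can no longer be compared bitwise.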

Redo the implementation for the following flags to mirror the well-tested mxfp8 implementation introduced in #614 (see the sketch after this list):

  • --num-layers-at-start-in-bf16
  • --num-layers-at-end-in-bf16
  • --extra-high-precision-layers-hf
  • --extra-high-precision-layers-megatron
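
As a rough illustration of the intended flag semantics, a hypothetical helper (not the repository's implementation; the argument names are made up) that decides which layers stay in BF16 instead of being quantized could look like:

def keep_layer_in_bf16(layer_idx: int, num_layers: int,
                       num_start_bf16: int, num_end_bf16: int,
                       extra_high_precision: set[int]) -> bool:
    # First N layers, last N layers, plus an explicit extra set stay in BF16;
    # everything else takes the NVFP4 quantization path.
    if layer_idx < num_start_bf16:
        return True
    if layer_idx >= num_layers - num_end_bf16:
        return True
    return layer_idx in extra_high_precision

# 8-layer model, first and last layer plus layer 3 kept in BF16.
print([keep_layer_in_bf16(i, 8, 1, 1, {3}) for i in range(8)])
# [True, False, False, True, False, False, False, True]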

@zianglih
Contributor Author

Hi @Zhichenzzz @yueming-yuan, let me know if you have any workloads that use the nvfp4 quantizers and want to test them. Thank you.

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request refactors NVFP4 quantization logic across the processor and conversion tools to use the TransformerEngine reference implementation, adds MoE gate/up pair validation, and introduces unit tests for MXFP8 and NVFP4. Feedback highlights a regression in zero global_amax handling and suggests optimizing performance by moving constant definitions and imports out of frequently called functions.

Comment on lines +156 to 177
            torch.finfo(torch.float32).max,
            device=global_encode_scale.device,
            dtype=torch.float32,
        ),
    )
    if global_encode_scale.numel() == 1:
        if global_encode_scale == torch.tensor(0.0, device=global_amax.device, dtype=torch.float32):
            global_encode_scale = torch.tensor(1.0, device=global_amax.device, dtype=torch.float32)
    else:
        global_encode_scale = torch.where(
            global_encode_scale == 0.0,
            torch.ones_like(global_encode_scale),
            global_encode_scale,
        )
    return torch.div(1.0, global_encode_scale)


def _quantize_nvfp4_1d(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
    group_size: int = NVFP4_GROUP_SIZE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """

medium

The implementation of _nvfp4_global_decode_scale_te introduces a regression for tensors with zero global_amax. In the previous version (line 201 on the LEFT), a zero global_amax resulted in a global_scale of 1.0. With the new logic, if global_amax is 0.0, global_encode_scale becomes inf (clamped to finfo.max), which leads to a global_decode_scale of approximately 0.0. While this might not affect the final dequantized values (which remain zero), it is inconsistent with the previous behavior and could lead to numerical issues if the scale is used in other contexts.

Additionally, creating tensors for constants like FP4_E2M1_MAX and FP8_E4M3_MAX inside the function is inefficient. You can use the scalars directly in torch.div or define them as tensors once at the module level.

def _nvfp4_global_decode_scale_te(global_amax: torch.Tensor) -> torch.Tensor:
    scale_factor = FP8_E4M3_MAX * FP4_E2M1_MAX
    global_encode_scale = torch.div(scale_factor, global_amax.to(torch.float32))

    # Handle zero amax (inf scale) and inf amax (zero scale) by defaulting to 1.0
    # to match previous behavior and ensure numerical stability.
    is_invalid = torch.isinf(global_encode_scale) | (global_encode_scale == 0)
    global_encode_scale = torch.where(
        is_invalid,
        torch.ones_like(global_encode_scale),
        global_encode_scale
    )

    # Clamp to float32 max to prevent overflow
    finfo_max = torch.finfo(torch.float32).max
    global_encode_scale = torch.clamp(global_encode_scale, max=finfo_max)

    return torch.div(1.0, global_encode_scale)

        global_amax,
        NVFP4_GROUP_SIZE,
        1,
        pow_2_scales=False,

medium

Importing NVFP4QuantizerRef inside _quantize_nvfp4_1d is inefficient because this function is called within a loop when processing 3D weights (e.g., MoE experts). Although Python caches imports, performing the lookup repeatedly adds unnecessary overhead. Consider moving this import to the top of the file or at least outside the loop in quantize_nvfp4.
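
To make the overhead concrete, here is a standalone micro-benchmark (math stands in for NVFP4QuantizerRef; absolute timings are machine-dependent) comparing a function-local import against a module-level one:

import timeit

def local_import():
    # Re-runs the sys.modules lookup and name binding on every call.
    import math
    return math.sqrt(2.0)

import math

def module_import():
    # Import resolved once at module load; the call body only does the work.
    return math.sqrt(2.0)

print("local import :", timeit.timeit(local_import, number=1_000_000))
print("module import:", timeit.timeit(module_import, number=1_000_000))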

Comment on lines +158 to +179
def _nvfp4_global_decode_scale_te(global_amax: torch.Tensor) -> torch.Tensor:
    fp4_max = torch.tensor(FP4_E2M1_MAX, device=global_amax.device, dtype=torch.float32)
    fp8_max = torch.tensor(FP8_E4M3_MAX, device=global_amax.device, dtype=torch.float32)
    global_encode_scale = torch.div(fp8_max * fp4_max, global_amax.to(torch.float32))
    global_encode_scale = torch.min(
        global_encode_scale,
        torch.tensor(
            torch.finfo(torch.float32).max,
            device=global_encode_scale.device,
            dtype=torch.float32,
        ),
    )
    if global_encode_scale.numel() == 1:
        if global_encode_scale == torch.tensor(0.0, device=global_amax.device, dtype=torch.float32):
            global_encode_scale = torch.tensor(1.0, device=global_amax.device, dtype=torch.float32)
    else:
        global_encode_scale = torch.where(
            global_encode_scale == 0.0,
            torch.ones_like(global_encode_scale),
            global_encode_scale,
        )
    return torch.div(1.0, global_encode_scale)

medium

This function duplicates the logic in the processor and shares the same issues: regression for zero global_amax and inefficient constant tensor creation. Please apply the same improvements suggested for the processor version.

def _nvfp4_global_decode_scale_te(global_amax: torch.Tensor) -> torch.Tensor:
    scale_factor = FP8_E4M3_MAX * FP4_E2M1_MAX
    global_encode_scale = torch.div(scale_factor, global_amax.to(torch.float32))

    is_invalid = torch.isinf(global_encode_scale) | (global_encode_scale == 0)
    global_encode_scale = torch.where(
        is_invalid,
        torch.ones_like(global_encode_scale),
        global_encode_scale
    )

    finfo_max = torch.finfo(torch.float32).max
    global_encode_scale = torch.clamp(global_encode_scale, max=finfo_max)

    return torch.div(1.0, global_encode_scale)

@Zhichenzzz
Contributor

> Hi @Zhichenzzz @yueming-yuan, let me know if you have any workloads that use the nvfp4 quantizers and want to test them. Thank you.

Thank you @zianglih for the great work! Could you also share some accuracy validation experiments on top of this PR? That way, we can reproduce and re-verify it. Thanks!

@zianglih
Contributor Author

@Zhichenzzz There is no E2E RL validation, but the quantizer is bitwise exact with TE's ground-truth reference implementation. You can check the new tests I have added.
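
For readers who want the shape of such a check, here is a sketch of a bitwise-equality harness; the (packed data, block scales, global scale) return convention and the callables passed in are assumptions for illustration, not the actual test code from this PR:

import torch
from typing import Callable, Tuple

QuantFn = Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

def _as_bytes(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret any tensor as a flat byte view so equality is truly bitwise.
    return x.contiguous().reshape(-1).view(torch.uint8)

def assert_bitwise_equal(weight: torch.Tensor, quant_a: QuantFn, quant_b: QuantFn) -> None:
    # quant_a / quant_b would be this repo's quantizer and TE's reference;
    # each is assumed to return (packed_qdata, block_scales, global_scale).
    for ours, ref in zip(quant_a(weight), quant_b(weight)):
        assert torch.equal(_as_bytes(ours), _as_bytes(ref))

# Trivial smoke check: an implementation compared against itself must pass.
dummy = lambda t: (t.clone(), t.abs().amax(dim=-1), t.abs().amax())
assert_bitwise_equal(torch.randn(128, 64), dummy, dummy)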

Contributor

Zhichenzzz left a comment


LGTM! Due to resource constraints, I verified this with kernel-level matching tests.

yueming-yuan merged commit ce10d21 into radixark:main on May 1, 2026
zianglih deleted the bitwise branch on May 1, 2026 at 07:40
