[Issue]: Incorrect MXFP4 mantissa rounding in fp4_utils.py #2247
Description
Problem Description
aiter/utility/fp4_utils.py contains an MXFP4 quantization kernel (_dynamic_mxfp4_quant_kernel_asm_layout) that still uses the old round-ties-up rounding logic, which was identified and fixed in aiter/ops/triton/_triton_kernels/quant.py via PR #975 (fixing #974).
The fix in #975 replaced the manual shift-based conversion with proper roundTiesToEven (banker's rounding) using three-way branching (saturate/denormal/normal masks) and the magic-number addition trick for denormals — matching torchao and MI355 v_cvt_scalef32_pk_fp4_f32 behavior. However, only quant.py was patched. The same buggy logic remains in fp4_utils.py.
Affected code
https://github.com/ROCm/aiter/blob/main/aiter/utility/fp4_utils.py#L321
The kernel at this location still contains the pre-#975 conversion:
# Extract sign, exponents and mantissa fields from FP32
s = qx & 0x80000000
e = (qx >> 23) & 0xFF
m = qx & 0x7FFFFF
E8_BIAS: tl.constexpr = 127
E2_BIAS: tl.constexpr = 1
# Denormal numbers
adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)
What's wrong
Two issues (same as originally reported in #974):
- Normal values use round-ties-up instead of roundTiesToEven. The (value + 1) >> 1 pattern always rounds midpoints up in magnitude. For example, -0.625 / scale with scale = 0.5 is -1.25, the exact midpoint between FP4 values -1.0 and -1.5; it rounds to -1.5 instead of the correct -1.0 (even mantissa).
- Denormals cannot round up properly. The manual shift-based denormal path doesn't handle rounding correctly. For example, FP32 value 0x3F000003 should round up to 0.5 but rounds down to 0.0.
Both torchao and the MI355 hardware instruction v_cvt_scalef32_pk_fp4_f32 use roundTiesToEven, so this code produces results inconsistent with the hardware and with the already-patched quant.py.
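The tie behavior of the quoted block can be checked on the host with a scalar emulation. This is a minimal sketch, not the Triton kernel: the helper names (f32_bits, f32_from_bits, old_e2m1_code, decode) are hypothetical, Python integers stand in for uint32 lanes, and shift counts of 32 or more (which Python handles but the GPU may not) are not modeled.

```python
import struct

# Magnitudes representable in FP4 E2M1, indexed by the 3-bit code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def f32_bits(x: float) -> int:
    """Bit pattern of x after rounding it to FP32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def f32_from_bits(bits: int) -> float:
    """FP32 value with the given 32-bit pattern."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def old_e2m1_code(x: float) -> int:
    """Scalar emulation of the quoted pre-#975 conversion (hypothetical helper)."""
    qx = f32_bits(x)
    s = qx & 0x80000000
    e = (qx >> 23) & 0xFF
    m = qx & 0x7FFFFF
    if e < 127:  # |x| < 1.0: the "denormal numbers" path of the snippet
        m = (0x400000 | (m >> 1)) >> (127 - (e + 1))
    e = max(e, 126) - 126
    # Add 1 to the bit right of the LSB, then shift right: ties go up.
    tmp = min((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
    return (s >> 28) | tmp


def decode(code: int) -> float:
    """Map a 4-bit sign+E2M1 code back to its value."""
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


if __name__ == "__main__":
    # Exact midpoints between adjacent FP4 magnitudes.
    for x in (0.25, 1.25, -1.25, 2.5, 5.0):
        print(x, "->", decode(old_e2m1_code(x)))
```

In this emulation every listed midpoint moves away from zero (for example 1.25 maps to 1.5 and 0.25 maps to 0.5), whereas roundTiesToEven would pick 1.0 and 0.0. Under the same emulation the FP32 pattern 0x3F000003 (just above 0.5) maps to 0.5, so reproducing the denormal failure may need a different input or the real kernel.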
Expected behavior
fp4_utils.py should use the same corrected rounding logic that was applied to quant.py in PR #975, which implements:
- Three-way masking: saturate_mask, denormal_mask, normal_mask
- Denormal conversion via the magic-number addition trick (denorm_mask_float)
- Normal conversion with proper round-to-nearest-even via mant_odd bias
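The three branches listed above can be sketched in scalar Python as follows. This is an illustration of the general technique under stated assumptions, not the code from PR #975: the function name rte_e2m1_code and the constant names are hypothetical, the constants are derived here for E2M1 (2 exponent bits, 1 mantissa bit), NaN is rejected rather than encoded, and struct.pack("<f", ...) stands in for FP32 rounding of the magic-number addition.

```python
import struct

F32_EXP_BIAS = 127
MBITS_F32 = 23
MBITS = 1                      # E2M1 mantissa bits
EXP_BIAS = 1                   # E2M1 exponent bias
MAX_NORMAL = 6.0               # largest E2M1 magnitude
MIN_NORMAL = 1.0               # smallest normal E2M1 magnitude
MAGIC_ADDER = (1 << (MBITS_F32 - MBITS - 1)) - 1          # 0x1FFFFF
DENORM_EXP = (F32_EXP_BIAS - EXP_BIAS) + (MBITS_F32 - MBITS) + 1
DENORM_MASK_INT = DENORM_EXP << MBITS_F32


def f32_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]


def f32_from_bits(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits))[0]


DENORM_MASK_FLOAT = f32_from_bits(DENORM_MASK_INT)        # 2**22


def rte_e2m1_code(x: float) -> int:
    """FP32 -> 4-bit sign+E2M1 code with round-ties-to-even (sketch)."""
    if x != x:
        raise ValueError("NaN is outside the scope of this sketch")
    bits = f32_bits(x)
    sign = bits & 0x80000000
    mag_bits = bits ^ sign
    mag = f32_from_bits(mag_bits)
    if mag >= MAX_NORMAL:
        # Saturate branch: clamp to the largest code.
        code = 0x7
    elif mag < MIN_NORMAL:
        # Denormal branch: adding 2**22 in FP32 leaves a 0.5 ulp, so the
        # hardware-style ties-to-even rounding of the sum does the work.
        code = f32_bits(mag + DENORM_MASK_FLOAT) - DENORM_MASK_INT
    else:
        # Normal branch: rebias the exponent, add the rounding bias, and
        # add the parity of the kept mantissa bit so ties go to even.
        mant_odd = (mag_bits >> (MBITS_F32 - MBITS)) & 1
        v = mag_bits + ((EXP_BIAS - F32_EXP_BIAS) << MBITS_F32)
        v += MAGIC_ADDER + mant_odd
        code = v >> (MBITS_F32 - MBITS)
    return (sign >> 28) | code


if __name__ == "__main__":
    for x in (0.25, 0.75, 1.25, -1.25, 2.5, 5.0, 100.0):
        print(x, "->", format(rte_e2m1_code(x), "04b"))
```

A vectorized kernel would compute all three branches and select between them with masks instead of using if/elif; the scalar form is used here only to keep the arithmetic easy to follow.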
Suggested fix
Apply the same transformation from PR #975 to the _dynamic_mxfp4_quant_kernel_asm_layout kernel in fp4_utils.py. The corrected conversion block from quant.py can be used directly.
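One way to guard the port against regressions is a brute-force reference that any candidate conversion can be compared with. The sketch below is an assumption about how such a check could look, not an existing aiter test: reference_e2m1_code, tie_points, and mismatches are hypothetical names, and wiring a Triton kernel's output into the candidate callable is left out.

```python
import math

# Magnitudes representable in FP4 E2M1, indexed by the 3-bit code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def reference_e2m1_code(x: float) -> int:
    """Nearest E2M1 code; ties go to the code with an even mantissa bit."""
    sign = 0x8 if math.copysign(1.0, x) < 0 else 0x0
    mag = min(abs(x), 6.0)  # saturate at the largest magnitude
    # Sort key: distance first, then parity of the low (mantissa) bit.
    code = min(range(8), key=lambda c: (abs(E2M1_VALUES[c] - mag), c & 1))
    return sign | code


def tie_points() -> list:
    """Exact midpoints between adjacent E2M1 magnitudes, both signs."""
    mids = [(a + b) / 2 for a, b in zip(E2M1_VALUES, E2M1_VALUES[1:])]
    return mids + [-m for m in mids]


def mismatches(candidate, samples) -> list:
    """Inputs where candidate(x) differs from the reference code."""
    return [x for x in samples if candidate(x) != reference_e2m1_code(x)]


if __name__ == "__main__":
    # A correct roundTiesToEven conversion should report no mismatches.
    print(mismatches(reference_e2m1_code, tie_points()))
```

The midpoints are the inputs where round-ties-up and roundTiesToEven can disagree, so they make a compact first test set; random inputs and values near 1.0 and 6.0 would be a natural extension.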
Related
- [Issue]: Incorrect MXFP4 mantissa rounding #974 (original issue report for incorrect MXFP4 mantissa rounding)
- [TRITON] fix: MXFP4 mantissa rounding #975 (fix applied to aiter/ops/triton/_triton_kernels/quant.py, merged Nov 26, 2025)
Operating System
Linux-6.8.0-60-generic-x86_64-with-glibc2.39
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD Instinct MI355X
ROCm Version
ROCm 7.1
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response