[Issue]: Incorrect MXFP4 mantissa rounding in fp4_utils.py #2247

@Knarf04

Problem Description

aiter/utility/fp4_utils.py contains an MXFP4 quantization kernel (_dynamic_mxfp4_quant_kernel_asm_layout) that still uses the old round-ties-up rounding logic, which was identified and fixed in aiter/ops/triton/_triton_kernels/quant.py via PR #975 (fixing #974).

The fix in #975 replaced the manual shift-based conversion with proper roundTiesToEven (banker's rounding) using three-way branching (saturate/denormal/normal masks) and the magic-number addition trick for denormals — matching torchao and MI355 v_cvt_scalef32_pk_fp4_f32 behavior. However, only quant.py was patched. The same buggy logic remains in fp4_utils.py.

Affected code

https://github.com/ROCm/aiter/blob/main/aiter/utility/fp4_utils.py#L321

The kernel at this location still contains the pre-#975 conversion:

# Extract sign, exponents and mantissa fields from FP32
s = qx & 0x80000000
e = (qx >> 23) & 0xFF
m = qx & 0x7FFFFF

E8_BIAS: tl.constexpr = 127
E2_BIAS: tl.constexpr = 1

# Denormal numbers
adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)

e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)

What's wrong

Two issues (same as originally reported in #974):

  1. Normal values use round-ties-up instead of roundTiesToEven. The sign is stripped before rounding, so the (value + 1) >> 1 pattern always rounds a midpoint to the larger magnitude. For example, when -0.625 / scale lands on -1.25 (scale = 0.5), the exact midpoint between the FP4 values -1.0 and -1.5, the kernel returns -1.5 instead of the correct -1.0 (E2M1 code 0b010, even mantissa).

  2. Denormals cannot round up properly. The manual shift-based denormal path doesn't handle rounding correctly. For example, FP32 value 0x3F000003 should round up to 0.5 but rounds down to 0.0.

Both torchao and the MI355 hardware instruction v_cvt_scalef32_pk_fp4_f32 use roundTiesToEven, so this code produces results inconsistent with the hardware and with the already-patched quant.py.
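The tie behaviour can be checked on the CPU without Triton. The following is a minimal scalar sketch, not code from the repository: old_e2m1_code re-expresses the conversion quoted above with plain Python integers, and rne_e2m1_code is a brute-force nearest-value reference with ties to the even code. All helper names (E2M1_VALUES, f32_bits, old_e2m1_code, rne_e2m1_code, decode) are hypothetical and exist only for this illustration.

```python
import struct

# E2M1 magnitudes indexed by their 3-bit code (sign bit excluded).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def f32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def old_e2m1_code(x: float) -> int:
    """Scalar emulation of the shift-based conversion quoted in the issue."""
    qx = f32_bits(x)
    s = qx & 0x80000000
    e = (qx >> 23) & 0xFF
    m = qx & 0x7FFFFF
    if e < 127:
        # Denormal path. Python integers accept any shift count, so very
        # small inputs shift to 0 here; a GPU may behave differently.
        m = (0x400000 | (m >> 1)) >> (127 - (e + 1))
    e = max(e, 126) - 126
    # Adds 1 one bit right of the LSB, then shifts: ties always go up.
    tmp = min((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
    return (s >> 28) | tmp


def rne_e2m1_code(x: float) -> int:
    """Reference: nearest E2M1 magnitude, ties to the even code, saturating."""
    mag = abs(x)
    best = min(range(8), key=lambda c: (abs(E2M1_VALUES[c] - mag), c & 1))
    return (0x8 if x < 0 else 0x0) | best


def decode(code: int) -> float:
    """Map a 4-bit sign + E2M1 code back to its real value."""
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


if __name__ == "__main__":
    # 0.25, 1.25, 2.5 and 5.0 are ties whose lower neighbour is even;
    # 1.3 and 0.7 are ordinary non-tie inputs.
    for x in (0.25, 1.25, -1.25, 2.5, 5.0, 1.3, 0.7):
        old = decode(old_e2m1_code(x))
        ref = decode(rne_e2m1_code(x))
        print(f"x={x:+.4f}  old={old:+.1f}  rne={ref:+.1f}")
```

The two columns are expected to differ only on tie inputs. The same reference function could serve as an oracle when comparing kernel output against already-scaled FP32 inputs.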

Expected behavior

fp4_utils.py should use the same corrected rounding logic that was applied to quant.py in PR #975, which implements:

  • Three-way masking: saturate_mask, denormal_mask, normal_mask
  • Denormal conversion via the magic-number addition trick (denorm_mask_float)
  • Normal conversion with proper round-to-nearest-even via mant_odd bias
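The three bullets above can be sketched in scalar Python as follows. This is an illustration of the general technique under stated assumptions, not the code from PR #975: the constants follow from the E2M1 layout (2 exponent bits, 1 mantissa bit, bias 1), the names saturate_mask, denormal_mask, normal_mask, denorm_mask_float and mant_odd are taken from the list above, and everything else (to_bits, from_bits, f32_to_e2m1_rne) is hypothetical. NaN inputs are not handled.

```python
import struct

F32_EXP_BIAS = 127
MBITS_F32 = 23
MBITS = 1                       # E2M1 mantissa bits
EXP_BIAS = 1                    # E2M1 exponent bias
MAX_CODE = 0x7                  # encodes 6.0, the largest magnitude
MAX_NORMAL = 6.0
MIN_NORMAL = 1.0
MAGIC_ADDER = (1 << (MBITS_F32 - MBITS - 1)) - 1           # 0x1FFFFF
DENORM_EXP = (F32_EXP_BIAS - EXP_BIAS) + (MBITS_F32 - MBITS) + 1
DENORM_MASK_INT = DENORM_EXP << MBITS_F32                  # bits of 2**22


def to_bits(x: float) -> int:
    """Round x to float32 and return its bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def from_bits(b: int) -> float:
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


denorm_mask_float = from_bits(DENORM_MASK_INT)             # 4194304.0


def f32_to_e2m1_rne(x: float) -> int:
    """Convert an already-scaled float32 value to a 4-bit sign + E2M1 code."""
    bits = to_bits(x)
    sign = bits & 0x80000000
    mag_bits = bits ^ sign
    mag = from_bits(mag_bits)

    saturate_mask = mag >= MAX_NORMAL
    denormal_mask = (not saturate_mask) and mag < MIN_NORMAL
    normal_mask = not (saturate_mask or denormal_mask)

    if denormal_mask:
        # The float32 spacing at 2**22 is 0.5, so rounding the sum to
        # float32 (done by struct.pack, ties to even) snaps mag onto the
        # 0 / 0.5 / 1.0 grid. Subtracting the mask bits leaves the code.
        code = to_bits(mag + denorm_mask_float) - DENORM_MASK_INT
    elif normal_mask:
        # Lowest mantissa bit that survives; adding it turns the
        # round-half-down bias (MAGIC_ADDER) into ties-to-even.
        mant_odd = (mag_bits >> (MBITS_F32 - MBITS)) & 1
        rebias = (EXP_BIAS - F32_EXP_BIAS) << MBITS_F32
        code = (mag_bits + rebias + MAGIC_ADDER + mant_odd) >> (MBITS_F32 - MBITS)
    else:
        code = MAX_CODE
    return (sign >> 28) | code


if __name__ == "__main__":
    for x in (0.25, 0.75, 1.25, -1.25, 1.75, 2.5, 5.0, 100.0):
        print(f"x={x:+.4f}  code={f32_to_e2m1_rne(x):04b}")
```

A vectorized kernel would compute all three branches and select between them with the masks; the explicit if/elif here only keeps the scalar version short.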

Suggested fix

Apply the same transformation from PR #975 to the _dynamic_mxfp4_quant_kernel_asm_layout kernel in fp4_utils.py. The corrected conversion block from quant.py can be used directly.

Related

  • Issue #974 (original rounding report)
  • PR #975 (fix applied to quant.py)

Operating System

Linux-6.8.0-60-generic-x86_64-with-glibc2.39

CPU

AMD EPYC 9575F 64-Core Processor

GPU

AMD Instinct MI355X

ROCm Version

ROCm 7.1
