Support 4over6 nvfp4 for quantizer and fused MoE #3264
zianglih wants to merge 15 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough
This PR adds a runtime-configurable NVFP4 4-over-6 per-token quantization mode with MSE-based scale-candidate selection, and extends the kernel templates to support it.
Changes: NVFP4 4-over-6 per-token quantization
🚥 Pre-merge checks: 4 passed, 1 failed (1 warning).
Code Review
This pull request introduces a new NVFP4 quantization mode called '4/6 MSE scale-candidate mode,' which is activated via the FLASHINFER_NVFP4_FOUR_OVER_SIX environment variable. The implementation includes updates to CUDA kernels for per-token scaling and quantization, as well as corresponding Python tests and documentation. Reviewer feedback suggests several optimizations for the CUDA code, including refactoring duplicated logic into helper functions, precalculating values to reduce redundant arithmetic operations within loops, and replacing switch statements with lookup tables to improve performance and readability.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/utils/test_fp4_quantize.py (1)
706-747: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Pin FOUR_OVER_SIX off for the baseline TE-reference test.
Line 706 validates the non-4/6 reference path, but this test can be affected by an externally set `FLASHINFER_NVFP4_FOUR_OVER_SIX`. Make the mode explicit in-test to avoid environment-coupled failures.
🔧 Proposed fix
```diff
 def test_nvfp4_per_token_quantize_te_reference(
     dtype: torch.dtype,
     shape: tuple[int, int],
     sf_layout: SfLayout,
     init_data: str,
     device: str,
+    monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     """Per-token NVFP4 quantization should match the TE Python reference bitwise."""
     if not _is_fp4_supported(torch.device(device)):
         pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8")
+    monkeypatch.setenv("FLASHINFER_NVFP4_FOUR_OVER_SIX", "0")
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/utils/test_fp4_quantize.py` around lines 706 - 747, In test_nvfp4_per_token_quantize_te_reference ensure the FOUR_OVER_SIX mode is pinned off so the TE-reference path is deterministic: at the start of test_nvfp4_per_token_quantize_te_reference set the environment flag FLASHINFER_NVFP4_FOUR_OVER_SIX="0" (or call your library’s setter if available) before creating x and running ref_fp4_quant_te/nvfp4_quantize, and restore the previous value at the end of the test to avoid leaking global state.
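Outside pytest, the same idea — pin the flag for one scope and restore the prior state afterwards — can be expressed as a small context manager. This is a sketch, not code from the PR; only the `FLASHINFER_NVFP4_FOUR_OVER_SIX` flag name comes from this review, the `pinned_env` helper is hypothetical:

```python
import os
from contextlib import contextmanager


@contextmanager
def pinned_env(name: str, value: str):
    """Temporarily pin an environment variable, restoring the prior state on exit."""
    prior = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if prior is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = prior


# Usage: keep the reference path deterministic regardless of outer env state.
with pinned_env("FLASHINFER_NVFP4_FOUR_OVER_SIX", "0"):
    pass  # run the quantize + reference comparison here
```

The `monkeypatch.setenv` fixture in the proposed fix does the save/restore automatically, which is why it is preferable inside pytest.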
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@tests/utils/test_fp4_quantize.py`:
- Around line 706-747: In test_nvfp4_per_token_quantize_te_reference ensure the
FOUR_OVER_SIX mode is pinned off so the TE-reference path is deterministic: at
the start of test_nvfp4_per_token_quantize_te_reference set the environment flag
FLASHINFER_NVFP4_FOUR_OVER_SIX="0" (or call your library’s setter if available)
before creating x and running ref_fp4_quant_te/nvfp4_quantize, and restore the
previous value at the end of the test to avoid leaking global state.
📒 Files selected for processing (5)
- csrc/nv_internal/cpp/kernels/quantization.cu
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
- flashinfer/quantization/fp4_quantization.py
- tests/utils/test_fp4_quantize.py
aleozlx left a comment:
looks good to me so far!
thx for the contrib. pls address conflicts
Actionable comments posted: 1
📒 Files selected for processing (7)
- csrc/nv_internal/cpp/common/envUtils.cpp
- csrc/nv_internal/cpp/kernels/quantization.cu
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
- flashinfer/quantization/fp4_quantization.py
- tests/utils/test_fp4_quantize.py
✅ Files skipped from review due to trivial changes (2)
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- flashinfer/quantization/fp4_quantization.py
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/moe/test_trtllm_gen_per_token_moe.py (1)
114-134: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win — This changes the scales, but not the backend mode.
The new `use_4over6` branch only rewrites the Python-side NVFP4 scale factors. The test never enables 4over6 via `set_nvfp4_4over6_env` before calling `nvfp4_quantize()` and `trtllm_fp4_block_scale_routed_moe()`, so the `True` cases are not validating the actual 4over6 implementation. Apply the shared env helper around the quantize + kernel section so both sides run in the same mode.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/moe/test_trtllm_gen_per_token_moe.py` around lines 114 - 134, The test only updates Python-side scales via nvfp4_global_decode_scale_te but never flips the backend mode, so wrap the quantize+kernel calls with the shared helper set_nvfp4_4over6_env(use_4over6) so the backend is actually in 4over6 mode when calling nvfp4_quantize and trtllm_fp4_block_scale_routed_moe; specifically, call set_nvfp4_4over6_env(use_4over6) around the block that computes hidden_states/hidden_states_scale/per_token_scale_inv with nvfp4_quantize and the subsequent trtllm_fp4_block_scale_routed_moe invocation so both scale computation and kernel execution use the same mode (references: nvfp4_global_decode_scale_te, nvfp4_quantize, set_nvfp4_4over6_env, trtllm_fp4_block_scale_routed_moe).
♻️ Duplicate comments (1)
csrc/nv_internal/cpp/kernels/quantization.cu (1)
338-362: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win — FP32 input still aborts when `FLASHINFER_NVFP4_4OVER6=1`.
`use4Over6` is read unconditionally from the process-global env var, and the `if constexpr (std::is_same_v<T, float>)` branch then aborts via `TLLM_CHECK_WITH_INFO(!USE_4OVER6, ...)`. Any caller that quantizes a `float` input in a process where the env var is set (e.g. an MoE test running after a 4-over-6 test set the env in the same process) will fail, even though the legacy FP32 kernel is unchanged and capable of handling the request. Force `use4Over6 = false` for `T = float` at the env-read site instead of aborting downstream.
💡 Suggested fix
```diff
-  bool const disableFP4QuantFastMath = tensorrt_llm::common::getEnvDisableFP4QuantFastMath();
-  bool const use4Over6 = tensorrt_llm::common::getEnvNVFP4Use4Over6();
-  bool const disable4Over6MSEFastMath = tensorrt_llm::common::getEnvNVFP4Disable4Over6MSEFastMath();
+  bool const disableFP4QuantFastMath = tensorrt_llm::common::getEnvDisableFP4QuantFastMath();
+  bool const use4Over6 =
+      !std::is_same_v<T, float> && tensorrt_llm::common::getEnvNVFP4Use4Over6();
+  bool const disable4Over6MSEFastMath =
+      use4Over6 && tensorrt_llm::common::getEnvNVFP4Disable4Over6MSEFastMath();
```
With that, the `TLLM_CHECK_WITH_INFO(!USE_4OVER6, ...)` inside the `T = float` branch becomes unreachable and can be dropped (or kept defensively).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/nv_internal/cpp/kernels/quantization.cu` around lines 338 - 362, The code reads the process-global use4Over6 unconditionally which causes FP32 instantiations to abort; fix by making the env-read T-aware: move or re-evaluate tensorrt_llm::common::getEnvNVFP4Use4Over6() into the template/lambda scope where T is visible (the launchKernel capture/instantiation) and force it false for T=float (e.g. compute auto const use4Over6 = tensorrt_llm::common::getEnvNVFP4Use4Over6() && !std::is_same_v<T,float> and pass that as the use4Over6Tag/std::bool_constant), then remove or leave the now-unreachable TLLM_CHECK_WITH_INFO(!USE_4OVER6, ...) in the float branch.
🧹 Nitpick comments (1)
tests/test_helpers/utils_fp4.py (1)
295-302: ⚡ Quick win — Vectorize the per-element MSE accumulation.
The explicit Python loop over `block_size=16` is unnecessary work and obscures the intent. A vectorized form is shorter, faster, and (because the reduction order across the last dim is implementation-defined either way) preserves the strict `<` tiebreak on `pick_four`.
♻️ Proposed refactor
```diff
-    err4 = torch.zeros((m, n // block_size), dtype=torch.float32, device=x.device)
-    err6 = torch.zeros((m, n // block_size), dtype=torch.float32, device=x.device)
-    for i in range(block_size):
-        diff4 = dq4[:, :, i] - x_blocks[:, :, i]
-        diff6 = dq6[:, :, i] - x_blocks[:, :, i]
-        err4 += diff4 * diff4
-        err6 += diff6 * diff6
-    pick_four = err4 < err6
+    err4 = ((dq4 - x_blocks) ** 2).sum(dim=-1)
+    err6 = ((dq6 - x_blocks) ** 2).sum(dim=-1)
+    pick_four = err4 < err6
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_helpers/utils_fp4.py` around lines 295 - 302, The loop computes per-block MSE by accumulating squared differences across the last dim; replace the explicit for-loop with a vectorized reduction: compute diff4 = dq4 - x_blocks and diff6 = dq6 - x_blocks, square them and sum over the last axis to produce err4 and err6, then set pick_four = err4 < err6 (preserving the strict < tiebreak). Update variables err4, err6, diff4, diff6 and use the existing dq4, dq6, x_blocks, and pick_four names so the change is localized to that block.
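As a standalone illustration of that vectorized MSE candidate pick, here is a NumPy sketch with made-up dequantized candidates (`dq4`, `dq6`, `x_blocks` are synthetic stand-ins, not the test's actual tensors):

```python
import numpy as np

rng = np.random.default_rng(0)
m, blocks, block_size = 4, 8, 16

# x_blocks: original values; dq4 / dq6: two dequantized candidates per block.
x_blocks = rng.standard_normal((m, blocks, block_size)).astype(np.float32)
dq4 = x_blocks + 0.01 * rng.standard_normal(x_blocks.shape).astype(np.float32)
dq6 = x_blocks + 0.02 * rng.standard_normal(x_blocks.shape).astype(np.float32)

# Per-block squared-error totals, reduced over the last axis in one shot
# instead of a Python loop over the 16 block elements.
err4 = ((dq4 - x_blocks) ** 2).sum(axis=-1)
err6 = ((dq6 - x_blocks) ** 2).sum(axis=-1)

# Strict < keeps the 6-candidate on ties, matching the loop it replaces.
pick_four = err4 < err6
```

The reduction collapses the `block_size` axis, so `err4`/`err6`/`pick_four` all have one entry per (token, block) pair.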
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Around line 2680-2687: run_moe_test currently only uses use_4over6 for
skipping but never actually sets the process env, so 4over6 paths may not be
exercised; wrap the quantize/reference/production section inside the
set_nvfp4_4over6_env context by calling set_nvfp4_4over6_env(use_4over6) (and
ensure the helper is imported) before entering the FP4
quantize/reference/production logic in run_moe_test and restore/unset it after
that block so the FLASHINFER_NVFP4_4OVER6 env state is consistently applied only
for those test cases.
In `@tests/moe/test_trtllm_gen_moe_autotune_tactics.py`:
- Around line 160-169: The test never actually enables the 4over6 NVFP4 runtime
flag because set_nvfp4_4over6_env is never applied; update the test harness so
that when _quant_mode_config is called with use_4over6=True the runtime
environment is toggled for the kernel run: call set_nvfp4_4over6_env(True)
before invoking _run_kernel_with_tactic (and set_nvfp4_4over6_env(False) or
restore the previous state after) so the launched kernel uses the 4over6 path;
adjust every place that constructs the use_4over6=True matrix (including the
other occurrences you noted) to wrap the kernel invocation with the env setter
rather than only changing scales.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 82-83: The test toggles use_4over6 but never actually flips the
NVFP4 4over6 environment, so fp4_quantize() and the routed/non-routed MoE kernel
calls still use the global env; fix by wrapping the sections that perform FP4
quantization and invoke the MoE kernels (references: fp4_quantize, the routed
MoE kernel call(s) and the non-routed MoE kernel call(s)) in the
set_nvfp4_4over6_env context when use_4over6 is True (e.g., with
set_nvfp4_4over6_env(): ...) so the env is applied for those operations and is
restored afterward; apply this same wrapping to the other similar test blocks
currently duplicated later in the file.
In `@tests/moe/utils.py`:
- Around line 40-65: The fixture set_nvfp4_4over6_env currently force-sets
TRTLLM_DISABLE_FP4_QUANT_FAST_MATH and
FLASHINFER_NVFP4_4OVER6_DISABLE_MSE_FAST_MATH unconditionally; change it so
those two env vars are only set when request.getfixturevalue("use_4over6") is
truthy (i.e., set them inside the branch where use_4over6 is True and leave them
untouched when False), while still recording original_values and restoring them
after yield; keep FLASHINFER_NVFP4_4OVER6 set to "1"/"0" based on use_4over6 as
before.
---
Outside diff comments:
In `@tests/moe/test_trtllm_gen_per_token_moe.py`:
- Around line 114-134: The test only updates Python-side scales via
nvfp4_global_decode_scale_te but never flips the backend mode, so wrap the
quantize+kernel calls with the shared helper set_nvfp4_4over6_env(use_4over6) so
the backend is actually in 4over6 mode when calling nvfp4_quantize and
trtllm_fp4_block_scale_routed_moe; specifically, call
set_nvfp4_4over6_env(use_4over6) around the block that computes
hidden_states/hidden_states_scale/per_token_scale_inv with nvfp4_quantize and
the subsequent trtllm_fp4_block_scale_routed_moe invocation so both scale
computation and kernel execution use the same mode (references:
nvfp4_global_decode_scale_te, nvfp4_quantize, set_nvfp4_4over6_env,
trtllm_fp4_block_scale_routed_moe).
---
Duplicate comments:
In `@csrc/nv_internal/cpp/kernels/quantization.cu`:
- Around line 338-362: The code reads the process-global use4Over6
unconditionally which causes FP32 instantiations to abort; fix by making the
env-read T-aware: move or re-evaluate
tensorrt_llm::common::getEnvNVFP4Use4Over6() into the template/lambda scope
where T is visible (the launchKernel capture/instantiation) and force it false
for T=float (e.g. compute auto const use4Over6 =
tensorrt_llm::common::getEnvNVFP4Use4Over6() && !std::is_same_v<T,float> and
pass that as the use4Over6Tag/std::bool_constant), then remove or leave the
now-unreachable TLLM_CHECK_WITH_INFO(!USE_4OVER6, ...) in the float branch.
---
Nitpick comments:
In `@tests/test_helpers/utils_fp4.py`:
- Around line 295-302: The loop computes per-block MSE by accumulating squared
differences across the last dim; replace the explicit for-loop with a vectorized
reduction: compute diff4 = dq4 - x_blocks and diff6 = dq6 - x_blocks, square
them and sum over the last axis to produce err4 and err6, then set pick_four =
err4 < err6 (preserving the strict < tiebreak). Update variables err4, err6,
diff4, diff6 and use the existing dq4, dq6, x_blocks, and pick_four names so the
change is localized to that block.
📒 Files selected for processing (14)
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
- csrc/nv_internal/cpp/common/envUtils.cpp
- csrc/nv_internal/cpp/kernels/quantization.cu
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
- tests/moe/test_trtllm_cutlass_fused_moe.py
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/test_trtllm_gen_moe_autotune_tactics.py
- tests/moe/test_trtllm_gen_per_token_moe.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
- tests/moe/utils.py
- tests/test_helpers/utils_fp4.py
- tests/utils/test_fp4_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/nv_internal/cpp/common/envUtils.cpp
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
/bot run
📌 Description
Implement 4over6 nvfp4 from:
TE PR:
Both original nvfp4 and per-token nvfp4 quantizer and moe are supported.
The results are bitwise exact with the reference implementation by enabling:
FLASHINFER_NVFP4_4OVER6_DISABLE_MSE_FAST_MATH=1
TRTLLM_DISABLE_FP4_QUANT_FAST_MATH=1
Under strict no-fast-math mode, the quantizer is bitwise exact with the PyTorch reference implementation.
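For instance, the strict mode can be enabled from Python before the first quantizer call (a sketch; only the two flag names come from this PR, the placement is illustrative):

```python
import os

# Set both flags before the quantization kernels read them; they disable
# fast-math shortcuts so results match the PyTorch reference bitwise.
os.environ["FLASHINFER_NVFP4_4OVER6_DISABLE_MSE_FAST_MATH"] = "1"
os.environ["TRTLLM_DISABLE_FP4_QUANT_FAST_MATH"] = "1"
```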
Need to rebase after:
Future work:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
New environment variables (FLASHINFER_NVFP4_4OVER6, TRTLLM_DISABLE_FP4_QUANT_FAST_MATH, FLASHINFER_NVFP4_4OVER6_DISABLE_MSE_FAST_MATH) for quantization behavior customization.
Tests