perf: optimize per-token nvfp4 quantization kernel. #3237

aleozlx merged 57 commits into flashinfer-ai:main
Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
📝 Walkthrough

This PR introduces runtime control over FP4 quantization fast-math behavior via a new environment variable.

Changes: FP4 Quantization Fast-Math Runtime Control

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a mechanism to disable fast FP4 quantization math via an environment variable and implements shared memory caching for input vectors in NVFP4 kernels to improve performance. The review feedback highlights several critical concerns regarding potential integer overflows in memory offset calculations where uint32_t was used instead of int64_t, which could lead to out-of-bounds memory access on large tensors. Additionally, the reviewer suggests caching environment variable lookups for better efficiency and resolving inconsistencies between code comments and the implementation of shared memory caching.
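The overflow class the review flags can be illustrated by emulating `uint32_t` arithmetic in Python; the row and column counts below are hypothetical, chosen only so that the element count exceeds 2**32.

```python
def offset_u32(row: int, cols: int) -> int:
    # Emulate uint32_t: the product wraps modulo 2**32.
    return (row * cols) & 0xFFFFFFFF

def offset_i64(row: int, cols: int) -> int:
    # int64_t-style arithmetic holds the full product.
    return row * cols

row, cols = 1_000_000, 8192        # 8.192e9 elements > 2**32 (~4.29e9)
full = offset_i64(row, cols)       # 8_192_000_000: correct base offset
wrapped = offset_u32(row, cols)    # 3_897_032_704: silently wrapped
assert wrapped != full             # a uint32_t offset would index out of bounds
```

A wrapped offset like this does not fail loudly; the kernel simply reads or writes the wrong address, which is why the review treats the `uint32_t` usages as a correctness risk on large tensors rather than a style issue.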
/bot run
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (1)
392-415: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win: Complete the template parameter rename in the sibling function.

cvt_warp_fp16_to_fp4 (line 292) was renamed to use DISABLE_FP4_QUANT_FAST_MATH, but its sibling cvt_warp_fp16_to_fp4_with_vec_max (line 393) still uses the old TE_EXACT_NVFP4 template parameter and branch check (line 415). Both functions implement the same compile-time switch. Rename both occurrences in cvt_warp_fp16_to_fp4_with_vec_max to maintain consistency and avoid confusion for callers using named template arguments.
♻️ Proposed rename

```diff
 template <class Type, int SF_VEC_SIZE, int CVT_ELTS_PER_THREAD, bool UE8M0_SF,
-          bool TE_EXACT_NVFP4 = false>
+          bool DISABLE_FP4_QUANT_FAST_MATH = false>
 __device__ std::conditional_t<CVT_ELTS_PER_THREAD == 16, uint64_t, uint32_t>
 cvt_warp_fp16_to_fp4_with_vec_max(PackedVec<Type, CVT_ELTS_PER_THREAD>& vec, float SFScaleVal,
                                   float reciprocalSFScaleVal, float vecMax, uint8_t* SFout) {
 @@
-  } else if constexpr (TE_EXACT_NVFP4) {
+  } else if constexpr (DISABLE_FP4_QUANT_FAST_MATH) {
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh` around lines 392 - 415, The sibling function cvt_warp_fp16_to_fp4_with_vec_max still declares and checks the old template parameter TE_EXACT_NVFP4; update its template parameter list and the corresponding constexpr branch (the occurrence after the UE8M0_SF branch) to use the new name DISABLE_FP4_QUANT_FAST_MATH so it matches cvt_warp_fp16_to_fp4 and the compile-time switch is consistent for callers using named template arguments.
🧹 Nitpick comments (1)
csrc/nv_internal/cpp/kernels/quantization.cu (1)
239-250: 💤 Low value | Typo in macro name: NVP4 should be NVFP4.

The macro is named DISPATCH_NVP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL (missing the F) and is referenced by the same misspelling at lines 275, 277, 282, 284, 289, 291. Worth fixing while the surface is fresh, since the macro is local to this file.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/nv_internal/cpp/kernels/quantization.cu` around lines 239 - 250, The macro name is misspelled: rename the macro DISPATCH_NVP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL to DISPATCH_NVFP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL (note the added "F") and update all local references that use the misspelled identifier (e.g., the call sites at the later references in this file) so they match the new name; ensure you change both the macro definition and every usage of DISPATCH_NVP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL to DISPATCH_NVFP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL so the dispatcher resolves correctly.
📒 Files selected for processing (7):

- csrc/nv_internal/cpp/common/envUtils.cpp
- csrc/nv_internal/cpp/kernels/quantization.cu
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization.h
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
- tests/utils/test_fp4_quantize.py
```python
@pytest.fixture
def set_te_reference_test_env():
    """Fixture to set and reset TRTLLM_DISABLE_FP4_QUANT_FAST_MATH environment variable."""
    original_value = os.environ.get("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", None)

    def _set_algo(algo: str):
        if algo == "auto":
            os.environ.pop("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", None)
        else:
            os.environ["TRTLLM_DISABLE_FP4_QUANT_FAST_MATH"] = algo

    yield _set_algo

    # Restore original value
    if original_value is None:
        os.environ.pop("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", None)
    else:
        os.environ["TRTLLM_DISABLE_FP4_QUANT_FAST_MATH"] = original_value
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Confirm the env getter is cached and locate all call sites of the per-token dispatch.
rg -nP -C2 'getEnvDisableFP4QuantFastMath'
rg -nP -C2 'invokeNvfp4QuantAndPerTokenScale\b'
rg -nP -C2 'per_token_activation\s*=\s*True' tests/
```

Repository: flashinfer-ai/flashinfer
Length of output: 6776

🏁 Script executed:

```bash
sed -n '553,595p' tests/utils/test_fp4_quantize.py | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 2066
C++ static cache defeats fixture's "restore" semantics, causing test order fragility.
getEnvDisableFP4QuantFastMath() in csrc/nv_internal/cpp/common/envUtils.cpp initializes a static bool const on first call, so the C++ side captures whatever the env var is at the moment of the first per-token NVFP4 dispatch and never re-reads it. Consequences:
- The "restore original value" branch (lines 567–571) only restores
os.environ; the C++ runtime keeps the cached value for the rest of the process. - If any earlier test in the run dispatches
invokeNvfp4QuantAndPerTokenScalebefore this fixture sets the env var, the cache is locked and the fixture's env var manipulation becomes a no-op for the kernel. - Conversely, once this test runs first and locks the cache to a value, all later tests using per-token activation (e.g., at
test_fp4_quantize.py:636,688ortest_trtllm_gen_per_token_moe.py:113) will permanently use that cached path.
This works only due to pytest's declaration ordering; it breaks with test sharding, -k filtering, or pytest-randomly. Either document this constraint on the fixture or refactor the kernel to read the env var dynamically when needed for tests.
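The cache-lock behavior can be sketched in Python. The function names here are hypothetical stand-ins for the real C++ getter, which caches the lookup in a function-local static bool:

```python
import os

_cache: dict = {}  # stands in for the C++ function-local `static bool const`

def get_env_disable_fast_math_cached() -> bool:
    # First call locks in whatever the env var holds at that moment.
    if "v" not in _cache:
        _cache["v"] = os.environ.get("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", "0") == "1"
    return _cache["v"]

def get_env_disable_fast_math_dynamic() -> bool:
    # Re-reads the environment on every call, so test fixtures can toggle it.
    return os.environ.get("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", "0") == "1"

os.environ.pop("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", None)
get_env_disable_fast_math_cached()                  # first call caches False
os.environ["TRTLLM_DISABLE_FP4_QUANT_FAST_MATH"] = "1"
assert get_env_disable_fast_math_cached() is False  # stale: the cache wins
assert get_env_disable_fast_math_dynamic() is True  # sees the update
```

The cached variant is what makes the fixture's env-var manipulation a no-op after the first dispatch; the dynamic variant is the behavior the review suggests exposing (or approximating via a reset hook) for tests.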
Minor: Line 36 is a dangling string literal, not a docstring. The first executable statement in the function is at line 35 (set_te_reference_test_env("1")), so the triple-quoted string on line 36 is discarded at runtime. Move it to line 35 or restructure as a proper docstring.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/utils/test_fp4_quantize.py` around lines 553 - 571, The C++ function
getEnvDisableFP4QuantFastMath() caches the env var in a static bool so changing
os.environ in the Python fixture set_te_reference_test_env() doesn't affect the
kernel after the first use (invokeNvfp4QuantAndPerTokenScale), making tests
order-dependent; fix by either (preferred) adding a C++ API to reset/re-read
that cached value (e.g., expose a resetEnvDisableFP4QuantFastMath() or make
getEnvDisableFP4QuantFastMath() read the env dynamically) and call that reset
from the fixture (set_te_reference_test_env) after changing the env, or
alternatively document the ordering constraint in the fixture docstring; also
move the dangling triple-quoted string into the function start so it becomes a
real docstring for set_te_reference_test_env.
```python
    set_te_reference_test_env("1")
    """Per-token NVFP4 quantization should match the TE Python reference bitwise."""
    if not _is_fp4_supported(torch.device(device)):
        pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8")
    if os.getenv("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", "0") == "0":
        pytest.skip(
            "Environment variable TRTLLM_DISABLE_FP4_QUANT_FAST_MATH is not set or false, "
            "skipping test_nvfp4_per_token_quantize_te_reference."
        )
```
Misplaced docstring and unreachable skip block.
Two issues in this body:
- Docstring is now a no-op: set_te_reference_test_env("1") is the first statement in the function, so the triple-quoted string on line 588 is no longer interpreted as a docstring; it is a discarded string expression and won't show up in --collect-only / --co -q listings or pydoc.
- Dead skip: with the fixture forcing the env var to "1" on line 587, the check at lines 591-595 can never evaluate to True, so the skip is unreachable. (This appears to address the prior request to override the env var inside the test, but the leftover skip should be removed.)
🔧 Suggested cleanup

```diff
 def test_nvfp4_per_token_quantize_te_reference(
     dtype: torch.dtype,
     shape: tuple[int, int],
     is_sf_swizzled_layout: bool,
     init_data: str,
     device: str,
     set_te_reference_test_env,
 ) -> None:
-    set_te_reference_test_env("1")
     """Per-token NVFP4 quantization should match the TE Python reference bitwise."""
+    set_te_reference_test_env("1")
     if not _is_fp4_supported(torch.device(device)):
         pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8")
-    if os.getenv("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", "0") == "0":
-        pytest.skip(
-            "Environment variable TRTLLM_DISABLE_FP4_QUANT_FAST_MATH is not set or false, "
-            "skipping test_nvfp4_per_token_quantize_te_reference."
-        )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    """Per-token NVFP4 quantization should match the TE Python reference bitwise."""
    set_te_reference_test_env("1")
    if not _is_fp4_supported(torch.device(device)):
        pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8")
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/utils/test_fp4_quantize.py` around lines 587 - 595, The triple-quoted
string meant as the test docstring is placed after
set_te_reference_test_env("1") so it becomes a discarded string expression; move
that triple-quoted string to be the very first statement in the test function
(so it is the actual docstring for test_nvfp4_per_token_quantize_te_reference)
and remove the unreachable skip block that checks
os.getenv("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", "0") == "0" (since
set_te_reference_test_env forces the env var to "1"), leaving only the valid
device/FP4 support skip using _is_fp4_supported(torch.device(device)).
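The docstring rule driving this comment can be shown with a minimal, hypothetical example (the function names here are invented for illustration): only a string literal that is the first statement of a function body becomes its `__doc__`; a string after any other statement is evaluated and discarded.

```python
def with_docstring():
    """This is a real docstring."""

def without_docstring():
    _ = 1  # first statement is code...
    """...so this string literal is evaluated and thrown away."""

assert with_docstring.__doc__ == "This is a real docstring."
assert without_docstring.__doc__ is None
```

This is why swapping the order of the fixture call and the triple-quoted string restores the docstring in test listings and pydoc output.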
📌 Description
Optimize the performance of the per-token nvfp4 quantization kernel introduced by #3027.
- TE_EXACT_FP4 renamed to TRTLLM_DISABLE_FP4_QUANT_FAST_MATH, controlled by an environment variable.
- get_sf_out_offset_128x4 and get_sf_out_offset_8x4.

TODOs:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests