
MXFP4 x BF16 CUTLASS MoE backend perf and profiling improvement on Hopper#12451

Open
StudyingShao wants to merge 6 commits into NVIDIA:main from StudyingShao:jiangs/1.3.0rc3/opt_hopper_mix_dtype_moe

Conversation

Collaborator

@StudyingShao StudyingShao commented Mar 23, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Added INT4→FP8 and FP4→BF16 quantization conversion support with lookup tables.
    • Added weight interleaving utilities for Hopper GPU optimization.
    • Introduced tactic description retrieval for fused MoE kernels.
    • Enhanced MoE quantization support for W4A8 and MXFP4 methods.
  • Bug Fixes

    • Improved SM90 kernel scheduler configuration and heuristics.
  • Tests

    • Added comprehensive test suite for fused MoE quantization kernels.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Contributor

coderabbitai Bot commented Mar 23, 2026

📝 Walkthrough

Walkthrough

This PR introduces Int4→FP8 lookup-table conversion capabilities for SM90 Hopper mixed-precision GEMMs, including new tensormap management logic, A operand loading pipeline refactoring, weight interleaving utilities, and comprehensive test coverage for fused MoE configurations.

Changes

  • CUTLASS Mixed-Input Utility Extensions
    cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
    Added FP8 E4M3s LUT constants and device converters for FP4→BF16 conversion; implemented Int4→FP8 LUT conversion via psx_cvt_lut_prmt_int4x8_to_fp8x8(...) with register-bit LUT constants; added the UseInt4ToFP8LookupTable compile-time flag; introduced copy_tensors_A(...) and copy_tensors_SFA(...) helper copy functions with bounds guards; integrated routing logic in convert_A_kblock(...) to select the Int4→FP8 lookup path when enabled.
  • SM90 MMA Collective Kernel Spec
    cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp
    Added the UseInt4ToFP8LookupTable and TensormapUpdateShapesStridesForAandScale flags; introduced a scale_convertor(...) device helper for scale type conversion; refactored the A operand loading pipeline with LDSM-based copy/retiling (smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM); adjusted deferred scaling with group-wise chunk-offset logic; extended the tensormaps_replace_global_address(...) signature to accept an input_tensormaps tuple and conditionally update shape/stride; updated tensormaps_cp_fence_release(...) with conditional shared-memory vs. global-memory release logic.
  • Kernel Scheduling Configuration
    cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
    Modified SM90 W4A-FP8 mainloop schedule candidate generation to start empty instead of defaulting to PINGPONG; conditionally adds COOPERATIVE for coop-supported tiles (excluding CtaShape128x128x128B) or PINGPONG for non-coop-supported tiles.
  • MOE GEMM Launcher & Infrastructure
    cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl, cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h, cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
    Added persistent tile scheduler configuration with RasterOrderOptions and max_swizzle_size=2 in the launcher; introduced two Hopper weight interleaving utilities, interleave_fp4_weights_for_Hopper_mixed_gemm(...) and interleave_int4_weights_for_Hopper_mixed_gemm(...), as CUDA kernels with nibble/16-bit reinterpretation logic and host wrapper functions.
  • MOE Kernel Wrappers & Python Integration
    cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu, cpp/tensorrt_llm/thop/moeOp.cpp, cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
    Marked unused variables with [[maybe_unused]] in moe_kernels.cu; added a protected FusedMoeRunner::getTacticDesc(...) method with a Torch binding in moeOp.cpp; introduced the trtllm::interleave_4bit_weights_for_Hopper_mixed_gemm Torch operator in weightOnlyQuantOp.cpp with type validation and kernel dispatch.
  • Python Quantization & Testing
    tensorrt_llm/_torch/modules/fused_moe/quantization.py, tests/unittest/_torch/modules/test_fused_moe_jiangs.py
    Refactored SM90 weight loading conditionals and added a post-processing step invoking interleave_4bit_weights_for_Hopper_mixed_gemm(...) for 2D shard tensors; added a test module with three GPU-focused MoE fused-kernel tests covering W4A-FP8, WFP4A16, and Triton MXFP4 variants, including reference implementations and AutoTuner kernel tactic capture/replay.
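The weight interleaving utilities above manipulate packed 4-bit values. As a rough, pure-Python illustration of the nibble packing such kernels operate on (the actual Hopper interleave layout is defined by the CUDA kernels and is not reproduced here; the low-nibble-first byte order below is an assumption for illustration, not the kernel's documented format):

```python
def unpack_nibbles(packed):
    """Split each packed byte into two 4-bit values, low nibble first.

    Low-nibble-first is an illustrative assumption; the real Hopper
    interleave pattern lives in the CUDA interleaving kernels.
    """
    out = []
    for byte in packed:
        out.append(byte & 0xF)          # low nibble
        out.append((byte >> 4) & 0xF)   # high nibble
    return out


def pack_nibbles(nibbles):
    """Inverse of unpack_nibbles: pair 4-bit values back into bytes."""
    assert len(nibbles) % 2 == 0
    return [nibbles[i] | (nibbles[i + 1] << 4)
            for i in range(0, len(nibbles), 2)]
```

Any interleaving scheme is then a permutation applied between the unpack and repack steps; round-tripping with no permutation must reproduce the input bytes.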

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • nvchenghaoz
  • yuxianq
  • liji-nv
  • QiJune
🚥 Pre-merge checks | ✅ 1 passed | ❌ 2 failed

❌ Failed checks (2 warnings)

  • Description check — ⚠️ Warning. The PR description is essentially the template with no substantive content filled in; the author provided only the template structure without explaining what was changed, why it was changed, or what tests validate the changes. Resolution: fill in the Description and Test Coverage sections with substantive details about the changes and the tests that validate them.
  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 18.33%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed. The title 'MXFP4 x BF16 CUTLASS MoE backend perf and profiling improvement on Hopper' clearly summarizes the main change: performance and profiling improvements for a mixed-dtype MoE backend.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Important

Merge conflicts detected (Beta)

  • Resolve merge conflict in branch jiangs/1.3.0rc3/opt_hopper_mix_dtype_moe

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 12

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (1)

2-2: ⚠️ Potential issue | 🟠 Major

Update copyright year for a modified source file.

This file was modified in 2026, but the header still ends at 2025.

Proposed fix
- * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION.  All rights reserved.

As per coding guidelines, "Add NVIDIA copyright header on ALL new files, and update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` at line 2,
Update the copyright header in the top-of-file comment in moe_kernels.cu to
include 2026 (e.g., change "2020-2025" to "2020-2026") so the modified source
file reflects the current year; locate the header comment near the top of
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (the file-level
copyright comment) and update the year range accordingly.
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)

2-2: ⚠️ Potential issue | 🟠 Major

Update copyright year for this modified file.

The header still ends at 2023.

Proposed fix
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION.  All rights reserved.

As per coding guidelines, "Add NVIDIA copyright header on ALL new files, and update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp` at line 2,
Update the file header in cutlass_heuristic.cpp to reflect the current
modified-year range and include the NVIDIA copyright header; specifically change
the existing "2020-2023" year span to include the current year (e.g.,
"2020-2026") and ensure the full NVIDIA copyright header text used across the
repo is present and matches formatting of other modified files.
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (1)

2-2: ⚠️ Potential issue | 🟠 Major

Update copyright year for this modified file.

The header currently ends at 2023.

Proposed fix
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION.  All rights reserved.

As per coding guidelines, "Add NVIDIA copyright header on ALL new files, and update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl`
at line 2, Update the copyright header in the file
moe_gemm_tma_ws_mixed_input_launcher.inl to reflect the current modification
year (change the existing "2020-2023" to include the current year), i.e., update
the top-of-file NVIDIA copyright comment block in
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
so the range covers through the current year; keep the existing header text and
formatting intact except for the year range.
🧹 Nitpick comments (4)
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (2)

4616-4616: Remove the unused use_w4_groupwise local.

At Line 4616, the value is computed and never read.

Proposed fix
-    [[maybe_unused]] bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;

As per coding guidelines, "Avoid dead code in C++".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` at line
4616, Remove the dead local variable declaration [[maybe_unused]] bool
use_w4_groupwise = use_w4afp8 || use_wfp4a16; in moe_kernels.cu (the computed
value is never read); simply delete this unused variable (or if the intention
was to use it later, replace unused declaration with the appropriate logic that
actually references use_w4_groupwise in the relevant control flow) so there is
no unused local symbol left.

2157-2158: Use const for immutable locals introduced in this change.

These locals are not reassigned after initialization.

Proposed fix
-            [[maybe_unused]] float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
-            [[maybe_unused]] int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
+            [[maybe_unused]] float const global_scale_val
+                = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
+            [[maybe_unused]] int64_t const num_tokens_before_expert
+                = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
...
-            [[maybe_unused]] auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
+            [[maybe_unused]] auto const NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
                 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
-            [[maybe_unused]] auto MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
+            [[maybe_unused]] auto const MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
                 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{};
-            [[maybe_unused]] auto NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
+            [[maybe_unused]] auto const NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
                 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{};

As per coding guidelines, "A variable that is not modified after its initialization should be declared as const".

Also applies to: 2402-2407

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` around
lines 2157 - 2158, Declare the local variables that are not modified after
initialization as const: change the declarations of global_scale_val and
num_tokens_before_expert to use const (e.g., const float global_scale_val = ...;
const int64_t num_tokens_before_expert = ...;). Apply the same change to the
other immutable locals noted in the diff region around 2402-2407 so any
variables that are initialized once and never reassigned become const (use the
exact symbol names from those lines when updating).
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (1)

262-263: Replace hard-coded scheduler swizzle value with a named constant.

At Line 262, 2 should be expressed as a named constant.

Proposed fix
+    int constexpr maxSwizzleSize = 2;
-    arguments.scheduler.max_swizzle_size = 2;
+    arguments.scheduler.max_swizzle_size = maxSwizzleSize;
     arguments.scheduler.raster_order = RasterOrderOptions::Heuristic;

As per coding guidelines, "Except 0 (only used in comparison), nullptr, true, false, all other literals should only be used for variable initialization in C++. Other occurrences should use named constants".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl`
around lines 262 - 263, Replace the magic literal 2 used for
arguments.scheduler.max_swizzle_size with a named constant (e.g.,
kDefaultMaxSwizzleSize or MAX_SWIZZLE_SIZE) and use that constant in the
assignment to arguments.scheduler.max_swizzle_size; define the constant as a
constexpr integer with an appropriate scope (file-scope in this implementation
file or in the module header alongside other scheduler constants) so calls like
arguments.scheduler.max_swizzle_size = kDefaultMaxSwizzleSize; replace the
hard-coded 2 and keep the adjacent use of RasterOrderOptions::Heuristic
unchanged.
tests/unittest/_torch/modules/test_fused_moe_jiangs.py (1)

667-680: Avoid a shared ModelConfig default.

Line 674 constructs one ModelConfig at import time. If any test mutates it, later RefGatedMLPFusedMoE instances inherit that state.

♻️ Proposed fix
     def __init__(
         self,
         num_experts: int,
         routing_method: BaseMoeRoutingMethod,
         hidden_size: int,
         intermediate_size: int,
         dtype: Optional[torch.dtype] = None,
-        model_config: ModelConfig = ModelConfig(),
+        model_config: Optional[ModelConfig] = None,
         use_cute_dsl_blockscaling_mm: bool = False,
         bias=False,
         swiglu_alpha: Optional[float] = None,
         swiglu_beta: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ):
         super().__init__()
+        model_config = ModelConfig() if model_config is None else model_config
         self.num_experts = num_experts
         self.routing_method = routing_method
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.bias = bias
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 667 -
680, The constructor currently uses a shared ModelConfig default (model_config:
ModelConfig = ModelConfig()), which can be mutated across tests; change the
signature to accept model_config: Optional[ModelConfig] = None and inside
__init__ set self.model_config = model_config if model_config is not None else
ModelConfig() so each instance gets a fresh ModelConfig; update any references
to model_config in the constructor/body accordingly (look for the __init__
method and the symbol model_config in this class).
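The hazard behind this suggestion is that Python evaluates default argument values once, at function definition time, so a default constructed in the signature is shared by every call. A minimal sketch with a stand-in Config class (illustrative only, not the real ModelConfig):

```python
class Config:
    """Stand-in for a config object like ModelConfig (illustrative only)."""

    def __init__(self):
        self.flags = []


def make_module_bad(config=Config()):
    # Config() is built ONCE, when the def statement executes; every
    # call without an argument receives the same shared object.
    return config


def make_module_good(config=None):
    # A None sentinel defers construction, so each call gets a fresh one.
    return Config() if config is None else config
```

Mutating the result of one make_module_bad() call is visible to every later call, which is exactly the cross-test state leak the review warns about; the None-sentinel form avoids it.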
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp`:
- Around line 1493-1527: The zero-mode branch for
ConversionMode::ConvertAndScaleWithZero updates only the zero tensormap but
omits updating the scale tensormap, causing tma_load_scale to read stale or
pointer-table data; in the functions that call
cute::tma_descriptor_replace_addr_in_shared_mem /
cute::tma_descriptor_replace_addr_in_global_mem (look for
tensormaps_replace_global_address/static_assert text and the branches using
KernelConversionMode), add the same scale update as in the ConvertAndScale
branch: replace the shared/global smem_tensormap_scale/get<2>(input_tensormaps)
with mainloop_params.ptr_S[next_batch] wherever the ConvertAndScaleWithZero
branch touches A/zero metadata, and also ensure smem_tensormap_scale is
initialized/staged inside tensormaps_init() so per-batch scale updates exist for
grouped GEMM; repeat the same fix in the equivalent locations (the other similar
blocks noted in the review).
- Around line 1015-1016: The LDSM thread-slice is using the global thread_idx
instead of the warpgroup-local index, causing incorrect lane mapping when
multiple consumer warpgroups participate; change the call to
smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx) to use
get_thread_slice(warp_group_thread_idx) (consistent with the earlier A-copy path
and other SM90 collectives) so SmemCopyAtomA_LDSM / make_tiled_copy_A with
tiled_mma uses the warpgroup-local index for the slice.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Around line 313-314: The single-line if statement checking "if (tile_config ==
CutlassTileConfigSM90::CtaShape128x128x128B) continue;" should be rewritten to
use brace-delimited body per style guidelines: locate the check in
cutlass_heuristic.cpp (the loop handling tile_config values) and replace the
single-statement form with an explicit block using braces around the continue;
so the condition and its action are enclosed in { } to ensure consistent
formatting and avoid style warnings.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`:
- Around line 104-106: The two wrapper functions that currently launch
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel with <<<1024, block>>> must
accept a cudaStream_t stream parameter, use a computed grid (e.g., dim3
grid(1024,1,1) or existing grid variable) and launch with <<<grid, block, 0,
stream>>> instead of the implicit default stream, and then call
cudaGetLastError() (or cudaPeekAtLastError()) immediately after the launch to
catch failures; update both function signatures and all call sites (e.g., where
weightOnlyQuantOp.cpp invokes these wrappers) to pass the caller's stream, and
follow the same error-checking pattern used in unfusedAttentionKernels.cu /
cudaCoreGemmNVFP4.cu.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h`:
- Line 17: Replace the lone `#pragma once` with the repository-required include
guard: add a preprocessor guard using the macro name
TRTLLM_MOE_GEMM_MIXED_UTILS_H (open with `#ifndef` TRTLLM_MOE_GEMM_MIXED_UTILS_H /
`#define` TRTLLM_MOE_GEMM_MIXED_UTILS_H and close with `#endif`), ensuring the macro
matches the header filename moe_gemm_mixed_utils.h in all caps and without
directory names or trailing underscores; update the top and bottom of the file
(moe_gemm_mixed_utils.h) accordingly so other symbols in this header (e.g.,
declarations used by moe_gemm_mixed_utils) remain unchanged.

In `@cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp`:
- Around line 412-435: Add explicit shape validation before computing n/k and
before dispatching to the Hopper kernels: ensure weight.dim() == 2,
weight.size(0) % 16 == 0, and weight.size(1) % 32 == 0, and emit TORCH_CHECK
errors if any condition fails. Place these checks at the start of the
function/block handling the public entrypoint (before the lines that compute n =
weight.size(0) and k = weight.size(1) * 2) so that inputs violating the kernel
tile assumptions are rejected prior to calling
interleave_int4_weights_for_Hopper_mixed_gemm or
interleave_fp4_weights_for_Hopper_mixed_gemm.
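The suggested checks can be stated independently of Torch. A minimal Python sketch of the same shape constraints (the function name is illustrative; the divisibility requirements of 16 and 32 are taken from the review comment above):

```python
def validate_4bit_weight_shape(shape):
    """Reject shapes violating the kernel tile assumptions described
    in the review: 2-D, rows % 16 == 0, packed cols % 32 == 0."""
    if len(shape) != 2:
        raise ValueError(f"expected a 2-D weight, got {len(shape)} dims")
    n, k_packed = shape
    if n % 16 != 0:
        raise ValueError(f"rows ({n}) must be a multiple of 16")
    if k_packed % 32 != 0:
        raise ValueError(f"packed cols ({k_packed}) must be a multiple of 32")
```

In the Torch operator itself the equivalent guards would be TORCH_CHECK calls placed before n and k are computed, so invalid inputs fail with a clear message rather than corrupting the kernel launch.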
- Around line 416-427: The output tensor is allocated on the current CUDA device
instead of the input tensor's device, causing mismatched allocations on
multi-GPU; add `#include` <c10/cuda/CUDAGuard.h>, create a c10::cuda::CUDAGuard
guard{weight.device()}; (or at::cuda::CUDAGuard guard(weight.device().index());)
immediately before allocating weight_interleaved so the torch::empty call uses
the same device as weight, and ensure subsequent get_ptr<uint8_t>(weight) /
get_ptr<uint8_t>(weight_interleaved) dereference the correct device pointers.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 1735-1737: The interleave op
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm is CUDA-only, so
before calling it move the FP4 shard tensors to CUDA (e.g., w31_weight_shard =
w31_weight_shard.to(dst_device or torch.device('cuda'), non_blocking=True) or
use .cuda()) and ensure the same for the other shard(s) used in the nearby block
(the similar call around the w13/w31 handling at the later location). Update the
code that calls interleave_4bit_weights_for_Hopper_mixed_gemm to stage the
shard(s) onto CUDA first (preserving dtype/contiguity) so the CUDA-only op runs
on GPU-hosted tensors.
- Around line 1341-1344: The bare .cuda() calls moving w31_weight_shard to the
current CUDA device break multi-GPU placement; replace those .cuda() uses (the
ones before calling
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm) with an explicit
.to(device) (using the previously captured device variable `device =
dst_w3_w1_weight.device`) so the tensor is moved to the intended device before
invoking the Hopper interleave operator.

In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py`:
- Around line 229-234: The reference dequant path loads q2 and q3 swapped: when
weight_loading_mode == MoEWeightLoadingMode.VANILLA set q2 and q3 from the
correct weight keys (q2 should come from f"{e_idx}.w2.{lut['weight_scale_2']}"
and q3 from f"{e_idx}.w3.{lut['weight_scale_2']}") so fc2 receives the proper
tensor and q3_q1 combines w3 with w1 for fc13; update both occurrences (the
block around q1/q2/q3/q3_q1 and the similar block at lines ~263-275) to swap the
assignments accordingly while keeping q1 loaded from w1.
- Around line 466-469: The test currently forces a runner lookup with
next(iter(MoERunner.runner_dict.values())) which can raise StopIteration when
the backend doesn't populate MoERunner.runner_dict; change this to a safe
retrieval (e.g., use next(iter(...), None) or guard with a try/except) and
ensure any use of cpp_runner (the commented diagnostic block) only runs when
cpp_runner is not None so the test won't error when the registry is empty.
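The two-argument form of next() is what makes that lookup safe. A minimal sketch with a stand-in registry dict (the helper name is illustrative):

```python
def first_runner(runner_dict):
    """Return the first registered runner, or None when the registry is
    empty — the two-argument next() never raises StopIteration."""
    return next(iter(runner_dict.values()), None)
```

The test would then guard its diagnostic block with "if cpp_runner is not None:" instead of assuming the backend populated the registry.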
- Around line 1-3: This test module is missing the required NVIDIA Apache 2.0
header; add the repository-mandated NVIDIA copyright/license block at the very
top of the file (above the import lines) using the current year of the latest
meaningful modification and the standard Apache-2.0 header format used across
the repo so the file (test_fused_moe_jiangs.py) includes the exact NVIDIA
copyright and Apache 2.0 license block.

---

Outside diff comments:
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Line 2: Update the file header in cutlass_heuristic.cpp to reflect the current
modified-year range and include the NVIDIA copyright header; specifically change
the existing "2020-2023" year span to include the current year (e.g.,
"2020-2026") and ensure the full NVIDIA copyright header text used across the
repo is present and matches formatting of other modified files.

In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl`:
- Line 2: Update the copyright header in the file
moe_gemm_tma_ws_mixed_input_launcher.inl to reflect the current modification
year (change the existing "2020-2023" to include the current year), i.e., update
the top-of-file NVIDIA copyright comment block in
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
so the range covers through the current year; keep the existing header text and
formatting intact except for the year range.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu`:
- Line 2: Update the copyright header in the top-of-file comment in
moe_kernels.cu to include 2026 (e.g., change "2020-2025" to "2020-2026") so the
modified source file reflects the current year; locate the header comment near
the top of cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (the
file-level copyright comment) and update the year range accordingly.

---

Nitpick comments:
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl`:
- Around line 262-263: Replace the magic literal 2 used for
arguments.scheduler.max_swizzle_size with a named constant (e.g.,
kDefaultMaxSwizzleSize or MAX_SWIZZLE_SIZE) and use that constant in the
assignment to arguments.scheduler.max_swizzle_size; define the constant as a
constexpr integer with an appropriate scope (file-scope in this implementation
file or in the module header alongside other scheduler constants) so calls like
arguments.scheduler.max_swizzle_size = kDefaultMaxSwizzleSize; replace the
hard-coded 2 and keep the adjacent use of RasterOrderOptions::Heuristic
unchanged.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu`:
- Line 4616: Remove the dead local variable declaration [[maybe_unused]] bool
use_w4_groupwise = use_w4afp8 || use_wfp4a16; in moe_kernels.cu (the computed
value is never read); simply delete this unused variable (or if the intention
was to use it later, replace unused declaration with the appropriate logic that
actually references use_w4_groupwise in the relevant control flow) so there is
no unused local symbol left.
- Around line 2157-2158: Declare the local variables that are not modified after
initialization as const: change the declarations of global_scale_val and
num_tokens_before_expert to use const (e.g., const float global_scale_val = ...;
const int64_t num_tokens_before_expert = ...;). Apply the same change to the
other immutable locals noted in the diff region around 2402-2407 so any
variables that are initialized once and never reassigned become const (use the
exact symbol names from those lines when updating).

In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py`:
- Around line 667-680: The constructor currently uses a shared ModelConfig
default (model_config: ModelConfig = ModelConfig()), which can be mutated across
tests; change the signature to accept model_config: Optional[ModelConfig] = None
and inside __init__ set self.model_config = model_config if model_config is not
None else ModelConfig() so each instance gets a fresh ModelConfig; update any
references to model_config in the constructor/body accordingly (look for the
__init__ method and the symbol model_config in this class).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8cd5f7bd-7d25-496d-86b6-147cca74e2fa

📥 Commits

Reviewing files that changed from the base of the PR and between 4f929fe and 79315f6.

📒 Files selected for processing (11)
  • cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
  • cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp
  • cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
  • cpp/tensorrt_llm/thop/moeOp.cpp
  • cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • tests/unittest/_torch/modules/test_fused_moe_jiangs.py

Comment on lines +1015 to +1016
auto smem_tiled_copy_A_LDSM = make_tiled_copy_A(SmemCopyAtomA_LDSM{}, tiled_mma);
auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx);
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Current collective:"
rg -nP 'make_tiled_copy_A\(|get_thread_slice\((thread_idx|warp_group_thread_idx)\)' \
  cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp \
  -A3 -B3

echo
echo "Nearby SM90 collectives:"
rg -nP 'get_thread_slice\((thread_idx|warp_group_thread_idx)\)' \
  cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective \
  -A2 -B2

Repository: NVIDIA/TensorRT-LLM

Length of output: 8930


Use the warpgroup-local thread index for the LDSM slice.

The existing A-copy path on line 993 uses get_thread_slice(warp_group_thread_idx), but the new LDSM path on line 1016 incorrectly uses thread_idx. When multiple consumer warpgroups participate, later warpgroups will select the wrong copy lane mapping and read incorrect A fragments. All similar patterns in nearby SM90 collectives use warp_group_thread_idx.

Fix
-        auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx);
+        auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(warp_group_thread_idx);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 auto smem_tiled_copy_A_LDSM = make_tiled_copy_A(SmemCopyAtomA_LDSM{}, tiled_mma);
-auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx);
+auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(warp_group_thread_idx);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp`
around lines 1015 - 1016, The LDSM thread-slice is using the global thread_idx
instead of the warpgroup-local index, causing incorrect lane mapping when
multiple consumer warpgroups participate; change the call to
smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx) to use
get_thread_slice(warp_group_thread_idx) (consistent with the earlier A-copy path
and other SM90 collectives) so SmemCopyAtomA_LDSM / make_tiled_copy_A with
tiled_mma uses the warpgroup-local index for the slice.
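The indexing mismatch the bot describes can be illustrated with plain numbers, no CUTLASS required. On Hopper a warpgroup is 128 threads (four 32-thread warps), and a per-warpgroup copy atom expects a local lane index in 0..127, so the global `thread_idx` is only valid for the first consumer warpgroup:

```python
# Illustrative sketch: warpgroup-local vs. global thread index.
WARPGROup_SIZE = 128  # intentional? no -- fixed below
WARPGROUP_SIZE = 128  # a Hopper warpgroup = 4 warps x 32 threads

def warp_group_thread_idx(thread_idx: int) -> int:
    # The warpgroup-local index wraps the global index into 0..127.
    return thread_idx % WARPGROUP_SIZE

# Thread 200 sits in the *second* consumer warpgroup:
print(warp_group_thread_idx(200))   # local lane 72 -> valid slice index
print(200 < WARPGROUP_SIZE)         # False: the global index is out of range
```

Passing the global index 200 to a 128-lane copy slice selects the wrong lane mapping, which is exactly the failure mode called out for the LDSM path.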

Comment on lines +1493 to +1527
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
{
cute::tma_descriptor_replace_addr_in_shared_mem(
shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
cute::tma_descriptor_replace_addr_in_shared_mem(
shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]);
}
else if constexpr (KernelConversionMode != ConversionMode::DirectConvert)
{
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in tensormaps_replace_global_address.");
}
}
-        else if constexpr (KernelConversionMode != ConversionMode::DirectConvert)
+        else
         {
-            static_assert(cutlass::detail::dependent_false<KernelSchedule>,
-                "Conversion mode not handled in tensormaps_replace_global_address.");
+            cute::tma_descriptor_replace_addr_in_global_mem(
+                get<0>(input_tensormaps), mainloop_params.ptr_A[next_batch]);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
{
cute::tma_descriptor_replace_addr_in_global_mem(
get<2>(input_tensormaps), mainloop_params.ptr_S[next_batch]);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
cute::tma_descriptor_replace_addr_in_global_mem(
get<3>(input_tensormaps), mainloop_params.ptr_Z[next_batch]);
}
else if constexpr (KernelConversionMode != ConversionMode::DirectConvert)
{
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in tensormaps_replace_global_address.");
}

⚠️ Potential issue | 🔴 Critical

ConvertAndScaleWithZero still drops the scale tensormap.

These branches only patch/release the zero descriptor. tma_load_scale is still consumed in zero-point mode, and for grouped GEMM its placeholder descriptor is built from the pointer table, so skipping the per-batch scale update means this path can read the pointer array or a stale expert’s scales. Mirror the scale handling anywhere the zero-mode branch currently touches A/zero metadata, and stage smem_tensormap_scale in tensormaps_init() too.

🐛 Minimum fix
             else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
             {
+                cute::tma_descriptor_replace_addr_in_shared_mem(
+                    shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]);
                 cute::tma_descriptor_replace_addr_in_shared_mem(
                     shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]);
             }
...
             else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
             {
+                tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale);
                 tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero);
             }

Also applies to: 1571-1621, 1669-1676

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp`
around lines 1493 - 1527, The zero-mode branch for
ConversionMode::ConvertAndScaleWithZero updates only the zero tensormap but
omits updating the scale tensormap, causing tma_load_scale to read stale or
pointer-table data; in the functions that call
cute::tma_descriptor_replace_addr_in_shared_mem /
cute::tma_descriptor_replace_addr_in_global_mem (look for
tensormaps_replace_global_address/static_assert text and the branches using
KernelConversionMode), add the same scale update as in the ConvertAndScale
branch: replace the shared/global smem_tensormap_scale/get<2>(input_tensormaps)
with mainloop_params.ptr_S[next_batch] wherever the ConvertAndScaleWithZero
branch touches A/zero metadata, and also ensure smem_tensormap_scale is
initialized/staged inside tensormaps_init() so per-batch scale updates exist for
grouped GEMM; repeat the same fix in the equivalent locations (the other similar
blocks noted in the review).

Comment on lines +313 to +314
if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
continue;

⚠️ Potential issue | 🟡 Minor

Use braces for the single-statement if body.

At Line 313-314, the if branch should be brace-delimited.

Proposed fix
-                if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
-                    continue;
+                if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+                {
+                    continue;
+                }

As per coding guidelines, "In C++, if and else should always be followed by brace-delimited statements, even if empty or a single statement".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp` around lines
313 - 314, The single-line if statement checking "if (tile_config ==
CutlassTileConfigSM90::CtaShape128x128x128B) continue;" should be rewritten to
use brace-delimited body per style guidelines: locate the check in
cutlass_heuristic.cpp (the loop handling tile_config values) and replace the
single-statement form with an explicit block using braces around the continue;
so the condition and its action are enclosed in { } to ensure consistent
formatting and avoid style warnings.

Comment on lines +104 to +106
dim3 block(32, 32);
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel<<<1024, block>>>(
fp4_weight, fp4_weight_interleaved, rows, cols);

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's read the file to see the actual code at lines 104-106 and 112-114
cat -n cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu | sed -n '95,125p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 1033


🏁 Script executed:

# Check the function signature that contains these launches
cat -n cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu | sed -n '1,120p' | tail -60

Repository: NVIDIA/TensorRT-LLM

Length of output: 2614


🏁 Script executed:

# Look for the function definition and parameters
rg -A 30 "interleave_fp4_weights_for_Hopper_mixed_gemm\(" cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 732


🏁 Script executed:

# Check if there's a cudaGetLastError or stream handling pattern elsewhere in the file
rg "cudaGetLastError|<<<.*stream|cudaStreamSynchronize" cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check for similar kernel launches in other moe files to see the pattern
find cpp/tensorrt_llm/kernels -name "*.cu" -type f | xargs rg "<<<" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 3431


🏁 Script executed:

# Check if NVIDIA copyright header is present
head -20 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu

Repository: NVIDIA/TensorRT-LLM

Length of output: 757


🏁 Script executed:

# Check where these functions are called from
rg "interleave_fp4_weights_for_Hopper_mixed_gemm|interleave_int4_weights_for_Hopper_mixed_gemm" --type cpp --type cu --type h

Repository: NVIDIA/TensorRT-LLM

Length of output: 89


🏁 Script executed:

# Check the header file to see function declarations
cat cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h

Repository: NVIDIA/TensorRT-LLM

Length of output: 1162


🏁 Script executed:

# Check if there's a Torch binding or wrapper that calls these
rg -l "moe_gemm_mixed_utils" --type cpp --type h --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 106


🏁 Script executed:

# Check the Torch binding file to understand the context
rg -B 5 -A 10 "moe_gemm_mixed_utils|interleave_fp4_weights_for_Hopper_mixed_gemm|interleave_int4_weights_for_Hopper_mixed_gemm" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 1224


🏁 Script executed:

# Also check if there's a .h header that might show stream parameter expectation
rg "stream" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check how other kernel wrapper functions in TensorRT-LLM handle streams
rg -B 2 "void.*weight.*gemm\(" cpp/tensorrt_llm/kernels/ -A 5 | head -50

Repository: NVIDIA/TensorRT-LLM

Length of output: 5034


🏁 Script executed:

# Check if there's any stream handling in the Torch binding that should be passed
rg -B 10 -A 10 "interleave_fp4_weights_for_Hopper_mixed_gemm|interleave_int4_weights_for_Hopper_mixed_gemm" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 894


🏁 Script executed:

# Check how other Torch-facing kernel wrappers handle streams
rg "at::cuda::getCurrentCUDAStream|cudaStream_t" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check pattern in similar preprocessor functions
rg -B 5 -A 10 "cudaStream_t stream" cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check if there's error handling expected for kernel launches in TensorRT-LLM patterns
rg "cudaGetLastError|CHECK_CUDA_ERROR" cpp/tensorrt_llm/kernels/cutlass_kernels/ | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 1167


🏁 Script executed:

# Verify the full context of the Torch function that calls these
rg -B 20 "interleave_int4_weights_for_Hopper_mixed_gemm" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -35

Repository: NVIDIA/TensorRT-LLM

Length of output: 835


🏁 Script executed:

# Check if similar weight preprocessing functions take stream parameters
rg "void.*weight.*Hopper\(" cpp/tensorrt_llm/kernels/ -A 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


Pass the caller's CUDA stream and add error checking.

Both wrapper functions at lines 101–107 and 109–115 use <<<1024, block>>>, which binds them to the implicit default stream. When called from the Torch binding in weightOnlyQuantOp.cpp, this creates unsafe semantics that can race with work already queued on the current stream and fails under CUDA graph capture. Add a cudaStream_t stream parameter to both functions, launch with <<<grid, block, 0, stream>>>, and add cudaGetLastError() to catch launch failures—matching the pattern used throughout the rest of the codebase (e.g., unfusedAttentionKernels.cu, cudaCoreGemmNVFP4.cu).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`
around lines 104 - 106, The two wrapper functions that currently launch
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel with <<<1024, block>>> must
accept a cudaStream_t stream parameter, use a computed grid (e.g., dim3
grid(1024,1,1) or existing grid variable) and launch with <<<grid, block, 0,
stream>>> instead of the implicit default stream, and then call
cudaGetLastError() (or cudaPeekAtLastError()) immediately after the launch to
catch failures; update both function signatures and all call sites (e.g., where
weightOnlyQuantOp.cpp invokes these wrappers) to pass the caller's stream, and
follow the same error-checking pattern used in unfusedAttentionKernels.cu /
cudaCoreGemmNVFP4.cu.
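The stream-aware launch pattern the bot asks for can be sketched as below. This is a hedged sketch, not the repo's actual wrapper: `interleave_kernel`, the argument list, and the grid size are placeholders standing in for `interleave_fp4_weights_for_Hopper_mixed_gemm_kernel` and its parameters.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>

__global__ void interleave_kernel(uint8_t const* in, uint8_t* out, int rows, int cols);

// Wrapper that launches on the caller's stream instead of the implicit
// default stream, then surfaces launch failures immediately.
void interleave_weights(
    uint8_t const* in, uint8_t* out, int rows, int cols, cudaStream_t stream)
{
    dim3 block(32, 32);
    dim3 grid(1024);
    interleave_kernel<<<grid, block, 0, stream>>>(in, out, rows, cols);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        // The repo typically wraps this in a TLLM CUDA-check macro.
        throw std::runtime_error(cudaGetErrorString(err));
    }
}
```

Launching on the caller's stream keeps the work ordered with whatever Torch already queued and makes the wrapper usable under CUDA graph capture, where default-stream launches are rejected.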

* limitations under the License.
*/

#pragma once

⚠️ Potential issue | 🟡 Minor

Use the required TRTLLM_*_H guard here.

#pragma once alone doesn't meet the repo header convention for new .h files. Please switch this to a TRTLLM_MOE_GEMM_MIXED_UTILS_H style include guard.

As per coding guidelines, **/*.{h,hpp,cuh}: Use a preprocessor guard in C++ header files with the format TRTLLM_<FILENAME_IN_CAPS>_H (e.g., TRTLLM_FOO_BAR_HELLO_H). Do not use directory names or trailing underscores.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h` at
line 17, Replace the lone `#pragma once` with the repository-required include
guard: add a preprocessor guard using the macro name
TRTLLM_MOE_GEMM_MIXED_UTILS_H (open with `#ifndef` TRTLLM_MOE_GEMM_MIXED_UTILS_H /
`#define` TRTLLM_MOE_GEMM_MIXED_UTILS_H and close with `#endif`), ensuring the macro
matches the header filename moe_gemm_mixed_utils.h in all caps and without
directory names or trailing underscores; update the top and bottom of the file
(moe_gemm_mixed_utils.h) accordingly so other symbols in this header (e.g.,
declarations used by moe_gemm_mixed_utils) remain unchanged.

Comment on lines +1341 to +1344
if w31_weight_shard.ndim == 2:
w31_weight_shard = w31_weight_shard.cuda()
w31_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm(
w31_weight_shard, 0)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file structure and the specific lines mentioned
head -30 tensorrt_llm/_torch/modules/fused_moe/quantization.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1228


🏁 Script executed:

# Get the file size to determine how to read it
wc -l tensorrt_llm/_torch/modules/fused_moe/quantization.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 121


🏁 Script executed:

# Look at the problematic lines in context (1327-1344 for w3_w1)
sed -n '1320,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1936


🏁 Script executed:

# Also check the w2 weight loading (1379-1396)
sed -n '1370,1405p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 2184


🏁 Script executed:

# Let's search for dst_*_weight initialization to understand device setup
rg "dst_.*weight.*device|create_weights" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 16767


🏁 Script executed:

# Get full context of load_expert_w3_w1_weight method for WInt4AFP8FusedMoEMethod
sed -n '1300,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 3046


🏁 Script executed:

# Get full context of load_expert_w2_weight method
sed -n '1360,1410p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 2985


🏁 Script executed:

# Check if device variable is used later in the copy operation
sed -n '1340,1360p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1109


🏁 Script executed:

# Look for WFP4A16FusedMoEMethod to see the FP4 path mentioned
sed -n '1720,1780p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 3209


🏁 Script executed:

# Check what the device variable is actually used for - search around copy_ calls
rg "\.copy_\(" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 5 -A 2 | head -60

Repository: NVIDIA/TensorRT-LLM

Length of output: 2840


🏁 Script executed:

# Let me check what happens after the .cuda() call and before the copy_
# Look at the exact lines with full context for WInt4AFP8FusedMoEMethod
sed -n '1295,1355p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 3536


🏁 Script executed:

# Check the SM89 path to see how it handles device differently
sed -n '1312,1327p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 957


🏁 Script executed:

# Check if the custom op is called elsewhere and how device is managed there
rg "interleave_4bit_weights_for_Hopper_mixed_gemm" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 3 -A 1

Repository: NVIDIA/TensorRT-LLM

Length of output: 1051


🏁 Script executed:

# Verify the difference between INT4 (WInt4AFP8FusedMoEMethod) and FP4 (WFP4A16FusedMoEMethod) paths
sed -n '1710,1780p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 3774


🏁 Script executed:

# Let me check if the custom op is CUDA-only by looking for its definition or usage patterns
rg "interleave_4bit_weights_for_Hopper_mixed_gemm" tensorrt_llm/ -A 2 | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 1289


🏁 Script executed:

# Check if there's a torch.cuda.current_device() call or similar device management
rg "torch.cuda.current_device|\.cuda\(\)" tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1154


🏁 Script executed:

# Look at how SM89 path handles device vs SM90 to understand the pattern
sed -n '1318,1332p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1008


🏁 Script executed:

# Check if dst_w3_w1_weight.device is always CUDA in practice
rg "register_parameter|register_buffer" tensorrt_llm/_torch/modules/fused_moe/quantization.py | grep -E "(w3_w1_weight|w2_weight)" | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 603


🏁 Script executed:

# Let me trace the full flow - what is w31_weight_shard.device after line 1343?
# Check if the tensor at line 1343 would be on CPU or already on device
sed -n '1328,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 1427


🏁 Script executed:

# Verify the device where packer/unpacker operations happen
rg "pack_int8_tensor_to_packed_int4|unpack_int4_packed_tensor_to_int8" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 1 -A 1 | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 390


🏁 Script executed:

# Check what device w31_weight_shard is on at line 1343 in SM90 path
# After line 1338: w31_weight_shard = packer(transposed)
# transposed comes from unpacker on CPU, then .T.contiguous() preserves CPU
# so packer output would be on CPU
sed -n '1328,1348p' tensorrt_llm/_torch/modules/fused_moe/quantization.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1187


🏁 Script executed:

# Check if there are any comments or documentation about why bare .cuda() is used
sed -n '1340,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 630


🏁 Script executed:

# Look for any documentation about the custom op requirements
rg "interleave_4bit_weights_for_Hopper" tensorrt_llm/ --type-list | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 452


🏁 Script executed:

# Check if there's any enforcement that dst_w3_w1_weight must be CUDA
sed -n '1295,1310p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 948


🏁 Script executed:

# Final verification - check if the device variable is still available at line 1343
# and confirm SM89 uses it correctly
sed -n '1295,1330p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | grep -E "device\s*=|def load_expert"

Repository: NVIDIA/TensorRT-LLM

Length of output: 250


Use explicit device in Hopper interleave path instead of bare .cuda().

At lines 1343 and 1389, the code uses bare .cuda() which moves tensors to the current CUDA device (via torch.cuda.current_device()). This ignores the device = dst_w3_w1_weight.device captured earlier at the method start. In multi-GPU loads where different ranks own different devices, this causes tensors to be moved to the wrong GPU before the custom operator runs.

The SM89 code path correctly uses .to(w31_weight_shard.device) to preserve device placement. Use the same pattern here, or use .to(device) with the captured device variable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 1341 -
1344, The bare .cuda() calls moving w31_weight_shard to the current CUDA device
break multi-GPU placement; replace those .cuda() uses (the ones before calling
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm) with an explicit
.to(device) (using the previously captured device variable `device =
dst_w3_w1_weight.device`) so the tensor is moved to the intended device before
invoking the Hopper interleave operator.
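The difference between bare `.cuda()` and an explicit `.to(device)` can be shown with a GPU-free mock. `FakeTensor` below is a stand-in for `torch.Tensor`, and `"cuda:1"` plays the role of the device captured from `dst_w3_w1_weight` at the start of the method:

```python
class FakeTensor:
    """Minimal mock of a tensor that only tracks its device string."""

    def __init__(self, device: str):
        self.device = device

    def cuda(self) -> "FakeTensor":
        # Bare .cuda() targets the *current* CUDA device, assumed "cuda:0" here.
        return FakeTensor("cuda:0")

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)

dst_weight = FakeTensor("cuda:1")   # this rank owns cuda:1
device = dst_weight.device          # captured at method start, as in the review

shard = FakeTensor("cpu")
wrong = shard.cuda()                # lands on cuda:0 -> wrong GPU for this rank
right = shard.to(device)            # lands on cuda:1, matching dst_weight

print(wrong.device, right.device)   # -> cuda:0 cuda:1
```

In the real loader the mismatch forces a cross-device copy (or an error) when the interleave result is copied into the destination parameter.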

Comment on lines +1735 to +1737
if w31_weight_shard.ndim == 2:
w31_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm(
w31_weight_shard, 1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Stage FP4 shards to CUDA before calling the new interleave op.

interleave_4bit_weights_for_Hopper_mixed_gemm is CUDA-only, but these tensors are still built on dst_*_weight.device. Since FusedMoEMethodBase.create_weights() allocates parameters on CPU by default, host-side loading can now fail here at runtime instead of preserving the previous flow.

Also applies to: 1769-1771

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 1735 -
1737, The interleave op
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm is CUDA-only, so
before calling it move the FP4 shard tensors to CUDA (e.g., w31_weight_shard =
w31_weight_shard.to(dst_device or torch.device('cuda'), non_blocking=True) or
use .cuda()) and ensure the same for the other shard(s) used in the nearby block
(the similar call around the w13/w31 handling at the later location). Update the
code that calls interleave_4bit_weights_for_Hopper_mixed_gemm to stage the
shard(s) onto CUDA first (preserving dtype/contiguity) so the CUDA-only op runs
on GPU-hosted tensors.

Comment on lines +1 to +3
import pickle
import sys
from typing import Dict, List, Optional

⚠️ Potential issue | 🟠 Major

Add the required NVIDIA Apache header.

This new test module is missing the repository-mandated copyright/license block.

As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 1 - 3,
This test module is missing the required NVIDIA Apache 2.0 header; add the
repository-mandated NVIDIA copyright/license block at the very top of the file
(above the import lines) using the current year of the latest meaningful
modification and the standard Apache-2.0 header format used across the repo so
the file (test_fused_moe_jiangs.py) includes the exact NVIDIA copyright and
Apache 2.0 license block.
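For reference, one common form of such a header looks like the fragment below; treat it as a sketch, since the repo's canonical block and the correct year may differ:

```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```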

Comment on lines +229 to +234
q1 = q2 = q3 = q3_q1 = None
if weight_loading_mode == MoEWeightLoadingMode.VANILLA:
q1 = weights[f"{e_idx}.w1.{lut['weight_scale_2']}"].cuda()
q2 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda()
q3 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda()
q3_q1 = torch.max(q3, q1)

⚠️ Potential issue | 🟠 Major

Swap q2 and q3 in the reference dequant path.

q2 is passed to fc2 on Line 275, but it is loaded from w3.weight_scale_2 here. That also makes q3_q1 combine w2 with w1 for fc13. The all-ones fixtures hide it today, but any non-trivial weight_scale_2 will make the reference wrong.

🐛 Proposed fix
-                    q2 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda()
-                    q3 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda()
+                    q2 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda()
+                    q3 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda()

Also applies to: 263-275

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 229 -
234, The reference dequant path loads q2 and q3 swapped: when
weight_loading_mode == MoEWeightLoadingMode.VANILLA set q2 and q3 from the
correct weight keys (q2 should come from f"{e_idx}.w2.{lut['weight_scale_2']}"
and q3 from f"{e_idx}.w3.{lut['weight_scale_2']}") so fc2 receives the proper
tensor and q3_q1 combines w3 with w1 for fc13; update both occurrences (the
block around q1/q2/q3/q3_q1 and the similar block at lines ~263-275) to swap the
assignments accordingly while keeping q1 loaded from w1.

Comment on lines +466 to +469
from tensorrt_llm._torch.custom_ops.torch_custom_ops import MoERunner

# Get the C++ FusedMoeRunner to query tactic descriptions
cpp_runner = next(iter(MoERunner.runner_dict.values()))

⚠️ Potential issue | 🟡 Minor

Don't force a runner lookup for commented diagnostics.

next(iter(MoERunner.runner_dict.values())) can raise StopIteration when this backend does not populate the registry, and nothing below consumes cpp_runner outside the commented debug block.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 466 -
469, The test currently forces a runner lookup with
next(iter(MoERunner.runner_dict.values())) which can raise StopIteration when
the backend doesn't populate MoERunner.runner_dict; change this to a safe
retrieval (e.g., use next(iter(...), None) or guard with a try/except) and
ensure any use of cpp_runner (the commented diagnostic block) only runs when
cpp_runner is not None so the test won't error when the registry is empty.
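The safe variant is a one-line change, since `next` accepts a default. `runner_dict` below is a stand-in for `MoERunner.runner_dict`, shown empty as in the failing case:

```python
runner_dict = {}  # backend did not populate the registry

# Default of None instead of raising StopIteration on an empty registry.
cpp_runner = next(iter(runner_dict.values()), None)

if cpp_runner is not None:
    pass  # run the (currently commented-out) diagnostics only when a runner exists

print(cpp_runner)  # -> None when the registry is empty
```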

@StudyingShao
Collaborator Author

This file (tests/unittest/_torch/modules/test_fused_moe_jiangs.py) is for debugging and will be deleted later.

samuellees added a commit to flashinfer-ai/flashinfer that referenced this pull request Apr 23, 2026

## Summary

Port [TensorRT-LLM PR
#12451](NVIDIA/TensorRT-LLM#12451) to
FlashInfer's `cutlass_fused_moe` SM90 path. Adds an LDSM +
interleaved-LUT weight-load pipeline for 4-bit weights × 16/8-bit
activations, plus the two preprocessing helpers the new kernel layout
requires.

## Changes

### Kernel
- `mixed_input_utils.hpp` /
`sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp` — sync
with TRT-LLM PR #12451 (LDSM path + FP4/INT4 → BF16 LUT converter).
- `moe_gemm_mixed_utils.{cu,h}` (new) — per-row CUDA kernels for
FP4/INT4 byte interleave.
- `cutlass_heuristic.cpp` — for `has_w4afp8`, skip `CtaShape128x128x128B
+ COOPERATIVE` (register overflow on SM90) and pick COOP / PINGPONG per
tile.
- `moe_gemm_tma_ws_mixed_input_launcher.inl` —
`scheduler.max_swizzle_size = 2`, `raster_order = Heuristic`.

### Python
`flashinfer/fused_moe/core.py` exposes two helpers (re-exported by the
package):
- `interleave_moe_weights_for_hopper_mixed_gemm(weight, quant_type)` —
byte-level interleave for `"fp4"` / `"int4"` packed uint8 weights;
delegates to the C++ kernel above.
- `interleave_moe_scales_for_hopper_mixed_gemm(scales, group_size=32)` —
pure PyTorch reshape + permute matching TRT-LLM's
`WFP4A16FusedMoEMethod.load_quant_scales`, factor = `128 // group_size`.

### Tests — inside `tests/moe/test_trtllm_cutlass_fused_moe.py` (18 new)
- `test_moe_bf16_mxfp4_hopper_correctness` (5 shapes, strict
`assert_close` vs a GPU-side dequantized reference that only
materialises active experts to stay under H200 memory at e=256).
- `test_moe_bf16_mxfp4_hopper_coverage` (5 shapes, percent-based ≥
99.9%).
- `test_moe_bf16_mxfp4_hopper_activations` (3 SwiGLU variants).
- `test_moe_w4a8_hopper_correctness` (2 shapes × bf16/fp16) — envelope
matches the upstream CI shape (h = inter = 512, e = 2); larger exceeds
strict tolerance because of FP8 + INT4 accumulation noise, same as the
existing `test_moe_w4a8`.
- `test_moe_w4a8_hopper_autotune` — smoke that `autotune(True)` doesn't
break the W4A8 path.

All 18 green on H200 in 5.2 s cache-hot.

## Performance

H200 (SM90 / HBM3e), `hidden = 4096, intermediate = 2048, experts = 256,
topk = 6`, bf16 output, MXFP4 weights. `cutlass_fused_moe` median over
`bench_gpu_time`. Weight + scale interleave is a one-shot model-load
step and is excluded from timing. `autotune` column runs one pass under
`autotune(True)` to populate the tactic cache before timing.

| batch | main no-autotune | main autotune | **PR no-autotune** | **PR autotune** | **speedup (autotune)** |
|------:|-----------------:|--------------:|-------------------:|----------------:|-----------------------:|
| 4 | 0.791 ms | 0.513 ms | **0.221 ms** | **0.193 ms** | **2.66×** |
| 16 | 1.598 ms | 1.607 ms | **0.530 ms** | **0.532 ms** | **3.02×** |
| 64 | 3.761 ms | 3.757 ms | **1.200 ms** | **1.207 ms** | **3.11×** |

---------

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>