MXFP4 x BF16 CUTLASS MoE backend perf and profiling improvement on Hopper #12451
StudyingShao wants to merge 6 commits into NVIDIA:main
Conversation
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
📝 Walkthrough

This PR introduces Int4→FP8 lookup-table conversion capabilities for SM90 Hopper mixed-precision GEMMs, including new tensormap management logic, A operand loading pipeline refactoring, weight interleaving utilities, and comprehensive test coverage for fused MoE configurations.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ❌ 2 failed (2 warnings) | ✅ 1 passed
⚠️ Important: merge conflicts detected (Beta).
Actionable comments posted: 12
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (1)
Line 2: ⚠️ Potential issue | 🟠 Major — Update copyright year for a modified source file.
This file was modified in 2026, but the header still ends at 2025.
Proposed fix
```diff
- * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines, "Add NVIDIA copyright header on ALL new files, and update year on modified files".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` at line 2, update the copyright header in the top-of-file comment in moe_kernels.cu to include 2026 (e.g., change "2020-2025" to "2020-2026") so the modified source file reflects the current year; locate the header comment near the top of cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (the file-level copyright comment) and update the year range accordingly.

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)
Line 2: ⚠️ Potential issue | 🟠 Major — Update copyright year for this modified file.
The header still ends at 2023.
Proposed fix
```diff
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines, "Add NVIDIA copyright header on ALL new files, and update year on modified files".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp` at line 2, update the file header in cutlass_heuristic.cpp to reflect the current modified-year range and include the NVIDIA copyright header; specifically change the existing "2020-2023" year span to include the current year (e.g., "2020-2026") and ensure the full NVIDIA copyright header text used across the repo is present and matches formatting of other modified files.

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (1)
Line 2: ⚠️ Potential issue | 🟠 Major — Update copyright year for this modified file.
The header currently ends at 2023.
Proposed fix
```diff
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines, "Add NVIDIA copyright header on ALL new files, and update year on modified files".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl` at line 2, Update the copyright header in the file moe_gemm_tma_ws_mixed_input_launcher.inl to reflect the current modification year (change the existing "2020-2023" to include the current year), i.e., update the top-of-file NVIDIA copyright comment block in cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl so the range covers through the current year; keep the existing header text and formatting intact except for the year range.
🧹 Nitpick comments (4)
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (2)
Line 4616: Remove the unused `use_w4_groupwise` local. At Line 4616, the value is computed and never read.
Proposed fix
```diff
- [[maybe_unused]] bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
```

As per coding guidelines, "Avoid dead code in C++".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` at line 4616, Remove the dead local variable declaration [[maybe_unused]] bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; in moe_kernels.cu (the computed value is never read); simply delete this unused variable (or if the intention was to use it later, replace unused declaration with the appropriate logic that actually references use_w4_groupwise in the relevant control flow) so there is no unused local symbol left.
Lines 2157-2158: Use `const` for immutable locals introduced in this change. These locals are not reassigned after initialization.
Proposed fix
```diff
- [[maybe_unused]] float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
- [[maybe_unused]] int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
+ [[maybe_unused]] float const global_scale_val
+     = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
+ [[maybe_unused]] int64_t const num_tokens_before_expert
+     = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
  ...
- [[maybe_unused]] auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
+ [[maybe_unused]] auto const NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
      TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
- [[maybe_unused]] auto MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
+ [[maybe_unused]] auto const MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
      TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{};
- [[maybe_unused]] auto NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
+ [[maybe_unused]] auto const NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
      TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{};
```

As per coding guidelines, "A variable that is not modified after its initialization should be declared as `const`".

Also applies to: 2402-2407
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` around lines 2157 - 2158, Declare the local variables that are not modified after initialization as const: change the declarations of global_scale_val and num_tokens_before_expert to use const (e.g., const float global_scale_val = ...; const int64_t num_tokens_before_expert = ...;). Apply the same change to the other immutable locals noted in the diff region around 2402-2407 so any variables that are initialized once and never reassigned become const (use the exact symbol names from those lines when updating).cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (1)
Lines 262-263: Replace the hard-coded scheduler swizzle value with a named constant. At Line 262, `2` should be expressed as a named constant.

Proposed fix

```diff
+ int constexpr maxSwizzleSize = 2;
- arguments.scheduler.max_swizzle_size = 2;
+ arguments.scheduler.max_swizzle_size = maxSwizzleSize;
  arguments.scheduler.raster_order = RasterOrderOptions::Heuristic;
```

As per coding guidelines, "Except `0` (only used in comparison), `nullptr`, `true`, `false`, all other literals should only be used for variable initialization in C++. Other occurrences should use named constants".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl` around lines 262 - 263, Replace the magic literal 2 used for arguments.scheduler.max_swizzle_size with a named constant (e.g., kDefaultMaxSwizzleSize or MAX_SWIZZLE_SIZE) and use that constant in the assignment to arguments.scheduler.max_swizzle_size; define the constant as a constexpr integer with an appropriate scope (file-scope in this implementation file or in the module header alongside other scheduler constants) so calls like arguments.scheduler.max_swizzle_size = kDefaultMaxSwizzleSize; replace the hard-coded 2 and keep the adjacent use of RasterOrderOptions::Heuristic unchanged.tests/unittest/_torch/modules/test_fused_moe_jiangs.py (1)
Lines 667-680: Avoid a shared `ModelConfig` default. Line 674 constructs one `ModelConfig` at import time. If any test mutates it, later `RefGatedMLPFusedMoE` instances inherit that state.

♻️ Proposed fix
```diff
 def __init__(
     self,
     num_experts: int,
     routing_method: BaseMoeRoutingMethod,
     hidden_size: int,
     intermediate_size: int,
     dtype: Optional[torch.dtype] = None,
-    model_config: ModelConfig = ModelConfig(),
+    model_config: Optional[ModelConfig] = None,
     use_cute_dsl_blockscaling_mm: bool = False,
     bias=False,
     swiglu_alpha: Optional[float] = None,
     swiglu_beta: Optional[float] = None,
     swiglu_limit: Optional[float] = None,
 ):
     super().__init__()
+    model_config = ModelConfig() if model_config is None else model_config
     self.num_experts = num_experts
     self.routing_method = routing_method
     self.hidden_size = hidden_size
     self.intermediate_size = intermediate_size
     self.bias = bias
```
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 667 - 680, The constructor currently uses a shared ModelConfig default (model_config: ModelConfig = ModelConfig()), which can be mutated across tests; change the signature to accept model_config: Optional[ModelConfig] = None and inside __init__ set self.model_config = model_config if model_config is not None else ModelConfig() so each instance gets a fresh ModelConfig; update any references to model_config in the constructor/body accordingly (look for the __init__ method and the symbol model_config in this class).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp`:
- Around line 1493-1527: The zero-mode branch for
ConversionMode::ConvertAndScaleWithZero updates only the zero tensormap but
omits updating the scale tensormap, causing tma_load_scale to read stale or
pointer-table data; in the functions that call
cute::tma_descriptor_replace_addr_in_shared_mem /
cute::tma_descriptor_replace_addr_in_global_mem (look for
tensormaps_replace_global_address/static_assert text and the branches using
KernelConversionMode), add the same scale update as in the ConvertAndScale
branch: replace the shared/global smem_tensormap_scale/get<2>(input_tensormaps)
with mainloop_params.ptr_S[next_batch] wherever the ConvertAndScaleWithZero
branch touches A/zero metadata, and also ensure smem_tensormap_scale is
initialized/staged inside tensormaps_init() so per-batch scale updates exist for
grouped GEMM; repeat the same fix in the equivalent locations (the other similar
blocks noted in the review).
- Around line 1015-1016: The LDSM thread-slice is using the global thread_idx
instead of the warpgroup-local index, causing incorrect lane mapping when
multiple consumer warpgroups participate; change the call to
smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx) to use
get_thread_slice(warp_group_thread_idx) (consistent with the earlier A-copy path
and other SM90 collectives) so SmemCopyAtomA_LDSM / make_tiled_copy_A with
tiled_mma uses the warpgroup-local index for the slice.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Around line 313-314: The single-line if statement checking "if (tile_config ==
CutlassTileConfigSM90::CtaShape128x128x128B) continue;" should be rewritten to
use brace-delimited body per style guidelines: locate the check in
cutlass_heuristic.cpp (the loop handling tile_config values) and replace the
single-statement form with an explicit block using braces around the continue;
so the condition and its action are enclosed in { } to ensure consistent
formatting and avoid style warnings.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`:
- Around line 104-106: The two wrapper functions that currently launch
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel with <<<1024, block>>> must
accept a cudaStream_t stream parameter, use a computed grid (e.g., dim3
grid(1024,1,1) or existing grid variable) and launch with <<<grid, block, 0,
stream>>> instead of the implicit default stream, and then call
cudaGetLastError() (or cudaPeekAtLastError()) immediately after the launch to
catch failures; update both function signatures and all call sites (e.g., where
weightOnlyQuantOp.cpp invokes these wrappers) to pass the caller's stream, and
follow the same error-checking pattern used in unfusedAttentionKernels.cu /
cudaCoreGemmNVFP4.cu.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h`:
- Line 17: Replace the lone `#pragma once` with the repository-required include
guard: add a preprocessor guard using the macro name
TRTLLM_MOE_GEMM_MIXED_UTILS_H (open with `#ifndef` TRTLLM_MOE_GEMM_MIXED_UTILS_H /
`#define` TRTLLM_MOE_GEMM_MIXED_UTILS_H and close with `#endif`), ensuring the macro
matches the header filename moe_gemm_mixed_utils.h in all caps and without
directory names or trailing underscores; update the top and bottom of the file
(moe_gemm_mixed_utils.h) accordingly so other symbols in this header (e.g.,
declarations used by moe_gemm_mixed_utils) remain unchanged.
In `@cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp`:
- Around line 412-435: Add explicit shape validation before computing n/k and
before dispatching to the Hopper kernels: ensure weight.dim() == 2,
weight.size(0) % 16 == 0, and weight.size(1) % 32 == 0, and emit TORCH_CHECK
errors if any condition fails. Place these checks at the start of the
function/block handling the public entrypoint (before the lines that compute n =
weight.size(0) and k = weight.size(1) * 2) so that inputs violating the kernel
tile assumptions are rejected prior to calling
interleave_int4_weights_for_Hopper_mixed_gemm or
interleave_fp4_weights_for_Hopper_mixed_gemm.
- Around line 416-427: The output tensor is allocated on the current CUDA device
instead of the input tensor's device, causing mismatched allocations on
multi-GPU; add `#include` <c10/cuda/CUDAGuard.h>, create a c10::cuda::CUDAGuard
guard{weight.device()}; (or at::cuda::CUDAGuard guard(weight.device().index());)
immediately before allocating weight_interleaved so the torch::empty call uses
the same device as weight, and ensure subsequent get_ptr<uint8_t>(weight) /
get_ptr<uint8_t>(weight_interleaved) dereference the correct device pointers.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 1735-1737: The interleave op
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm is CUDA-only, so
before calling it move the FP4 shard tensors to CUDA (e.g., w31_weight_shard =
w31_weight_shard.to(dst_device or torch.device('cuda'), non_blocking=True) or
use .cuda()) and ensure the same for the other shard(s) used in the nearby block
(the similar call around the w13/w31 handling at the later location). Update the
code that calls interleave_4bit_weights_for_Hopper_mixed_gemm to stage the
shard(s) onto CUDA first (preserving dtype/contiguity) so the CUDA-only op runs
on GPU-hosted tensors.
- Around line 1341-1344: The bare .cuda() calls moving w31_weight_shard to the
current CUDA device break multi-GPU placement; replace those .cuda() uses (the
ones before calling
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm) with an explicit
.to(device) (using the previously captured device variable `device =
dst_w3_w1_weight.device`) so the tensor is moved to the intended device before
invoking the Hopper interleave operator.
In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py`:
- Around line 229-234: The reference dequant path loads q2 and q3 swapped: when
weight_loading_mode == MoEWeightLoadingMode.VANILLA set q2 and q3 from the
correct weight keys (q2 should come from f"{e_idx}.w2.{lut['weight_scale_2']}"
and q3 from f"{e_idx}.w3.{lut['weight_scale_2']}") so fc2 receives the proper
tensor and q3_q1 combines w3 with w1 for fc13; update both occurrences (the
block around q1/q2/q3/q3_q1 and the similar block at lines ~263-275) to swap the
assignments accordingly while keeping q1 loaded from w1.
- Around line 466-469: The test currently forces a runner lookup with
next(iter(MoERunner.runner_dict.values())) which can raise StopIteration when
the backend doesn't populate MoERunner.runner_dict; change this to a safe
retrieval (e.g., use next(iter(...), None) or guard with a try/except) and
ensure any use of cpp_runner (the commented diagnostic block) only runs when
cpp_runner is not None so the test won't error when the registry is empty.
- Around line 1-3: This test module is missing the required NVIDIA Apache 2.0
header; add the repository-mandated NVIDIA copyright/license block at the very
top of the file (above the import lines) using the current year of the latest
meaningful modification and the standard Apache-2.0 header format used across
the repo so the file (test_fused_moe_jiangs.py) includes the exact NVIDIA
copyright and Apache 2.0 license block.
---
Outside diff comments:
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Line 2: Update the file header in cutlass_heuristic.cpp to reflect the current
modified-year range and include the NVIDIA copyright header; specifically change
the existing "2020-2023" year span to include the current year (e.g.,
"2020-2026") and ensure the full NVIDIA copyright header text used across the
repo is present and matches formatting of other modified files.
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl`:
- Line 2: Update the copyright header in the file
moe_gemm_tma_ws_mixed_input_launcher.inl to reflect the current modification
year (change the existing "2020-2023" to include the current year), i.e., update
the top-of-file NVIDIA copyright comment block in
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
so the range covers through the current year; keep the existing header text and
formatting intact except for the year range.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu`:
- Line 2: Update the copyright header in the top-of-file comment in
moe_kernels.cu to include 2026 (e.g., change "2020-2025" to "2020-2026") so the
modified source file reflects the current year; locate the header comment near
the top of cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (the
file-level copyright comment) and update the year range accordingly.
---
Nitpick comments:
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl`:
- Around line 262-263: Replace the magic literal 2 used for
arguments.scheduler.max_swizzle_size with a named constant (e.g.,
kDefaultMaxSwizzleSize or MAX_SWIZZLE_SIZE) and use that constant in the
assignment to arguments.scheduler.max_swizzle_size; define the constant as a
constexpr integer with an appropriate scope (file-scope in this implementation
file or in the module header alongside other scheduler constants) so calls like
arguments.scheduler.max_swizzle_size = kDefaultMaxSwizzleSize; replace the
hard-coded 2 and keep the adjacent use of RasterOrderOptions::Heuristic
unchanged.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu`:
- Line 4616: Remove the dead local variable declaration [[maybe_unused]] bool
use_w4_groupwise = use_w4afp8 || use_wfp4a16; in moe_kernels.cu (the computed
value is never read); simply delete this unused variable (or if the intention
was to use it later, replace unused declaration with the appropriate logic that
actually references use_w4_groupwise in the relevant control flow) so there is
no unused local symbol left.
- Around line 2157-2158: Declare the local variables that are not modified after
initialization as const: change the declarations of global_scale_val and
num_tokens_before_expert to use const (e.g., const float global_scale_val = ...;
const int64_t num_tokens_before_expert = ...;). Apply the same change to the
other immutable locals noted in the diff region around 2402-2407 so any
variables that are initialized once and never reassigned become const (use the
exact symbol names from those lines when updating).
In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py`:
- Around line 667-680: The constructor currently uses a shared ModelConfig
default (model_config: ModelConfig = ModelConfig()), which can be mutated across
tests; change the signature to accept model_config: Optional[ModelConfig] = None
and inside __init__ set self.model_config = model_config if model_config is not
None else ModelConfig() so each instance gets a fresh ModelConfig; update any
references to model_config in the constructor/body accordingly (look for the
__init__ method and the symbol model_config in this class).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8cd5f7bd-7d25-496d-86b6-147cca74e2fa
📒 Files selected for processing (11)
- cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
- cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp
- cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
- cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
- cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
- cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
- cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
- cpp/tensorrt_llm/thop/moeOp.cpp
- cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tests/unittest/_torch/modules/test_fused_moe_jiangs.py
```cpp
auto smem_tiled_copy_A_LDSM = make_tiled_copy_A(SmemCopyAtomA_LDSM{}, tiled_mma);
auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx);
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail
echo "Current collective:"
rg -nP 'make_tiled_copy_A\(|get_thread_slice\((thread_idx|warp_group_thread_idx)\)' \
    cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp \
    -A3 -B3
echo
echo "Nearby SM90 collectives:"
rg -nP 'get_thread_slice\((thread_idx|warp_group_thread_idx)\)' \
    cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective \
    -A2 -B2
```

Length of output: 8930
Use the warpgroup-local thread index for the LDSM slice.
The existing A-copy path on line 993 uses get_thread_slice(warp_group_thread_idx), but the new LDSM path on line 1016 incorrectly uses thread_idx. When multiple consumer warpgroups participate, later warpgroups will select the wrong copy lane mapping and read incorrect A fragments. All similar patterns in nearby SM90 collectives use warp_group_thread_idx.
Fix
```diff
- auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx);
+ auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(warp_group_thread_idx);
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
auto smem_tiled_copy_A_LDSM = make_tiled_copy_A(SmemCopyAtomA_LDSM{}, tiled_mma);
auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(warp_group_thread_idx);
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp`
around lines 1015 - 1016, The LDSM thread-slice is using the global thread_idx
instead of the warpgroup-local index, causing incorrect lane mapping when
multiple consumer warpgroups participate; change the call to
smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx) to use
get_thread_slice(warp_group_thread_idx) (consistent with the earlier A-copy path
and other SM90 collectives) so SmemCopyAtomA_LDSM / make_tiled_copy_A with
tiled_mma uses the warpgroup-local index for the slice.
```diff
     if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
     {
         cute::tma_descriptor_replace_addr_in_shared_mem(
             shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]);
     }
     else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
     {
         cute::tma_descriptor_replace_addr_in_shared_mem(
             shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]);
     }
     else if constexpr (KernelConversionMode != ConversionMode::DirectConvert)
     {
         static_assert(cutlass::detail::dependent_false<KernelSchedule>,
             "Conversion mode not handled in tensormaps_replace_global_address.");
     }
 }
-else if constexpr (KernelConversionMode != ConversionMode::DirectConvert)
+else
 {
-    static_assert(cutlass::detail::dependent_false<KernelSchedule>,
-        "Conversion mode not handled in tensormaps_replace_global_address.");
+    cute::tma_descriptor_replace_addr_in_global_mem(
+        get<0>(input_tensormaps), mainloop_params.ptr_A[next_batch]);
+    if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
+    {
+        cute::tma_descriptor_replace_addr_in_global_mem(
+            get<2>(input_tensormaps), mainloop_params.ptr_S[next_batch]);
+    }
+    else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
+    {
+        cute::tma_descriptor_replace_addr_in_global_mem(
+            get<3>(input_tensormaps), mainloop_params.ptr_Z[next_batch]);
+    }
+    else if constexpr (KernelConversionMode != ConversionMode::DirectConvert)
+    {
+        static_assert(cutlass::detail::dependent_false<KernelSchedule>,
+            "Conversion mode not handled in tensormaps_replace_global_address.");
+    }
```
ConvertAndScaleWithZero still drops the scale tensormap.
These branches only patch/release the zero descriptor. tma_load_scale is still consumed in zero-point mode, and for grouped GEMM its placeholder descriptor is built from the pointer table, so skipping the per-batch scale update means this path can read the pointer array or a stale expert’s scales. Mirror the scale handling anywhere the zero-mode branch currently touches A/zero metadata, and stage smem_tensormap_scale in tensormaps_init() too.
🐛 Minimum fix
```diff
 else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
 {
+    cute::tma_descriptor_replace_addr_in_shared_mem(
+        shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]);
     cute::tma_descriptor_replace_addr_in_shared_mem(
         shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]);
 }
 ...
 else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
 {
+    tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale);
     tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero);
 }
```

Also applies to: 1571-1621, 1669-1676
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp`
around lines 1493 - 1527, The zero-mode branch for
ConversionMode::ConvertAndScaleWithZero updates only the zero tensormap but
omits updating the scale tensormap, causing tma_load_scale to read stale or
pointer-table data; in the functions that call
cute::tma_descriptor_replace_addr_in_shared_mem /
cute::tma_descriptor_replace_addr_in_global_mem (look for
tensormaps_replace_global_address/static_assert text and the branches using
KernelConversionMode), add the same scale update as in the ConvertAndScale
branch: replace the shared/global smem_tensormap_scale/get<2>(input_tensormaps)
with mainloop_params.ptr_S[next_batch] wherever the ConvertAndScaleWithZero
branch touches A/zero metadata, and also ensure smem_tensormap_scale is
initialized/staged inside tensormaps_init() so per-batch scale updates exist for
grouped GEMM; repeat the same fix in the equivalent locations (the other similar
blocks noted in the review).
```cpp
if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
    continue;
```
Use braces for the single-statement if body.
At Line 313-314, the if branch should be brace-delimited.
Proposed fix
```diff
- if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
-     continue;
+ if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+ {
+     continue;
+ }
```

As per coding guidelines, "In C++, if and else should always be followed by brace-delimited statements, even if empty or a single statement".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp` around lines
313 - 314, The single-line if statement checking "if (tile_config ==
CutlassTileConfigSM90::CtaShape128x128x128B) continue;" should be rewritten to
use brace-delimited body per style guidelines: locate the check in
cutlass_heuristic.cpp (the loop handling tile_config values) and replace the
single-statement form with an explicit block using braces around the continue;
so the condition and its action are enclosed in { } to ensure consistent
formatting and avoid style warnings.
```cpp
dim3 block(32, 32);
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel<<<1024, block>>>(
    fp4_weight, fp4_weight_interleaved, rows, cols);
```
🧩 Analysis chain
🏁 Scripts executed:

```bash
# First, let's read the file to see the actual code at lines 104-106 and 112-114
cat -n cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu | sed -n '95,125p'
# → output length: 1033

# Check the function signature that contains these launches
cat -n cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu | sed -n '1,120p' | tail -60
# → output length: 2614

# Look for the function definition and parameters
rg -A 30 "interleave_fp4_weights_for_Hopper_mixed_gemm\(" cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu | head -40
# → output length: 732

# Check if there's a cudaGetLastError or stream handling pattern elsewhere in the file
rg "cudaGetLastError|<<<.*stream|cudaStreamSynchronize" cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
# → output length: 45

# Check for similar kernel launches in other moe files to see the pattern
find cpp/tensorrt_llm/kernels -name "*.cu" -type f | xargs rg "<<<" | head -20
# → output length: 3431

# Check if NVIDIA copyright header is present
head -20 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
# → output length: 757

# Check where these functions are called from
rg "interleave_fp4_weights_for_Hopper_mixed_gemm|interleave_int4_weights_for_Hopper_mixed_gemm" --type cpp --type cu --type h
# → output length: 89

# Check the header file to see function declarations
cat cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
# → output length: 1162

# Check if there's a Torch binding or wrapper that calls these
rg -l "moe_gemm_mixed_utils" --type cpp --type h --type py
# → output length: 106

# Check the Torch binding file to understand the context
rg -B 5 -A 10 "moe_gemm_mixed_utils|interleave_fp4_weights_for_Hopper_mixed_gemm|interleave_int4_weights_for_Hopper_mixed_gemm" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
# → output length: 1224

# Also check if there's a .h header that might show stream parameter expectation
rg "stream" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -20
# → output length: 45

# Check how other kernel wrapper functions in TensorRT-LLM handle streams
rg -B 2 "void.*weight.*gemm\(" cpp/tensorrt_llm/kernels/ -A 5 | head -50
# → output length: 5034

# Check if there's any stream handling in the Torch binding that should be passed
rg -B 10 -A 10 "interleave_fp4_weights_for_Hopper_mixed_gemm|interleave_int4_weights_for_Hopper_mixed_gemm" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -40
# → output length: 894

# Check how other Torch-facing kernel wrappers handle streams
rg "at::cuda::getCurrentCUDAStream|cudaStream_t" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -10
# → output length: 45

# Check pattern in similar preprocessor functions
rg -B 5 -A 10 "cudaStream_t stream" cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp | head -30
# → output length: 45

# Check if there's error handling expected for kernel launches in TensorRT-LLM patterns
rg "cudaGetLastError|CHECK_CUDA_ERROR" cpp/tensorrt_llm/kernels/cutlass_kernels/ | head -10
# → output length: 1167

# Verify the full context of the Torch function that calls these
rg -B 20 "interleave_int4_weights_for_Hopper_mixed_gemm" cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | head -35
# → output length: 835

# Check if similar weight preprocessing functions take stream parameters
rg "void.*weight.*Hopper\(" cpp/tensorrt_llm/kernels/ -A 3
# → output length: 45
```
Pass the caller's CUDA stream and add error checking.
Both wrapper functions at lines 101–107 and 109–115 use <<<1024, block>>>, which binds them to the implicit default stream. When called from the Torch binding in weightOnlyQuantOp.cpp, this creates unsafe semantics that can race with work already queued on the current stream and fails under CUDA graph capture. Add a cudaStream_t stream parameter to both functions, launch with <<<grid, block, 0, stream>>>, and add cudaGetLastError() to catch launch failures—matching the pattern used throughout the rest of the codebase (e.g., unfusedAttentionKernels.cu, cudaCoreGemmNVFP4.cu).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`
around lines 104 - 106, The two wrapper functions that currently launch
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel with <<<1024, block>>> must
accept a cudaStream_t stream parameter, use a computed grid (e.g., dim3
grid(1024,1,1) or existing grid variable) and launch with <<<grid, block, 0,
stream>>> instead of the implicit default stream, and then call
cudaGetLastError() (or cudaPeekAtLastError()) immediately after the launch to
catch failures; update both function signatures and all call sites (e.g., where
weightOnlyQuantOp.cpp invokes these wrappers) to pass the caller's stream, and
follow the same error-checking pattern used in unfusedAttentionKernels.cu /
cudaCoreGemmNVFP4.cu.
```cpp
 * limitations under the License.
 */

#pragma once
```
Use the required TRTLLM_*_H guard here.
#pragma once alone doesn't meet the repo header convention for new .h files. Please switch this to a TRTLLM_MOE_GEMM_MIXED_UTILS_H style include guard.
As per coding guidelines, **/*.{h,hpp,cuh}: Use a preprocessor guard in C++ header files with the format TRTLLM_<FILENAME_IN_CAPS>_H (e.g., TRTLLM_FOO_BAR_HELLO_H). Do not use directory names or trailing underscores.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h` at
line 17, Replace the lone `#pragma once` with the repository-required include
guard: add a preprocessor guard using the macro name
TRTLLM_MOE_GEMM_MIXED_UTILS_H (open with `#ifndef` TRTLLM_MOE_GEMM_MIXED_UTILS_H /
`#define` TRTLLM_MOE_GEMM_MIXED_UTILS_H and close with `#endif`), ensuring the macro
matches the header filename moe_gemm_mixed_utils.h in all caps and without
directory names or trailing underscores; update the top and bottom of the file
(moe_gemm_mixed_utils.h) accordingly so other symbols in this header (e.g.,
declarations used by moe_gemm_mixed_utils) remain unchanged.
```python
if w31_weight_shard.ndim == 2:
    w31_weight_shard = w31_weight_shard.cuda()
    w31_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm(
        w31_weight_shard, 0)
```
🧩 Analysis chain
🏁 Scripts executed:

```bash
# First, let's look at the file structure and the specific lines mentioned
head -30 tensorrt_llm/_torch/modules/fused_moe/quantization.py
# → output length: 1228

# Get the file size to determine how to read it
wc -l tensorrt_llm/_torch/modules/fused_moe/quantization.py
# → output length: 121

# Look at the problematic lines in context (1327-1344 for w3_w1)
sed -n '1320,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 1936

# Also check the w2 weight loading (1379-1396)
sed -n '1370,1405p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 2184

# Let's search for dst_*_weight initialization to understand device setup
rg "dst_.*weight.*device|create_weights" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 2 -A 2
# → output length: 16767

# Get full context of load_expert_w3_w1_weight method for WInt4AFP8FusedMoEMethod
sed -n '1300,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 3046

# Get full context of load_expert_w2_weight method
sed -n '1360,1410p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 2985

# Check if device variable is used later in the copy operation
sed -n '1340,1360p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 1109

# Look for WFP4A16FusedMoEMethod to see the FP4 path mentioned
sed -n '1720,1780p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 3209

# Check what the device variable is actually used for - search around copy_ calls
rg "\.copy_\(" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 5 -A 2 | head -60
# → output length: 2840

# Let me check what happens after the .cuda() call and before the copy_
# Look at the exact lines with full context for WInt4AFP8FusedMoEMethod
sed -n '1295,1355p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 3536

# Check the SM89 path to see how it handles device differently
sed -n '1312,1327p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 957

# Check if the custom op is called elsewhere and how device is managed there
rg "interleave_4bit_weights_for_Hopper_mixed_gemm" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 3 -A 1
# → output length: 1051

# Verify the difference between INT4 (WInt4AFP8FusedMoEMethod) and FP4 (WFP4A16FusedMoEMethod) paths
sed -n '1710,1780p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 3774

# Let me check if the custom op is CUDA-only by looking for its definition or usage patterns
rg "interleave_4bit_weights_for_Hopper_mixed_gemm" tensorrt_llm/ -A 2 | head -30
# → output length: 1289

# Check if there's a torch.cuda.current_device() call or similar device management
rg "torch.cuda.current_device|\.cuda\(\)" tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -20
# → output length: 1154

# Look at how SM89 path handles device vs SM90 to understand the pattern
sed -n '1318,1332p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 1008

# Check if dst_w3_w1_weight.device is always CUDA in practice
rg "register_parameter|register_buffer" tensorrt_llm/_torch/modules/fused_moe/quantization.py | grep -E "(w3_w1_weight|w2_weight)" | head -10
# → output length: 603

# Let me trace the full flow - what is w31_weight_shard.device after line 1343?
# Check if the tensor at line 1343 would be on CPU or already on device
sed -n '1328,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 1427

# Verify the device where packer/unpacker operations happen
rg "pack_int8_tensor_to_packed_int4|unpack_int4_packed_tensor_to_int8" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 1 -A 1 | head -20
# → output length: 390

# Check what device w31_weight_shard is on at line 1343 in SM90 path
# After line 1338: w31_weight_shard = packer(transposed)
# transposed comes from unpacker on CPU, then .T.contiguous() preserves CPU
# so packer output would be on CPU
sed -n '1328,1348p' tensorrt_llm/_torch/modules/fused_moe/quantization.py
# → output length: 1187

# Check if there are any comments or documentation about why bare .cuda() is used
sed -n '1340,1350p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 630

# Look for any documentation about the custom op requirements
rg "interleave_4bit_weights_for_Hopper" tensorrt_llm/ --type-list | head -5
# → output length: 452

# Check if there's any enforcement that dst_w3_w1_weight must be CUDA
sed -n '1295,1310p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n
# → output length: 948

# Final verification - check if the device variable is still available at line 1343
# and confirm SM89 uses it correctly
sed -n '1295,1330p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | grep -E "device\s*=|def load_expert"
# → output length: 250
```
Use explicit device in Hopper interleave path instead of bare .cuda().
At lines 1343 and 1389, the code uses bare .cuda() which moves tensors to the current CUDA device (via torch.cuda.current_device()). This ignores the device = dst_w3_w1_weight.device captured earlier at the method start. In multi-GPU loads where different ranks own different devices, this causes tensors to be moved to the wrong GPU before the custom operator runs.
The SM89 code path correctly uses .to(w31_weight_shard.device) to preserve device placement. Use the same pattern here, or use .to(device) with the captured device variable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 1341 -
1344, The bare .cuda() calls moving w31_weight_shard to the current CUDA device
break multi-GPU placement; replace those .cuda() uses (the ones before calling
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm) with an explicit
.to(device) (using the previously captured device variable `device =
dst_w3_w1_weight.device`) so the tensor is moved to the intended device before
invoking the Hopper interleave operator.
```python
if w31_weight_shard.ndim == 2:
    w31_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm(
        w31_weight_shard, 1)
```
Stage FP4 shards to CUDA before calling the new interleave op.
interleave_4bit_weights_for_Hopper_mixed_gemm is CUDA-only, but these tensors are still built on dst_*_weight.device. Since FusedMoEMethodBase.create_weights() allocates parameters on CPU by default, host-side loading can now fail here at runtime instead of preserving the previous flow.
Also applies to: 1769-1771
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 1735 -
1737, The interleave op
torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm is CUDA-only, so
before calling it move the FP4 shard tensors to CUDA (e.g., w31_weight_shard =
w31_weight_shard.to(dst_device or torch.device('cuda'), non_blocking=True) or
use .cuda()) and ensure the same for the other shard(s) used in the nearby block
(the similar call around the w13/w31 handling at the later location). Update the
code that calls interleave_4bit_weights_for_Hopper_mixed_gemm to stage the
shard(s) onto CUDA first (preserving dtype/contiguity) so the CUDA-only op runs
on GPU-hosted tensors.
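A minimal sketch of the staging pattern this comment asks for, assuming the destination parameter may be CPU-allocated; the helper name is hypothetical, and the op name and quant-type flag come from the surrounding diff:

```python
import torch

# Hypothetical helper: stage a (possibly CPU-resident) packed FP4 shard onto
# CUDA for the CUDA-only interleave op, then return it to its original device
# so the later copy_ into the (possibly CPU-allocated) parameter still works.
def interleave_fp4_shard_on_cuda(shard: torch.Tensor) -> torch.Tensor:
    orig_device = shard.device
    staged = shard if shard.is_cuda else shard.to("cuda", non_blocking=True)
    out = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm(staged, 1)
    return out.to(orig_device)
```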
```python
import pickle
import sys
from typing import Dict, List, Optional
```
Add the required NVIDIA Apache header.
This new test module is missing the repository-mandated copyright/license block.
As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 1 - 3,
This test module is missing the required NVIDIA Apache 2.0 header; add the
repository-mandated NVIDIA copyright/license block at the very top of the file
(above the import lines) using the current year of the latest meaningful
modification and the standard Apache-2.0 header format used across the repo so
the file (test_fused_moe_jiangs.py) includes the exact NVIDIA copyright and
Apache 2.0 license block.
```python
q1 = q2 = q3 = q3_q1 = None
if weight_loading_mode == MoEWeightLoadingMode.VANILLA:
    q1 = weights[f"{e_idx}.w1.{lut['weight_scale_2']}"].cuda()
    q2 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda()
    q3 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda()
    q3_q1 = torch.max(q3, q1)
```
Swap q2 and q3 in the reference dequant path.
q2 is passed to fc2 on Line 275, but it is loaded from w3.weight_scale_2 here. That also makes q3_q1 combine w2 with w1 for fc13. The all-ones fixtures hide it today, but any non-trivial weight_scale_2 will make the reference wrong.
🐛 Proposed fix
```diff
- q2 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda()
- q3 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda()
+ q2 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda()
+ q3 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda()
```

Also applies to: 263-275
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 229 -
234, The reference dequant path loads q2 and q3 swapped: when
weight_loading_mode == MoEWeightLoadingMode.VANILLA set q2 and q3 from the
correct weight keys (q2 should come from f"{e_idx}.w2.{lut['weight_scale_2']}"
and q3 from f"{e_idx}.w3.{lut['weight_scale_2']}") so fc2 receives the proper
tensor and q3_q1 combines w3 with w1 for fc13; update both occurrences (the
block around q1/q2/q3/q3_q1 and the similar block at lines ~263-275) to swap the
assignments accordingly while keeping q1 loaded from w1.
```python
from tensorrt_llm._torch.custom_ops.torch_custom_ops import MoERunner

# Get the C++ FusedMoeRunner to query tactic descriptions
cpp_runner = next(iter(MoERunner.runner_dict.values()))
```
Don't force a runner lookup for commented diagnostics.
next(iter(MoERunner.runner_dict.values())) can raise StopIteration when this backend does not populate the registry, and nothing below consumes cpp_runner outside the commented debug block.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/modules/test_fused_moe_jiangs.py` around lines 466 -
469, The test currently forces a runner lookup with
next(iter(MoERunner.runner_dict.values())) which can raise StopIteration when
the backend doesn't populate MoERunner.runner_dict; change this to a safe
retrieval (e.g., use next(iter(...), None) or guard with a try/except) and
ensure any use of cpp_runner (the commented diagnostic block) only runs when
cpp_runner is not None so the test won't error when the registry is empty.
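A minimal sketch of the safe retrieval the prompt suggests, using the two-argument form of `next()` so an empty registry yields `None` instead of raising:

```python
from tensorrt_llm._torch.custom_ops.torch_custom_ops import MoERunner

# Fall back to None instead of letting next() raise StopIteration when the
# backend never populated the runner registry.
cpp_runner = next(iter(MoERunner.runner_dict.values()), None)
if cpp_runner is not None:
    # The commented diagnostic block would go here, guarded so it only runs
    # when a C++ runner is actually registered for this backend.
    pass
```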
This file (tests/unittest/_torch/modules/test_fused_moe_jiangs.py) is for debugging and will be deleted later. |
## Summary

Port [TensorRT-LLM PR #12451](NVIDIA/TensorRT-LLM#12451) to FlashInfer's `cutlass_fused_moe` SM90 path. Adds an LDSM + interleaved-LUT weight-load pipeline for 4-bit weights × 16/8-bit activations, plus the two preprocessing helpers the new kernel layout requires.

## Changes

### Kernel

- `mixed_input_utils.hpp` / `sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp` — sync with TRT-LLM PR #12451 (LDSM path + FP4/INT4 → BF16 LUT converter).
- `moe_gemm_mixed_utils.{cu,h}` (new) — per-row CUDA kernels for FP4/INT4 byte interleave.
- `cutlass_heuristic.cpp` — for `has_w4afp8`, skip `CtaShape128x128x128B + COOPERATIVE` (register overflow on SM90) and pick COOP / PINGPONG per tile.
- `moe_gemm_tma_ws_mixed_input_launcher.inl` — `scheduler.max_swizzle_size = 2`, `raster_order = Heuristic`.

### Python

`flashinfer/fused_moe/core.py` exposes two helpers (re-exported by the package):

- `interleave_moe_weights_for_hopper_mixed_gemm(weight, quant_type)` — byte-level interleave for `"fp4"` / `"int4"` packed uint8 weights; delegates to the C++ kernel above.
- `interleave_moe_scales_for_hopper_mixed_gemm(scales, group_size=32)` — pure PyTorch reshape + permute matching TRT-LLM's `WFP4A16FusedMoEMethod.load_quant_scales`, factor = `128 // group_size`.

### Tests — inside `tests/moe/test_trtllm_cutlass_fused_moe.py` (18 new)

- `test_moe_bf16_mxfp4_hopper_correctness` (5 shapes, strict `assert_close` vs a GPU-side dequantized reference that only materialises active experts to stay under H200 memory at e=256).
- `test_moe_bf16_mxfp4_hopper_coverage` (5 shapes, percent-based ≥ 99.9%).
- `test_moe_bf16_mxfp4_hopper_activations` (3 SwiGLU variants).
- `test_moe_w4a8_hopper_correctness` (2 shapes × bf16/fp16) — envelope matches the upstream CI shape (h = inter = 512, e = 2); larger exceeds strict tolerance because of FP8 + INT4 accumulation noise, same as the existing `test_moe_w4a8`.
- `test_moe_w4a8_hopper_autotune` — smoke that `autotune(True)` doesn't break the W4A8 path.

All 18 green on H200 in 5.2 s cache-hot.

## Performance

H200 (SM90 / HBM3e), `hidden = 4096, intermediate = 2048, experts = 256, topk = 6`, bf16 output, MXFP4 weights. `cutlass_fused_moe` median over `bench_gpu_time`. Weight + scale interleave is a one-shot model-load step and is excluded from timing. `autotune` column runs one pass under `autotune(True)` to populate the tactic cache before timing.

| batch | main no-autotune | main autotune | **PR no-autotune** | **PR autotune** | **speedup (autotune)** |
|------:|-----------------:|--------------:|-------------------:|----------------:|-----------------------:|
| 4 | 0.791 ms | 0.513 ms | **0.221 ms** | **0.193 ms** | **2.66×** |
| 16 | 1.598 ms | 1.607 ms | **0.530 ms** | **0.532 ms** | **3.02×** |
| 64 | 3.761 ms | 3.757 ms | **1.200 ms** | **1.207 ms** | **3.11×** |

---------

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
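As a rough usage sketch of the two helpers named above (the signatures come from this description; the shapes, dtypes, and exact import path are illustrative assumptions):

```python
import torch
from flashinfer.fused_moe import (
    interleave_moe_weights_for_hopper_mixed_gemm,
    interleave_moe_scales_for_hopper_mixed_gemm,
)

num_experts, n, k = 8, 2048, 4096
# Packed MXFP4 weights: two 4-bit values per uint8 byte, so k/2 bytes per row.
w_packed = torch.randint(0, 256, (num_experts, n, k // 2),
                         dtype=torch.uint8, device="cuda")
# One E8M0 scale byte per 32-element group along k.
scales = torch.randint(118, 130, (num_experts, n, k // 32),
                       dtype=torch.uint8, device="cuda")

# One-shot preprocessing at model-load time; excluded from the timings above.
w_interleaved = interleave_moe_weights_for_hopper_mixed_gemm(w_packed, "fp4")
scales_interleaved = interleave_moe_scales_for_hopper_mixed_gemm(scales, group_size=32)
```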
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.