Skip to content

perf: optimize MXFP4xBF16 & INT4xFP8 CUTLASS MoE backend for SM90#3084

Merged
samuellees merged 33 commits intoflashinfer-ai:mainfrom
samuellees:feat/w4a16-moe-kernel
Apr 23, 2026
Merged

perf: optimize MXFP4xBF16 & INT4xFP8 CUTLASS MoE backend for SM90#3084
samuellees merged 33 commits intoflashinfer-ai:mainfrom
samuellees:feat/w4a16-moe-kernel

Conversation

@samuellees
Copy link
Copy Markdown
Collaborator

@samuellees samuellees commented Apr 15, 2026

Summary

Port TensorRT-LLM PR #12451 to FlashInfer's cutlass_fused_moe SM90 path. Adds an LDSM + interleaved-LUT weight-load pipeline for 4-bit weights × 16/8-bit activations, plus the two preprocessing helpers the new kernel layout requires.

Changes

Kernel

  • mixed_input_utils.hpp / sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp — sync with TRT-LLM PR #12451 (LDSM path + FP4/INT4 → BF16 LUT converter).
  • moe_gemm_mixed_utils.{cu,h} (new) — per-row CUDA kernels for FP4/INT4 byte interleave.
  • cutlass_heuristic.cpp — for has_w4afp8, skip CtaShape128x128x128B + COOPERATIVE (register overflow on SM90) and pick COOP / PINGPONG per tile.
  • moe_gemm_tma_ws_mixed_input_launcher.inlscheduler.max_swizzle_size = 2, raster_order = Heuristic.

Python

flashinfer/fused_moe/core.py exposes two helpers (re-exported by the package):

  • interleave_moe_weights_for_hopper_mixed_gemm(weight, quant_type) — byte-level interleave for "fp4" / "int4" packed uint8 weights; delegates to the C++ kernel above.
  • interleave_moe_scales_for_hopper_mixed_gemm(scales, group_size=32) — pure PyTorch reshape + permute matching TRT-LLM's WFP4A16FusedMoEMethod.load_quant_scales, factor = 128 // group_size.

Tests — inside tests/moe/test_trtllm_cutlass_fused_moe.py (18 new)

  • test_moe_bf16_mxfp4_hopper_correctness (5 shapes, strict assert_close vs a GPU-side dequantized reference that only materialises active experts to stay under H200 memory at e=256).
  • test_moe_bf16_mxfp4_hopper_coverage (5 shapes, percent-based ≥ 99.9%).
  • test_moe_bf16_mxfp4_hopper_activations (3 SwiGLU variants).
  • test_moe_w4a8_hopper_correctness (2 shapes × bf16/fp16) — envelope matches the upstream CI shape (h = inter = 512, e = 2); larger exceeds strict tolerance because of FP8 + INT4 accumulation noise, same as the existing test_moe_w4a8.
  • test_moe_w4a8_hopper_autotune — smoke that autotune(True) doesn't break the W4A8 path.

All 18 green on H200 in 5.2 s cache-hot.

Performance

H200 (SM90 / HBM3e), hidden = 4096, intermediate = 2048, experts = 256, topk = 6, bf16 output, MXFP4 weights. cutlass_fused_moe median over bench_gpu_time. Weight + scale interleave is a one-shot model-load step and is excluded from timing. autotune column runs one pass under autotune(True) to populate the tactic cache before timing.

batch main no-autotune main autotune PR no-autotune PR autotune speedup (autotune)
4 0.791 ms 0.513 ms 0.221 ms 0.193 ms 2.66×
16 1.598 ms 1.607 ms 0.530 ms 0.532 ms 3.02×
64 3.761 ms 3.757 ms 1.200 ms 1.207 ms 3.11×

Integrate key optimizations from TensorRT-LLM PR #12451 for the
mixed-dtype MoE GEMM path (MXFP4 weights + BF16 activations):

- cutlass_heuristic.cpp: Skip 128x128x128B + COOPERATIVE scheduler
  combo for W4A16 grouped GEMM to avoid register overflow on SM90.
  Fall back to PINGPONG for this tile, keeping COOPERATIVE for others.

- moe_gemm_tma_ws_mixed_input_launcher.inl: Add max_swizzle_size=2
  scheduler hint for better L2 cache locality.

- test_w4a16_moe.py: Add 30 parametrized test cases covering:
  batch_size=[1..512], hidden_size=[2048..7168], num_experts=[8..256],
  top_k=[1..8], intermediate_size=[1024..4096], activation variants.
  Core target config: experts=256, topk=6, hidden=4096, inter=2048.

AI-assisted

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 15, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Restricts a CUTLASS mainloop schedule for one SM90 tile when W4A16+AFP8 is enabled, sets a grouped-GEMM scheduler swizzle-size and raster-order override, adds Hopper-specific weight-interleaving CUDA helpers, new SM90-only parametrized tests for W4A16 MoE, and extends mixed-input conversion utilities (int4→fp8, fp4→bf16 LUTs) and SM100 helpers.

Changes

Cohort / File(s) Summary
CUTLASS heuristics
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
When building SM90 candidate configs for W4A16+AFP8, exclude MainloopScheduleType::COOPERATIVE for tile CtaShape128x128x128B; keep PINGPONG.
Grouped GEMM launcher
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
Add include for scheduler params and set arguments.scheduler.max_swizzle_size = 2 and arguments.scheduler.raster_order = PersistentTileSchedulerSm90Params::RasterOrderOptions::Heuristic before group-size assertion and workspace logic.
MoE mixed GEMM utils (new)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu, csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
Add Hopper-specific CUDA kernels and host wrappers to interleave FP4 and INT4 packed weights for mixed-precision GEMM; new exported helper declarations.
Mixed-input conversion & SM100 helpers
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
Add int4→fp8 LUT conversion, rework fp4→bf16 LUT path to lane-dependent constants, add MixedInputUtilsSM100 with transform-based dequantization and adjust conversion dispatch and elements_per_smem_zero logic.
Tests (new)
tests/moe/test_w4a16_moe.py
Add SM90-only pytest module exercising fused MoE mixed-dtype (BF16 activations + W4A16 packed weights) with parametrized coverage, activation-parameter variants, and core-config test; skip when SM90 not supported.

Sequence Diagram(s)

sequenceDiagram
    participant Test
    participant Host as CPU Memory
    participant Launcher as GEMM Launcher
    participant Device as GPU Kernel
    Note over Test,Launcher: SM90-only test prepares inputs, packed weights, and quant scales
    Test->>Host: allocate activations, packed weights, scales, router logits
    Test->>Launcher: invoke cutlass_fused_moe(..., use_w4_group_scaling=True)
    Launcher->>Host: configure grouped-GEMM arguments (scheduler.max_swizzle_size=2, raster_order=Heuristic)
    Launcher->>Device: launch interleave_fp4/int4 kernels on stream
    Device-->>Host: write interleaved weight buffers
    Launcher->>Device: launch CUTLASS fused MoE GEMM kernel(s)
    Device-->>Host: write output tensor
    Test->>Host: validate output (finiteness, shape)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci, op: gemm

Suggested reviewers

  • aleozlx
  • jiahanc
  • IwakuraRein
  • yongwww
  • cyx-6
  • djns99

Poem

🐰
I hop through tiles and swap two little nibbles,
Tweak schedulers, pack weights, avoid sleepy dribbles.
Tests bound in fields of BF16 light,
Kernels hum softly into the night.
Small hops, bright kernels — all set just right.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.50% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'perf: optimize MXFP4xBF16 & INT4xFP8 CUTLASS MoE backend for SM90' accurately summarizes the main optimization work across multiple kernel components and configurations for SM90 Hopper hardware.
Description check ✅ Passed The PR description is comprehensive, detailed, and directly addresses the changes in the code.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds a test suite for W4A16 MoE kernels on SM90 and optimizes kernel heuristics by disabling the cooperative scheduler for specific tile configurations and adjusting swizzle sizes for better L2 locality. Feedback suggests simplifying conditional logic in the heuristics and addressing unused functions and parameters in the test code.

Comment thread csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp Outdated
Comment thread tests/moe/test_w4a16_moe.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/moe/test_w4a16_moe.py`:
- Around line 124-125: The test currently sets check_correctness=True but never
compares numeric outputs to a reference; update the test so that when
check_correctness is True it runs only the smallest, quick case (e.g., the
minimal batch/num_experts config used for fast CI), computes a deterministic
baseline (seed RNG) or loads a stored reference output, and compares outputs
(e.g., logits/tensors) elementwise or via assert_allclose to the reference;
modify the validation block that currently only checks finiteness/shape (the
code around the finiteness/shape asserts) to perform this numeric comparison
guarded by check_correctness and only for the quick case to avoid long runtimes.
Ensure you reference the check_correctness flag and the small-case configuration
when adding the assert_allclose comparison.
- Around line 16-20: Remove the custom _is_sm90() helper and instead import and
call is_sm90a_supported() from flashinfer.utils; update any conditional test
skips in tests/moe/test_w4a16_moe.py to use is_sm90a_supported() (and add the
import), and delete the _is_sm90() definition so the test follows the repo-wide
standardized architecture check helper pattern.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0216a0c2-9850-420e-abe7-1b08b3746c09

📥 Commits

Reviewing files that changed from the base of the PR and between 25b324d and 4c2372b.

📒 Files selected for processing (3)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
  • tests/moe/test_w4a16_moe.py

Comment thread tests/moe/test_w4a16_moe.py Outdated
Comment thread tests/moe/test_w4a16_moe.py Outdated
samuellees and others added 3 commits April 15, 2026 18:30
- Invert condition to eliminate empty if-block in heuristic (readability)
- Use is_sm90a_supported() instead of custom _is_sm90() (codebase convention)
- Remove unused _dequant_mxfp4_host, _compute_reference, check_correctness

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… #12451

- Add moe_gemm_mixed_utils.cu/.h: CUDA kernels for FP4 and INT4 weight
  interleaving on Hopper, reorganizing weight layout for optimal TMA
  scheduling in mixed-precision grouped GEMM.

- Add RasterOrderOptions::Heuristic to mixed-input launcher scheduler
  config, complementing the existing max_swizzle_size=2 setting.

These are direct ports from TRTLLM PR #12451 commits cd541ba and 79315f6.

Note: The CUTLASS extension changes (mixed_input_utils.hpp LDSM pipeline,
Int4->FP8 LUT, scale_convertor in sm90_mma_array...hpp) are NOT included
in this commit. Those changes involve deep CUTLASS template restructuring
and the upstream PR has open review issues (scale tensormap not updated,
lane mapping errors). They should be migrated after the upstream PR is
finalized and merged.

AI-assisted

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu (1)

35-39: Document the swizzle and launch geometry choices.

The lane remaps here and the fixed 1024-block launches are hard to reason about without context. A short note on why Hopper mixed GEMM needs this layout, and why a simpler linear repack or shape-derived grid was not used, would make future tuning much safer.

As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered."

Also applies to: 73-77, 95-108

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`
around lines 35 - 39, Add concise in-code documentation explaining the swizzle
and launch geometry choices: annotate the computation of interleaved_lane_id
(using lane_id and partition_id) and derived indices col_id and dst_col_id, plus
the fixed 1024-block launch size, with why this particular remapping is required
for Hopper mixed GEMM (e.g., tensor core/hardware lane packing, bank conflicts,
warp-per-thread mapping) and why simpler alternatives (linear repack or
shape-derived grid) were rejected; place these comments next to the interleaving
code (interleaved_lane_id, col_id, dst_col_id) and repeat/expand the
justification near the other identical sections mentioned (the code around lines
referenced as 73-77 and 95-108) so future tuners can understand the trade-offs
and assumptions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`:
- Around line 23-109: This file was reformatted by clang-format and the
pre-commit check fails; run clang-format (or your project's formatting tool) on
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
to apply the expected style. Ensure you reformat the entire file and stage the
changes before committing; focus on the functions
interleave_fp4_weights_for_Hopper_mixed_gemm_kernel,
interleave_int4_weights_for_Hopper_mixed_gemm_kernel and their callers
interleave_fp4_weights_for_Hopper_mixed_gemm /
interleave_int4_weights_for_Hopper_mixed_gemm so the diff matches the project's
clang-format rules.
- Around line 27-28: The kernels in moe_gemm_mixed_utils.cu assume rows % 16 ==
0 and cols % 64 == 0, but the current launch loops (using block_id from
blockIdx.x and partition_id from threadIdx.y) can read past bounds (e.g.,
accesses like row_id + 8) and silently drop column remainders; add a host-side
validation before launching these kernels that checks the input dimensions (rows
and cols) and either (a) returns/throws an error for unsupported shapes or (b)
pads/rounds up the buffers to multiples of 16 (rows) and 64 (cols) and documents
that fallback behavior; ensure this check is performed wherever these kernels
are invoked so the loops governed by block_id/partition_id never encounter tails
that would corrupt the interleaved buffer.

In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h`:
- Around line 17-32: Run clang-format on this header to match the repo style:
reorder and format includes and function declarations so they follow the
project's include order/spacing rules and line-wrapping conventions.
Specifically, fix the include order and spacing around the header guard/pragma
and reflow the declarations for interleave_fp4_weights_for_Hopper_mixed_gemm and
interleave_int4_weights_for_Hopper_mixed_gemm so they match the repo's
wrapped-declaration style (use the project's preferred line breaks, parameter
alignment, and trailing semicolons). Save the file after clang-format so
pre-commit no longer fails.

---

Nitpick comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`:
- Around line 35-39: Add concise in-code documentation explaining the swizzle
and launch geometry choices: annotate the computation of interleaved_lane_id
(using lane_id and partition_id) and derived indices col_id and dst_col_id, plus
the fixed 1024-block launch size, with why this particular remapping is required
for Hopper mixed GEMM (e.g., tensor core/hardware lane packing, bank conflicts,
warp-per-thread mapping) and why simpler alternatives (linear repack or
shape-derived grid) were rejected; place these comments next to the interleaving
code (interleaved_lane_id, col_id, dst_col_id) and repeat/expand the
justification near the other identical sections mentioned (the code around lines
referenced as 73-77 and 95-108) so future tuners can understand the trade-offs
and assumptions.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d1944a26-f0ae-49a7-b1ae-2bc06aa19d0a

📥 Commits

Reviewing files that changed from the base of the PR and between 931b87a and 9c4d204.

📒 Files selected for processing (3)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl

Comment on lines +27 to +28
for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x) {
for (int partition_id = threadIdx.y; partition_id < cols / 64; partition_id += blockDim.y) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast on unsupported matrix shapes.

These kernels only work when rows is a multiple of 16 and cols is a multiple of 64. On Line 42 and Line 80, row_id + 8 can read past the last tile when rows has a tail, and the partition_id < cols / 64 loops silently drop remainder columns. Please add a host-side check or fallback before launch so unsupported shapes do not corrupt the interleaved buffer.

Also applies to: 41-56, 66-67, 79-88, 95-108

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu`
around lines 27 - 28, The kernels in moe_gemm_mixed_utils.cu assume rows % 16 ==
0 and cols % 64 == 0, but the current launch loops (using block_id from
blockIdx.x and partition_id from threadIdx.y) can read past bounds (e.g.,
accesses like row_id + 8) and silently drop column remainders; add a host-side
validation before launching these kernels that checks the input dimensions (rows
and cols) and either (a) returns/throws an error for unsupported shapes or (b)
pads/rounds up the buffers to multiples of 16 (rows) and 64 (cols) and documents
that fallback behavior; ensure this check is performed wherever these kernels
are invoked so the loops governed by block_id/partition_id never encounter tails
that would corrupt the interleaved buffer.

… #12451

Port core CUTLASS extension changes from TensorRT-LLM PR #12451 (commit
cd541ba) for Hopper mixed-precision MoE GEMM performance:

mixed_input_utils.hpp:
- Add Int4->FP8 E4M3 lookup table conversion (psx_cvt_lut_prmt_int4x8_to_fp8x8)
  with lane-dependent LUT constants for parallel thread processing
- Rename cvt_lut_bf16 -> cvt_lut_fp4_to_bf16 with thread-aware LUT indexing
- Rename psx_cvt_lut_prmt_fp4x8_to_bf16x8 -> _interleaved with optimized
  bit manipulation for interleaved weight layout
- Add copy_tensors_A/copy_tensors_SFA for separated tensor copy paths
- Add int4tofp8_lookup_table_convert template method
- Add UseInt4ToFP8LookupTable constraint flag

sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp:
- Restructure A operand loading with LDSM-based copy/retiling using
  SM75_U32x4_LDSM_N for improved memory access patterns
- Add scale_convertor template for FP2M1 scale type conversion
- Add TensormapUpdateShapesStridesForAandScale flag for conditional
  tensormap shape/stride updates
- Separate scale copying into dedicated copy_tensors_SFA calls
- Enhance tensormaps_replace_global_address/properties/cp_fence_release
  with conditional update logic

Data format note: The interleaved weight layout requires weights to be
preprocessed with interleave_fp4_weights_for_Hopper_mixed_gemm (added in
prior commit) before being passed to the GEMM kernel.

AI-assisted

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp (1)

652-669: Misleading comment should be updated.

The comment on line 652 states KernelConversionMode == ConversionMode::DirectConvert, but this function now dispatches to UseFP4ToBF16LookupTable and UseInt4ToFP8LookupTable paths, which are only enabled when KernelConversionMode == ConversionMode::ConvertAndScale (per the definitions in the collective header). Consider updating or removing the comment to reflect the actual conversion modes handled.

Suggested comment update
-        // KernelConversionMode == ConversionMode::DirectConvert
+        // Type conversion dispatch: LUT paths for FP4→BF16 and INT4→FP8, generic converter otherwise
         CUTLASS_PRAGMA_UNROLL
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp`
around lines 652 - 669, The comment "KernelConversionMode ==
ConversionMode::DirectConvert" is now misleading because this block also
dispatches to UseFP4ToBF16LookupTable and UseInt4ToFP8LookupTable conversion
paths; update or remove the comment to accurately reflect supported modes (e.g.,
mention both DirectConvert and ConvertAndScale or remove mode-specific wording).
Modify the comment near the loop that references
KernelConversionMode/ConversionMode::DirectConvert and ensure it calls out the
conditional branches UseFP4ToBF16LookupTable, UseInt4ToFP8LookupTable, and the
fallback LayoutAwareConvert (and the called helpers
fp4tobf16_lookup_table_convert and int4tofp8_lookup_table_convert) so the
comment matches actual behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In
`@csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp`:
- Around line 652-669: The comment "KernelConversionMode ==
ConversionMode::DirectConvert" is now misleading because this block also
dispatches to UseFP4ToBF16LookupTable and UseInt4ToFP8LookupTable conversion
paths; update or remove the comment to accurately reflect supported modes (e.g.,
mention both DirectConvert and ConvertAndScale or remove mode-specific wording).
Modify the comment near the loop that references
KernelConversionMode/ConversionMode::DirectConvert and ensure it calls out the
conditional branches UseFP4ToBF16LookupTable, UseInt4ToFP8LookupTable, and the
fallback LayoutAwareConvert (and the called helpers
fp4tobf16_lookup_table_convert and int4tofp8_lookup_table_convert) so the
comment matches actual behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4bacfaeb-b7f8-4a6f-b27f-76978cce257c

📥 Commits

Reviewing files that changed from the base of the PR and between 9c4d204 and 31f80e2.

📒 Files selected for processing (2)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp

samuellees and others added 3 commits April 17, 2026 03:30
Benchmark wMXFP4 x BF16 MoE GEMM on H20 (SM90) with configurable
batch size, TP/EP splits, and expert/hidden/intermediate dimensions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ctness checks

Revert mixed_input_utils.hpp and sm90_mma_array...hpp to main branch versions.
The TRTLLM PR #12451 versions require interleaved weight layout preprocessing
which is incompatible with FlashInfer's current API — all 18 tests failed with
95.8% element mismatch and max abs error 37.25 (expected < 0.1).

The heuristic, launcher scheduler, and weight interleave utility changes from
TRTLLM PR #12451 are retained as they don't change the GEMM kernel data path.

Also update test_w4a16_moe.py to include dequant-based reference correctness
verification (matching test_trtllm_cutlass_fused_moe.py, rtol=1e-1, atol=1e-1).

AI-assisted

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@samuellees samuellees marked this pull request as draft April 18, 2026 13:18
…ty only

The dequant-based reference comparison (rtol=1e-1, atol=1e-1) only works
reliably at small problem sizes (hidden=128, experts=2-8), matching the
original test_trtllm_cutlass_fused_moe.py methodology. At large sizes
(hidden=4096, experts=256), cumulative FP4 quantization error exceeds
the tolerance threshold.

- CORRECTNESS_CONFIGS: small sizes with strict assert_close
- COVERAGE_CONFIGS: large sizes with finite+shape sanity checks
- ACTIVATION_CONFIGS: small sizes with strict assert_close

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
samuellees and others added 5 commits April 20, 2026 19:33
…ntouched

Moves all PR-specific test coverage out of
tests/moe/test_trtllm_cutlass_fused_moe.py and into two new dedicated files,
so the existing upstream test matrix stays at its original (bs=1, h=128)
scale and the PR does not bundle matrix bumps that regressed unrelated
SM100/SM120 MXFP8 paths.

Files
-----
- tests/moe/test_trtllm_cutlass_fused_moe.py:
  reverted to origin/main (drops the BATCH_SIZES/HIDDEN_SIZES/INTERMEDIATE_SIZES
  bump from (1, 128, 128) to (16, 512, 512), drops the interleave calls added
  inside test_moe_bf16_mxfp4 and test_moe_w4a8, drops the act-scale dtype /
  reference input_scale tweaks inside test_moe_w4a8). All of those logical
  fixes are re-applied in the new files below, but only for the PR's own
  SM90 mixed-input paths.

- tests/moe/test_w4a16_moe.py (restored from earlier history):
  MXFP4 x BF16 (W4A16) SM90 tests. Covers 5 CORRECTNESS_CONFIGS
  (h=128..4096) under strict assert_close, 8 COVERAGE_CONFIGS (primary target
  e=256/topk=6/h=4096/n=2048 stressed along batch / shape / expert-count /
  top_k axes) under a 99.9%-elements-within-tolerance check, 3 SwiGLU
  ACTIVATION_CONFIGS, and a single primary-target smoke. All inputs go
  through interleave_moe_weights_for_hopper_mixed_gemm("fp4") +
  interleave_moe_scales_for_hopper_mixed_gemm before cutlass_fused_moe.
  Reference dequantizes MXFP4 on-device via a small FP4 LUT and only covers
  the top-k active experts (avoids OOM on e=256).

- tests/moe/test_w4a8_moe.py (new):
  INT4 x FP8_e4m3 (W4A8) SM90 tests, dedicated to the W4A8 AWQ-style path
  that uses per-group weight scales + per-channel pre-quant act scales.
  Mirrors the structure of test_w4a16_moe.py (correctness / coverage /
  autotune / core_config) and parametrizes both bf16 and fp16 output dtypes
  on the correctness sweep. Weights go through
  interleave_moe_weights_for_hopper_mixed_gemm("int4"); weight scales use
  TRTLLM's factor-4/2/1 reshape+permute and the SM90 bf16-bitpattern trick;
  activation scales stay in the native dtype (consumed by
  expandInputRows / applyPrequantScale as OutputType).
  fc1_input_scale for the reference is a broadcast max over experts — the
  kernel folds per-expert input scales into a single divisor, so per-expert
  scales in the reference would double-correct and diverge.

AI-assisted
The SM90 W4A8 kernel carries enough FP8 + INT4 accumulation noise that the
assert_close(rtol=1e-2, atol=1e-1) tolerance vs a float32 PyTorch reference
is only achievable at h == intermediate_size == 512 with num_experts == 2.
This matches the envelope the upstream CI test already uses; going beyond
it fails in the upstream test too (independently verified on H200:
e=2/h=2048 and e=8/h=512 both fail in upstream).

Shrinks the file to two strict configs plus one autotune smoke, drops the
coverage / core_config sweeps that exceeded the envelope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move the dedicated test_w4a16_moe.py / test_w4a8_moe.py coverage into
test_trtllm_cutlass_fused_moe.py as six new test functions:

  test_moe_bf16_mxfp4_hopper_correctness   (5 configs)
  test_moe_bf16_mxfp4_hopper_coverage      (8 configs, percent >= 99.9%)
  test_moe_bf16_mxfp4_hopper_activations   (3 swiglu variants)
  test_moe_w4a8_hopper_correctness         (2 configs x 2 dtypes)
  test_moe_w4a8_hopper_autotune            (smoke)

Reuses existing module-level helpers (compute_routing, dequantize_int4_to_dtype,
torch_moe_w4a8). Adds a GPU-side MXFP4 dequant + active-expert reference
compute to keep e=256 / h=4096 / n=2048 coverage from OOMing vs
dequant_mxfp4_batches_host.

21/21 green on H200 (bf16 + fp16, strict assert_close for correctness, 99.9%
percent-based for coverage).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously batch was only covered at m=4 / m=16; add a single-token entry
so the W4A8 strict envelope matches the W4A16 side (which already covers
m=1 in both correctness and coverage). Verified on H200.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
User-requested trim: remove the m=16 batch across W4A16 coverage, W4A8
correctness, and W4A8 autotune. Replace the six W4A16 coverage entries
that relied on m=16 (to stress non-batch axes) with m=4; drop three
configs that OOM on H200 under parametrize fragmentation
(m=512 / h=7168-e256 / n=4096-e256).

Final envelope: 18 tests, H200 5.21s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu Outdated
Comment thread csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu Outdated
Comment thread csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu Outdated
Comment thread tests/moe/test_trtllm_cutlass_fused_moe.py Outdated
Comment thread tests/moe/test_trtllm_cutlass_fused_moe.py
Addresses review feedback (samuellees on PR flashinfer-ai#3084): function names should
use the SM arch identifier (sm90) rather than the marketing name
(Hopper/hopper), for consistency with the rest of the fused_moe API.

Renamed across:
  C++ kernel impl (moe_gemm_mixed_utils.{cu,h})
  C++ binding + TVM_FFI export (flashinfer_cutlass_fused_moe_binding.cu)
  Python helpers + docstrings (flashinfer/fused_moe/core.py, __init__.py)
  Tests (tests/moe/test_trtllm_cutlass_fused_moe.py)

External references to TRT-LLM's upstream
interleave_4bit_weights_for_Hopper_mixed_gemm retain the original name
(it's still called that in TRT-LLM). Also tightened the block comment
introducing the new tests to no longer depend on cross-references into
the upstream test_moe_w4a8.

Tests: 18/18 green on H200 (cache-cold build + test run 13 min).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@samuellees samuellees force-pushed the feat/w4a16-moe-kernel branch from 9388f37 to cb90611 Compare April 21, 2026 11:57
@samuellees
Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !569 has been updated with latest changes, and the CI pipeline #49088402 is currently running. I'll report back once the pipeline job completes.

samuellees and others added 2 commits April 21, 2026 05:42
- clang-format / ruff-format reflow triggered by the earlier
  Hopper → sm90 rename (shorter names let the formatters pack more on
  fewer lines).
- Add the two new public helpers to docs/api/fused_moe.rst so Sphinx
  autodoc picks them up:
    * interleave_moe_weights_for_sm90_mixed_gemm
    * interleave_moe_scales_for_sm90_mixed_gemm

mypy errors reported by pre-commit (`ActivationType` etc. "not defined")
are pre-existing — same 74 errors on plain origin/main — and come from
the wildcard `from ..tllm_enums import *` in flashinfer/fused_moe/core.py.
Not touched by this PR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@samuellees
Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !569 has been updated with latest changes, and the CI pipeline #49090861 is currently running. I'll report back once the pipeline job completes.

@samuellees
Copy link
Copy Markdown
Collaborator Author

@flashinfer-bot rerun failed

The new SM90 mixed-input MoE kernel (ported from TRT-LLM PR #12451)
expects weights and MXFP4 block scales in an interleaved byte layout.
The upstream tests passed raw weights / raw scales, which produced stale
output:

  test_moe_bf16_mxfp4[*-128-2-2-128-1]: 18-54% mismatched on H200/H100
  test_moe_w4a8: passes at h=128 by tolerance, but is wrong layout-wise

Both tests now apply the public preprocessing helpers added by this PR:

  test_moe_bf16_mxfp4:
    + interleave_moe_weights_for_sm90_mixed_gemm(w, "fp4")
    + interleave_moe_scales_for_sm90_mixed_gemm(w_scale)

  test_moe_w4a8:
    + interleave_moe_weights_for_sm90_mixed_gemm(w, "int4")
    (scale interleave was already done via the local interleave_weights
     helper.)

7 H200 tests under upstream test_moe_{bf16_mxfp4,w4a8} green; together
with the 18 new sm90-hopper tests this PR added: 25/25 in 5.3s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@samuellees
Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !569 has been updated with latest changes, and the CI pipeline #49156313 is currently running. I'll report back once the pipeline job completes.

@samuellees
Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !569 has been updated with latest changes, and the CI pipeline #49156422 is currently running. I'll report back once the pipeline job completes.

@samuellees samuellees enabled auto-merge (squash) April 22, 2026 14:41
Copy link
Copy Markdown
Collaborator

@aleozlx aleozlx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@samuellees samuellees merged commit d454492 into flashinfer-ai:main Apr 23, 2026
76 of 78 checks passed
@samuellees samuellees deleted the feat/w4a16-moe-kernel branch April 24, 2026 01:57
@aleozlx aleozlx mentioned this pull request Apr 25, 2026
aleozlx added a commit that referenced this pull request May 5, 2026
## Description

Bump version to 0.6.10 for release.

## Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10

## Reviewer Notes

**API changes review**

API changes since v0.6.9

```diff
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
     register_custom_op,
@@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_trace)
 def silu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -112,7 +118,7 @@ def silu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_tanh_and_mul_trace)
 def gelu_tanh_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -153,7 +159,7 @@ def gelu_tanh_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_and_mul_trace)
 def gelu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -194,7 +200,7 @@ def gelu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace)
 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index dfb05150..d26d5407 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -543,6 +543,7 @@ def gen_all_modules(
     if add_comm:
         from .jit.comm import (
             gen_comm_alltoall_module,
+            gen_dcp_alltoall_module,
             gen_moe_alltoall_module,
             gen_trtllm_comm_module,
             gen_trtllm_mnnvl_comm_module,
@@ -554,6 +555,11 @@ def gen_all_modules(
             jit_specs.append(gen_trtllm_comm_module())
             jit_specs.append(gen_trtllm_mnnvl_comm_module())
             jit_specs.append(gen_moe_alltoall_module())
+            # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has
--
 
-def flashinfer_api(func: Callable = None) -> Callable:
+# ---------------------------------------------------------------------------
+# Trace template registry
+# ---------------------------------------------------------------------------
+# Populated automatically by _attach_fi_trace whenever @flashinfer_api is
+# given a trace= argument.  Each entry is (original_func, template, label)
+# where label is the template's name_prefix (or op_type as fallback).
+#
+# For dispatch callables (trace=some_fn), every template listed in
+# some_fn.templates is registered if that attribute exists.
+#
+# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover
+# all registered templates without requiring manual maintenance.
+_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = []
+
+
+def _attach_fi_trace(
+    wrapped: Callable,
+    original: Callable,
+    trace_template=None,
+) -> Callable:
+    """Attach a ``fi_trace`` callable to *wrapped*.
+
+    Three resolution strategies, tried in order:
+
--
+
+        warnings.warn(
+            f"[flashinfer] Failed to attach fi_trace to '{_func_name}': "
+            f"{type(_exc).__name__}: {_exc}\n"
+            f"The function will work normally but fi_trace will be unavailable. "
+            f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).",
+            stacklevel=3,
+        )
+    return wrapped
+
+
+def flashinfer_api(func: Callable = None, *, trace=None) -> Callable:
     """
     Decorator to FlashInfer's APIs.
 
@@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable:
     - The %i pattern is automatically replaced with the process ID for multi-process environments.
     - The logger does not propagate to the root logger to avoid duplicate logs.
     """
-    # If logging is disabled, return original function with zero overhead
+    # If logging is disabled, return original function with zero overhead.
+    # We still attach fi_trace so it is always available regardless of log level.
     if _API_LOG_LEVEL == 0:
         if func is None:
-            return lambda f: f
-        return func
--
 @functools.cache
@@ -135,7 +136,7 @@ class BatchAttention:
             causal,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_attention_run_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper
     a convenient interface for using attention sinks during prefill or decode attention.
     """
 
+    # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py
new file mode 100644
index 00000000..3e029627
--- /dev/null
+++ b/flashinfer/attention/cute_dsl/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
--
 
@@ -31,7 +37,7 @@ def get_cascade_module():
     return gen_cascade_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_trace)
 @register_custom_op("flashinfer::merge_state", mutates_args=())
 def merge_state(
     v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -98,7 +104,7 @@ def _fake_merge_state(
     return v, s
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_in_place_trace)
 @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
 def merge_state_in_place(
     v: torch.Tensor,
@@ -159,7 +165,7 @@ def _fake_merge_state_in_place(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_states_trace)
 @register_custom_op("flashinfer::merge_states", mutates_args=())
 def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Merge multiple attention states (v, s).
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=multi_level_cascade_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 5f186002..31d23a99 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import (
     moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace,
 )
 
+# DCP A2A (Decode Context Parallel Attention Reduction)
+from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
+from .dcp_alltoall import (
+    decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
+)
+from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
+from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
+
 # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
--
 from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
@@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=allreduce_fusion_trace)
 def allreduce_fusion(
     input: torch.Tensor,
     workspace: AllReduceFusionWorkspace,
diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py
new file mode 100644
index 00000000..3047f76c
--- /dev/null
+++ b/flashinfer/comm/dcp_alltoall.py
@@ -0,0 +1,255 @@
+"""
+DCP All-to-All Operations for DCP Attention Reduction
+
+Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
+attention reduction. Uses SM90+ features (TMA, mbarrier).
+
+Usage protocol::
+
+    # 1. Query workspace size
+    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
+
--
+
+
+# ─── Public API ───────────────────────────────────────────────────────────
+
+
+@flashinfer_api
+def decode_cp_a2a_workspace_size(cp_size: int) -> int:
+    """Return the workspace size **in bytes** per rank for the given CP group size.
+
+    Args:
+        cp_size: Context-parallel group size (number of ranks).
+
+    Returns:
+        Workspace size in bytes per rank.
+
+    Example::
+
+        >>> decode_cp_a2a_workspace_size(4)
+        16778240
+    """
+    return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size)
+
+
+@flashinfer_api
+def decode_cp_a2a_allocate_workspace(
+    cp_size: int,
+    cp_rank: int,
+    *,
+    mapping: Optional[Mapping] = None,
+    mnnvl_config: Optional[MnnvlConfig] = None,
+) -> torch.Tensor:
+    """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+
+    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
+
+    Two allocation modes:
+
+    - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
+      FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
+      cannot see each other's device memory directly.
+    - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
+      allocation. Sufficient for single-node with NVLink P2P.
+
--
+
+    ws_elems_per_rank = (ws_bytes + 7) // 8
+    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+
+
+@flashinfer_api
+def decode_cp_a2a_init_workspace(
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+) -> None:
+    """Initialize the workspace FIFO buffers. Call once before the first alltoall.
+
+    Resets the FIFO buffers in the **local** workspace row
+    (``workspace[cp_rank]``). This function is **synchronous**: when it
+    returns, the GPU memset is guaranteed to have completed.
+
+    .. important::
+        With MNNVL workspaces, **all ranks** must complete
+        ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier
+        (e.g. ``dist.barrier(group)``) before **any** rank calls
+        :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may
+        start writing to a peer's FIFO before that peer has finished
+        initializing → deadlock.
+
+    Args:
--
+    # subsequent cross-GPU alltoall can race with the unfinished memset
+    # on MNNVL memory, causing a deadlock.
+    torch.cuda.current_stream().synchronize()
+
+
+@flashinfer_api(trace=decode_cp_a2a_alltoall_trace)
+def decode_cp_a2a_alltoall(
+    partial_o: torch.Tensor,
+    softmax_stats: torch.Tensor,
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+    enable_pdl: Optional[bool] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Perform the DCP all-to-all exchange.
+
+    Each rank sends its ``partial_o[..., peer, :]`` slice to the
+    corresponding peer and receives all peers' contributions into the
+    output tensors.
+
+    Args:
+        partial_o: ``[..., cp_size, D]`` — half or bfloat16.
+            ``D * element_size`` must be 16-byte aligned.
+        softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
+            Batch dimensions must match ``partial_o``.
+        workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
--
+    MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather,
+    MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce,
+}
+
+
+@flashinfer_api
+@backend_requirement(
+    backend_checks={},
+    common_check=_common_check,
+)
+def run_mixed_comm(
+    op: MixedCommOp,
+    handler: MixedCommHandler,
+    x_in: torch.Tensor,
+    x_out: torch.Tensor | None = None,
+    mode: MixedCommMode | None = None,
+) -> torch.Tensor:
+    """Execute a mixed communication operation.
+
+    This is the main entry point for running communication collectives
+    through the mixed communication handler. It supports fused GPU kernels
+    (using virtual memory intra-node and nvshmem inter-node), NCCL-based
+    fallbacks, and autotuned mode selection.
+
+    Args:
+        op: The communication operation to perform.
--
 @functools.cache
@@ -28,7 +29,7 @@ def get_concat_mla_module():
     return gen_concat_mla_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=concat_mla_k_trace)
 def concat_mla_k(
     k: torch.Tensor,
     k_nope: torch.Tensor,
diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py
index 195ca2d4..9b593095 100644
--- a/flashinfer/cudnn/decode.py
+++ b/flashinfer/cudnn/decode.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_decode_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_decode_trace)
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc1bbb5f..b16d6043 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_prefill_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache(
         return out, None
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_prefill_trace)
 def cudnn_batch_prefill_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
index 0b50c22c..f25aa6fd 100644
--- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
@@ -38,6 +38,7 @@ import torch
 from cutlass import Float32, Int32, Int64, Uint32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import add_rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -1042,7 +1043,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=add_rmsnorm_fp4quant_trace)
 def add_rmsnorm_fp4quant(
     input: torch.Tensor,
     residual: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
index 333697ab..b7aabc36 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
@@ -20,6 +20,7 @@ import torch
 from cutlass import Float32, Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace
 from flashinfer.utils import device_support_pdl
 from flashinfer.cute_dsl.utils import (
     get_max_active_clusters,
@@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper:
                 f"out_dtype={self._o_dtype}"
             )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_mla_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
index 58a24abe..ee0cd5e7 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
@@ -21,6 +21,7 @@ import cutlass.cute as cute
 from cutlass.cute.typing import Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace
 
 from ..config import AttentionConfig, AttentionFusion
 from ..fusion.mask import MaskType
@@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper:
                     f"device={self._device}"
                 )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
index bc4acffc..97ce68a1 100644
--- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
@@ -32,6 +32,7 @@ import torch
 from cutlass import Float32, Int32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -771,7 +772,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_fp4quant_trace)
 def rmsnorm_fp4quant(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 822aca40..5e9eb515 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_decode_trace,
+    single_decode_with_kv_cache_trace,
+    trtllm_batch_decode_trace,
+    xqa_batch_decode_trace,
+)
 
 ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
 from .mla import (
@@ -400,7 +406,7 @@ def single_decode_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_decode_with_kv_cache_trace)
 def single_decode_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
         kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_decode_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra
     :class:`BatchDecodeWithPagedKVCacheWrapper`
     """
 
+    # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         workspace_buffer: torch.Tensor,
@@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_trace)
 def trtllm_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache(
 
 
 # xqa uses NHD layout
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_trace)
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py
new file mode 100644
index 00000000..1104eb6f
--- /dev/null
+++ b/flashinfer/fi_trace.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--
+
+"""
+fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_
+compatible definition JSON for FlashInfer APIs.
+
+Every ``@flashinfer_api(trace=<template>)``-decorated function supports two
+usage modes:
+
+Auto-dump (recommended)
+-----------------------
+Set environment variables **before** importing flashinfer, then run your
+workload normally.  No explicit ``fi_trace`` call is needed.
+
+.. code-block:: bash
+
+    FLASHINFER_TRACE_DUMP=1 \\
+    FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\
+    python my_script.py
+
+Every decorated function writes a ``<name>.json`` file on its **first** call
+for each unique set of const-axis values (e.g. head dimensions, vocab size).
+Subsequent calls with the same shape are deduplicated — the file is written
+only once per process.  The output directory is created automatically.
+
+Explicit call (for selective or programmatic use)
+-------------------------------------------------
--
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+# ---------------------------------------------------------------------------
+# Legacy registry — kept for backwards compatibility.
+# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead.
+# ---------------------------------------------------------------------------
+
+_REGISTRY: Dict[str, Any] = {}
+
+
+def register_fi_trace(qualname: str, spec: Any) -> None:
+    """Register a legacy FiTraceSpec for the function with the given qualname.
+
+    .. deprecated::
+        Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead.
+    """
+    _REGISTRY[qualname] = spec
+
+
+def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]:
+    """Build a fi_trace callable from a legacy FiTraceSpec.
+
+    .. deprecated::
+        Use ``TraceTemplate.build_fi_trace_fn`` instead.
+    """
+    # Import the old implementation from the trace package for backwards compat.
+    from .trace.template import (  # noqa: PLC0415,F401
+        Const,
+        Scalar,
+        Tensor,
+        TraceTemplate,
+        Var,
+    )
+    import json  # noqa: PLC0415
+    import os  # noqa: PLC0415
--
+    """Generate a flashinfer-bench definition JSON for any FlashInfer API call.
+
+    Parameters
+    ----------
+    func_or_method:
+        A ``@flashinfer_api``-decorated function or (bound) method.
+    save_dir:
+        Directory where the JSON definition file should be written.
+        Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
+    **kwargs:
+        The same tensor arguments you would pass to the real API.
+
+    Returns
+    -------
+    dict
+        A flashinfer-bench compatible definition dictionary.
+
+    Examples
+    --------
+    Standalone function::
+
+        defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)
+
+    Bound method (instance.run)::
+
+        defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))
--
+    trace_fn = getattr(actual_func, "fi_trace", None)
+    if trace_fn is None:
+        qualname = getattr(actual_func, "__qualname__", repr(actual_func))
+        raise ValueError(
+            f"No fi_trace spec is registered for '{qualname}'. "
+            "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+        )
+    return trace_fn(save_dir=save_dir, **kwargs)
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index df6e1f72..d983f9d4 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -17,6 +17,8 @@ limitations under the License.
 from .core import (
     convert_to_block_layout,
     cutlass_fused_moe,
+    interleave_moe_scales_for_sm90_mixed_gemm,
+    interleave_moe_weights_for_sm90_mixed_gemm,
     gen_cutlass_fused_moe_sm120_module,
     gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
@@ -64,6 +66,8 @@ __all__ = [
     "WeightLayout",
     "convert_to_block_layout",
     "cutlass_fused_moe",
+    "interleave_moe_scales_for_sm90_mixed_gemm",
--
+        ),
     )
 
 
-# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
 @flashinfer_api
+def interleave_moe_scales_for_sm90_mixed_gemm(
+    scales: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
+
+    The kernel expects scales in layout
+    ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural
+    ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer.
+    This helper performs the reshape + permute equivalent to TensorRT-LLM's
+    ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), with the fixed
+    interleave factor of ``128 // group_size`` used for MXFP4.
+
+    Parameters
+    ----------
+    scales:
+        ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block
+        scales.
+    group_size:
+        MXFP4 quantization group size (default 32).
--
+        scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous()
+    )
+    return tmp.reshape(e, kgs // factor, rows * factor)
+
+
+@flashinfer_api
+def interleave_moe_weights_for_sm90_mixed_gemm(
+    weight: torch.Tensor,
+    quant_type: str = "fp4",
+) -> torch.Tensor:
+    """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
+
+    The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with
+    ``use_w4_group_scaling=True``) expects weights in a specific interleaved
+    layout; without preprocessing, the LUT-based FP4→BF16 conversion reads
+    bytes from the wrong positions and the output diverges from a dequantized
+    reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent
+    preprocessing at weight-load time (see
+    ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451).
+
+    Parameters
+    ----------
+    weight:
+        ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed
+        two-per-byte).
+    quant_type:
--
+    )
+    return out
+
+
+# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
+@flashinfer_api(trace=cutlass_fused_moe_trace)
 def cutlass_fused_moe(
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
@@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module():
                     DynamicTensorSpec(
                         input_idx,
                         dim_idx,
-                        get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-                        lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                        get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+                        lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
                         initializers,
                     ),
                 ),
@@ -2344,7 +2468,7 @@ def _validate_routing_replay_out(
         raise ValueError("routing_replay_out must be contiguous (packed row-major)")
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_moe_trace)
 def trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2452,7 +2576,7 @@ def trtllm_bf16_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_routed_moe_trace)
 def trtllm_bf16_routed_moe(
     topk_ids: torch.Tensor,
     hidden_states: torch.Tensor,
@@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace)
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace)
 def trtllm_fp8_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch)
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace)
 def trtllm_mxint4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
index d2cbc8b0..34916df5 100644
--- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
@@ -42,11 +42,12 @@ from typing import Optional, Tuple
 import torch
 
 from ...api_logging import flashinfer_api
+from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace
 from ...utils import supported_compute_capability
 
 
 @supported_compute_capability([120, 121])
-@flashinfer_api
+@flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
     x: torch.Tensor,
     w1_weight: torch.Tensor,
@@ -293,7 +294,7 @@ class B12xMoEWrapper:
             device=self.device,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
index f6cf1b67..e266cb77 100644
--- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
+++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
@@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import (
     st_global_u64,
     scatter_add_bf16x2,
 )
-from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
-    Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
+from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
+    Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
 )
 
 
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
index e7fdae92..670b3ad8 100644
--
 from .moe_utils import (
@@ -530,7 +534,7 @@ class CuteDslMoEWrapper:
             enable_pdl=enable_pdl,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
 
 
 @supported_compute_capability([100, 103])
-@flashinfer_api
+@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace)
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
     x_sf: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index 0cc8628e..636043db 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@ from ...autotuner import (
     TuningConfig,
 )
 from ..utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
 )
 
 logger = logging.getLogger(__name__)
@@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
                 DynamicTensorSpec(
--
 import torch
@@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module():
 
 
 @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
-@flashinfer_api
+@flashinfer_api(trace=fused_topk_deepseek_trace)
 def fused_topk_deepseek(
     scores: torch.Tensor,
     bias: torch.Tensor,
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 004271a1..91f37aa5 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
     return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
 
 
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1."""
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
+_PHASE1_END = 256
--
 
@@ -106,7 +114,7 @@ TILE_V = 8  # pretranspose tile size
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -535,7 +543,7 @@ def gated_delta_rule_decode(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_mtp_trace)
 def gated_delta_rule_mtp(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 68398d28..53fe44ce 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel:
 
         gate_handle = load_gate_consumer.wait_and_advance()
 
-        max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
-        cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
+        cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]
 
         valid_state = not is_first_chunk or self.use_initial_state
         if cutlass.const_expr(valid_state):
diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
index 82dcc72b..aafcc671 100644
--- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
--
     register_custom_op,
@@ -95,7 +96,7 @@ def get_gdn_prefill_module():
     return SimpleNamespace(gdn_prefill=gdn_prefill)
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_prefill_trace)
 def chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index a7795beb..def82216 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -61,11 +61,11 @@ try:
     from flashinfer.cute_dsl.utils import is_cute_dsl_available
 
     if is_cute_dsl_available():
-        from .kernels.dense_blockscaled_gemm_sm120 import (
-            Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
+        from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+            Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
         )
 
-        _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+        _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
 except ImportError:
--
 from ..utils import (
@@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16(
     common_check=_check_mm_bf16_problem_size,
     heuristic_func=_heuristic_func_mm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_bf16_trace)
 def mm_bf16(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16(
     common_check=_check_bmm_bf16_problem_size,
     heuristic_func=_heuristic_func_bmm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_bf16_trace)
 def bmm_bf16(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
         ),
     ),
     constraint_specs=(
@@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
--
     constraint_specs=(
@@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=tgv_gemm_sm100_trace)
 def tgv_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1173,8 +1187,8 @@ def tgv_gemm_sm100(
             DynamicTensorSpec(
                 (a_tensor_index,),
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper:
     True
     """
 
+    @flashinfer_api
     def __init__(
         self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
     ) -> None:
@@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper:
         self._float_workspace_buffer = float_workspace_buffer
         self._int_workspace_buffer = int_workspace_buffer
 
-    @flashinfer_api
+    @flashinfer_api(trace=segment_gemm_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_fp4; the user-facing mm_fp4 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_fp4_graph_override_shape(
     graph,
     a,
@@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
--
 ):
@@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
     return (tuple(block_scale_shape), tuple(block_scale_stride))
 
 
-@flashinfer_api
+@flashinfer_api(trace=mm_fp8_trace)
 def mm_fp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8(
     common_check=_check_mm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_mm_mxfp8,  # result stored in mm_mxfp8.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_mxfp8_trace)
 def mm_mxfp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner(
     """
     import cutlass
 
-    from .kernels.dense_blockscaled_gemm_sm120 import (
-        Sm120BlockScaledDenseGemmKernel,
+    from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+        Sm120B12xBlockScaledDenseGemmKernel,
     )
 
     cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner(
             ]
             swap_ab = False
             for mma_tiler_mn in sm120_mma_tiler_candidates:
-                if not Sm120BlockScaledDenseGemmKernel.can_implement(
+                if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
--
     constraint_specs=(
@@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig(
     common_check=_check_mm_fp4_problem_size,
     heuristic_func=_heuristic_func_mm_fp4,  # result stored in mm_fp4.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_fp4_trace)
 def mm_fp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8(
     common_check=_check_bmm_fp8_problem_size,
     heuristic_func=_heuristic_func_bmm_fp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_fp8_trace)
 def bmm_fp8(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise(
     {},
     common_check=_check_batch_deepgemm_fp8_nt_groupwise,
 )
-@flashinfer_api
+@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace)
 def batch_deepgemm_fp8_nt_groupwise(
     a: torch.Tensor,  # (batch_size, m, k)
     b: torch.Tensor,  # (batch_size, n, k)
@@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90():
     return module.init()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace)
 def fp8_blockscale_gemm_sm90(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8(
     common_check=_check_bmm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_bmm_mxfp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_mxfp8_trace)
 def bmm_mxfp8(
     A: torch.Tensor,
     B: torch.Tensor,
diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
similarity index 99%
rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
index c49bc815..6eee27a7 100644
--- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
+++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
@@ -1550,7 +1550,7 @@ class DenseGemmKernel:
 
 
 # Alias for FlashInfer integration
-Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
+Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
 
 
 class _DenseGemmLaunch:
diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
--
     get_cutlass_dtype,
@@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=grouped_gemm_nt_masked_trace)
 def grouped_gemm_nt_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py
index cfde7d43..f83c8974 100644
--- a/flashinfer/gemm/routergemm.py
+++ b/flashinfer/gemm/routergemm.py
@@ -1,4 +1,8 @@
 from ..api_logging import flashinfer_api
+from ..trace.templates.gemm import (
+    mm_M1_16_K7168_N256_trace,
+    tinygemm_bf16_trace,
+)
 from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
 import functools
 from types import SimpleNamespace
@@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128(
 
 
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
@@ -324,7 +328,7 @@ def get_tinygemm2_module():
 
 
 @backend_requirement({}, common_check=_tinygemm_bf16_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=tinygemm_bf16_trace)
 def tinygemm_bf16(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 7f36a314..8378e0ab 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module
 from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module
 from .comm import gen_vllm_comm_module as gen_vllm_comm_module
 from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module
+from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module
 from .dsv3_optimizations import (
     gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module,
 )
diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py
index 46768eed..834f77f9 100644
--- a/flashinfer/jit/comm.py
+++ b/flashinfer/jit/comm.py
@@ -15,7 +15,13 @@ limitations under the License.
--
     gen_selective_state_update_sm100_module,
@@ -99,7 +100,7 @@ def get_selective_state_update_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=selective_state_update_trace)
 def selective_state_update(
     state: torch.Tensor,
     x: torch.Tensor,
diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index 4e8bdd72..e27e3807 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import (
+    mla_paged_decode_trace,
+    trtllm_batch_decode_mla_trace,
+    xqa_batch_decode_mla_trace,
+)
 from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
 from ..jit.mla import gen_mla_module
 from ..utils import (
@@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper:
         return_lse_base_on_e: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=mla_paged_decode_trace)
     def run(
         self,
         q_nope: torch.Tensor,
@@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper:
         return (out, lse) if return_lse else out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_mla_trace)
 def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
@@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         raise ValueError(f"Backend {backend} not supported")
 
 
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_mla_trace)
 def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index 0f9911a6..ba612b28 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -32,6 +32,16 @@ from typing import Optional, Union
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import (
+    fused_add_rmsnorm_quant_trace,
+    fused_add_rmsnorm_trace,
+    fused_rmsnorm_silu_trace,
+    gemma_fused_add_rmsnorm_trace,
+    gemma_rmsnorm_trace,
+    layernorm_trace,
+    rmsnorm_quant_trace,
+    rmsnorm_trace,
--
     get_compute_capability,
@@ -94,7 +104,7 @@ def _normalize_scale_tensor(
     return scale.contiguous()
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_trace)
 def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -165,7 +175,7 @@ def _rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_quant_trace)
 @register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",))
 def rmsnorm_quant(
     out: torch.Tensor,
@@ -219,7 +229,7 @@ def _rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_trace)
 @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor,
@@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_quant_trace)
 @register_custom_op(
     "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
 )
@@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_rmsnorm_trace)
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace)
 @register_custom_op(
     "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
 )
@@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=layernorm_trace)
 @register_custom_op("flashinfer::layernorm", mutates_args=())
 def layernorm(
     input: torch.Tensor,
@@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_rmsnorm_silu_trace)
 def fused_rmsnorm_silu(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/page.py b/flashinfer/page.py
index 12ea3613..7fb33cf3 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.page import (
+    append_paged_kv_cache_trace,
+    append_paged_mla_kv_cache_trace,
+)
 from .jit.page import gen_page_module
 from .utils import (
     TensorLayout,
@@ -222,7 +226,7 @@ def get_seq_lens(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_mla_kv_cache_trace)
 def append_paged_mla_kv_cache(
     append_ckv: torch.Tensor,
     append_kpe: torch.Tensor,
@@ -272,7 +276,7 @@ def append_paged_mla_kv_cache(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_kv_cache_trace)
 def append_paged_kv_cache(
     append_key: torch.Tensor,
     append_value: torch.Tensor,
diff --git a/flashinfer/pod.py b/flashinfer/pod.py
index fe2e36c1..4fa2d9bf 100644
--- a/flashinfer/pod.py
+++ b/flashinfer/pod.py
@@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    batch_pod_with_paged_kv_cache_run_trace,
+    pod_with_paged_kv_cache_run_trace,
+)
 from .jit import gen_pod_module, gen_batch_pod_module
 from .page import get_seq_lens
 from .prefill import get_batch_prefill_module
@@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
@@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 4ec6a29e..d491dd35 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_prefill_trace,
+    gqa_ragged_prefill_trace,
+    single_prefill_with_kv_cache_trace,
+    trtllm_batch_context_trace,
+)
+from .trace.templates.gemm import (
+    fmha_v2_prefill_deepseek_trace,
+    trtllm_ragged_attention_deepseek_trace,
--
     gen_customize_batch_prefill_module,
@@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_prefill_with_kv_cache_trace)
 def single_prefill_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
         skip_softmax_threshold_scale_factor: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
         enable_pdl: Optional[bool] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_ragged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module():
     return op
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace)
 def trtllm_ragged_attention_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek(
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    backend: str = "trtllm-gen",
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Parameters
@@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek(
         output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
     lse : Optional[torch.Tensor]
         lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]
+    backend : str
+        Attention backend to use. "trtllm-gen" (default) or "cute-dsl".
+        When backend="cute-dsl", query/key/value/out tensors must be
+        front-padded with max_seq_len rows of valid GPU memory before
+        index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details).
--
             "lse assumed not None beyond this point when return_lse is True"
@@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek(
         return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_context_trace)
 def trtllm_batch_context_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module():
     return gen_trtllm_fmha_v2_sm120_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module(
     return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace)
 def trtllm_fmha_v2_prefill(
     qkv: Union[
         torch.Tensor,
diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py
index 4cd5cd34..84f7ade6 100644
--- a/flashinfer/quantization/fp4_quantization.py
+++ b/flashinfer/quantization/fp4_quantization.py
@@ -21,6 +21,12 @@ from typing import List, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import (
+    fp4_quantize_trace,
+    mxfp4_quantize_trace,
+    nvfp4_kv_quantize_trace,
+    nvfp4_quantize_trace,
+)
 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import (
@@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp4_quantize_trace)
 def fp4_quantize(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
@@ -923,7 +929,7 @@ def shuffle_matrix_sf_a(
     return block_scale_interleave(w_shuffled)
 
 
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_quantize_trace)
 def nvfp4_quantize(
     a,
     a_global_sf,
@@ -1024,7 +1030,7 @@ def nvfp4_quantize(
     return a_fp4, a_sf
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp4_quantize_trace)
 def mxfp4_quantize(
     a: torch.Tensor,
     backend: str = "cuda",
@@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale):
 
 
 @backend_requirement({}, common_check=_nvfp4_kv_quant_check)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_kv_quantize_trace)
 def nvfp4_kv_quantize(
     input: torch.Tensor,
     global_scale: torch.Tensor,
diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py
index f2c9f412..49e13a8b 100644
--- a/flashinfer/quantization/fp8_quantization.py
+++ b/flashinfer/quantization/fp8_quantization.py
@@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import mxfp8_quantize_trace
 from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
 from ..utils import (
     device_support_pdl,
@@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module():
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp8_quantize_trace)
 def mxfp8_quantize(
     input: torch.Tensor,
     is_sf_swizzled_layout: bool = True,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07..df5c7d4d 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -20,6 +20,21 @@ from typing import Optional, Tuple
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.rope import (
+    apply_llama31_rope_inplace_trace,
+    apply_llama31_rope_pos_ids_inplace_trace,
+    apply_llama31_rope_pos_ids_trace,
+    apply_llama31_rope_trace,
+    apply_rope_inplace_trace,
+    apply_rope_pos_ids_inplace_trace,
+    apply_rope_pos_ids_trace,
+    apply_rope_trace,
--
 
@@ -414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_inplace_trace)
 def apply_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -502,7 +517,7 @@ def apply_rope_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace)
 def apply_rope_pos_ids_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_llama31_rope_inplace_trace)
 def apply_llama31_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
... (truncated -- see full diff via the command above)
```

**Summary of API changes:**

- **Decorator semantic addition (backward-compatible):**
`@flashinfer_api` now accepts an optional `trace=<TraceTemplate>`
keyword. Bare `@flashinfer_api` still works. Existing call sites of
decorated functions are unaffected. Most of the diff above is mechanical
rewrites of existing `@flashinfer_api` to `@flashinfer_api(trace=...)`,
plus the new `flashinfer/trace/` package and `fi_trace.py` for
flashinfer-bench JSON dumps.

- **New public APIs (7):**
- `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size,
decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace,
decode_cp_a2a_alltoall}` — DCP all-to-all for context-parallel attention
reduction (#2951).
- `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm,
interleave_moe_weights_for_sm90_mixed_gemm}` — SM90 mixed-input MoE GEMM
helpers (#3084).
- `flashinfer.comm.run_mixed_comm` — combinations of allreduce /
allgather / reducescatter (#2563).

- **New `@flashinfer_api`-decorated wrapper init:**
- `SegmentGEMMWrapper.__init__` is now decorated. Previously the class
itself was undecorated; `run()` already was. No call-site change.

- **Backward-compatible signature additions (defaults preserve old
behavior):**
- `top_k_page_table_transform`: `+dsa_graph_safe: bool = False`,
`+row_starts: Optional[torch.Tensor] = None` (#3133).
  - `top_k_ragged_transform`: same two new params (#3133).
- `trtllm_ragged_attention_deepseek`: `+backend: str = "trtllm-gen"`
(cute-dsl backend selection).

- **No breaking signature changes** to any `@flashinfer_api` function.
Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0
removals.

- **Module reorganization to flag (not `@flashinfer_api`, but in public
re-export):**
- `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` →
`dense_blockscaled_gemm_sm120_b12x.py`
- Class renamed: `Sm120BlockScaledDenseGemmKernel` →
`Sm120B12xBlockScaledDenseGemmKernel`
- Re-export in `flashinfer/gemm/__init__.py` updated to the new name
only — direct importers of the old name break. Decision needed: ship as
breaking, or add a deprecation alias.

- **Internal autotuner helper rename (not public, but used by downstream
extensions):**
- `get_last_power_of_2_num_tokens_buckets` →
`get_hybrid_num_tokens_buckets`
- `last_positive_power_of_2` → `map_to_hybrid_bucket` /
`map_to_hybrid_bucket_uncapped`

> Diff truncated above due to GitHub PR body length limit. Run the
command at the top locally to see the full output.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Patch Release**
  * Version updated to 0.6.10

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants