
feat: Add DCP All-to-All kernel for context-parallel attention reduction#2951

Merged
aleozlx merged 12 commits into flashinfer-ai:main from davidjpyu:pr-dcp-a2a
Apr 23, 2026

Conversation

@davidjpyu
Contributor

@davidjpyu davidjpyu commented Apr 2, 2026

What is DCP?

DCP (Decode Context Parallelism) splits the KV cache across multiple GPUs for long-context inference. During the decode phase, each GPU computes partial attention over its local KV shard, producing two tensors:

  • partial_o (bf16/fp16): partial attention output, shape [B, cp_size, head_dim]
  • softmax_stats (fp32): softmax statistics for online correction, shape [B, cp_size, stats_dim]

These must be exchanged across all ranks (all-to-all) before the final attention reduction.
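As a CPU-side illustration (not the kernel itself), the exchange semantics can be simulated in NumPy: rank r's slot j is sent to rank j, so after the exchange every rank holds its own chunk from all peers. The dimensions below are made up for the example; only the shape structure follows the description above.

```python
import numpy as np

B, cp_size, head_dim, stats_dim = 2, 4, 8, 2  # illustrative sizes, not from the PR

# Per-rank inputs: partial_o [B, cp_size, head_dim] and softmax_stats [B, cp_size, stats_dim].
rng = np.random.default_rng(0)
partial_o = [rng.random((B, cp_size, head_dim)).astype(np.float16) for _ in range(cp_size)]
stats = [rng.random((B, cp_size, stats_dim)).astype(np.float32) for _ in range(cp_size)]

def all_to_all(tensors):
    # Rank r's output slot j is what rank j held in its slot r.
    return [np.stack([tensors[j][:, r, :] for j in range(cp_size)], axis=1)
            for r in range(cp_size)]

recv_o, recv_s = all_to_all(partial_o), all_to_all(stats)

# Every rank now holds its own chunk of each peer's partial result.
assert np.array_equal(recv_o[0][:, 3, :], partial_o[3][:, 0, :])
```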

Why this kernel?

The current approach issues two NCCL all_to_all_single calls, one per tensor, since the two tensors have different dtypes. At the small message sizes typical of DCP (KB range), NCCL protocol overhead dominates latency.

This PR adds a fused LL128 FIFO kernel (ported from TensorRT-LLM) that exchanges both tensors in a single kernel launch via MNNVL cross-GPU direct memory writes, eliminating NCCL protocol overhead.
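For intuition, the LL128-style idea can be sketched in plain Python: each 128-byte line reserves one 8-byte word as a flag, so the receiver can poll for arrival inside the data stream itself instead of running a separate synchronization round. This is a conceptual sketch only; the actual line layout and FIFO handshake live in helixAllToAll.cu and ll128Proto.cuh.

```python
LINE = 16          # 8-byte words per 128-byte line
DATA_WORDS = 15    # one word per line is reserved for the flag

def pack(words, flag):
    """Pack 8-byte data words into lines, appending the current FIFO flag to each line."""
    lines = []
    for i in range(0, len(words), DATA_WORDS):
        chunk = words[i:i + DATA_WORDS]
        chunk += [0] * (DATA_WORDS - len(chunk))   # pad the last line
        lines.append(chunk + [flag])
    return lines

def unpack(lines, flag, n_words):
    """Receiver side: a line is valid once its flag word matches; no extra barrier needed."""
    out = []
    for line in lines:
        assert line[-1] == flag, "line not yet written for this FIFO slot"
        out.extend(line[:-1])
    return out[:n_words]

data = list(range(40))
lines = pack(data, flag=7)
assert unpack(lines, 7, len(data)) == data
```

The flag doubles as the FIFO-slot generation counter, which is what lets a single kernel both transfer data and synchronize without an extra round trip.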

Kernel-level microbenchmark

GB200 NVL72, 4 GPUs, cp_size=4. Measures raw kernel time only (not end-to-end pipeline).

| batch | NCCL 2×a2a p50 | Native p50 | Speedup |
|------:|---------------:|-----------:|--------:|
| 1     | 0.176 ms       | 0.035 ms   | 5.0×    |
| 16    | 0.145 ms       | 0.034 ms   | 4.3×    |
| 64    | 0.140 ms       | 0.033 ms   | 4.3×    |
| 128   | 0.138 ms       | 0.032 ms   | 4.3×    |

Speedup is latency-dominated — these tensors are small, so NCCL protocol overhead is the bottleneck.
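A back-of-envelope check of the "KB-range" claim (head_dim and stats_dim are assumed values for illustration; the PR does not state them):

```python
def msg_bytes(batch, cp_size, head_dim=128, stats_dim=2):
    """Per-rank payload of one all-to-all, under the assumed dims."""
    o = batch * cp_size * head_dim * 2      # bf16/fp16 partial_o: 2 bytes/elem
    s = batch * cp_size * stats_dim * 4     # fp32 softmax stats: 4 bytes/elem
    return o + s

for b in (1, 16, 64, 128):
    print(b, msg_bytes(b, cp_size=4))
```

Even at batch=128 the per-rank payload is on the order of 100 KiB, small enough that fixed protocol overhead, not bandwidth, sets the latency floor.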

Components

  • Kernel: csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.{cu,h} — LL128 FIFO protocol, SM90+
  • Launcher + TVM-FFI binding: csrc/trtllm_dcp_alltoall.cu
  • JIT generator: flashinfer/jit/comm.py (gen_dcp_alltoall_module)
  • Python API: flashinfer/comm/dcp_alltoall.py
    • decode_cp_a2a_workspace_size — query workspace bytes per rank
    • decode_cp_a2a_allocate_workspace — allocate workspace (MNNVL or device memory)
    • decode_cp_a2a_init_workspace — initialize FIFO buffers
    • decode_cp_a2a_alltoall — run all-to-all exchange
  • Tests:
    • tests/comm/test_dcp_alltoall.py — 30 single-GPU tests (correctness, edge cases, input validation)
    • tests/comm/test_mnnvl_dcp_alltoall.py — 9 multi-GPU MNNVL tests (workspace shape, cross-rank visibility, alltoall correctness, FIFO reuse)
  • Benchmark: benchmarks/bench_dcp_alltoall.py — native LL128 vs NCCL baseline
  • AOT registration in flashinfer/aot.py

Requirements

  • SM90+ GPU (Hopper/Blackwell)
  • MNNVL support for multi-GPU operation (GB200 NVL72 or similar)
  • mpi4py for multi-GPU tests and benchmark

Test plan

  • 30 single-GPU tests pass (pytest tests/comm/test_dcp_alltoall.py)
  • 9 multi-GPU MNNVL tests pass on 4× GB200 (mpirun -np 4 pytest tests/comm/test_mnnvl_dcp_alltoall.py)
  • Benchmark shows 4.3-5.0× speedup vs NCCL on GB200 NVL72
  • CI passes (single-GPU tests; MNNVL tests auto-skip without fabric memory)

Summary by CodeRabbit

  • New Features

    • DCP decode-context-parallel all-to-all: GPU-native workspace allocation, init, and all-to-all invocation for faster multi-GPU attention reductions.
    • New Helix-based GPU all-to-all kernel plus low-level async-copy/barrier primitives to improve multi-GPU transfer performance.
    • Env-driven toggle to enable a new GPU launch mode on supported architectures.
  • Tests

    • Added GPU microbenchmark and extensive GPU pytest suites validating correctness, workspace lifecycle, reuse, and MNNVL-backed multi-node integration.
  • Chores

    • Communication modules and JIT build specs updated to target SM90+.

Add the DCP (Decode Context Parallelism) LL128 FIFO-based
all-to-all kernel, ported from TensorRT-LLM. This kernel fuses the
exchange of partial attention outputs (bf16/fp16) and softmax statistics
(fp32) into a single kernel launch using MNNVL cross-GPU memory,
replacing 2x NCCL all_to_all_single calls.

Requires SM90+ (Hopper/Blackwell). MNNVL workspace required for
multi-GPU operation.
@coderabbitai
Contributor

coderabbitai Bot commented Apr 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds a DCP (Decode Context Parallel) Helix all-to-all implementation: device cp.async/mbarrier primitives and LL128 protocol, Helix CUDA kernel with workspace sizing/initialization/launch, TVM/JIT bindings and Python APIs (optional MNNVL-backed workspace), AOT/JIT integration, benchmark, and single-/multi‑GPU tests.

Changes

  • Benchmark (benchmarks/bench_dcp_alltoall.py): New MPI/NCCL microbenchmark measuring single-kernel latency for DCP all-to-all with optional MNNVL workspace, warmup/iters, per-rank timing aggregation and p50/p95/mean stats.
  • Python API & JIT/AOT (flashinfer/comm/dcp_alltoall.py, flashinfer/comm/__init__.py, flashinfer/jit/comm.py, flashinfer/jit/__init__.py, flashinfer/aot.py): Adds Python APIs for workspace sizing/allocation/init/alltoall (MNNVL or device memory), registers the JIT spec generator for dcp_alltoall, exposes the JIT generator, and includes the comm module in AOT for SM90/SM100.
  • FFI / TVM Entrypoints (csrc/trtllm_dcp_alltoall.cu): Registers TVM FFI functions: per-rank workspace sizing, workspace initializer, and native all-to-all wrapper with input validation and Helix params construction.
  • Helix Kernel Implementation (csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.h, csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu): New Helix all-to-all CUDA kernel and host-side helpers: FIFO metadata, LL128 pack/unpack integration, workspace sizing, initialization, and kernel launch APIs.
  • Device Primitives & Protocol (csrc/nv_internal/tensorrt_llm/kernels/cudaAsyncOps.cuh, csrc/nv_internal/tensorrt_llm/kernels/ll128Proto.cuh, csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h): Adds device-side cp.async and mbarrier helpers, shared-memory barrier utilities, LL128 proto pack/unpack, and small compile-time constants/utilities.
  • Env Utilities (csrc/nv_internal/cpp/common/envUtils.cpp, csrc/nv_internal/tensorrt_llm/common/envUtils.h): Adds cached env accessor getEnvEnablePDL() with SM-version gating and single-character env parsing.
  • JIT Build Inputs (flashinfer/jit/comm.py): JIT spec for dcp_alltoall now includes new CUDA/C++ sources (trtllm_dcp_alltoall.cu, helixAllToAll.cu, envUtils.cpp, etc.) and NVCC flags targeting SM90/SM100.
  • Tests (tests/comm/test_dcp_alltoall.py, tests/comm/test_mnnvl_dcp_alltoall.py): New GPU tests: single-GPU correctness/workspace tests and MPI multi-GPU tests validating MNNVL-backed workspace allocation/visibility, repeated runs, synchronization, and input validation.

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyAPI as Python API
    participant JIT as JIT/FFI Module
    participant Kernel as Helix Kernel
    participant Workspace as FIFO Workspace
    participant Device as CUDA Device
    participant MPI as MPI/NCCL

    User->>PyAPI: decode_cp_a2a_allocate_workspace(cp_size, cp_rank)
    PyAPI->>Device: allocate tensor (MNNVL or torch.cuda)
    Device-->>PyAPI: workspace tensor
    PyAPI-->>User: workspace

    User->>PyAPI: decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)
    PyAPI->>JIT: initialize_dcp_workspace(workspace, cp_rank, cp_size)
    JIT->>Kernel: initializeHelixWorkspace(workspace_ptr)
    Kernel->>Workspace: memset FIFO metadata
    Kernel->>Device: stream sync
    JIT-->>PyAPI: init complete

    User->>PyAPI: decode_cp_a2a_alltoall(partial_o, softmax_stats, workspace, cp_rank, cp_size)
    PyAPI->>JIT: alltoall_dcp_native(inputs, workspace, cp_rank, cp_size)
    JIT->>Kernel: launchHelixAllToAll(params)
    Kernel->>Workspace: cp.async / LL128 pack writes
    Kernel->>Kernel: mbarrier / FIFO handshake
    Kernel->>Workspace: LL128 unpack reads
    Kernel-->>JIT: outputs tensors
    JIT-->>PyAPI: return outputs
    PyAPI-->>User: outputs
```

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes


Suggested reviewers

  • aleozlx
  • cyx-6
  • yzh119
  • yongwww
  • bkryu
  • kahyunnam
  • nv-yunzheq

Poem

🐰 I tunneled bytes through sync and stream,
I packed each proto like a dream,
Barriers blink, cp.async sings,
Peers trade crumbs on silent wings,
Helix hops — and data gleams.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 48.86%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title 'feat: Add DCP All-to-All kernel for context-parallel attention reduction' accurately and concisely describes the main change: a new DCP all-to-all kernel for context-parallel operations.
  • Description check (✅ Passed): The PR description is comprehensive and well-structured, covering the DCP context, kernel rationale, performance benchmarks, detailed component list, requirements, and test plan. It aligns with the template structure and provides substantial technical context.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the DCP (Decode Context Parallel) All-to-All communication operation, featuring a fused LL128 FIFO kernel (Helix) optimized for SM90+ architectures. The implementation includes core CUDA kernels, LL128 protocol handling, TVM FFI wrappers, and a Python API with JIT support. Review feedback identifies several critical issues: Programmatic Dependent Launch (PDL) is incorrectly enabled by default on unsupported pre-SM90 architectures, the kernel launch configuration cache lacks thread-safety for concurrent environments, and static caching of workspace parameters fails to account for varying input arguments, potentially leading to incorrect memory allocations.

Comment on lines +361 to +374
```cpp
static bool enablePDL = true;

std::call_once(flag, [&]() {
  if (getSMVersion() >= 90) {
    char const* env = std::getenv("TRTLLM_ENABLE_PDL");
    if (env) {
      if (env[0] == '1' && env[1] == '\0') {
        enablePDL = true;
      } else if (env[0] == '0' && env[1] == '\0') {
        enablePDL = false;
      }
    };
  }
});
```
Contributor

high

The enablePDL variable is initialized to true and only modified if getSMVersion() >= 90. This means Programmatic Dependent Launch (PDL) will be incorrectly reported as enabled on older GPU architectures (SM < 90), which do not support this feature. This could lead to kernel launch failures when cudaLaunchKernelEx is called with PDL attributes. It should default to false and only be enabled for SM90+.

```cpp
static bool enablePDL = false;

std::call_once(flag, [&]() {
  if (getSMVersion() >= 90) {
    enablePDL = true;
    char const* env = std::getenv("TRTLLM_ENABLE_PDL");
    if (env && env[0] == '0' && env[1] == '\0') {
      enablePDL = false;
    }
  }
});
```
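The suggested decision table can be modeled in a few lines of Python (a conceptual model of the fix, not the C++ code): PDL is off below SM90, on by default at SM90+, and TRTLLM_ENABLE_PDL=0 is an explicit opt-out.

```python
def enable_pdl(sm_version, env):
    """Model of the suggested getEnvEnablePDL logic."""
    if sm_version < 90:
        return False          # pre-Hopper: PDL unsupported, never enable
    if env == "0":
        return False          # explicit opt-out via TRTLLM_ENABLE_PDL=0
    return True               # default on for SM90+, including TRTLLM_ENABLE_PDL=1

# The bug the review flags: the original code returned True for SM < 90.
assert enable_pdl(80, None) is False
assert enable_pdl(90, None) is True
assert enable_pdl(100, "0") is False
```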

Contributor Author

I don't think a lock is needed, as the host-side launch path is single-threaded in practice.

Comment on lines +487 to +496
```cpp
static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key>
    cache;
int deviceId = 0;
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
auto it = cache.find(key);
if (it != cache.end()) {
  return it->second;
}
```
Contributor

high

The cache variable is a static std::unordered_map accessed without any synchronization. In a multi-threaded environment (e.g., a multi-stream inference server), concurrent calls to launchHelixAllToAll could lead to race conditions, data corruption, or crashes when accessing or modifying this map. Additionally, the cudaFuncSetAttribute call on line 515 is not thread-safe and modifies global state for the kernel function. Consider using a mutex to protect access to the cache and the attribute setting logic.

Contributor Author

@davidjpyu davidjpyu Apr 8, 2026

fixed

Comment on lines +596 to +599
static int maxChannelCount = 0;
if (maxChannelCount == 0) {
maxChannelCount = computeHelixMaxChannelCount(cpSize);
}
Contributor

high

The maxChannelCount is cached as a static variable, but its value depends on the cpSize argument. If computeHelixWorkspaceSizePerRank is called with different cpSize values in the same process (e.g., in a multi-model or multi-tenant setup), it will return an incorrect workspace size for all calls after the first one. The caching should be removed or made dependent on cpSize.

```cpp
int maxChannelCount = computeHelixMaxChannelCount(cpSize);
```

Contributor Author

Although cpSize will not change during a run, the computation is a lightweight division, so I fixed it: maxChannelCount is no longer static and is recomputed on every call.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/aot.py (1)

546-550: ⚠️ Potential issue | 🟠 Major

DCP module only generated for SM100+ builds despite supporting SM90.

The gen_dcp_alltoall_module() is placed inside the if has_sm100: block, but its supported_major_versions=[9, 10] declaration indicates support for SM90 as well. This means SM90-only AOT builds (e.g., H100 without Blackwell) won't include the DCP all-to-all module, despite the kernel supporting it. The kernel itself contains no SM100-specific code.

Move the call outside the SM100-only conditional:

Suggested fix

```diff
         jit_specs.append(gen_comm_alltoall_module())
+        if has_sm90 or has_sm100:
+            jit_specs.append(gen_dcp_alltoall_module())
         if has_sm100:
             jit_specs.append(gen_trtllm_comm_module())
             jit_specs.append(gen_trtllm_mnnvl_comm_module())
             jit_specs.append(gen_moe_alltoall_module())
-            jit_specs.append(gen_dcp_alltoall_module())
```
🧹 Nitpick comments (8)
tests/comm/test_dcp_alltoall.py (2)

54-58: Session-scoped fixture may not seed all tests as expected.

The setup_test_environment fixture sets torch.manual_seed(0xA2A) once per session, but pytest may run tests in different orders, and torch's RNG state is mutated by each test. Consider making this function-scoped or removing autouse and calling it explicitly where determinism is needed.

```diff
-@pytest.fixture(autouse=True, scope="session")
+@pytest.fixture(autouse=True, scope="function")
 def setup_test_environment():
     """Set torch seed for deterministic tests."""
     torch.manual_seed(0xA2A)
     yield
```

This ensures each test starts with the same RNG state.


38-51: Consider using flashinfer.utils.get_compute_capability() per coding guidelines.

The custom _sm90_available() function reimplements SM capability detection. Per coding guidelines, tests should use flashinfer.utils functions for architecture checks.

```diff
-def _sm90_available() -> bool:
-    try:
-        if not torch.cuda.is_available():
-            return False
-        major, _ = torch.cuda.get_device_capability(0)
-        return major >= 9
-    except Exception:
-        return False
+from flashinfer.utils import get_compute_capability
+
+def _sm90_available() -> bool:
+    try:
+        major, _ = get_compute_capability(0)
+        return major >= 9
+    except Exception:
+        return False
```

This also addresses the static analysis hint about catching blind exceptions, as the exception handling is limited to a fallback in a skip-check context.

As per coding guidelines: "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures"

flashinfer/comm/dcp_alltoall.py (2)

101-116: Consider adding @backend_requirement decorator for SM90+ compute capability check.

Per coding guidelines, APIs with compute capability requirements should use the @backend_requirement decorator. The DCP all-to-all requires SM90+ (Hopper/Blackwell).

```diff
+from ..compilation_context import backend_requirement
+
+@backend_requirement(sm_major=9)
 @flashinfer_api
 def decode_cp_a2a_workspace_size(cp_size: int) -> int:
```

This would provide is_compute_capability_supported(cc) and is_backend_supported() methods for runtime checks.

As per coding guidelines: "Use @backend_requirement decorator on APIs that have compute capability requirements"


243-248: Minor: __all__ is not sorted alphabetically.

Static analysis (RUF022) flags unsorted __all__. Consider sorting for consistency:

```diff
 __all__ = [
-    "decode_cp_a2a_workspace_size",
     "decode_cp_a2a_allocate_workspace",
+    "decode_cp_a2a_alltoall",
     "decode_cp_a2a_init_workspace",
-    "decode_cp_a2a_alltoall",
+    "decode_cp_a2a_workspace_size",
 ]
```
csrc/nv_internal/cpp/common/envUtils.cpp (1)

359-376: Consider setting enablePDL = false for SM < 90.

The function returns true by default (line 361), and only processes the environment variable when getSMVersion() >= 90 (line 364). This means on pre-SM90 hardware, it always returns true.

While the CUDA runtime likely ignores the programmaticStreamSerializationAllowed attribute on older architectures, semantically it would be cleaner to default to false when the feature isn't supported:

```diff
 bool getEnvEnablePDL() {
   static std::once_flag flag;
-  static bool enablePDL = true;
+  static bool enablePDL = false;

   std::call_once(flag, [&]() {
     if (getSMVersion() >= 90) {
+      enablePDL = true;  // default to enabled on supported hardware
       char const* env = std::getenv("TRTLLM_ENABLE_PDL");
```
This makes the default state consistent with hardware capability.

benchmarks/bench_dcp_alltoall.py (2)

141-147: Prefix unused return values with underscore to suppress linter warnings.

The recv_o and recv_s variables are intentionally unused during warmup iterations. Prefixing them with underscore makes this intent explicit and silences the Ruff RUF059 warning.

♻️ Suggested fix

```diff
     # Warmup
     for _ in range(warmup):
-        recv_o, recv_s = decode_cp_a2a_alltoall(
+        _recv_o, _recv_s = decode_cp_a2a_alltoall(
             partial_o, softmax_stats, workspace, rank, cp_size
         )
```

Similarly for the timed loop at lines 158-160:

```diff
-        recv_o, recv_s = decode_cp_a2a_alltoall(
+        _recv_o, _recv_s = decode_cp_a2a_alltoall(
             partial_o, softmax_stats, workspace, rank, cp_size
         )
```

243-244: Prefix unused local_rank with underscore.

The local_rank is computed inside setup_mpi() and used there to set the CUDA device, but the returned value is never used at the call site.

♻️ Suggested fix

```diff
-    rank, world_size, mpi_comm, local_rank = setup_mpi()
+    rank, world_size, mpi_comm, _local_rank = setup_mpi()
```
tests/comm/test_mnnvl_dcp_alltoall.py (1)

53-60: Consider catching a more specific exception or using flashinfer utility.

The blind Exception catch (Ruff BLE001) could mask unexpected errors. However, this is a safety guard for capability checking in an MPI context where various errors could occur.

♻️ Alternative using specific exceptions

```diff
 def _sm90_available() -> bool:
     try:
         if not torch.cuda.is_available():
             return False
         major, _ = torch.cuda.get_device_capability(0)
         return major >= 9
-    except Exception:
+    except (RuntimeError, AssertionError):
         return False
```

Alternatively, consider using flashinfer.utils.get_compute_capability() if it's available in the MPI-launched test context. As per coding guidelines, tests should use flashinfer.utils functions for architecture checks.

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu`:
- Around line 486-534: computeChannelAndGroupCount uses a static
std::unordered_map named cache that is mutated without synchronization; protect
concurrent access by adding a static mutex (e.g., cache_mutex) and using a
std::lock_guard<std::mutex> (or similar) around lookups/insertions of cache
keyed by key so find/insert are atomic; for better performance do a
double-checked pattern: do an unsynchronized find, and if not found take the
lock, re-check cache.find(key), and only then compute/insert the value (refer to
computeChannelAndGroupCount, cache, key, and the returned value tuple).
- Around line 595-609: computeHelixWorkspaceSizePerRank currently uses a static
maxChannelCount cached across calls, but computeHelixMaxChannelCount(cpSize)
depends on cpSize so the cached value can be wrong for different cpSize; fix by
removing the static single-value cache and compute maxChannelCount =
computeHelixMaxChannelCount(cpSize) each call (or replace the static with a
small cache keyed by cpSize, e.g., a map from cpSize to channel count) so
computeHelixWorkspaceSizePerRank and its fifo/sender/receiver size calculations
always use the correct channel count.

In `@csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h`:
- Around line 74-84: Replace the current overflow-prone implementation of
ceil_div and add type constraints: add a static_assert(std::is_integral_v<T>) to
both ceil_div and align_up to enforce integral types, reimplement ceil_div as "a
/ b + (a % b != 0)" (which avoids a+b overflow and is constexpr-friendly) and
ensure align_up continues to call ceil_div(value, alignment) * alignment; also
consider validating alignment > 0 (e.g., via assert or precondition) to avoid
division by zero.
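The overflow concern in the last comment is easy to reproduce by simulating 32-bit wrap-around in Python (Python integers never overflow, so we mask the intermediate sum to mimic a 32-bit machine word):

```python
M = 1 << 32  # simulate unsigned 32-bit wrap-around

def ceil_div_naive(a, b):
    return ((a + b - 1) % M) // b      # (a + b - 1) can wrap before the divide

def ceil_div_safe(a, b):
    return a // b + (a % b != 0)       # no intermediate sum, cannot overflow

a, b = M - 2, 256                      # a near the top of the 32-bit range
assert ceil_div_safe(a, b) == 16777216
assert ceil_div_naive(a, b) != ceil_div_safe(a, b)   # naive version wraps and is wrong
```

For typical small operands both forms agree; the `a / b + (a % b != 0)` form is simply safe across the whole range, which is why the review suggests it.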

---

Outside diff comments:
In `@flashinfer/aot.py`:
- Around line 546-550: The DCP all-to-all module is incorrectly only appended
when has_sm100 is true; move the call to gen_dcp_alltoall_module() out of the
has_sm100 conditional so it is appended to jit_specs for SM90 and SM100 builds
(keep gen_trtllm_comm_module(), gen_trtllm_mnnvl_comm_module(),
gen_moe_alltoall_module() inside the if block and only relocate the
gen_dcp_alltoall_module() invocation so that
jit_specs.append(gen_dcp_alltoall_module()) runs regardless of has_sm100).

---

Nitpick comments:
In `@benchmarks/bench_dcp_alltoall.py`:
- Around line 141-147: The warmup loop (and the timed loop that also calls
decode_cp_a2a_alltoall) currently assigns returned values to recv_o and recv_s
but never uses them; rename those local variables to _recv_o and _recv_s when
calling decode_cp_a2a_alltoall to indicate they are intentionally unused and
silence the RUF059 linter warning, e.g., replace recv_o, recv_s =
decode_cp_a2a_alltoall(...) with _recv_o, _recv_s = decode_cp_a2a_alltoall(...)
in both the warmup and the timed invocation sites.
- Around line 243-244: The unpacking of setup_mpi() returns a local_rank value
that isn’t used; change the tuple unpack to ignore it by prefixing with an
underscore (e.g., replace local_rank with _local_rank or use a single underscore
_) so the call becomes rank, world_size, mpi_comm, _local_rank = setup_mpi()
which makes intent explicit while leaving setup_mpi(), rank, world_size and
mpi_comm unchanged.

In `@csrc/nv_internal/cpp/common/envUtils.cpp`:
- Around line 359-376: The function getEnvEnablePDL currently defaults enablePDL
to true and only reads TRTLLM_ENABLE_PDL when getSMVersion() >= 90; change the
default to false so that for SM < 90 getEnvEnablePDL() returns false (feature
unsupported), and keep the existing std::call_once logic to override enablePDL
only when getSMVersion() >= 90 by reading the TRTLLM_ENABLE_PDL env var; update
the static bool enablePDL initialization and ensure getEnvEnablePDL,
getSMVersion, and the TRTLLM_ENABLE_PDL handling are the only touched symbols.

In `@flashinfer/comm/dcp_alltoall.py`:
- Around line 101-116: The API decode_cp_a2a_workspace_size lacks the required
compute-capability guard; wrap it with the `@backend_requirement` decorator
configured for SM90+ so runtime checks are available. Update the function
definition to use `@backend_requirement`(...) before `@flashinfer_api` (or alongside
as per project convention) and ensure calls to
get_dcp_alltoall_module()/get_workspace_size_per_rank are protected by the
decorator’s provided helpers (is_compute_capability_supported(cc) and
is_backend_supported()) so the API only runs on supported SM90+ backends.
- Around line 243-248: The __all__ list in dcp_alltoall.py is not alphabetically
sorted; reorder the export names so they are lexically ascending (i.e.
decode_cp_a2a_allocate_workspace, decode_cp_a2a_alltoall,
decode_cp_a2a_init_workspace, decode_cp_a2a_workspace_size) to satisfy static
analysis (RUF022) and keep consistency when referencing exports like
decode_cp_a2a_workspace_size, decode_cp_a2a_allocate_workspace,
decode_cp_a2a_init_workspace, and decode_cp_a2a_alltoall.

In `@tests/comm/test_dcp_alltoall.py`:
- Around line 54-58: The session-scoped autouse fixture setup_test_environment
only seeds torch once per session, so tests mutate RNG state and won't start
from the same seed; change the fixture scope to "function" (or remove autouse
and call the fixture explicitly in tests that need determinism) so
torch.manual_seed(0xA2A) runs before each test, ensuring each test begins with
the same RNG state; locate and update the setup_test_environment fixture
definition accordingly.
- Around line 38-51: Replace the custom _sm90_available() and pytestmark check
with the utility functions from flashinfer.utils: call
flashinfer.utils.get_compute_capability() or use
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() to determine support
and base the pytest.mark.skipif on that result; update the symbol references in
the module (remove _sm90_available and change pytestmark to call the
flashinfer.utils helper) so the test uses the canonical capability-checking
helpers rather than reimplementing device queries and broad exception handling.
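The fixture change in the first comment can be sketched as follows, with `random.seed` standing in for `torch.manual_seed` so the sketch runs without torch:

```python
import random

import pytest

@pytest.fixture(autouse=True)  # "function" scope is pytest's default
def setup_test_environment():
    # Reseed before every test so each one starts from identical RNG state.
    random.seed(0xA2A)

def test_first_draw_is_deterministic():
    # Seeding the global generator with 0xA2A matches a fresh
    # random.Random(0xA2A) instance's first draw.
    assert random.random() == random.Random(0xA2A).random()
```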

In `@tests/comm/test_mnnvl_dcp_alltoall.py`:
- Around line 53-60: Replace the broad Exception catch in _sm90_available with a
targeted approach: prefer calling flashinfer.utils.get_compute_capability() (and
return major >= 9) if that utility is available in test runtime; otherwise
restrict the except clause to concrete errors that can occur during CUDA queries
(e.g., RuntimeError, OSError, IndexError) instead of catching Exception so
unexpected bugs aren't swallowed; update the function _sm90_available to
reference the chosen approach and adjust imports accordingly.
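The helper shape both test files converge on, with the capability probe injected so the sketch is self-contained (`get_cc` stands in for `flashinfer.utils.get_compute_capability`):

```python
def sm90_available(cuda_available, get_cc):
    # No blanket `except Exception`: if the capability probe raises,
    # let the error surface instead of silently skipping.
    if not cuda_available():
        return False
    major, _minor = get_cc()
    return major >= 9
```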

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5c6f78ce-54b4-4218-a142-da477e8bf9a5

📥 Commits

Reviewing files that changed from the base of the PR and between 637209a and 82e8c81.

📒 Files selected for processing (16)
  • benchmarks/bench_dcp_alltoall.py
  • csrc/nv_internal/cpp/common/envUtils.cpp
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h
  • csrc/nv_internal/tensorrt_llm/kernels/cudaAsyncOps.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu
  • csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.h
  • csrc/nv_internal/tensorrt_llm/kernels/ll128Proto.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h
  • csrc/trtllm_dcp_alltoall.cu
  • flashinfer/aot.py
  • flashinfer/comm/__init__.py
  • flashinfer/comm/dcp_alltoall.py
  • flashinfer/jit/__init__.py
  • flashinfer/jit/comm.py
  • tests/comm/test_dcp_alltoall.py
  • tests/comm/test_mnnvl_dcp_alltoall.py

Comment on lines +486 to +534
std::tuple<int, int, int> computeChannelAndGroupCount(int cpSize, HelixFieldInfo const* fields) {
  static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key>
      cache;
  int deviceId = 0;
  TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
  int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
  auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second;
  }

  int maxGroupCountPerCta = std::min(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
  int groupCountPerCta = maxGroupCountPerCta;  // Start with max
  int totalDynamicShmemSize = singleShmSize * groupCountPerCta;
  int maxDynamicShmSize = 0;
  TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxDynamicShmSize,
                                         cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId));

  while (totalDynamicShmemSize > maxDynamicShmSize) {
    groupCountPerCta--;
    totalDynamicShmemSize = singleShmSize * groupCountPerCta;
  }

  TLLM_CHECK_WITH_INFO(totalDynamicShmemSize <= maxDynamicShmSize,
                       "Single packed size %d exceeds limit %d", singleShmSize, maxDynamicShmSize);

  // Set shared memory attribute if needed
  if (totalDynamicShmemSize > 48 * 1024) {
    TLLM_CUDA_CHECK(cudaFuncSetAttribute(helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>,
                                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                                         totalDynamicShmemSize));
  }

  int blockCountPerChannel = ceil_div(cpSize, groupCountPerCta);
  blockCountPerChannel *= 2;  // for send and recv

  int smCount = 0;
  TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
  // TODO: we might only want to use half the SMs to overlap with other kernels.
  // note that overlap with FMHA is almost impossible because it must use
  // all SMs and probably uses >50% shmem per SM.
  // overlap with the subsequent BMM / out proj GEMMs might be possible,
  // so we need experiments to see whether it makes sense.
  int channelCount = std::max(smCount / blockCountPerChannel, 1);
  auto value = std::make_tuple(channelCount, groupCountPerCta, totalDynamicShmemSize);
  cache[key] = value;
  return value;
}
Contributor

⚠️ Potential issue | 🟡 Minor

Static cache in computeChannelAndGroupCount is not thread-safe for concurrent writes.

The static cache map is written to without synchronization. If multiple threads call this function concurrently with different keys, there's a potential data race during map insertion.

For single-threaded host code paths (typical for kernel launch configuration), this may not be an issue in practice. However, if this code could be called from multiple host threads, consider adding a mutex.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu` around lines 486 -
534, computeChannelAndGroupCount uses a static std::unordered_map named cache
that is mutated without synchronization; protect concurrent access by adding a
static mutex (e.g., cache_mutex) and using a std::lock_guard<std::mutex> (or
similar) around lookups/insertions of cache keyed by key so find/insert are
atomic; for better performance do a double-checked pattern: do an unsynchronized
find, and if not found take the lock, re-check cache.find(key), and only then
compute/insert the value (refer to computeChannelAndGroupCount, cache, key, and
the returned value tuple).
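A Python analog of the thread-safety fix (in C++ even the initial unordered_map::find can race with a concurrent rehash, so the simplest safe form holds the mutex across both lookup and insert, as here):

```python
import threading

_cache = {}
_cache_lock = threading.Lock()

def cached_channel_config(key, compute):
    # Hold the lock across lookup and insert so concurrent callers
    # with different keys cannot corrupt the shared cache.
    with _cache_lock:
        value = _cache.get(key)
        if value is None:
            value = compute(key)
            _cache[key] = value
        return value
```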

Comment thread csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu
Comment on lines +74 to +84
inline constexpr T ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

/**
 * Align value up to nearest multiple of alignment
 */
template <typename T>
inline constexpr T align_up(T value, T alignment) {
  return ceil_div(value, alignment) * alignment;
}
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify call-site safety for ceil_div/align_up usage patterns.
# Expectation: divisors/alignments are always non-zero and arguments are non-negative integral expressions.
rg -n -C3 '\bceil_div\s*\('
rg -n -C3 '\balign_up\s*\('

Repository: flashinfer-ai/flashinfer

Length of output: 50381


🏁 Script executed:

fd -type f -name "moeCommKernelsCommon.h" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

find . -name "moeCommKernelsCommon.h" 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 130


🏁 Script executed:

sed -n '70,90p' ./csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h

Repository: flashinfer-ai/flashinfer

Length of output: 622


🏁 Script executed:

rg -n '\bceil_div\s*\(|\balign_up\s*\(' ./csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h

Repository: flashinfer-ai/flashinfer

Length of output: 217


🏁 Script executed:

# Check for any calls to these functions in dependent headers/files
rg -l 'ceil_div|align_up' ./csrc/nv_internal/tensorrt_llm/kernels/ | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 282


🏁 Script executed:

# Check usages in helixAllToAll.cu
rg -n -B2 -A2 'ceil_div\s*\(|align_up\s*\(' ./csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1027


🏁 Script executed:

# Check the other file with usage
rg -n -B2 -A2 'ceil_div\s*\(|align_up\s*\(' ./csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h

Repository: flashinfer-ai/flashinfer

Length of output: 471


🏁 Script executed:

# Check types and values of variables used in calls
rg -n 'groupCountPerCta|MAX_GROUP_COUNT_PER_BLOCK' ./csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 879


Improve robustness of ceil_div and align_up for signed types.

The current implementation (a + b - 1) / b can overflow for signed types when a + b - 1 exceeds the type's maximum. Consider using the alternative formula a / b + (a % b != 0), which avoids the overflow and still works in constexpr contexts. Additionally, add static_assert(std::is_integral_v<T>) to enforce integral types.

Suggested change
+#include <type_traits>
+
 template <typename T>
 inline constexpr T ceil_div(T a, T b) {
+  static_assert(std::is_integral_v<T>, "ceil_div requires integral type");
-  return (a + b - 1) / b;
+  return a / b + (a % b != 0);
 }
 
 template <typename T>
 inline constexpr T align_up(T value, T alignment) {
+  static_assert(std::is_integral_v<T>, "align_up requires integral type");
   return ceil_div(value, alignment) * alignment;
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
inline constexpr T ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

/**
 * Align value up to nearest multiple of alignment
 */
template <typename T>
inline constexpr T align_up(T value, T alignment) {
  return ceil_div(value, alignment) * alignment;
}

#include <type_traits>

template <typename T>
inline constexpr T ceil_div(T a, T b) {
  static_assert(std::is_integral_v<T>, "ceil_div requires integral type");
  return a / b + (a % b != 0);
}

/**
 * Align value up to nearest multiple of alignment
 */
template <typename T>
inline constexpr T align_up(T value, T alignment) {
  static_assert(std::is_integral_v<T>, "align_up requires integral type");
  return ceil_div(value, alignment) * alignment;
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h` around lines 74
- 84, Replace the current overflow-prone implementation of ceil_div and add type
constraints: add a static_assert(std::is_integral_v<T>) to both ceil_div and
align_up to enforce integral types, reimplement ceil_div as "a / b + (a % b !=
0)" (which avoids a+b overflow and is constexpr-friendly) and ensure align_up
continues to call ceil_div(value, alignment) * alignment; also consider
validating alignment > 0 (e.g., via assert or precondition) to avoid division by
zero.
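The two formulas agree for non-negative a and positive b; Python integers cannot overflow, but the check below confirms the term-by-term equivalence that makes the rewrite safe for fixed-width C++ types:

```python
def ceil_div(a, b):
    # Overflow-safe form: no `a + b - 1` intermediate.
    # Assumes a >= 0 and b > 0, matching the kernel's size math.
    return a // b + (a % b != 0)

def align_up(value, alignment):
    return ceil_div(value, alignment) * alignment
```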

gen_trtllm_comm_module, gen_moe_alltoall_module, and gen_dcp_alltoall_module
all declare supported_major_versions=[9, 10] but were only included in AOT
builds when has_sm100 was true. SM90 (Hopper) users would miss these modules
in pre-compiled packages.
@davidjpyu
Contributor Author

Re: aot.py — DCP module only generated for SM100+ (CodeRabbit)

Fixed in 9f34778 — changed if has_sm100: to if has_sm90 or has_sm100:.

This also correctly covers gen_trtllm_comm_module(), gen_trtllm_mnnvl_comm_module(), and gen_moe_alltoall_module() in the same block, which all declare supported_major_versions=[9, 10].

Verified all 30 DCP unit tests pass on H200 (SM90).

@aleozlx
Collaborator

aleozlx commented Apr 3, 2026

/bot run

@aleozlx aleozlx self-assigned this Apr 3, 2026
@flashinfer-bot
Collaborator

GitLab MR !500 has been created, and the CI pipeline #47588815 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47588815: 5/20 passed

@aleozlx aleozlx added the run-ci label Apr 3, 2026
@aleozlx aleozlx enabled auto-merge (squash) April 3, 2026 22:38
@aleozlx aleozlx disabled auto-merge April 3, 2026 22:38
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/comm/test_mnnvl_dcp_alltoall.py`:
- Line 47: pynvml.nvmlInit() is executed at import time and can cause collection
to fail; move it behind the existing skip protection or guard it so it only runs
when GPUs/driver are present: either relocate the call into the conditional that
checks _sm90_available() and mnnvl_available() (the same block around the
guarded workspace allocation at line 107) or wrap the call in a try/except that
catches initialization errors and sets a flag so tests will be skipped;
reference the symbols pynvml.nvmlInit(), _sm90_available(), and
mnnvl_available() when making the change.
- Around line 53-60: The helper _sm90_available currently duplicates CUDA
capability detection and swallows errors; replace its probe to use
flashinfer.utils.get_compute_capability instead of
torch.cuda.get_device_capability and remove the blanket except so CUDA init
errors surface. Specifically, import get_compute_capability and implement
_sm90_available to first check torch.cuda.is_available(), then call
get_compute_capability(torch.device("cuda")) to obtain (major, minor) and return
major >= 9, leaving exceptions to propagate.
- Around line 332-333: The test runner currently ignores pytest.main()'s return
code in the if __name__ == "__main__" block; capture the return value from
pytest.main([__file__, "-v", "-s"]) and propagate it by calling
sys.exit(return_code) so the process exits nonzero on test failures, and add an
import for sys if not present; update the main block around pytest.main to use
the captured result and sys.exit.
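The import-time guard the first comment describes, sketched generically (`init` stands in for `pynvml.nvmlInit`):

```python
def try_init(init):
    # Guarded driver init: a missing GPU/driver becomes a skip flag
    # instead of a pytest collection error. Catch only concrete
    # initialization errors, not Exception.
    try:
        init()
        return True
    except (RuntimeError, OSError):
        return False
```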

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2dad6958-3076-40c6-ba77-d69c94122820

📥 Commits

Reviewing files that changed from the base of the PR and between 9f34778 and 30a52894d9b6e3de0d8e572797d59cce3b187783.

📒 Files selected for processing (1)
  • tests/comm/test_mnnvl_dcp_alltoall.py

Comment thread tests/comm/test_mnnvl_dcp_alltoall.py
Comment thread tests/comm/test_mnnvl_dcp_alltoall.py Outdated
Comment on lines +332 to +333
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
Contributor

⚠️ Potential issue | 🟡 Minor

Propagate the pytest exit code here.

pytest.main() returns the process status, but this block ignores it, so python tests/comm/test_mnnvl_dcp_alltoall.py exits 0 even when tests fail.

♻️ Proposed fix
 if __name__ == "__main__":
-    pytest.main([__file__, "-v", "-s"])
+    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_mnnvl_dcp_alltoall.py` around lines 332 - 333, The test
runner currently ignores pytest.main()'s return code in the if __name__ ==
"__main__" block; capture the return value from pytest.main([__file__, "-v",
"-s"]) and propagate it by calling sys.exit(return_code) so the process exits
nonzero on test failures, and add an import for sys if not present; update the
main block around pytest.main to use the captured result and sys.exit.

The module-level MNNVL workspace allocation runs during pytest
collection, before pytestmark skipif conditions take effect. In CI
environments without SYS_PTRACE capability, this causes a collection
error instead of a graceful skip. Guard the allocation behind the
same mnnvl_available() check used by pytestmark.
- Wrap all MNNVL test bodies in try/finally to ensure _comm.Barrier()
  runs even if assertions fail (prevents peer rank hangs)
- Remove static cache in computeHelixWorkspaceSizePerRank so workspace
  size is correct if cpSize varies across calls
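Why the static cache had to go, in miniature (the size formula here is made up purely for illustration):

```python
_stale = {"value": None}

def ws_size_cached_once(cp_size):
    # Buggy pattern: the first call's result is reused for every cp_size.
    if _stale["value"] is None:
        _stale["value"] = cp_size * 1024  # made-up size formula
    return _stale["value"]

def ws_size(cp_size):
    # Fixed: recompute from the actual argument.
    return cp_size * 1024
```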
@aleozlx aleozlx enabled auto-merge (squash) April 13, 2026 04:34
@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

possible AOT test regression on main from another PR
the fix:
#3056

@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !500 has been updated with latest changes, and the CI pipeline #48533956 is currently running. I'll report back once the pipeline job completes.

…as_sm90+

The previous gate change (has_sm100 -> has_sm90 or has_sm100) caused
mnnvl_moe_alltoall and trtllm_comm/mnnvl_comm to be compiled on
CUDA 12.6 (SM90), but those modules use SM100-only PTX instructions
(fence .release/.acquire, cp.async.bulk with fabric state space).

Split the gate so only gen_dcp_alltoall_module() is under has_sm90,
while the existing SM100-only modules stay behind has_sm100.
auto-merge was automatically disabled April 15, 2026 19:31

Head branch was pushed to by a user without write access

helixAllToAll.cu uses cp.async.bulk instructions that trigger a
known ptxas state-space inference bug in CUDA 12.6.0 on arm64.
Gate dcp_alltoall under has_sm100 (requires CUDA >= 12.8) alongside
the other comm modules. SM90 users still get dcp_alltoall via JIT
fallback at runtime.
@aleozlx
Collaborator

aleozlx commented Apr 16, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !500 has been updated with latest changes, and the CI pipeline #48737703 is currently running. I'll report back once the pipeline job completes.

The kernel is compiled with supported_major_versions=[9, 10] in
gen_dcp_alltoall_module, so running on SM11x/SM12x GPUs triggers
'No supported CUDA architectures found for major versions [9, 10].'
at JIT time. Use is_sm90a_supported/is_sm100a_supported (which also
verify the CUDA toolkit version) so the skip matches the kernel's
actual compilation gate.
The helix A2A kernel only requires TMA + mbarrier + PDL (`__CUDA_ARCH__ >= 900`
baseline) — no SM100-exclusive PTX. TRT-LLM upstream compiles it for
SM80/86/89/90/100/103/120 with an internal `#if __CUDA_ARCH__ >= 900` guard,
so FlashInfer's previous `supported_major_versions=[9, 10]` gate was unnecessarily
narrow and broke JIT on SM11x/SM12x runners.

- jit/comm.py: supported_major_versions -> [9, 10, 11, 12]
- aot.py: gen_dcp_alltoall_module now runs under any SM90+ family, not just has_sm100
- tests: skipif now accepts major in (9, 10, 11, 12)
@aleozlx
Collaborator

aleozlx commented Apr 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !500 has been updated with latest changes, and the CI pipeline #48816154 is currently running. I'll report back once the pipeline job completes.

…0 bug

ptxas 12.6.0 has a known bug with cp.async.bulk state-space inference that
aborts compilation of helixAllToAll.cu (fixed in 12.6.3). Widening the AOT
gate to `has_sm90 or ...` in the previous commit caused the cu126 x64 AOT
build to hit this bug (it targets only sm_90a, which is exactly the path
that trips ptxas).

Keep the JIT arch list widened to [9, 10, 11, 12] — that was the fix for
SM12x JIT runners, and is independent of AOT. has_sm100 implies CUDA >= 12.8,
so the AOT build naturally avoids the buggy compiler. SM90/SM12x wheel users
still get dcp_alltoall via JIT on their own machines.
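The end state of the gating, condensed into two predicates (names are illustrative, mirroring the discussion above):

```python
JIT_SUPPORTED_MAJORS = (9, 10, 11, 12)

def jit_can_build(sm_major):
    # JIT arch list stays widened so SM90-SM12x runners compile locally.
    return sm_major in JIT_SUPPORTED_MAJORS

def aot_includes_dcp(has_sm100):
    # AOT stays behind has_sm100: that implies CUDA >= 12.8, which
    # sidesteps the ptxas 12.6.0 cp.async.bulk bug.
    return has_sm100
```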
@aleozlx
Collaborator

aleozlx commented Apr 18, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !500 has been updated with latest changes, and the CI pipeline #48833983 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx enabled auto-merge (squash) April 23, 2026 22:10
@aleozlx aleozlx merged commit 498e837 into flashinfer-ai:main Apr 23, 2026
29 of 30 checks passed
@aleozlx aleozlx mentioned this pull request Apr 25, 2026
davidjpyu added a commit to davidjpyu/vllm that referenced this pull request Apr 30, 2026
The FlashInfer decode_cp_a2a_alltoall kernel addresses peer FIFOs via a
single workspace base pointer + per-rank stride, which only works when
the workspace is unified VA (MNNVL fabric memory). The torch.zeros
fallback we previously selected for single-node CP groups deadlocks the
kernel. Match TRT-LLM upstream and require MNNVL.

Refs: flashinfer-ai/flashinfer#2951 follow-up
aleozlx added a commit that referenced this pull request May 5, 2026
## Description

Bump version to 0.6.10 for release.

## Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10

## Reviewer Notes

**API changes review**

API changes since v0.6.9

```diff
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
     register_custom_op,
@@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_trace)
 def silu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -112,7 +118,7 @@ def silu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_tanh_and_mul_trace)
 def gelu_tanh_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -153,7 +159,7 @@ def gelu_tanh_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_and_mul_trace)
 def gelu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -194,7 +200,7 @@ def gelu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace)
 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index dfb05150..d26d5407 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -543,6 +543,7 @@ def gen_all_modules(
     if add_comm:
         from .jit.comm import (
             gen_comm_alltoall_module,
+            gen_dcp_alltoall_module,
             gen_moe_alltoall_module,
             gen_trtllm_comm_module,
             gen_trtllm_mnnvl_comm_module,
@@ -554,6 +555,11 @@ def gen_all_modules(
             jit_specs.append(gen_trtllm_comm_module())
             jit_specs.append(gen_trtllm_mnnvl_comm_module())
             jit_specs.append(gen_moe_alltoall_module())
+            # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has
--
 
-def flashinfer_api(func: Callable = None) -> Callable:
+# ---------------------------------------------------------------------------
+# Trace template registry
+# ---------------------------------------------------------------------------
+# Populated automatically by _attach_fi_trace whenever @flashinfer_api is
+# given a trace= argument.  Each entry is (original_func, template, label)
+# where label is the template's name_prefix (or op_type as fallback).
+#
+# For dispatch callables (trace=some_fn), every template listed in
+# some_fn.templates is registered if that attribute exists.
+#
+# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover
+# all registered templates without requiring manual maintenance.
+_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = []
+
+
+def _attach_fi_trace(
+    wrapped: Callable,
+    original: Callable,
+    trace_template=None,
+) -> Callable:
+    """Attach a ``fi_trace`` callable to *wrapped*.
+
+    Three resolution strategies, tried in order:
+
--
+
+        warnings.warn(
+            f"[flashinfer] Failed to attach fi_trace to '{_func_name}': "
+            f"{type(_exc).__name__}: {_exc}\n"
+            f"The function will work normally but fi_trace will be unavailable. "
+            f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).",
+            stacklevel=3,
+        )
+    return wrapped
+
+
+def flashinfer_api(func: Callable = None, *, trace=None) -> Callable:
     """
     Decorator to FlashInfer's APIs.
 
@@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable:
     - The %i pattern is automatically replaced with the process ID for multi-process environments.
     - The logger does not propagate to the root logger to avoid duplicate logs.
     """
-    # If logging is disabled, return original function with zero overhead
+    # If logging is disabled, return original function with zero overhead.
+    # We still attach fi_trace so it is always available regardless of log level.
     if _API_LOG_LEVEL == 0:
         if func is None:
-            return lambda f: f
-        return func
--
 @functools.cache
@@ -135,7 +136,7 @@ class BatchAttention:
             causal,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_attention_run_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper
     a convenient interface for using attention sinks during prefill or decode attention.
     """
 
+    # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py
new file mode 100644
index 00000000..3e029627
--- /dev/null
+++ b/flashinfer/attention/cute_dsl/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
--
 
@@ -31,7 +37,7 @@ def get_cascade_module():
     return gen_cascade_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_trace)
 @register_custom_op("flashinfer::merge_state", mutates_args=())
 def merge_state(
     v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -98,7 +104,7 @@ def _fake_merge_state(
     return v, s
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_in_place_trace)
 @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
 def merge_state_in_place(
     v: torch.Tensor,
@@ -159,7 +165,7 @@ def _fake_merge_state_in_place(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_states_trace)
 @register_custom_op("flashinfer::merge_states", mutates_args=())
 def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Merge multiple attention states (v, s).
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=multi_level_cascade_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 5f186002..31d23a99 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import (
     moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace,
 )
 
+# DCP A2A (Decode Context Parallel Attention Reduction)
+from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
+from .dcp_alltoall import (
+    decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
+)
+from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
+from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
+
 # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
--
 from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
@@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=allreduce_fusion_trace)
 def allreduce_fusion(
     input: torch.Tensor,
     workspace: AllReduceFusionWorkspace,
diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py
new file mode 100644
index 00000000..3047f76c
--- /dev/null
+++ b/flashinfer/comm/dcp_alltoall.py
@@ -0,0 +1,255 @@
+"""
+DCP All-to-All Operations for DCP Attention Reduction
+
+Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
+attention reduction. Uses SM90+ features (TMA, mbarrier).
+
+Usage protocol::
+
+    # 1. Query workspace size
+    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
+
--
+
+
+# ─── Public API ───────────────────────────────────────────────────────────
+
+
+@flashinfer_api
+def decode_cp_a2a_workspace_size(cp_size: int) -> int:
+    """Return the workspace size **in bytes** per rank for the given CP group size.
+
+    Args:
+        cp_size: Context-parallel group size (number of ranks).
+
+    Returns:
+        Workspace size in bytes per rank.
+
+    Example::
+
+        >>> decode_cp_a2a_workspace_size(4)
+        16778240
+    """
+    return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size)
+
+
+@flashinfer_api
+def decode_cp_a2a_allocate_workspace(
+    cp_size: int,
+    cp_rank: int,
+    *,
+    mapping: Optional[Mapping] = None,
+    mnnvl_config: Optional[MnnvlConfig] = None,
+) -> torch.Tensor:
+    """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+
+    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
+
+    Two allocation modes:
+
+    - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
+      FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
+      cannot see each other's device memory directly.
+    - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
+      allocation. Sufficient for single-node with NVLink P2P.
+
--
+
+    ws_elems_per_rank = (ws_bytes + 7) // 8
+    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+
+
+@flashinfer_api
+def decode_cp_a2a_init_workspace(
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+) -> None:
+    """Initialize the workspace FIFO buffers. Call once before the first alltoall.
+
+    Resets the FIFO buffers in the **local** workspace row
+    (``workspace[cp_rank]``). This function is **synchronous**: when it
+    returns, the GPU memset is guaranteed to have completed.
+
+    .. important::
+        With MNNVL workspaces, **all ranks** must complete
+        ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier
+        (e.g. ``dist.barrier(group)``) before **any** rank calls
+        :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may
+        start writing to a peer's FIFO before that peer has finished
+        initializing, causing a deadlock.
+
+    Args:
--
+    # subsequent cross-GPU alltoall can race with the unfinished memset
+    # on MNNVL memory, causing a deadlock.
+    torch.cuda.current_stream().synchronize()
+
+
+@flashinfer_api(trace=decode_cp_a2a_alltoall_trace)
+def decode_cp_a2a_alltoall(
+    partial_o: torch.Tensor,
+    softmax_stats: torch.Tensor,
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+    enable_pdl: Optional[bool] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Perform the DCP all-to-all exchange.
+
+    Each rank sends its ``partial_o[..., peer, :]`` slice to the
+    corresponding peer and receives all peers' contributions into the
+    output tensors.
+
+    Args:
+        partial_o: ``[..., cp_size, D]`` — half or bfloat16.
+            ``D * element_size`` must be 16-byte aligned.
+        softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
+            Batch dimensions must match ``partial_o``.
+        workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
--
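The exchange semantics in the docstring above can be expressed as a small pure-Python reference. This is an illustrative sketch only, not the kernel's implementation: `partials[r][peer]` stands in for rank `r`'s `partial_o[..., peer, :]` slice, and the nested lists model what the LL128 kernel moves between GPUs.

```python
def reference_dcp_a2a(partials):
    """Reference semantics of the DCP all-to-all exchange.

    partials[r][peer] models what rank r computed for peer
    (rank r's partial_o[..., peer, :] slice). Returns out such that
    out[dst][src] is the contribution rank dst receives from rank src.
    """
    cp_size = len(partials)
    return [
        [partials[src][dst] for src in range(cp_size)]
        for dst in range(cp_size)
    ]
```

The real kernel additionally carries the float32 `softmax_stats` slices through the same FIFO, so both tensors move in a single launch.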
+    MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather,
+    MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce,
+}
+
+
+@flashinfer_api
+@backend_requirement(
+    backend_checks={},
+    common_check=_common_check,
+)
+def run_mixed_comm(
+    op: MixedCommOp,
+    handler: MixedCommHandler,
+    x_in: torch.Tensor,
+    x_out: torch.Tensor | None = None,
+    mode: MixedCommMode | None = None,
+) -> torch.Tensor:
+    """Execute a mixed communication operation.
+
+    This is the main entry point for running communication collectives
+    through the mixed communication handler. It supports fused GPU kernels
+    (using virtual memory intra-node and nvshmem inter-node), NCCL-based
+    fallbacks, and autotuned mode selection.
+
+    Args:
+        op: The communication operation to perform.
--
 @functools.cache
@@ -28,7 +29,7 @@ def get_concat_mla_module():
     return gen_concat_mla_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=concat_mla_k_trace)
 def concat_mla_k(
     k: torch.Tensor,
     k_nope: torch.Tensor,
diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py
index 195ca2d4..9b593095 100644
--- a/flashinfer/cudnn/decode.py
+++ b/flashinfer/cudnn/decode.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_decode_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_decode_trace)
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc1bbb5f..b16d6043 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_prefill_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache(
         return out, None
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_prefill_trace)
 def cudnn_batch_prefill_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
index 0b50c22c..f25aa6fd 100644
--- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
@@ -38,6 +38,7 @@ import torch
 from cutlass import Float32, Int32, Int64, Uint32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import add_rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -1042,7 +1043,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=add_rmsnorm_fp4quant_trace)
 def add_rmsnorm_fp4quant(
     input: torch.Tensor,
     residual: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
index 333697ab..b7aabc36 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
@@ -20,6 +20,7 @@ import torch
 from cutlass import Float32, Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace
 from flashinfer.utils import device_support_pdl
 from flashinfer.cute_dsl.utils import (
     get_max_active_clusters,
@@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper:
                 f"out_dtype={self._o_dtype}"
             )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_mla_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
index 58a24abe..ee0cd5e7 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
@@ -21,6 +21,7 @@ import cutlass.cute as cute
 from cutlass.cute.typing import Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace
 
 from ..config import AttentionConfig, AttentionFusion
 from ..fusion.mask import MaskType
@@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper:
                     f"device={self._device}"
                 )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
index bc4acffc..97ce68a1 100644
--- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
@@ -32,6 +32,7 @@ import torch
 from cutlass import Float32, Int32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -771,7 +772,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_fp4quant_trace)
 def rmsnorm_fp4quant(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 822aca40..5e9eb515 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_decode_trace,
+    single_decode_with_kv_cache_trace,
+    trtllm_batch_decode_trace,
+    xqa_batch_decode_trace,
+)
 
 ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
 from .mla import (
@@ -400,7 +406,7 @@ def single_decode_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_decode_with_kv_cache_trace)
 def single_decode_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
         kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_decode_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra
     :class:`BatchDecodeWithPagedKVCacheWrapper`
     """
 
+    # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         workspace_buffer: torch.Tensor,
@@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_trace)
 def trtllm_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache(
 
 
 # xqa uses NHD layout
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_trace)
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py
new file mode 100644
index 00000000..1104eb6f
--- /dev/null
+++ b/flashinfer/fi_trace.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--
+
+"""
+fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_
+compatible definition JSON for FlashInfer APIs.
+
+Every ``@flashinfer_api(trace=<template>)``-decorated function supports two
+usage modes:
+
+Auto-dump (recommended)
+-----------------------
+Set environment variables **before** importing flashinfer, then run your
+workload normally.  No explicit ``fi_trace`` call is needed.
+
+.. code-block:: bash
+
+    FLASHINFER_TRACE_DUMP=1 \\
+    FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\
+    python my_script.py
+
+Every decorated function writes a ``<name>.json`` file on its **first** call
+for each unique set of const-axis values (e.g. head dimensions, vocab size).
+Subsequent calls with the same const-axis values are deduplicated — the file is written
+only once per process.  The output directory is created automatically.
+
+Explicit call (for selective or programmatic use)
+-------------------------------------------------
--
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+# ---------------------------------------------------------------------------
+# Legacy registry — kept for backwards compatibility.
+# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead.
+# ---------------------------------------------------------------------------
+
+_REGISTRY: Dict[str, Any] = {}
+
+
+def register_fi_trace(qualname: str, spec: Any) -> None:
+    """Register a legacy FiTraceSpec for the function with the given qualname.
+
+    .. deprecated::
+        Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead.
+    """
+    _REGISTRY[qualname] = spec
+
+
+def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]:
+    """Build a fi_trace callable from a legacy FiTraceSpec.
+
+    .. deprecated::
+        Use ``TraceTemplate.build_fi_trace_fn`` instead.
+    """
+    # Import the old implementation from the trace package for backwards compat.
+    from .trace.template import (  # noqa: PLC0415,F401
+        Const,
+        Scalar,
+        Tensor,
+        TraceTemplate,
+        Var,
+    )
+    import json  # noqa: PLC0415
+    import os  # noqa: PLC0415
--
+    """Generate a flashinfer-bench definition JSON for any FlashInfer API call.
+
+    Parameters
+    ----------
+    func_or_method:
+        A ``@flashinfer_api``-decorated function or (bound) method.
+    save_dir:
+        Directory where the JSON definition file should be written.
+        Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
+    **kwargs:
+        The same tensor arguments you would pass to the real API.
+
+    Returns
+    -------
+    dict
+        A flashinfer-bench compatible definition dictionary.
+
+    Examples
+    --------
+    Standalone function::
+
+        defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)
+
+    Bound method (instance.run)::
+
+        defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))
--
+    trace_fn = getattr(actual_func, "fi_trace", None)
+    if trace_fn is None:
+        qualname = getattr(actual_func, "__qualname__", repr(actual_func))
+        raise ValueError(
+            f"No fi_trace spec is registered for '{qualname}'. "
+            "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+        )
+    return trace_fn(save_dir=save_dir, **kwargs)
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index df6e1f72..d983f9d4 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -17,6 +17,8 @@ limitations under the License.
 from .core import (
     convert_to_block_layout,
     cutlass_fused_moe,
+    interleave_moe_scales_for_sm90_mixed_gemm,
+    interleave_moe_weights_for_sm90_mixed_gemm,
     gen_cutlass_fused_moe_sm120_module,
     gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
@@ -64,6 +66,8 @@ __all__ = [
     "WeightLayout",
     "convert_to_block_layout",
     "cutlass_fused_moe",
+    "interleave_moe_scales_for_sm90_mixed_gemm",
--
+        ),
     )
 
 
-# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
 @flashinfer_api
+def interleave_moe_scales_for_sm90_mixed_gemm(
+    scales: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
+
+    The kernel expects scales in layout
+    ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural
+    ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer.
+    This helper performs the reshape + permute equivalent to TensorRT-LLM's
+    ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), with the fixed
+    interleave factor of ``128 // group_size`` used for MXFP4.
+
+    Parameters
+    ----------
+    scales:
+        ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block
+        scales.
+    group_size:
+        MXFP4 quantization group size (default 32).
--
+        scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous()
+    )
+    return tmp.reshape(e, kgs // factor, rows * factor)
+
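The reshape + permute above can be checked against a pure-Python reference. This is a sketch of the documented `(e, rows, K // group_size)` to `(e, K // (group_size * factor), rows * factor)` mapping, not the library code; the function name is hypothetical.

```python
def interleave_scales_ref(scales, group_size=32):
    """Pure-Python reference for the SM90 MXFP4 scale interleave.

    scales: nested list [e][rows][K // group_size] of E8M0 scale bytes.
    Mirrors reshape(e, rows, kgs // f, f).permute(0, 2, 1, 3)
    .reshape(e, kgs // f, rows * f) with f = 128 // group_size.
    """
    e, rows, kgs = len(scales), len(scales[0]), len(scales[0][0])
    factor = 128 // group_size  # 4 for MXFP4's group_size=32
    assert kgs % factor == 0
    out = [
        [[0] * (rows * factor) for _ in range(kgs // factor)]
        for _ in range(e)
    ]
    for ei in range(e):
        for r in range(rows):
            for k in range(kgs):
                # element [ei][r][k] lands at [ei][k // factor][r * factor + k % factor]
                out[ei][k // factor][r * factor + k % factor] = scales[ei][r][k]
    return out
```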
+
+@flashinfer_api
+def interleave_moe_weights_for_sm90_mixed_gemm(
+    weight: torch.Tensor,
+    quant_type: str = "fp4",
+) -> torch.Tensor:
+    """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
+
+    The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with
+    ``use_w4_group_scaling=True``) expects weights in a specific interleaved
+    layout; without preprocessing, the LUT-based FP4→BF16 conversion reads
+    bytes from the wrong positions and the output diverges from a dequantized
+    reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent
+    preprocessing at weight-load time (see
+    ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451).
+
+    Parameters
+    ----------
+    weight:
+        ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed
+        two-per-byte).
+    quant_type:
--
+    )
+    return out
+
+
+# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
+@flashinfer_api(trace=cutlass_fused_moe_trace)
 def cutlass_fused_moe(
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
@@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module():
                     DynamicTensorSpec(
                         input_idx,
                         dim_idx,
-                        get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-                        lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                        get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+                        lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
                         initializers,
                     ),
                 ),
@@ -2344,7 +2468,7 @@ def _validate_routing_replay_out(
         raise ValueError("routing_replay_out must be contiguous (packed row-major)")
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_moe_trace)
 def trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2452,7 +2576,7 @@ def trtllm_bf16_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_routed_moe_trace)
 def trtllm_bf16_routed_moe(
     topk_ids: torch.Tensor,
     hidden_states: torch.Tensor,
@@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace)
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace)
 def trtllm_fp8_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch)
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace)
 def trtllm_mxint4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
index d2cbc8b0..34916df5 100644
--- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
@@ -42,11 +42,12 @@ from typing import Optional, Tuple
 import torch
 
 from ...api_logging import flashinfer_api
+from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace
 from ...utils import supported_compute_capability
 
 
 @supported_compute_capability([120, 121])
-@flashinfer_api
+@flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
     x: torch.Tensor,
     w1_weight: torch.Tensor,
@@ -293,7 +294,7 @@ class B12xMoEWrapper:
             device=self.device,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
index f6cf1b67..e266cb77 100644
--- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
+++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
@@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import (
     st_global_u64,
     scatter_add_bf16x2,
 )
-from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
-    Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
+from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
+    Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
 )
 
 
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
index e7fdae92..670b3ad8 100644
--
 from .moe_utils import (
@@ -530,7 +534,7 @@ class CuteDslMoEWrapper:
             enable_pdl=enable_pdl,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
 
 
 @supported_compute_capability([100, 103])
-@flashinfer_api
+@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace)
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
     x_sf: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index 0cc8628e..636043db 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@ from ...autotuner import (
     TuningConfig,
 )
 from ..utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
 )
 
 logger = logging.getLogger(__name__)
@@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
                 DynamicTensorSpec(
--
 import torch
@@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module():
 
 
 @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
-@flashinfer_api
+@flashinfer_api(trace=fused_topk_deepseek_trace)
 def fused_topk_deepseek(
     scores: torch.Tensor,
     bias: torch.Tensor,
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 004271a1..91f37aa5 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
     return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
 
 
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1."""
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
+_PHASE1_END = 256
--
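For context, the legacy bucketing this hunk replaces is fully specified by the removed docstring and can be reconstructed as follows. This is a sketch: `next_positive_power_of_2` is re-derived here from its usage, and the hybrid replacement's behavior beyond the `_PHASE1_END = 256` constant is not shown in this diff.

```python
def next_positive_power_of_2(x: int) -> int:
    # Smallest power of two >= max(x, 1).
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def get_power_of_2_num_tokens_buckets(max_num_tokens: int) -> tuple:
    # Descending power-of-2 buckets from next_power_of_2(max_num_tokens)
    # down to 1, per the removed docstring.
    m = next_positive_power_of_2(max_num_tokens)
    buckets = []
    while m >= 1:
        buckets.append(m)
        m //= 2
    return tuple(buckets)
```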
 
@@ -106,7 +114,7 @@ TILE_V = 8  # pretranspose tile size
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -535,7 +543,7 @@ def gated_delta_rule_decode(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_mtp_trace)
 def gated_delta_rule_mtp(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 68398d28..53fe44ce 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel:
 
         gate_handle = load_gate_consumer.wait_and_advance()
 
-        max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
-        cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
+        cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]
 
         valid_state = not is_first_chunk or self.use_initial_state
         if cutlass.const_expr(valid_state):
diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
index 82dcc72b..aafcc671 100644
--- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
--
     register_custom_op,
@@ -95,7 +96,7 @@ def get_gdn_prefill_module():
     return SimpleNamespace(gdn_prefill=gdn_prefill)
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_prefill_trace)
 def chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index a7795beb..def82216 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -61,11 +61,11 @@ try:
     from flashinfer.cute_dsl.utils import is_cute_dsl_available
 
     if is_cute_dsl_available():
-        from .kernels.dense_blockscaled_gemm_sm120 import (
-            Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
+        from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+            Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
         )
 
-        _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+        _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
 except ImportError:
--
 from ..utils import (
@@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16(
     common_check=_check_mm_bf16_problem_size,
     heuristic_func=_heuristic_func_mm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_bf16_trace)
 def mm_bf16(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16(
     common_check=_check_bmm_bf16_problem_size,
     heuristic_func=_heuristic_func_bmm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_bf16_trace)
 def bmm_bf16(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
         ),
     ),
     constraint_specs=(
@@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
--
     constraint_specs=(
@@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=tgv_gemm_sm100_trace)
 def tgv_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1173,8 +1187,8 @@ def tgv_gemm_sm100(
             DynamicTensorSpec(
                 (a_tensor_index,),
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper:
     True
     """
 
+    @flashinfer_api
     def __init__(
         self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
     ) -> None:
@@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper:
         self._float_workspace_buffer = float_workspace_buffer
         self._int_workspace_buffer = int_workspace_buffer
 
-    @flashinfer_api
+    @flashinfer_api(trace=segment_gemm_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_fp4; the user-facing mm_fp4 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_fp4_graph_override_shape(
     graph,
     a,
@@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
--
 ):
@@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
     return (tuple(block_scale_shape), tuple(block_scale_stride))
 
 
-@flashinfer_api
+@flashinfer_api(trace=mm_fp8_trace)
 def mm_fp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8(
     common_check=_check_mm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_mm_mxfp8,  # result stored in mm_mxfp8.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_mxfp8_trace)
 def mm_mxfp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner(
     """
     import cutlass
 
-    from .kernels.dense_blockscaled_gemm_sm120 import (
-        Sm120BlockScaledDenseGemmKernel,
+    from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+        Sm120B12xBlockScaledDenseGemmKernel,
     )
 
     cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner(
             ]
             swap_ab = False
             for mma_tiler_mn in sm120_mma_tiler_candidates:
-                if not Sm120BlockScaledDenseGemmKernel.can_implement(
+                if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
--
     constraint_specs=(
@@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig(
     common_check=_check_mm_fp4_problem_size,
     heuristic_func=_heuristic_func_mm_fp4,  # result stored in mm_fp4.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_fp4_trace)
 def mm_fp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8(
     common_check=_check_bmm_fp8_problem_size,
     heuristic_func=_heuristic_func_bmm_fp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_fp8_trace)
 def bmm_fp8(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise(
     {},
     common_check=_check_batch_deepgemm_fp8_nt_groupwise,
 )
-@flashinfer_api
+@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace)
 def batch_deepgemm_fp8_nt_groupwise(
     a: torch.Tensor,  # (batch_size, m, k)
     b: torch.Tensor,  # (batch_size, n, k)
@@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90():
     return module.init()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace)
 def fp8_blockscale_gemm_sm90(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8(
     common_check=_check_bmm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_bmm_mxfp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_mxfp8_trace)
 def bmm_mxfp8(
     A: torch.Tensor,
     B: torch.Tensor,
diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
similarity index 99%
rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
index c49bc815..6eee27a7 100644
--- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
+++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
@@ -1550,7 +1550,7 @@ class DenseGemmKernel:
 
 
 # Alias for FlashInfer integration
-Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
+Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
 
 
 class _DenseGemmLaunch:
diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
--
     get_cutlass_dtype,
@@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=grouped_gemm_nt_masked_trace)
 def grouped_gemm_nt_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py
index cfde7d43..f83c8974 100644
--- a/flashinfer/gemm/routergemm.py
+++ b/flashinfer/gemm/routergemm.py
@@ -1,4 +1,8 @@
 from ..api_logging import flashinfer_api
+from ..trace.templates.gemm import (
+    mm_M1_16_K7168_N256_trace,
+    tinygemm_bf16_trace,
+)
 from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
 import functools
 from types import SimpleNamespace
@@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128(
 
 
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
@@ -324,7 +328,7 @@ def get_tinygemm2_module():
 
 
 @backend_requirement({}, common_check=_tinygemm_bf16_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=tinygemm_bf16_trace)
 def tinygemm_bf16(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 7f36a314..8378e0ab 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module
 from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module
 from .comm import gen_vllm_comm_module as gen_vllm_comm_module
 from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module
+from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module
 from .dsv3_optimizations import (
     gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module,
 )
diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py
index 46768eed..834f77f9 100644
--- a/flashinfer/jit/comm.py
+++ b/flashinfer/jit/comm.py
@@ -15,7 +15,13 @@ limitations under the License.
--
     gen_selective_state_update_sm100_module,
@@ -99,7 +100,7 @@ def get_selective_state_update_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=selective_state_update_trace)
 def selective_state_update(
     state: torch.Tensor,
     x: torch.Tensor,
diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index 4e8bdd72..e27e3807 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import (
+    mla_paged_decode_trace,
+    trtllm_batch_decode_mla_trace,
+    xqa_batch_decode_mla_trace,
+)
 from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
 from ..jit.mla import gen_mla_module
 from ..utils import (
@@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper:
         return_lse_base_on_e: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=mla_paged_decode_trace)
     def run(
         self,
         q_nope: torch.Tensor,
@@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper:
         return (out, lse) if return_lse else out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_mla_trace)
 def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
@@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         raise ValueError(f"Backend {backend} not supported")
 
 
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_mla_trace)
 def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index 0f9911a6..ba612b28 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -32,6 +32,16 @@ from typing import Optional, Union
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import (
+    fused_add_rmsnorm_quant_trace,
+    fused_add_rmsnorm_trace,
+    fused_rmsnorm_silu_trace,
+    gemma_fused_add_rmsnorm_trace,
+    gemma_rmsnorm_trace,
+    layernorm_trace,
+    rmsnorm_quant_trace,
+    rmsnorm_trace,
--
     get_compute_capability,
@@ -94,7 +104,7 @@ def _normalize_scale_tensor(
     return scale.contiguous()
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_trace)
 def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -165,7 +175,7 @@ def _rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_quant_trace)
 @register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",))
 def rmsnorm_quant(
     out: torch.Tensor,
@@ -219,7 +229,7 @@ def _rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_trace)
 @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor,
@@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_quant_trace)
 @register_custom_op(
     "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
 )
@@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_rmsnorm_trace)
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace)
 @register_custom_op(
     "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
 )
@@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=layernorm_trace)
 @register_custom_op("flashinfer::layernorm", mutates_args=())
 def layernorm(
     input: torch.Tensor,
@@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_rmsnorm_silu_trace)
 def fused_rmsnorm_silu(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/page.py b/flashinfer/page.py
index 12ea3613..7fb33cf3 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.page import (
+    append_paged_kv_cache_trace,
+    append_paged_mla_kv_cache_trace,
+)
 from .jit.page import gen_page_module
 from .utils import (
     TensorLayout,
@@ -222,7 +226,7 @@ def get_seq_lens(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_mla_kv_cache_trace)
 def append_paged_mla_kv_cache(
     append_ckv: torch.Tensor,
     append_kpe: torch.Tensor,
@@ -272,7 +276,7 @@ def append_paged_mla_kv_cache(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_kv_cache_trace)
 def append_paged_kv_cache(
     append_key: torch.Tensor,
     append_value: torch.Tensor,
diff --git a/flashinfer/pod.py b/flashinfer/pod.py
index fe2e36c1..4fa2d9bf 100644
--- a/flashinfer/pod.py
+++ b/flashinfer/pod.py
@@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    batch_pod_with_paged_kv_cache_run_trace,
+    pod_with_paged_kv_cache_run_trace,
+)
 from .jit import gen_pod_module, gen_batch_pod_module
 from .page import get_seq_lens
 from .prefill import get_batch_prefill_module
@@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
@@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 4ec6a29e..d491dd35 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_prefill_trace,
+    gqa_ragged_prefill_trace,
+    single_prefill_with_kv_cache_trace,
+    trtllm_batch_context_trace,
+)
+from .trace.templates.gemm import (
+    fmha_v2_prefill_deepseek_trace,
+    trtllm_ragged_attention_deepseek_trace,
--
     gen_customize_batch_prefill_module,
@@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_prefill_with_kv_cache_trace)
 def single_prefill_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
         skip_softmax_threshold_scale_factor: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
         enable_pdl: Optional[bool] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_ragged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module():
     return op
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace)
 def trtllm_ragged_attention_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek(
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    backend: str = "trtllm-gen",
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Parameters
@@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek(
         output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
     lse : Optional[torch.Tensor]
         lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]
+    backend : str
+        Attention backend to use. "trtllm-gen" (default) or "cute-dsl".
+        When backend="cute-dsl", query/key/value/out tensors must be
+        front-padded with max_seq_len rows of valid GPU memory before
+        index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details).
--
             "lse assumed not None beyond this point when return_lse is True"
@@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek(
         return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_context_trace)
 def trtllm_batch_context_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module():
     return gen_trtllm_fmha_v2_sm120_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module(
     return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace)
 def trtllm_fmha_v2_prefill(
     qkv: Union[
         torch.Tensor,
diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py
index 4cd5cd34..84f7ade6 100644
--- a/flashinfer/quantization/fp4_quantization.py
+++ b/flashinfer/quantization/fp4_quantization.py
@@ -21,6 +21,12 @@ from typing import List, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import (
+    fp4_quantize_trace,
+    mxfp4_quantize_trace,
+    nvfp4_kv_quantize_trace,
+    nvfp4_quantize_trace,
+)
 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import (
@@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp4_quantize_trace)
 def fp4_quantize(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
@@ -923,7 +929,7 @@ def shuffle_matrix_sf_a(
     return block_scale_interleave(w_shuffled)
 
 
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_quantize_trace)
 def nvfp4_quantize(
     a,
     a_global_sf,
@@ -1024,7 +1030,7 @@ def nvfp4_quantize(
     return a_fp4, a_sf
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp4_quantize_trace)
 def mxfp4_quantize(
     a: torch.Tensor,
     backend: str = "cuda",
@@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale):
 
 
 @backend_requirement({}, common_check=_nvfp4_kv_quant_check)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_kv_quantize_trace)
 def nvfp4_kv_quantize(
     input: torch.Tensor,
     global_scale: torch.Tensor,
diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py
index f2c9f412..49e13a8b 100644
--- a/flashinfer/quantization/fp8_quantization.py
+++ b/flashinfer/quantization/fp8_quantization.py
@@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import mxfp8_quantize_trace
 from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
 from ..utils import (
     device_support_pdl,
@@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module():
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp8_quantize_trace)
 def mxfp8_quantize(
     input: torch.Tensor,
     is_sf_swizzled_layout: bool = True,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07..df5c7d4d 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -20,6 +20,21 @@ from typing import Optional, Tuple
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.rope import (
+    apply_llama31_rope_inplace_trace,
+    apply_llama31_rope_pos_ids_inplace_trace,
+    apply_llama31_rope_pos_ids_trace,
+    apply_llama31_rope_trace,
+    apply_rope_inplace_trace,
+    apply_rope_pos_ids_inplace_trace,
+    apply_rope_pos_ids_trace,
+    apply_rope_trace,
--
 
@@ -414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_inplace_trace)
 def apply_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -502,7 +517,7 @@ def apply_rope_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace)
 def apply_rope_pos_ids_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_llama31_rope_inplace_trace)
 def apply_llama31_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
... (truncated -- see full diff via the command above)
```

**Summary of API changes:**

- **Decorator semantic addition (backward-compatible):**
`@flashinfer_api` now accepts an optional `trace=<TraceTemplate>`
keyword. Bare `@flashinfer_api` still works. Existing call sites of
decorated functions are unaffected. Most of the diff above is mechanical
rewrites of existing `@flashinfer_api` to `@flashinfer_api(trace=...)`,
plus the new `flashinfer/trace/` package and `fi_trace.py` for
flashinfer-bench JSON dumps.
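A decorator that works both bare and with keyword arguments, as described above, follows a standard Python pattern. A minimal sketch (the real `flashinfer_api` implementation and its trace-template handling may differ; `trace_log` is illustrative only):

```python
import functools

trace_log = []  # illustrative sink; the real tracing writes JSON dumps

def flashinfer_api(fn=None, *, trace=None):
    """Usable as @flashinfer_api or @flashinfer_api(trace=...)."""
    def wrap(f):
        @functools.wraps(f)
        def inner(*args, **kwargs):
            if trace is not None:
                trace_log.append((f.__name__, trace))  # record traced calls
            return f(*args, **kwargs)
        return inner
    if fn is not None:   # bare form: called directly with the function
        return wrap(fn)
    return wrap          # parameterized form: returns the real decorator

@flashinfer_api
def add(a, b):
    return a + b

@flashinfer_api(trace="add_trace")
def mul(a, b):
    return a * b

assert add(2, 3) == 5                       # bare decoration still works
assert mul(2, 3) == 6
assert trace_log == [("mul", "add_trace")]  # only traced calls recorded
```

This is why existing bare call sites keep working: when the decorator receives the function positionally, it wraps it immediately; otherwise it returns a closure over the keyword arguments.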

- **New public APIs (7):**
- `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size,
decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace,
decode_cp_a2a_alltoall}` — DCP all-to-all for context-parallel attention
reduction (#2951).
- `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm,
interleave_moe_weights_for_sm90_mixed_gemm}` — SM90 mixed-input MoE GEMM
helpers (#3084).
- `flashinfer.comm.run_mixed_comm` — combinations of allreduce /
allgather / reducescatter (#2563).

- **New `@flashinfer_api`-decorated wrapper init:**
- `SegmentGEMMWrapper.__init__` is now decorated. Previously the class
itself was undecorated; `run()` already was. No call-site change.

- **Backward-compatible signature additions (defaults preserve old
behavior):**
- `top_k_page_table_transform`: `+dsa_graph_safe: bool = False`,
`+row_starts: Optional[torch.Tensor] = None` (#3133).
  - `top_k_ragged_transform`: same two new params (#3133).
- `trtllm_ragged_attention_deepseek`: `+backend: str = "trtllm-gen"`
(cute-dsl backend selection).

- **No breaking signature changes** to any `@flashinfer_api` function.
Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0
removals.

- **Module reorganization to flag (not `@flashinfer_api`, but in public
re-export):**
- `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` →
`dense_blockscaled_gemm_sm120_b12x.py`
- Class renamed: `Sm120BlockScaledDenseGemmKernel` →
`Sm120B12xBlockScaledDenseGemmKernel`
- Re-export in `flashinfer/gemm/__init__.py` updated to the new name
only — direct importers of the old name break. Decision needed: ship as
breaking, or add a deprecation alias.

- **Internal autotuner helper rename (not public, but used by downstream
extensions):**
- `get_last_power_of_2_num_tokens_buckets` →
`get_hybrid_num_tokens_buckets`
- `last_positive_power_of_2` → `map_to_hybrid_bucket` /
`map_to_hybrid_bucket_uncapped`

> Diff truncated above due to GitHub PR body length limit. Run the
command at the top locally to see the full output.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Patch Release**
  * Version updated to 0.6.10

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
aleozlx added a commit that referenced this pull request May 5, 2026
…path (#3210)

## Summary

Follow-up to #2951. The merged DCP A2A code shipped with two latent
foot-guns that this PR cleans up:

1. The `mapping=None` branch in `decode_cp_a2a_allocate_workspace`
returns a per-rank `torch.zeros` tensor — this deadlocks at runtime on
any real multi-GPU setup.
2. The MNNVL workspace tensor's lifetime is pinned via a private
attribute, which is silently lost across ordinary tensor operations and
would let the fabric memory be unmapped while the kernel still holds raw
pointers into it.

Both are addressed below.

## Bug 1: workspace VA mismatch (silent deadlock)

`getFifoBasePtr` in
`csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu:177` addresses
peer FIFOs via:

```cpp
auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64;
```

This pointer arithmetic only resolves correctly when the workspace is a
single unified VA spanning all CP ranks — i.e. MNNVL fabric memory. With
per-rank `torch.zeros`, rank 0 writing to "rank 1's FIFO" lands in rank
0's own memory; peer 1 never sees it and the FIFO consumer signaling
hangs forever.
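The failure mode can be modeled without any GPU: the kernel resolves a peer's FIFO as `base + rank * stride`, which only lands in the peer's memory when every rank's region lives in one allocation. A toy sketch (lists standing in for device memory; `STRIDE` is an arbitrary toy value, not the real `workspaceStrideInU64`):

```python
STRIDE = 4    # toy stand-in for workspaceStrideInU64
CP_SIZE = 2

def fifo_base(rank):
    # mirrors: params.workspace + mappedMemoryRank * params.workspaceStrideInU64
    return rank * STRIDE

# MNNVL case: one unified allocation spanning all CP ranks.
unified = [0] * (CP_SIZE * STRIDE)
unified[fifo_base(1)] = 42               # rank 0 writes to "rank 1's FIFO"
assert unified[1 * STRIDE] == 42         # ...and rank 1 actually sees it

# Broken fallback: each rank has its own torch.zeros-like buffer.
rank0_local = [0] * (CP_SIZE * STRIDE)
rank1_local = [0] * (CP_SIZE * STRIDE)
rank0_local[fifo_base(1)] = 42           # same offset, but into rank 0's own buffer
assert rank1_local[STRIDE] == 0          # peer never observes the write -> FIFO spins forever
```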

Reproduced on H200 NVL with 4 GH200s — first `decode_cp_a2a_alltoall`
call hangs, no CUDA error, no assertion. TRT-LLM upstream
(`tensorrt_llm/_torch/distributed/ops.py:386`) only supports MNNVL —
there is no plain-memory branch. The fallback added during the
FlashInfer port was wrong from the start.

### Why CI didn't catch Bug 1

1. `tests/comm/test_dcp_alltoall.py` simulates `cp_size` ranks on
**one** GPU with **one shared** workspace tensor — pointer arithmetic on
the same allocation works, so the bug is invisible.
2. `tests/comm/test_mnnvl_dcp_alltoall.py::TestMnnvlDcpAlltoall`
exercises real multi-GPU but only on the MNNVL path.
3. `TestMnnvlDcpDeviceMemoryFallback` asserts shape only — never
actually calls `alltoall`, so the deadlock never fires.

## Bug 2: workspace keep-alive via tensor private attribute

The previous code did:

```python
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
return workspace
```

`MnnvlMemory.__del__` calls `close_mnnvl_memory` which unmaps the
underlying fabric VA. The `workspace._mnnvl_mem` private attribute is
the only thing keeping the wrapper alive. **But torch tensor private
attributes are NOT preserved across ordinary operations** —
`workspace[r]`, `.view()`, `.contiguous()`, `.clone()`, `.to(...)`,
slicing, indexing all return a fresh tensor without the attribute. Any
caller that derives a view or slice and drops the original would
silently free the workspace while the kernel still holds raw pointers
into it. Current direct callers happen to be safe, but this is a buried
landmine.
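The mechanism is demonstrable without torch: any operation that returns a *new* tensor object does not copy ad-hoc Python attributes. A stand-in sketch (`FakeTensor` models the relevant torch behavior; it is not a real torch class):

```python
class FakeTensor:
    """Minimal stand-in: view() returns a fresh object, as torch ops do."""
    def __init__(self, data):
        self.data = data

    def view(self):
        return FakeTensor(self.data)  # new instance; attributes not copied

ws = FakeTensor([0] * 8)
ws._mnnvl_mem = object()   # keep-alive pinned on this one instance only

derived = ws.view()
assert hasattr(ws, "_mnnvl_mem")
assert not hasattr(derived, "_mnnvl_mem")  # keep-alive silently lost
# If the caller keeps only `derived` and drops `ws`, nothing references
# the MNNVL wrapper any more, and its __del__ unmaps the fabric VA
# while the kernel may still hold raw pointers into it.
```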

## Changes

### Bug 1 — drop the broken plain-memory path
- **Rename** `decode_cp_a2a_allocate_workspace` →
`decode_cp_a2a_allocate_mnnvl_workspace`, matching `trtllm_mnnvl_ar`
naming style. The MNNVL requirement is now obvious at the call site.
- **Make `mapping` the only required argument.** Drop the `mapping=None`
branch entirely. The redundant `cp_size` and `cp_rank` parameters were
also removed — `mapping` already carries that info, and a separate path
was a double-source-of-truth footgun.
- **Refactor the single-GPU sim test** to use a local
`_alloc_sim_workspace(cp_size)` helper that does `torch.zeros(...)`
directly — that's what the test actually needs and what its docstring
claims (it does *not* need the public allocator).
- **Drop `TestMnnvlDcpDeviceMemoryFallback`** — the path it covered no
longer exists.
- **Use the public allocator in the multi-rank test**
(`test_mnnvl_dcp_alltoall.py::_allocate_mnnvl_workspace_once`) instead
of manually instantiating `MnnvlMemory`, so the public API is exercised
end-to-end.
- Refresh module docstring to make the MNNVL requirement explicit.

### Bug 2 — robust workspace keep-alive
- Replace `workspace._mnnvl_mem = mnnvl_mem` with a module-level
`_workspace_keepalive: Dict[int, MnnvlMemory]` keyed by
`workspace.data_ptr()`. The dict pins each `MnnvlMemory` for the process
lifetime, so the workspace stays mapped regardless of what the caller
does with the returned tensor. This matches how TRT-LLM holds workspace
ownership in its `HelixAllToAllNative._cache` class-level dict.

### Final allocator signature
```python
decode_cp_a2a_allocate_mnnvl_workspace(mapping: Mapping, *, mnnvl_config: Optional[MnnvlConfig] = None) -> torch.Tensor
```

## Test plan

All verified on dlcluster GB200-NVL72 (compute capability 10.0a, CUDA
13):

- [x] `tests/comm/test_dcp_alltoall.py` (single-GPU sim, container
`flashinfer/flashinfer-ci-cu130`) — **29/29 PASSED in 5.16s**
- [x] `mpirun -launcher fork -np 4 pytest
tests/comm/test_mnnvl_dcp_alltoall.py` — single-node 4-rank MNNVL —
**8/8 PASSED on all 4 ranks in 1.34s**
- [x] `srun -N 2 --ntasks-per-node=4 --mpi=pmix pytest
tests/comm/test_mnnvl_dcp_alltoall.py` — **multi-node 8-rank MNNVL** (2
nodes × 4 GPUs, cp_size=8, container `nvcr.io/nvidia/pytorch:26.02-py3`
with HPC-X) — **8/8 PASSED on all 8 ranks in ~17s**

The multi-node run exercises real cross-node fabric memory allocation
via `cuMemCreate` with FABRIC handles — same code path Helix production
uses on NVL72.

🤖 AI-assisted (Claude Code)

---------

Co-authored-by: Alex Yang <aleyang@nvidia.com>
aleozlx added a commit that referenced this pull request May 7, 2026
…path (#3210)
