
feat: Add row_starts and dsa_graph_safe to topk #3133

Merged
kahyunnam merged 4 commits into flashinfer-ai:main from zianglih:dsa-graph-safe
Apr 24, 2026

Conversation

@zianglih
Contributor

@zianglih zianglih commented Apr 21, 2026

📌 Description

Parent PR: #3095
SGLang PR: sgl-project/sglang#22851

Add row_starts and dsa_graph_safe for SGLang DSA integration.
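
A minimal usage sketch of the two new arguments, following the descriptions below; the call style is keyword-only for the new args, and the shapes, dtypes, and k value are illustrative assumptions rather than values taken from the PR:

```python
import torch
import flashinfer

# Illustrative shapes: 4 rows, 4096 candidate scores per row.
scores = torch.randn(4, 4096, device="cuda")
page_table = torch.arange(4 * 4096, device="cuda", dtype=torch.int32).view(4, 4096)
lengths = torch.full((4,), 4096, device="cuda", dtype=torch.int32)
# Per-row offset into each row's scores (the new row_starts argument).
row_starts = torch.tensor([0, 128, 0, 256], device="cuda", dtype=torch.int32)

# dsa_graph_safe=True disables the clusters fast path so the launch stays
# CUDA-graph capturable for the SGLang DSA integration.
out = flashinfer.top_k_page_table_transform(
    scores, page_table, lengths, 2048,
    row_starts=row_starts, dsa_graph_safe=True,
)
```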

🔍 Related Issues

sgl-project/sglang#22851 (comment)

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added dsa_graph_safe flag to top-k APIs to opt into DSA-graph safe execution.
    • Added optional row_starts parameter to page-table and ragged top-k transforms to support per-row score offsets.
  • Behavior

    • When dsa_graph_safe=True, the optimized clusters fast-path is disabled to ensure safe execution.
  • Tests

    • Added tests covering row_starts behavior for page-table and ragged transforms.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough


Threads a new boolean flag dsa_graph_safe and an optional row_starts tensor through the Top-K stack: Python API, FFI bindings, C++ dispatch, CUDA headers/kernels, and tests. Dispatch and kernel signatures, per-row indexing, vector-size selection, and filtered-topk control flow were updated accordingly.

Changes

Cohort / File(s): Summary

  • Bindings / FFI (csrc/flashinfer_topk_binding.cu): Exported Top-K entrypoint signatures extended to accept bool dsa_graph_safe; the page-table and ragged transforms also accept Optional<TensorView> maybe_row_starts.
  • C++ implementation (csrc/topk.cu): Threaded dsa_graph_safe through the radix_topk* functions and dispatch calls; added validation for the optional row_starts and passed a row_starts_ptr into dispatch.
  • Python API (flashinfer/topk.py): Added dsa_graph_safe: bool = False to the top-k APIs; the page-table and ragged transforms gain row_starts: Optional[torch.Tensor] = None; the clusters fast-path is disabled when dsa_graph_safe=True or row_starts is provided; both args are forwarded to the module.
  • CUDA headers & kernels (include/flashinfer/topk.cuh): Added a const IdType* row_starts kernel parameter across the unified Radix/FilteredTopK paths; page-table lookups are offset by row_start; ComputeFilteredTopKVecSize and the dispatchers accept dsa_graph_safe to force vec_size=1 when set; dispatch wrappers updated to accept row_starts/dsa_graph_safe.
  • Tests (tests/utils/test_topk.py): The reference page-table/ragged transforms accept an optional row_starts and adjust slicing/indexing accordingly; a new parametrized test, test_top_k_transform_with_row_starts, validates the behavior (a reference sketch of the row_starts semantics follows below).
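
For orientation, a rough Python reference of the row_starts semantics summarized above; this is a sketch, not the actual reference in tests/utils/test_topk.py, and it assumes a one-to-one row-to-batch mapping:

```python
import torch

def reference_page_table_topk(scores, src_page_table, lengths, k, row_starts=None):
    # scores, src_page_table: [num_rows, max_len]; lengths, row_starts: [num_rows]
    num_rows = scores.shape[0]
    out = torch.full(
        (num_rows, k), -1, dtype=src_page_table.dtype, device=src_page_table.device
    )
    for i in range(num_rows):
        start = 0 if row_starts is None else int(row_starts[i])
        length = int(lengths[i])
        window = scores[i, start : start + length]
        if length <= k:
            # Trivial case: copy the (row-shifted) page-table slice as-is.
            out[i, :length] = src_page_table[i, start : start + length]
        else:
            # Select top-k inside the shifted window, then map the positions
            # back through the page table at the same offset.
            idx = torch.topk(window, k).indices
            out[i] = src_page_table[i, start + idx]
    return out
```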

Sequence Diagram(s)

sequenceDiagram
  participant Py as Python API
  participant FFI as C++ FFI Binding
  participant Dispatch as TopK Dispatch
  participant Kernel as CUDA Kernel
  Py->>FFI: call radix_topk(..., dsa_graph_safe, maybe_row_starts)
  FFI->>Dispatch: forward tensors, dsa_graph_safe, row_starts_ptr
  Dispatch->>Dispatch: choose path (FilteredTopK vs Radix) using dsa_graph_safe, tie_break
  Dispatch->>Kernel: launch kernel with row_starts, dsa_graph_safe
  Kernel-->>Dispatch: return results/status
  Dispatch-->>FFI: propagate results
  FFI-->>Py: return tensors

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • aleozlx
  • yzh119
  • sricketts
  • cyx-6
  • bkryu
  • nv-yunzheq
  • jiahanc

Poem

🐰
Row starts mark where scores begin,
A safe graph flag keeps vecs to one,
From Python call to CUDA run,
Offsets thread and kernels hum,
Hopping through layers—Top-K done!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 36.67%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (4 passed)

  • Title check: ✅ Passed. The title clearly and concisely summarizes the main changes: adding two new parameters (row_starts and dsa_graph_safe) to the topk functionality.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Description check: ✅ Passed. The PR description includes the required template structure with all main sections populated: Description, Related Issues, and completed Pre-commit and Tests checklists.



@zianglih zianglih changed the title feat: Add a dsa_graph_safe flag to topk feat: Add a dsa_graph_safe flag to topk Apr 21, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces deterministic tie-breaking support for top-k operations, enabling users to specify whether to prefer smaller or larger indices for equal values at the selection boundary. The changes include the addition of a TopKTieBreak enum, updates to the CUDA kernels and Python API, and the implementation of a DeterministicContiguousCollect helper for contiguous index traversal. Benchmarking and testing utilities have also been expanded to cover these new modes. Review feedback highlights opportunities to improve performance by ensuring coalesced memory reads in the collection helper and suggests reusing shared memory buffers to stay within hardware limits.
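
For orientation, a hedged sketch of how the tie-break mode might be selected from the Python API; TopKTieBreak.NONE appears elsewhere on this page, but the import path and the SMALL/LARGE member names are assumptions:

```python
import torch
import flashinfer
from flashinfer.topk import TopKTieBreak  # assumed import location

scores = torch.randn(2, 4096, device="cuda")
scores[0, 10] = scores[0, 20]  # create a tie that may land on the k-th boundary

# Default: no tie-break guarantee (fastest path).
out_default = flashinfer.top_k(scores, 16, tie_break=TopKTieBreak.NONE)
# Deterministic tie-breaks: prefer smaller / larger indices among equal values.
out_small = flashinfer.top_k(scores, 16, tie_break=TopKTieBreak.SMALL)  # assumed member name
out_large = flashinfer.top_k(scores, 16, tie_break=TopKTieBreak.LARGE)  # assumed member name
```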

Comment thread include/flashinfer/topk.cuh
Comment thread include/flashinfer/topk.cuh
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
include/flashinfer/topk.cuh (1)

3218-3226: ⚠️ Potential issue | 🟠 Major

Let tie-break requests override the benchmark algorithm override.

With FLASHINFER_TOPK_ALGO=multi_cta, Line 3221 returns false before the tie-break check, so tie_break=Small/Large silently falls back to radix even though the comment says tie-break is only supported by FilteredTopK.

Proposed fix
-  // Check for algorithm override
-  const TopKAlgoOverride algo_override = GetTopKAlgoOverride();
-  if (algo_override == TopKAlgoOverride::FILTERED) return true;
-  if (algo_override == TopKAlgoOverride::MULTI_CTA) return false;
-
   // Tie-break modes are only supported by FilteredTopK
   if (tie_break != TopKTieBreak::None) {
     return true;
   }
+
+  // Check for algorithm override
+  const TopKAlgoOverride algo_override = GetTopKAlgoOverride();
+  if (algo_override == TopKAlgoOverride::FILTERED) return true;
+  if (algo_override == TopKAlgoOverride::MULTI_CTA) return false;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3218 - 3226, The current logic
checks GetTopKAlgoOverride() before considering tie_break, which lets
TopKAlgoOverride::MULTI_CTA override requested tie-breaks; change the branch
order so tie-break requests take precedence: first check if tie_break !=
TopKTieBreak::None and return true (support FilteredTopK), then query
GetTopKAlgoOverride() and handle TopKAlgoOverride::FILTERED / MULTI_CTA; update
the function containing these checks (refer to GetTopKAlgoOverride,
TopKAlgoOverride, and TopKTieBreak) so tie-break modes always force the
FilteredTopK path.
benchmarks/bench_topk.py (1)

883-889: ⚠️ Potential issue | 🟡 Minor

The "sglang_error" key is never populated — this branch is dead and inconsistent with other sections.

Line 888 checks "sglang_error" in result, but sglang_error is not set anywhere in the codebase. The sglang block (lines 208–212) only writes sglang_us, and failures surface as RuntimeError exceptions caught at lines 891–899. This makes the elif at line 888 unreachable.

Additionally, the analogous fallback branches in page_table (line 1126) and ragged (line 1231) still use the original k == 2048 check. This inconsistency suggests incomplete refactoring—either restore the k == 2048 check in the top_k section or populate result["sglang_error"] and mirror the change in page_table and ragged sections.
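
If option (A) is taken, the change might look roughly like this; the helper name and surrounding structure are placeholders for whatever the benchmark actually uses, with only sglang_us/sglang_error taken from the comment above:

```python
# In the top_k benchmark path (around lines 208-212), mirrored in the
# page_table and ragged sections for consistency:
try:
    result["sglang_us"] = run_sglang_topk_benchmark()  # placeholder for the real call
except RuntimeError as exc:
    # Populate the key so the reporting branch that checks
    # `"sglang_error" in result` (line 888) becomes reachable.
    result["sglang_error"] = str(exc)
```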

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 883 - 889, The branch checking
"sglang_error" is dead because result["sglang_error"] is never set; fix by
either (A) when catching the RuntimeError in the top_k benchmark code path (the
block that currently writes result["sglang_us"]) set result["sglang_error"]=True
(or an error message) so the existing display branch can detect failures, and
update the analogous page_table and ragged sections to populate the same key for
consistency; or (B) revert the refactor and restore the original k == 2048
fallback checks in the top_k, page_table and ragged reporting code so the
fallback branches behave the same across all sections—choose one approach and
apply it consistently to result handling for sglang.
🧹 Nitpick comments (4)
include/flashinfer/topk.cuh (1)

234-236: Document the hot-path tradeoffs.

ITEMS_PER_THREAD = 4 and forcing vec_size = 1 for dsa_graph_safe are special performance-sensitive choices. Please add a short rationale and note the alternative considered, especially because Line 234 already leaves this as a TODO.

As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered."

Also applies to: 2843-2846

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 234 - 236, Add a concise comment in
the hot path explaining why ITEMS_PER_THREAD is set to 4 and CHUNK_ITEMS derived
from it (e.g., memory/register pressure vs. occupancy tradeoff,
cache/vectorization limits) and document the decision to force vec_size = 1 for
dsa_graph_safe (e.g., alignment/unaligned memory access, divergent control flow,
or correctness constraints) along with the primary alternative(s) considered
(e.g., ITEMS_PER_THREAD=8 or using vectorized loads) and why they were rejected
(impact on shared memory, register usage, or branch divergence). Place this
justification adjacent to the ITEMS_PER_THREAD/CHUNK_ITEMS definitions and
mirror a similar explanatory note where vec_size is set for dsa_graph_safe so
future maintainers can understand the performance tradeoffs and tuning
rationale.
tests/utils/test_topk.py (1)

1931-2050: Add coverage for dsa_graph_safe=True.

These new tests cover tie-break behavior, but the PR’s graph-safe flag can regress independently through routing and VEC_SIZE=1 dispatch. Please add at least one top_k and one transform API case with dsa_graph_safe=True; ideally include a CUDA graph capture/replay smoke test.
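
A hedged sketch of such a smoke test; it only checks that CUDA graph capture and replay succeed with dsa_graph_safe=True, and the skip/warm-up logic of the real test file is omitted:

```python
import torch
import flashinfer

def test_top_k_dsa_graph_safe_cuda_graph_capture():
    torch.manual_seed(0)
    scores = torch.randn(4, 4096, device="cuda")
    k = 128
    # Warm up outside the graph so any lazy JIT compilation / module loading
    # happens before capture begins.
    flashinfer.top_k(scores, k, dsa_graph_safe=True)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        flashinfer.top_k(scores, k, dsa_graph_safe=True)
    graph.replay()
    torch.cuda.synchronize()
```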

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1931 - 2050, Add tests that exercise
the dsa_graph_safe=True path: in test_top_k_tie_break_modes add a case that
calls flashinfer.top_k(logits, k, tie_break=1/2, dsa_graph_safe=True) (use the
same seed/generator and skip logic with can_implement_filtered_topk() and
set_topk_algo), and in test_top_k_tie_break_modes_transform_apis add calls to
flashinfer.top_k_page_table_transform(..., tie_break=1/2, dsa_graph_safe=True)
and flashinfer.top_k_ragged_transform(..., tie_break=1/2, dsa_graph_safe=True)
validating expected indices/values as done for the non-graph-safe variants;
optionally wrap one of these calls in a simple CUDA graph capture/replay smoke
test to ensure graph capture works.
flashinfer/topk.py (1)

499-540: Optional: annotate tie_break with the enum type.

Since TopKTieBreak is now a first-class public enum and the default is a TopKTieBreak member, consider typing the parameter as TopKTieBreak (or Union[TopKTieBreak, int]) across all three public APIs (top_k, top_k_page_table_transform, top_k_ragged_transform). IntEnum values still satisfy the FFI int conversion, so runtime behavior is unchanged, but callers get enum-level type checking and IDE completion instead of a bare int.
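
For illustration, the annotated signature might look like the following inside flashinfer/topk.py; the parameter list is abbreviated and partly assumed, and the enum stand-in below only mirrors the real definition:

```python
from enum import IntEnum
from typing import Union

import torch

class TopKTieBreak(IntEnum):  # stand-in mirroring the enum defined in flashinfer/topk.py
    NONE = 0
    # other members omitted

def top_k(
    input: torch.Tensor,
    k: int,
    deterministic: bool = False,
    tie_break: Union[TopKTieBreak, int] = TopKTieBreak.NONE,  # enum-typed, still satisfies the FFI int conversion
    dsa_graph_safe: bool = False,
):
    ...
```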

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 499 - 540, Update the tie_break parameter
annotations to use the TopKTieBreak enum (or Union[TopKTieBreak, int]) in the
public APIs so callers get enum-level typing and IDE completion: change the type
on function signatures for top_k, top_k_page_table_transform, and
top_k_ragged_transform to TopKTieBreak (or Union[TopKTieBreak, int]) while
leaving default values and runtime behavior unchanged; ensure imports/typing
references for TopKTieBreak are added where needed and run typechecks to confirm
no FFI/int conversion assumptions are broken.
benchmarks/bench_topk.py (1)

89-103: Nit: bind tie_break via a default argument to silence B023 and harden against future refactors.

Ruff flags B023 on line 95. Today this is a false positive — bench_median_ms consumes the lambda synchronously before the loop advances, so the late-binding hazard does not actually trigger. It’s still cheap to make the capture explicit in case the lambda is ever deferred (e.g., scheduled, stored, or passed to an async benchmarker):

Proposed defensive fix
-    for suffix, tie_break in TIE_BREAK_VARIANTS:
-        try:
-            tie_ms = bench_median_ms(lambda: run_flashinfer_with_tie_break(tie_break))
+    for suffix, tie_break in TIE_BREAK_VARIANTS:
+        try:
+            tie_ms = bench_median_ms(
+                lambda tb=tie_break: run_flashinfer_with_tie_break(tb)
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 89 - 103, The loop in
bench_tie_break_variants closes over tie_break causing a potential late-binding
issue flagged by Ruff B023; change the lambda passed to bench_median_ms to
capture tie_break as a default argument (e.g., lambda tb=tie_break:
run_flashinfer_with_tie_break(tb)) so the current tie_break value is bound
immediately; update the invocation around bench_median_ms(...) and leave the
rest of the logic (metrics keys using suffix, error
handling/classify_benchmark_runtime_error) unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@benchmarks/bench_topk.py`:
- Around line 883-889: The branch checking "sglang_error" is dead because
result["sglang_error"] is never set; fix by either (A) when catching the
RuntimeError in the top_k benchmark code path (the block that currently writes
result["sglang_us"]) set result["sglang_error"]=True (or an error message) so
the existing display branch can detect failures, and update the analogous
page_table and ragged sections to populate the same key for consistency; or (B)
revert the refactor and restore the original k == 2048 fallback checks in the
top_k, page_table and ragged reporting code so the fallback branches behave the
same across all sections—choose one approach and apply it consistently to result
handling for sglang.

In `@include/flashinfer/topk.cuh`:
- Around line 3218-3226: The current logic checks GetTopKAlgoOverride() before
considering tie_break, which lets TopKAlgoOverride::MULTI_CTA override requested
tie-breaks; change the branch order so tie-break requests take precedence: first
check if tie_break != TopKTieBreak::None and return true (support FilteredTopK),
then query GetTopKAlgoOverride() and handle TopKAlgoOverride::FILTERED /
MULTI_CTA; update the function containing these checks (refer to
GetTopKAlgoOverride, TopKAlgoOverride, and TopKTieBreak) so tie-break modes
always force the FilteredTopK path.

---

Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Around line 89-103: The loop in bench_tie_break_variants closes over tie_break
causing a potential late-binding issue flagged by Ruff B023; change the lambda
passed to bench_median_ms to capture tie_break as a default argument (e.g.,
lambda tb=tie_break: run_flashinfer_with_tie_break(tb)) so the current tie_break
value is bound immediately; update the invocation around bench_median_ms(...)
and leave the rest of the logic (metrics keys using suffix, error
handling/classify_benchmark_runtime_error) unchanged.

In `@flashinfer/topk.py`:
- Around line 499-540: Update the tie_break parameter annotations to use the
TopKTieBreak enum (or Union[TopKTieBreak, int]) in the public APIs so callers
get enum-level typing and IDE completion: change the type on function signatures
for top_k, top_k_page_table_transform, and top_k_ragged_transform to
TopKTieBreak (or Union[TopKTieBreak, int]) while leaving default values and
runtime behavior unchanged; ensure imports/typing references for TopKTieBreak
are added where needed and run typechecks to confirm no FFI/int conversion
assumptions are broken.

In `@include/flashinfer/topk.cuh`:
- Around line 234-236: Add a concise comment in the hot path explaining why
ITEMS_PER_THREAD is set to 4 and CHUNK_ITEMS derived from it (e.g.,
memory/register pressure vs. occupancy tradeoff, cache/vectorization limits) and
document the decision to force vec_size = 1 for dsa_graph_safe (e.g.,
alignment/unaligned memory access, divergent control flow, or correctness
constraints) along with the primary alternative(s) considered (e.g.,
ITEMS_PER_THREAD=8 or using vectorized loads) and why they were rejected (impact
on shared memory, register usage, or branch divergence). Place this
justification adjacent to the ITEMS_PER_THREAD/CHUNK_ITEMS definitions and
mirror a similar explanatory note where vec_size is set for dsa_graph_safe so
future maintainers can understand the performance tradeoffs and tuning
rationale.

In `@tests/utils/test_topk.py`:
- Around line 1931-2050: Add tests that exercise the dsa_graph_safe=True path:
in test_top_k_tie_break_modes add a case that calls flashinfer.top_k(logits, k,
tie_break=1/2, dsa_graph_safe=True) (use the same seed/generator and skip logic
with can_implement_filtered_topk() and set_topk_algo), and in
test_top_k_tie_break_modes_transform_apis add calls to
flashinfer.top_k_page_table_transform(..., tie_break=1/2, dsa_graph_safe=True)
and flashinfer.top_k_ragged_transform(..., tie_break=1/2, dsa_graph_safe=True)
validating expected indices/values as done for the non-graph-safe variants;
optionally wrap one of these calls in a simple CUDA graph capture/replay smoke
test to ensure graph capture works.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 88e1832b-cadc-43e2-b4c6-4c84155aaf21

📥 Commits

Reviewing files that changed from the base of the PR and between 9e3d8b9 and 6bbd1dac40692808d999fbffca4a00ec189a1b6d.

📒 Files selected for processing (7)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/__init__.py
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/topk.cuh (1)

3070-3104: ⚠️ Potential issue | 🟡 Minor

Make tie_break imply deterministic mode inside the filtered launcher.

LaunchFilteredTopKUnified exposes tie_break, but direct callers passing tie_break != None with deterministic=false still launch the non-deterministic TopKTieBreak::None specialization. The higher-level dispatchers normalize this today, but this wrapper should enforce its own API contract.

Suggested fix
 cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_output,
                                       const IdType* aux_input, int64_t aux_stride,
                                       const IdType* row_to_batch, const IdType* lengths,
                                       uint32_t num_rows, uint32_t top_k_val, uint32_t max_len,
                                       bool deterministic = false,
                                       TopKTieBreak tie_break = TopKTieBreak::None,
                                       cudaStream_t stream = 0, bool dsa_graph_safe = false) {
   constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC;
   constexpr int MAX_VEC = 16 / sizeof(DType);
+  const bool effective_deterministic = deterministic || tie_break != TopKTieBreak::None;
@@
 #define DISPATCH_VEC_SIZE(VS)                                  \
   if (vec_size == VS) {                                        \
-    if (!deterministic) {                                      \
+    if (!effective_deterministic) {                            \
       LAUNCH_FILTERED_KERNEL(VS, false, TopKTieBreak::None);   \
     } else {                                                   \
       if (tie_break == TopKTieBreak::Small) {                  \
         LAUNCH_FILTERED_KERNEL(VS, true, TopKTieBreak::Small); \
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3070 - 3104, The launcher currently
ignores a non-None tie_break when deterministic==false; change the dispatch to
compute an effective deterministic flag (e.g., bool effective_det =
deterministic || (tie_break != TopKTieBreak::None)) and use effective_det in
DISPATCH_VEC_SIZE/launch logic so that any tie_break != TopKTieBreak::None
forces the deterministic specialization via LAUNCH_FILTERED_KERNEL(..., true,
tie_break) while preserving the existing non-deterministic path only when
effective_det is false; update references to deterministic in the
DISPATCH_VEC_SIZE block to use this effective_det and select the correct
TopKTieBreak template parameter accordingly.
🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)

234-236: Document the fixed chunking choice or remove the TODO.

ITEMS_PER_THREAD = 4 is now part of a performance-sensitive deterministic tie-break path. Please either justify why 4 is the intended trade-off here or link this TODO to a tracked tuning task so the algorithmic choice is explicit. As per coding guidelines, For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 234 - 236, Replace the TODO by
either (a) adding a short justification comment next to ITEMS_PER_THREAD = 4
explaining why 4 was chosen (trade-offs tested, microbenchmarks summary,
sensitivity in the deterministic tie-break hot path, interaction with
BLOCK_THREADS and CHUNK_ITEMS, and why vectorization wasn't chosen), or (b) if
the number is provisional, remove the TODO and add a one-line reference to a
tracked tuning task/issue ID that contains the benchmarking results and
alternative values tested; ensure the comment mentions the symbols
ITEMS_PER_THREAD, CHUNK_ITEMS and BLOCK_THREADS and that this choice affects the
deterministic tie-break/performance-critical path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 89-103: In bench_tie_break_variants, the lambda passed to
bench_median_ms captures the loop variable tie_break by reference and the
metrics dict is annotated too narrowly as dict[str, float]; fix by binding the
loop variable in the lambda (e.g., make it a default arg so you call
run_flashinfer_with_tie_break(tie_break=tie_break) inside the lambda) and widen
the return type to allow string error labels (e.g., change the annotation from
dict[str, float] to dict[str, float | str] or dict[str, Any]); keep references
to TIE_BREAK_VARIANTS, run_flashinfer_with_tie_break,
classify_benchmark_runtime_error, and metrics when making the edits.

---

Outside diff comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3070-3104: The launcher currently ignores a non-None tie_break
when deterministic==false; change the dispatch to compute an effective
deterministic flag (e.g., bool effective_det = deterministic || (tie_break !=
TopKTieBreak::None)) and use effective_det in DISPATCH_VEC_SIZE/launch logic so
that any tie_break != TopKTieBreak::None forces the deterministic specialization
via LAUNCH_FILTERED_KERNEL(..., true, tie_break) while preserving the existing
non-deterministic path only when effective_det is false; update references to
deterministic in the DISPATCH_VEC_SIZE block to use this effective_det and
select the correct TopKTieBreak template parameter accordingly.

---

Nitpick comments:
In `@include/flashinfer/topk.cuh`:
- Around line 234-236: Replace the TODO by either (a) adding a short
justification comment next to ITEMS_PER_THREAD = 4 explaining why 4 was chosen
(trade-offs tested, microbenchmarks summary, sensitivity in the deterministic
tie-break hot path, interaction with BLOCK_THREADS and CHUNK_ITEMS, and why
vectorization wasn't chosen), or (b) if the number is provisional, remove the
TODO and add a one-line reference to a tracked tuning task/issue ID that
contains the benchmarking results and alternative values tested; ensure the
comment mentions the symbols ITEMS_PER_THREAD, CHUNK_ITEMS and BLOCK_THREADS and
that this choice affects the deterministic tie-break/performance-critical path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 51145eb8-8c31-4857-a35e-958b56de4bbb

📥 Commits

Reviewing files that changed from the base of the PR and between 6bbd1dac40692808d999fbffca4a00ec189a1b6d and 30d7210cbbc0f685cc00026100c20aca9a4db702.

📒 Files selected for processing (5)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh

Comment thread benchmarks/bench_topk.py
@zianglih zianglih marked this pull request as draft April 21, 2026 08:18
@ziang-and ziang-and force-pushed the dsa-graph-safe branch 2 times, most recently from 5432f6d to e5f4eb0 on April 22, 2026 00:40
@zianglih zianglih changed the title feat: Add a dsa_graph_safe flag to topk feat: Add row_starts and dsa_graph_safe to topk Apr 22, 2026
@zianglih zianglih marked this pull request as ready for review April 22, 2026 04:14
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/topk.py (2)

789-799: ⚠️ Potential issue | 🟠 Major

Preserve positional compatibility for deterministic.

row_starts is inserted before the existing deterministic parameter, so existing calls like top_k_ragged_transform(scores, offsets, lengths, k, True) now pass True as row_starts.

Proposed fix
 def top_k_ragged_transform(
     input: torch.Tensor,
     offsets: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
-    row_starts: Optional[torch.Tensor] = None,
     deterministic: bool = False,
     tie_break: int = TopKTieBreak.NONE,
+    row_starts: Optional[torch.Tensor] = None,
     dsa_graph_safe: bool = False,
 ) -> torch.Tensor:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 789 - 799, The signature change to
top_k_ragged_transform moved row_starts before deterministic, breaking
positional callers; restore positional compatibility by ensuring deterministic
remains the positional parameter before row_starts (i.e., place deterministic as
the parameter immediately after k and make row_starts either follow
deterministic or be keyword-only), update the function signature accordingly and
adjust any internal usage of row_starts/deterministic inside
top_k_ragged_transform to match the restored parameter order.

658-669: ⚠️ Potential issue | 🟠 Major

Preserve positional compatibility for row_to_batch.

row_starts is inserted before the existing row_to_batch parameter, so existing calls like top_k_page_table_transform(scores, table, lengths, k, row_to_batch) now bind that tensor as row_starts and silently compute the wrong mapping.

Proposed fix
 def top_k_page_table_transform(
     input: torch.Tensor,
     src_page_table: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
-    row_starts: Optional[torch.Tensor] = None,
     row_to_batch: Optional[torch.Tensor] = None,
     deterministic: bool = False,
     tie_break: int = TopKTieBreak.NONE,
+    row_starts: Optional[torch.Tensor] = None,
     dsa_graph_safe: bool = False,
 ) -> torch.Tensor:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 658 - 669, The function signature change in
top_k_page_table_transform broke positional compatibility by inserting
row_starts before the existing row_to_batch parameter; restore compatibility by
reordering the parameters so row_to_batch appears before row_starts (i.e., keep
the original positional order: ..., k, row_to_batch:
Optional[torch.Tensor]=None, row_starts: Optional[torch.Tensor]=None,
deterministic=..., tie_break=..., dsa_graph_safe=...), update any internal
references to use the renamed parameters accordingly, and run tests that call
top_k_page_table_transform(positionally) to confirm behavior is unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/topk.py`:
- Around line 789-799: The signature change to top_k_ragged_transform moved
row_starts before deterministic, breaking positional callers; restore positional
compatibility by ensuring deterministic remains the positional parameter before
row_starts (i.e., place deterministic as the parameter immediately after k and
make row_starts either follow deterministic or be keyword-only), update the
function signature accordingly and adjust any internal usage of
row_starts/deterministic inside top_k_ragged_transform to match the restored
parameter order.
- Around line 658-669: The function signature change in
top_k_page_table_transform broke positional compatibility by inserting
row_starts before the existing row_to_batch parameter; restore compatibility by
reordering the parameters so row_to_batch appears before row_starts (i.e., keep
the original positional order: ..., k, row_to_batch:
Optional[torch.Tensor]=None, row_starts: Optional[torch.Tensor]=None,
deterministic=..., tie_break=..., dsa_graph_safe=...), update any internal
references to use the renamed parameters accordingly, and run tests that call
top_k_page_table_transform(positionally) to confirm behavior is unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: be26a571-3378-41eb-9773-499092e2e2f0

📥 Commits

Reviewing files that changed from the base of the PR and between 30d7210cbbc0f685cc00026100c20aca9a4db702 and 20061c2.

📒 Files selected for processing (5)
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/topk.cuh

@jiahanc jiahanc added the run-ci label Apr 23, 2026
@jiahanc
Collaborator

jiahanc commented Apr 23, 2026

/bot run

Collaborator

@jiahanc jiahanc left a comment


LGTM, thanks for contribution!

@flashinfer-bot
Collaborator

GitLab MR !588 has been created, and the CI pipeline #49277161 is currently running. I'll report back once the pipeline job completes.

Member

@kahyunnam kahyunnam left a comment


LGTM overall except for one comment! Will approve after this is updated + /bot run CICD passes

Comment thread flashinfer/topk.py Outdated
src_page_table: torch.Tensor,
lengths: torch.Tensor,
k: int,
row_starts: Optional[torch.Tensor] = None,
Member


Can we move these new optional arguments to the end? This breaks positional ordering for existing callers and may break backward compatibility of the API definition.

Contributor Author


Done by d9638b1

@zianglih
Contributor Author

Hi @kahyunnam, I have made the requested changes. All Python APIs and top-level C++ bindings now have both args at the end. The internal implementation still uses the previous ordering for better readability. Thank you!

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (4)
flashinfer/topk.py (3)

849-854: Minor doc gap: clarify row_starts interaction with the trivial ragged path.

For top_k_ragged_transform, the "If lengths[i] <= k" note still reads as if row_starts has no trivial-case role, but callers may reasonably assume symmetry with top_k_page_table_transform (which now documents the row-shifted slice). A short clarification avoids ambiguity, e.g.:

📝 Suggested wording
-    - If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, ..., offsets[i]+lengths[i]-1]
-      with remaining positions set to -1.
+    - If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, ..., offsets[i]+lengths[i]-1]
+      with remaining positions set to -1. ``row_starts`` only shifts the score window used for
+      top-k selection; it does not shift these local indices in the trivial case.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 849 - 854, The docstring for
top_k_ragged_transform is ambiguous about how row_starts shifts indices in the
trivial ragged path (lengths[i] <= k); update the Note to explicitly state that
when lengths[i] <= k the returned indices are the sequence
[row_starts[i]+offsets[i], row_starts[i]+offsets[i]+1, ...,
row_starts[i]+offsets[i]+lengths[i]-1] with remaining positions set to -1,
mirroring the documented behavior/symmetry of top_k_page_table_transform;
reference top_k_ragged_transform, row_starts, offsets, lengths, and
top_k_page_table_transform in the docstring so callers aren’t confused about
whether indices are row-shifted.

495-500: Default dsa_graph_safe to preserve backward compatibility.

can_use_clusters_topk is a module-level (non-underscore) helper. Adding a required third positional parameter is technically a breaking change for any external caller. Given the PR's explicit "Keep API backward compatibility" intent (and the prior review feedback about positional ordering), consider defaulting it:

♻️ Suggested default
-def can_use_clusters_topk(device, deterministic, dsa_graph_safe):
+def can_use_clusters_topk(device, deterministic, dsa_graph_safe=False):
     if dsa_graph_safe:
         return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 495 - 500, can_use_clusters_topk currently
requires a third positional parameter dsa_graph_safe which is a breaking API
change; make dsa_graph_safe optional with a default value (e.g., False) so
existing callers keep current behavior, update the function signature for
can_use_clusters_topk to set dsa_graph_safe=False and ensure the function body
still uses the parameter as before, and scan for external uses of
can_use_clusters_topk to confirm none rely on a mandatory third argument.

728-731: Nit: tighten the trivial-case wording for readability.

The inline parenthetical splits an RST code reference across lines and reads awkwardly. A small rewrite keeps the code literal intact:

📝 Suggested wording
-    - If lengths[i] <= k, the output simply contains
-      ``src_page_table[batch_idx, row_starts[i]:row_starts[i] + lengths[i]]`` (or start 0 when
-      ``row_starts`` is None)
-      with remaining positions set to -1.
+    - If lengths[i] <= k, the output simply contains
+      ``src_page_table[batch_idx, s:s + lengths[i]]`` where ``s = row_starts[i]`` (or 0 when
+      ``row_starts`` is None), with remaining positions set to -1.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 728 - 731, The docstring sentence describing
the trivial case splits the RST code reference across lines and reads awkwardly;
update the sentence in topk.py that starts "If lengths[i] <= k" to keep the code
literal intact by making it a single clear clause referencing the symbols
lengths, k, src_page_table, row_starts and batch_idx — e.g. state that when
lengths[i] <= k the output contains the entries of src_page_table for batch_idx
from row_starts[i] to row_starts[i] + lengths[i], with row_starts treated as
starting at 0 when row_starts is None, and any remaining positions set to -1.
tests/utils/test_topk.py (1)

459-476: Consider adding trivial-length coverage for row_starts.

In the ragged reference, row_start is read but intentionally unused in the trivial branch (length <= k), matching the documented semantics (output is local_topk + offsets[i]). The new test_top_k_transform_with_row_starts forces lengths >= k+1, so the kernel's trivial-length behavior under non-zero row_starts is not validated against this reference. A small additional case (e.g., one row with lengths[i] <= k and row_starts[i] > 0) would close that gap for both transforms.
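
A hedged sketch of the suggested extra case; argument order, keyword names, and the flat ragged-score layout are assumptions based on the discussion above:

```python
import torch
import flashinfer

k = 64
offsets = torch.tensor([0], device="cuda", dtype=torch.int32)
lengths = torch.tensor([32], device="cuda", dtype=torch.int32)     # lengths[0] <= k
row_starts = torch.tensor([16], device="cuda", dtype=torch.int32)  # non-zero row start
# Flat scores covering offsets[0] + row_starts[0] + lengths[0] = 48 elements.
scores = torch.randn(48, device="cuda")

out = flashinfer.top_k_ragged_transform(
    scores, offsets, lengths, k, row_starts=row_starts
)
# Trivial branch: indices are offsets[0] .. offsets[0] + lengths[0] - 1,
# unshifted by row_starts, with the remaining positions padded to -1.
assert (out[0, :32].cpu() == torch.arange(32)).all()
assert (out[0, 32:] == -1).all()
```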

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 459 - 476, Add a trivial-length test
case to exercise the branch where length <= k while row_starts is non-zero: in
tests/utils/test_topk.py (the test_top_k_transform_with_row_starts setup),
append or insert one row whose lengths[i] <= k and row_starts[i] > 0 (ensure
offsets[i] is set) and verify output[i, :length] equals torch.arange(offset,
offset+length) so the reference path that ignores row_start is validated; update
any generated scores/slices accordingly so that this single-row case triggers
the trivial branch alongside the existing longer rows.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 849-854: The docstring for top_k_ragged_transform is ambiguous
about how row_starts shifts indices in the trivial ragged path (lengths[i] <=
k); update the Note to explicitly state that when lengths[i] <= k the returned
indices are the sequence [row_starts[i]+offsets[i], row_starts[i]+offsets[i]+1,
..., row_starts[i]+offsets[i]+lengths[i]-1] with remaining positions set to -1,
mirroring the documented behavior/symmetry of top_k_page_table_transform;
reference top_k_ragged_transform, row_starts, offsets, lengths, and
top_k_page_table_transform in the docstring so callers aren’t confused about
whether indices are row-shifted.
- Around line 495-500: can_use_clusters_topk currently requires a third
positional parameter dsa_graph_safe which is a breaking API change; make
dsa_graph_safe optional with a default value (e.g., False) so existing callers
keep current behavior, update the function signature for can_use_clusters_topk
to set dsa_graph_safe=False and ensure the function body still uses the
parameter as before, and scan for external uses of can_use_clusters_topk to
confirm none rely on a mandatory third argument.
- Around line 728-731: The docstring sentence describing the trivial case splits
the RST code reference across lines and reads awkwardly; update the sentence in
topk.py that starts "If lengths[i] <= k" to keep the code literal intact by
making it a single clear clause referencing the symbols lengths, k,
src_page_table, row_starts and batch_idx — e.g. state that when lengths[i] <= k
the output contains the entries of src_page_table for batch_idx from
row_starts[i] to row_starts[i] + lengths[i], with row_starts treated as starting
at 0 when row_starts is None, and any remaining positions set to -1.

In `@tests/utils/test_topk.py`:
- Around line 459-476: Add a trivial-length test case to exercise the branch
where length <= k while row_starts is non-zero: in tests/utils/test_topk.py (the
test_top_k_transform_with_row_starts setup), append or insert one row whose
lengths[i] <= k and row_starts[i] > 0 (ensure offsets[i] is set) and verify
output[i, :length] equals torch.arange(offset, offset+length) so the reference
path that ignores row_start is validated; update any generated scores/slices
accordingly so that this single-row case triggers the trivial branch alongside
the existing longer rows.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f2e2f667-ff32-401b-a377-dac5ae258868

📥 Commits

Reviewing files that changed from the base of the PR and between 20061c2 and d9638b1.

📒 Files selected for processing (4)
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/topk.cu

@zianglih zianglih requested a review from kahyunnam April 23, 2026 19:22
@kahyunnam kahyunnam enabled auto-merge (squash) April 23, 2026 20:57
@kahyunnam kahyunnam merged commit ef46793 into flashinfer-ai:main Apr 24, 2026
32 of 38 checks passed
@zianglih zianglih deleted the dsa-graph-safe branch April 24, 2026 17:37
@aleozlx aleozlx mentioned this pull request Apr 25, 2026
aleozlx added a commit that referenced this pull request May 5, 2026
## Description

Bump version to 0.6.10 for release.

## Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10

## Reviewer Notes

**API changes review**

API changes since v0.6.9

```diff
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
     register_custom_op,
@@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_trace)
 def silu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -112,7 +118,7 @@ def silu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_tanh_and_mul_trace)
 def gelu_tanh_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -153,7 +159,7 @@ def gelu_tanh_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_and_mul_trace)
 def gelu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -194,7 +200,7 @@ def gelu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace)
 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index dfb05150..d26d5407 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -543,6 +543,7 @@ def gen_all_modules(
     if add_comm:
         from .jit.comm import (
             gen_comm_alltoall_module,
+            gen_dcp_alltoall_module,
             gen_moe_alltoall_module,
             gen_trtllm_comm_module,
             gen_trtllm_mnnvl_comm_module,
@@ -554,6 +555,11 @@ def gen_all_modules(
             jit_specs.append(gen_trtllm_comm_module())
             jit_specs.append(gen_trtllm_mnnvl_comm_module())
             jit_specs.append(gen_moe_alltoall_module())
+            # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has
--
 
-def flashinfer_api(func: Callable = None) -> Callable:
+# ---------------------------------------------------------------------------
+# Trace template registry
+# ---------------------------------------------------------------------------
+# Populated automatically by _attach_fi_trace whenever @flashinfer_api is
+# given a trace= argument.  Each entry is (original_func, template, label)
+# where label is the template's name_prefix (or op_type as fallback).
+#
+# For dispatch callables (trace=some_fn), every template listed in
+# some_fn.templates is registered if that attribute exists.
+#
+# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover
+# all registered templates without requiring manual maintenance.
+_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = []
+
+
+def _attach_fi_trace(
+    wrapped: Callable,
+    original: Callable,
+    trace_template=None,
+) -> Callable:
+    """Attach a ``fi_trace`` callable to *wrapped*.
+
+    Three resolution strategies, tried in order:
+
--
+
+        warnings.warn(
+            f"[flashinfer] Failed to attach fi_trace to '{_func_name}': "
+            f"{type(_exc).__name__}: {_exc}\n"
+            f"The function will work normally but fi_trace will be unavailable. "
+            f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).",
+            stacklevel=3,
+        )
+    return wrapped
+
+
+def flashinfer_api(func: Callable = None, *, trace=None) -> Callable:
     """
     Decorator to FlashInfer's APIs.
 
@@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable:
     - The %i pattern is automatically replaced with the process ID for multi-process environments.
     - The logger does not propagate to the root logger to avoid duplicate logs.
     """
-    # If logging is disabled, return original function with zero overhead
+    # If logging is disabled, return original function with zero overhead.
+    # We still attach fi_trace so it is always available regardless of log level.
     if _API_LOG_LEVEL == 0:
         if func is None:
-            return lambda f: f
-        return func
--
 @functools.cache
@@ -135,7 +136,7 @@ class BatchAttention:
             causal,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_attention_run_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper
     a convenient interface for using attention sinks during prefill or decode attention.
     """
 
+    # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py
new file mode 100644
index 00000000..3e029627
--- /dev/null
+++ b/flashinfer/attention/cute_dsl/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
--
 
@@ -31,7 +37,7 @@ def get_cascade_module():
     return gen_cascade_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_trace)
 @register_custom_op("flashinfer::merge_state", mutates_args=())
 def merge_state(
     v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -98,7 +104,7 @@ def _fake_merge_state(
     return v, s
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_in_place_trace)
 @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
 def merge_state_in_place(
     v: torch.Tensor,
@@ -159,7 +165,7 @@ def _fake_merge_state_in_place(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_states_trace)
 @register_custom_op("flashinfer::merge_states", mutates_args=())
 def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Merge multiple attention states (v, s).
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=multi_level_cascade_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 5f186002..31d23a99 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import (
     moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace,
 )
 
+# DCP A2A (Decode Context Parallel Attention Reduction)
+from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
+from .dcp_alltoall import (
+    decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
+)
+from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
+from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
+
 # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
--
 from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
@@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=allreduce_fusion_trace)
 def allreduce_fusion(
     input: torch.Tensor,
     workspace: AllReduceFusionWorkspace,
diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py
new file mode 100644
index 00000000..3047f76c
--- /dev/null
+++ b/flashinfer/comm/dcp_alltoall.py
@@ -0,0 +1,255 @@
+"""
+DCP All-to-All Operations for DCP Attention Reduction
+
+Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
+attention reduction. Uses SM90+ features (TMA, mbarrier).
+
+Usage protocol::
+
+    # 1. Query workspace size
+    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
+
--
+
+
+# ─── Public API ───────────────────────────────────────────────────────────
+
+
+@flashinfer_api
+def decode_cp_a2a_workspace_size(cp_size: int) -> int:
+    """Return the workspace size **in bytes** per rank for the given CP group size.
+
+    Args:
+        cp_size: Context-parallel group size (number of ranks).
+
+    Returns:
+        Workspace size in bytes per rank.
+
+    Example::
+
+        >>> decode_cp_a2a_workspace_size(4)
+        16778240
+    """
+    return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size)
+
+
+@flashinfer_api
+def decode_cp_a2a_allocate_workspace(
+    cp_size: int,
+    cp_rank: int,
+    *,
+    mapping: Optional[Mapping] = None,
+    mnnvl_config: Optional[MnnvlConfig] = None,
+) -> torch.Tensor:
+    """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+
+    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
+
+    Two allocation modes:
+
+    - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
+      FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
+      cannot see each other's device memory directly.
+    - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
+      allocation. Sufficient for single-node with NVLink P2P.
+
--
+
+    ws_elems_per_rank = (ws_bytes + 7) // 8
+    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+
+
+@flashinfer_api
+def decode_cp_a2a_init_workspace(
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+) -> None:
+    """Initialize the workspace FIFO buffers. Call once before the first alltoall.
+
+    Resets the FIFO buffers in the **local** workspace row
+    (``workspace[cp_rank]``). This function is **synchronous**: when it
+    returns, the GPU memset is guaranteed to have completed.
+
+    .. important::
+        With MNNVL workspaces, **all ranks** must complete
+        ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier
+        (e.g. ``dist.barrier(group)``) before **any** rank calls
+        :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may
+        start writing to a peer's FIFO before that peer has finished
+        initializing → deadlock.
+
+    Args:
--
+    # subsequent cross-GPU alltoall can race with the unfinished memset
+    # on MNNVL memory, causing a deadlock.
+    torch.cuda.current_stream().synchronize()
+
+
+@flashinfer_api(trace=decode_cp_a2a_alltoall_trace)
+def decode_cp_a2a_alltoall(
+    partial_o: torch.Tensor,
+    softmax_stats: torch.Tensor,
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+    enable_pdl: Optional[bool] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Perform the DCP all-to-all exchange.
+
+    Each rank sends its ``partial_o[..., peer, :]`` slice to the
+    corresponding peer and receives all peers' contributions into the
+    output tensors.
+
+    Args:
+        partial_o: ``[..., cp_size, D]`` — half or bfloat16.
+            ``D * element_size`` must be 16-byte aligned.
+        softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
+            Batch dimensions must match ``partial_o``.
+        workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
--
+    MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather,
+    MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce,
+}
+
+
+@flashinfer_api
+@backend_requirement(
+    backend_checks={},
+    common_check=_common_check,
+)
+def run_mixed_comm(
+    op: MixedCommOp,
+    handler: MixedCommHandler,
+    x_in: torch.Tensor,
+    x_out: torch.Tensor | None = None,
+    mode: MixedCommMode | None = None,
+) -> torch.Tensor:
+    """Execute a mixed communication operation.
+
+    This is the main entry point for running communication collectives
+    through the mixed communication handler. It supports fused GPU kernels
+    (using virtual memory intra-node and nvshmem inter-node), NCCL-based
+    fallbacks, and autotuned mode selection.
+
+    Args:
+        op: The communication operation to perform.
--
 @functools.cache
@@ -28,7 +29,7 @@ def get_concat_mla_module():
     return gen_concat_mla_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=concat_mla_k_trace)
 def concat_mla_k(
     k: torch.Tensor,
     k_nope: torch.Tensor,
diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py
index 195ca2d4..9b593095 100644
--- a/flashinfer/cudnn/decode.py
+++ b/flashinfer/cudnn/decode.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_decode_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_decode_trace)
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc1bbb5f..b16d6043 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_prefill_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache(
         return out, None
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_prefill_trace)
 def cudnn_batch_prefill_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
index 0b50c22c..f25aa6fd 100644
--- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
@@ -38,6 +38,7 @@ import torch
 from cutlass import Float32, Int32, Int64, Uint32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import add_rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -1042,7 +1043,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=add_rmsnorm_fp4quant_trace)
 def add_rmsnorm_fp4quant(
     input: torch.Tensor,
     residual: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
index 333697ab..b7aabc36 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
@@ -20,6 +20,7 @@ import torch
 from cutlass import Float32, Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace
 from flashinfer.utils import device_support_pdl
 from flashinfer.cute_dsl.utils import (
     get_max_active_clusters,
@@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper:
                 f"out_dtype={self._o_dtype}"
             )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_mla_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
index 58a24abe..ee0cd5e7 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
@@ -21,6 +21,7 @@ import cutlass.cute as cute
 from cutlass.cute.typing import Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace
 
 from ..config import AttentionConfig, AttentionFusion
 from ..fusion.mask import MaskType
@@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper:
                     f"device={self._device}"
                 )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
index bc4acffc..97ce68a1 100644
--- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
@@ -32,6 +32,7 @@ import torch
 from cutlass import Float32, Int32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -771,7 +772,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_fp4quant_trace)
 def rmsnorm_fp4quant(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 822aca40..5e9eb515 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_decode_trace,
+    single_decode_with_kv_cache_trace,
+    trtllm_batch_decode_trace,
+    xqa_batch_decode_trace,
+)
 
 ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
 from .mla import (
@@ -400,7 +406,7 @@ def single_decode_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_decode_with_kv_cache_trace)
 def single_decode_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
         kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_decode_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra
     :class:`BatchDecodeWithPagedKVCacheWrapper`
     """
 
+    # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         workspace_buffer: torch.Tensor,
@@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_trace)
 def trtllm_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache(
 
 
 # xqa uses NHD layout
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_trace)
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py
new file mode 100644
index 00000000..1104eb6f
--- /dev/null
+++ b/flashinfer/fi_trace.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--
+
+"""
+fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_
+compatible definition JSON for FlashInfer APIs.
+
+Every ``@flashinfer_api(trace=<template>)``-decorated function supports two
+usage modes:
+
+Auto-dump (recommended)
+-----------------------
+Set environment variables **before** importing flashinfer, then run your
+workload normally.  No explicit ``fi_trace`` call is needed.
+
+.. code-block:: bash
+
+    FLASHINFER_TRACE_DUMP=1 \\
+    FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\
+    python my_script.py
+
+Every decorated function writes a ``<name>.json`` file on its **first** call
+for each unique set of const-axis values (e.g. head dimensions, vocab size).
+Subsequent calls with the same shape are deduplicated — the file is written
+only once per process.  The output directory is created automatically.
+
+Explicit call (for selective or programmatic use)
+-------------------------------------------------
--
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+# ---------------------------------------------------------------------------
+# Legacy registry — kept for backwards compatibility.
+# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead.
+# ---------------------------------------------------------------------------
+
+_REGISTRY: Dict[str, Any] = {}
+
+
+def register_fi_trace(qualname: str, spec: Any) -> None:
+    """Register a legacy FiTraceSpec for the function with the given qualname.
+
+    .. deprecated::
+        Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead.
+    """
+    _REGISTRY[qualname] = spec
+
+
+def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]:
+    """Build a fi_trace callable from a legacy FiTraceSpec.
+
+    .. deprecated::
+        Use ``TraceTemplate.build_fi_trace_fn`` instead.
+    """
+    # Import the old implementation from the trace package for backwards compat.
+    from .trace.template import (  # noqa: PLC0415,F401
+        Const,
+        Scalar,
+        Tensor,
+        TraceTemplate,
+        Var,
+    )
+    import json  # noqa: PLC0415
+    import os  # noqa: PLC0415
--
+    """Generate a flashinfer-bench definition JSON for any FlashInfer API call.
+
+    Parameters
+    ----------
+    func_or_method:
+        A ``@flashinfer_api``-decorated function or (bound) method.
+    save_dir:
+        Directory where the JSON definition file should be written.
+        Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
+    **kwargs:
+        The same tensor arguments you would pass to the real API.
+
+    Returns
+    -------
+    dict
+        A flashinfer-bench compatible definition dictionary.
+
+    Examples
+    --------
+    Standalone function::
+
+        defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)
+
+    Bound method (instance.run)::
+
+        defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))
--
+    trace_fn = getattr(actual_func, "fi_trace", None)
+    if trace_fn is None:
+        qualname = getattr(actual_func, "__qualname__", repr(actual_func))
+        raise ValueError(
+            f"No fi_trace spec is registered for '{qualname}'. "
+            "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+        )
+    return trace_fn(save_dir=save_dir, **kwargs)
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index df6e1f72..d983f9d4 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -17,6 +17,8 @@ limitations under the License.
 from .core import (
     convert_to_block_layout,
     cutlass_fused_moe,
+    interleave_moe_scales_for_sm90_mixed_gemm,
+    interleave_moe_weights_for_sm90_mixed_gemm,
     gen_cutlass_fused_moe_sm120_module,
     gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
@@ -64,6 +66,8 @@ __all__ = [
     "WeightLayout",
     "convert_to_block_layout",
     "cutlass_fused_moe",
+    "interleave_moe_scales_for_sm90_mixed_gemm",
--
+        ),
     )
 
 
-# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
 @flashinfer_api
+def interleave_moe_scales_for_sm90_mixed_gemm(
+    scales: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
+
+    The kernel expects scales in layout
+    ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural
+    ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer.
+    This helper performs the reshape + permute equivalent to TensorRT-LLM's
+    ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), with the fixed
+    interleave factor of ``128 // group_size`` used for MXFP4.
+
+    Parameters
+    ----------
+    scales:
+        ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block
+        scales.
+    group_size:
+        MXFP4 quantization group size (default 32).
--
+        scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous()
+    )
+    return tmp.reshape(e, kgs // factor, rows * factor)
+
+
+@flashinfer_api
+def interleave_moe_weights_for_sm90_mixed_gemm(
+    weight: torch.Tensor,
+    quant_type: str = "fp4",
+) -> torch.Tensor:
+    """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
+
+    The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with
+    ``use_w4_group_scaling=True``) expects weights in a specific interleaved
+    layout; without preprocessing, the LUT-based FP4→BF16 conversion reads
+    bytes from the wrong positions and the output diverges from a dequantized
+    reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent
+    preprocessing at weight-load time (see
+    ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451).
+
+    Parameters
+    ----------
+    weight:
+        ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed
+        two-per-byte).
+    quant_type:
--
+    )
+    return out
+
+
+# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
+@flashinfer_api(trace=cutlass_fused_moe_trace)
 def cutlass_fused_moe(
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
@@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module():
                     DynamicTensorSpec(
                         input_idx,
                         dim_idx,
-                        get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-                        lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                        get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+                        lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
                         initializers,
                     ),
                 ),
@@ -2344,7 +2468,7 @@ def _validate_routing_replay_out(
         raise ValueError("routing_replay_out must be contiguous (packed row-major)")
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_moe_trace)
 def trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2452,7 +2576,7 @@ def trtllm_bf16_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_routed_moe_trace)
 def trtllm_bf16_routed_moe(
     topk_ids: torch.Tensor,
     hidden_states: torch.Tensor,
@@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace)
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace)
 def trtllm_fp8_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch)
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace)
 def trtllm_mxint4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
index d2cbc8b0..34916df5 100644
--- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
@@ -42,11 +42,12 @@ from typing import Optional, Tuple
 import torch
 
 from ...api_logging import flashinfer_api
+from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace
 from ...utils import supported_compute_capability
 
 
 @supported_compute_capability([120, 121])
-@flashinfer_api
+@flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
     x: torch.Tensor,
     w1_weight: torch.Tensor,
@@ -293,7 +294,7 @@ class B12xMoEWrapper:
             device=self.device,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
index f6cf1b67..e266cb77 100644
--- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
+++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
@@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import (
     st_global_u64,
     scatter_add_bf16x2,
 )
-from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
-    Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
+from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
+    Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
 )
 
 
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
index e7fdae92..670b3ad8 100644
--
 from .moe_utils import (
@@ -530,7 +534,7 @@ class CuteDslMoEWrapper:
             enable_pdl=enable_pdl,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
 
 
 @supported_compute_capability([100, 103])
-@flashinfer_api
+@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace)
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
     x_sf: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index 0cc8628e..636043db 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@ from ...autotuner import (
     TuningConfig,
 )
 from ..utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
 )
 
 logger = logging.getLogger(__name__)
@@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
                 DynamicTensorSpec(
--
 import torch
@@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module():
 
 
 @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
-@flashinfer_api
+@flashinfer_api(trace=fused_topk_deepseek_trace)
 def fused_topk_deepseek(
     scores: torch.Tensor,
     bias: torch.Tensor,
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 004271a1..91f37aa5 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
     return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
 
 
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1."""
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
+_PHASE1_END = 256
--
 
@@ -106,7 +114,7 @@ TILE_V = 8  # pretranspose tile size
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -535,7 +543,7 @@ def gated_delta_rule_decode(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_mtp_trace)
 def gated_delta_rule_mtp(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 68398d28..53fe44ce 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel:
 
         gate_handle = load_gate_consumer.wait_and_advance()
 
-        max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
-        cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
+        cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]
 
         valid_state = not is_first_chunk or self.use_initial_state
         if cutlass.const_expr(valid_state):
diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
index 82dcc72b..aafcc671 100644
--- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
--
     register_custom_op,
@@ -95,7 +96,7 @@ def get_gdn_prefill_module():
     return SimpleNamespace(gdn_prefill=gdn_prefill)
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_prefill_trace)
 def chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index a7795beb..def82216 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -61,11 +61,11 @@ try:
     from flashinfer.cute_dsl.utils import is_cute_dsl_available
 
     if is_cute_dsl_available():
-        from .kernels.dense_blockscaled_gemm_sm120 import (
-            Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
+        from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+            Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
         )
 
-        _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+        _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
 except ImportError:
--
 from ..utils import (
@@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16(
     common_check=_check_mm_bf16_problem_size,
     heuristic_func=_heuristic_func_mm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_bf16_trace)
 def mm_bf16(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16(
     common_check=_check_bmm_bf16_problem_size,
     heuristic_func=_heuristic_func_bmm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_bf16_trace)
 def bmm_bf16(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
         ),
     ),
     constraint_specs=(
@@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
--
     constraint_specs=(
@@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=tgv_gemm_sm100_trace)
 def tgv_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1173,8 +1187,8 @@ def tgv_gemm_sm100(
             DynamicTensorSpec(
                 (a_tensor_index,),
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper:
     True
     """
 
+    @flashinfer_api
     def __init__(
         self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
     ) -> None:
@@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper:
         self._float_workspace_buffer = float_workspace_buffer
         self._int_workspace_buffer = int_workspace_buffer
 
-    @flashinfer_api
+    @flashinfer_api(trace=segment_gemm_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_fp4; the user-facing mm_fp4 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_fp4_graph_override_shape(
     graph,
     a,
@@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
--
 ):
@@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
     return (tuple(block_scale_shape), tuple(block_scale_stride))
 
 
-@flashinfer_api
+@flashinfer_api(trace=mm_fp8_trace)
 def mm_fp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8(
     common_check=_check_mm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_mm_mxfp8,  # result stored in mm_mxfp8.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_mxfp8_trace)
 def mm_mxfp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner(
     """
     import cutlass
 
-    from .kernels.dense_blockscaled_gemm_sm120 import (
-        Sm120BlockScaledDenseGemmKernel,
+    from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+        Sm120B12xBlockScaledDenseGemmKernel,
     )
 
     cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner(
             ]
             swap_ab = False
             for mma_tiler_mn in sm120_mma_tiler_candidates:
-                if not Sm120BlockScaledDenseGemmKernel.can_implement(
+                if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
--
     constraint_specs=(
@@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig(
     common_check=_check_mm_fp4_problem_size,
     heuristic_func=_heuristic_func_mm_fp4,  # result stored in mm_fp4.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_fp4_trace)
 def mm_fp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8(
     common_check=_check_bmm_fp8_problem_size,
     heuristic_func=_heuristic_func_bmm_fp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_fp8_trace)
 def bmm_fp8(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise(
     {},
     common_check=_check_batch_deepgemm_fp8_nt_groupwise,
 )
-@flashinfer_api
+@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace)
 def batch_deepgemm_fp8_nt_groupwise(
     a: torch.Tensor,  # (batch_size, m, k)
     b: torch.Tensor,  # (batch_size, n, k)
@@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90():
     return module.init()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace)
 def fp8_blockscale_gemm_sm90(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8(
     common_check=_check_bmm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_bmm_mxfp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_mxfp8_trace)
 def bmm_mxfp8(
     A: torch.Tensor,
     B: torch.Tensor,
diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
similarity index 99%
rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
index c49bc815..6eee27a7 100644
--- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
+++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
@@ -1550,7 +1550,7 @@ class DenseGemmKernel:
 
 
 # Alias for FlashInfer integration
-Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
+Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
 
 
 class _DenseGemmLaunch:
diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
--
     get_cutlass_dtype,
@@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=grouped_gemm_nt_masked_trace)
 def grouped_gemm_nt_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py
index cfde7d43..f83c8974 100644
--- a/flashinfer/gemm/routergemm.py
+++ b/flashinfer/gemm/routergemm.py
@@ -1,4 +1,8 @@
 from ..api_logging import flashinfer_api
+from ..trace.templates.gemm import (
+    mm_M1_16_K7168_N256_trace,
+    tinygemm_bf16_trace,
+)
 from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
 import functools
 from types import SimpleNamespace
@@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128(
 
 
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
@@ -324,7 +328,7 @@ def get_tinygemm2_module():
 
 
 @backend_requirement({}, common_check=_tinygemm_bf16_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=tinygemm_bf16_trace)
 def tinygemm_bf16(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 7f36a314..8378e0ab 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module
 from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module
 from .comm import gen_vllm_comm_module as gen_vllm_comm_module
 from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module
+from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module
 from .dsv3_optimizations import (
     gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module,
 )
diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py
index 46768eed..834f77f9 100644
--- a/flashinfer/jit/comm.py
+++ b/flashinfer/jit/comm.py
@@ -15,7 +15,13 @@ limitations under the License.
--
     gen_selective_state_update_sm100_module,
@@ -99,7 +100,7 @@ def get_selective_state_update_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=selective_state_update_trace)
 def selective_state_update(
     state: torch.Tensor,
     x: torch.Tensor,
diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index 4e8bdd72..e27e3807 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import (
+    mla_paged_decode_trace,
+    trtllm_batch_decode_mla_trace,
+    xqa_batch_decode_mla_trace,
+)
 from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
 from ..jit.mla import gen_mla_module
 from ..utils import (
@@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper:
         return_lse_base_on_e: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=mla_paged_decode_trace)
     def run(
         self,
         q_nope: torch.Tensor,
@@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper:
         return (out, lse) if return_lse else out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_mla_trace)
 def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
@@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         raise ValueError(f"Backend {backend} not supported")
 
 
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_mla_trace)
 def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index 0f9911a6..ba612b28 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -32,6 +32,16 @@ from typing import Optional, Union
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import (
+    fused_add_rmsnorm_quant_trace,
+    fused_add_rmsnorm_trace,
+    fused_rmsnorm_silu_trace,
+    gemma_fused_add_rmsnorm_trace,
+    gemma_rmsnorm_trace,
+    layernorm_trace,
+    rmsnorm_quant_trace,
+    rmsnorm_trace,
--
     get_compute_capability,
@@ -94,7 +104,7 @@ def _normalize_scale_tensor(
     return scale.contiguous()
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_trace)
 def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -165,7 +175,7 @@ def _rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_quant_trace)
 @register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",))
 def rmsnorm_quant(
     out: torch.Tensor,
@@ -219,7 +229,7 @@ def _rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_trace)
 @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor,
@@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_quant_trace)
 @register_custom_op(
     "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
 )
@@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_rmsnorm_trace)
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace)
 @register_custom_op(
     "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
 )
@@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=layernorm_trace)
 @register_custom_op("flashinfer::layernorm", mutates_args=())
 def layernorm(
     input: torch.Tensor,
@@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_rmsnorm_silu_trace)
 def fused_rmsnorm_silu(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/page.py b/flashinfer/page.py
index 12ea3613..7fb33cf3 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.page import (
+    append_paged_kv_cache_trace,
+    append_paged_mla_kv_cache_trace,
+)
 from .jit.page import gen_page_module
 from .utils import (
     TensorLayout,
@@ -222,7 +226,7 @@ def get_seq_lens(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_mla_kv_cache_trace)
 def append_paged_mla_kv_cache(
     append_ckv: torch.Tensor,
     append_kpe: torch.Tensor,
@@ -272,7 +276,7 @@ def append_paged_mla_kv_cache(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_kv_cache_trace)
 def append_paged_kv_cache(
     append_key: torch.Tensor,
     append_value: torch.Tensor,
diff --git a/flashinfer/pod.py b/flashinfer/pod.py
index fe2e36c1..4fa2d9bf 100644
--- a/flashinfer/pod.py
+++ b/flashinfer/pod.py
@@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    batch_pod_with_paged_kv_cache_run_trace,
+    pod_with_paged_kv_cache_run_trace,
+)
 from .jit import gen_pod_module, gen_batch_pod_module
 from .page import get_seq_lens
 from .prefill import get_batch_prefill_module
@@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
@@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 4ec6a29e..d491dd35 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_prefill_trace,
+    gqa_ragged_prefill_trace,
+    single_prefill_with_kv_cache_trace,
+    trtllm_batch_context_trace,
+)
+from .trace.templates.gemm import (
+    fmha_v2_prefill_deepseek_trace,
+    trtllm_ragged_attention_deepseek_trace,
--
     gen_customize_batch_prefill_module,
@@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_prefill_with_kv_cache_trace)
 def single_prefill_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
         skip_softmax_threshold_scale_factor: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
         enable_pdl: Optional[bool] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_ragged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module():
     return op
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace)
 def trtllm_ragged_attention_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek(
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    backend: str = "trtllm-gen",
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Parameters
@@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek(
         output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
     lse : Optional[torch.Tensor]
         lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]
+    backend : str
+        Attention backend to use. "trtllm-gen" (default) or "cute-dsl".
+        When backend="cute-dsl", query/key/value/out tensors must be
+        front-padded with max_seq_len rows of valid GPU memory before
+        index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details).
--
             "lse assumed not None beyond this point when return_lse is True"
@@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek(
         return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_context_trace)
 def trtllm_batch_context_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module():
     return gen_trtllm_fmha_v2_sm120_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module(
     return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace)
 def trtllm_fmha_v2_prefill(
     qkv: Union[
         torch.Tensor,
diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py
index 4cd5cd34..84f7ade6 100644
--- a/flashinfer/quantization/fp4_quantization.py
+++ b/flashinfer/quantization/fp4_quantization.py
@@ -21,6 +21,12 @@ from typing import List, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import (
+    fp4_quantize_trace,
+    mxfp4_quantize_trace,
+    nvfp4_kv_quantize_trace,
+    nvfp4_quantize_trace,
+)
 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import (
@@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp4_quantize_trace)
 def fp4_quantize(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
@@ -923,7 +929,7 @@ def shuffle_matrix_sf_a(
     return block_scale_interleave(w_shuffled)
 
 
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_quantize_trace)
 def nvfp4_quantize(
     a,
     a_global_sf,
@@ -1024,7 +1030,7 @@ def nvfp4_quantize(
     return a_fp4, a_sf
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp4_quantize_trace)
 def mxfp4_quantize(
     a: torch.Tensor,
     backend: str = "cuda",
@@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale):
 
 
 @backend_requirement({}, common_check=_nvfp4_kv_quant_check)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_kv_quantize_trace)
 def nvfp4_kv_quantize(
     input: torch.Tensor,
     global_scale: torch.Tensor,
diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py
index f2c9f412..49e13a8b 100644
--- a/flashinfer/quantization/fp8_quantization.py
+++ b/flashinfer/quantization/fp8_quantization.py
@@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import mxfp8_quantize_trace
 from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
 from ..utils import (
     device_support_pdl,
@@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module():
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp8_quantize_trace)
 def mxfp8_quantize(
     input: torch.Tensor,
     is_sf_swizzled_layout: bool = True,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07..df5c7d4d 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -20,6 +20,21 @@ from typing import Optional, Tuple
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.rope import (
+    apply_llama31_rope_inplace_trace,
+    apply_llama31_rope_pos_ids_inplace_trace,
+    apply_llama31_rope_pos_ids_trace,
+    apply_llama31_rope_trace,
+    apply_rope_inplace_trace,
+    apply_rope_pos_ids_inplace_trace,
+    apply_rope_pos_ids_trace,
+    apply_rope_trace,
--
 
@@ -414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_inplace_trace)
 def apply_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -502,7 +517,7 @@ def apply_rope_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace)
 def apply_rope_pos_ids_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_llama31_rope_inplace_trace)
 def apply_llama31_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
... (truncated -- see full diff via the command above)
```

**Summary of API changes:**

- **Decorator semantic addition (backward-compatible):**
  `@flashinfer_api` now accepts an optional `trace=<TraceTemplate>` keyword.
  Bare `@flashinfer_api` still works, and existing call sites of decorated
  functions are unaffected. Most of the diff above is a mechanical rewrite of
  `@flashinfer_api` to `@flashinfer_api(trace=...)`, plus the new
  `flashinfer/trace/` package and `fi_trace.py` for flashinfer-bench JSON
  dumps. A minimal usage sketch is included below the list.

- **New public APIs (7):**
  - `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size,
    decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace,
    decode_cp_a2a_alltoall}` — DCP all-to-all for context-parallel attention
    reduction (#2951). The required call ordering is sketched below the list.
  - `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm,
    interleave_moe_weights_for_sm90_mixed_gemm}` — weight/scale preprocessing
    helpers for the SM90 mixed-input MoE GEMM (#3084); also sketched below.
  - `flashinfer.comm.run_mixed_comm` — combinations of allreduce /
    allgather / reducescatter (#2563).

- **New `@flashinfer_api`-decorated wrapper init:**
  - `SegmentGEMMWrapper.__init__` is now decorated. Previously only `run()`
    was decorated and `__init__` was not. No call-site change.

- **Backward-compatible signature additions (defaults preserve old behavior):**
  - `top_k_page_table_transform`: `+dsa_graph_safe: bool = False`,
    `+row_starts: Optional[torch.Tensor] = None` (#3133).
  - `top_k_ragged_transform`: same two new parameters (#3133).
  - `trtllm_ragged_attention_deepseek`: `+backend: str = "trtllm-gen"`
    (cute-dsl backend selection).

- **No breaking signature changes** to any `@flashinfer_api` function.
Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0
removals.

- **Module reorganization to flag (not `@flashinfer_api`, but in the public
  re-export):**
  - `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` →
    `dense_blockscaled_gemm_sm120_b12x.py`
  - Class renamed: `Sm120BlockScaledDenseGemmKernel` →
    `Sm120B12xBlockScaledDenseGemmKernel`
  - The re-export in `flashinfer/gemm/__init__.py` now points only to the new
    name, so direct importers of the old name break. Decision needed: ship as
    breaking, or add a deprecation alias (a possible shim is sketched below).

- **Internal autotuner helper rename (not public, but used by downstream
  extensions):**
  - `get_last_power_of_2_num_tokens_buckets` → `get_hybrid_num_tokens_buckets`
  - `last_positive_power_of_2` → `map_to_hybrid_bucket` /
    `map_to_hybrid_bucket_uncapped`

> Diff truncated above due to GitHub PR body length limit. Run the
command at the top locally to see the full output.
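
For reviewers who want to try the new tracing hook, here is a minimal sketch
of the two usage modes described in the `fi_trace.py` docstring. The tensor
shapes and the `flashinfer.fi_trace` import path are illustrative assumptions;
the environment variables and the `fi_trace(flashinfer.norm.rmsnorm, ...)`
call shape come from the diff.

```python
import os

# Auto-dump mode: set the env vars *before* importing flashinfer, then run
# the workload normally; each traced API writes <name>.json on its first call.
os.environ["FLASHINFER_TRACE_DUMP"] = "1"
os.environ["FLASHINFER_TRACE_DUMP_DIR"] = "./fi_trace_out"

import torch
import flashinfer
from flashinfer.fi_trace import fi_trace  # assumed import path

hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# Normal call; with the env vars above, a definition JSON is dumped as a
# side effect of the first invocation.
out = flashinfer.norm.rmsnorm(hidden, weight)

# Explicit mode: ask a single decorated API for its flashinfer-bench
# definition dict without relying on env vars.
defn = fi_trace(
    flashinfer.norm.rmsnorm, save_dir="./fi_trace_out", input=hidden, weight=weight
)
print(sorted(defn.keys()))
```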

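The DCP all-to-all APIs have a strict ordering contract: initialize the local
workspace row, barrier across ranks, then exchange. Below is a minimal
single-node sketch of that sequence; it assumes the `torch.distributed`
process group is already initialized with one GPU per rank, and it hard-codes
a placeholder workspace size rather than calling
`decode_cp_a2a_workspace_size` / `decode_cp_a2a_allocate_workspace`, whose
full argument lists are not shown in the truncated diff.

```python
import torch
import torch.distributed as dist

from flashinfer.comm.dcp_alltoall import (
    decode_cp_a2a_alltoall,
    decode_cp_a2a_init_workspace,
)

cp_rank = dist.get_rank()
cp_size = dist.get_world_size()

# Placeholder size (assumption). Real code should size the workspace via the
# dedicated helpers; the plain device-memory layout is
# [cp_size, ws_elems_per_rank] int64 zeros, per the docstring above.
WS_ELEMS_PER_RANK = 1 << 20
workspace = torch.zeros(cp_size, WS_ELEMS_PER_RANK, dtype=torch.int64, device="cuda")

# Required ordering: init the local FIFO row, cross-rank barrier, then the
# first alltoall.
decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)
dist.barrier()

# partial_o: [..., cp_size, D] (bf16, D * element_size 16-byte aligned);
# softmax_stats: [..., cp_size, S] (fp32, S even and >= 2).
batch, heads, D, S = 4, 16, 128, 2
partial_o = torch.randn(batch, heads, cp_size, D, dtype=torch.bfloat16, device="cuda")
softmax_stats = torch.randn(batch, heads, cp_size, S, dtype=torch.float32, device="cuda")

out, stats = decode_cp_a2a_alltoall(partial_o, softmax_stats, workspace, cp_rank, cp_size)
```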

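The SM90 mixed-input MoE helpers are pure layout transforms meant to run once
at weight-load time. The sketch below only exercises the documented shapes
(packed 4-bit weights `[num_experts, n, k // 2]`, E8M0 scales
`[num_experts, rows, K // group_size]` going to
`(num_experts, K // (group_size * 4), rows * 4)`); the concrete sizes are
arbitrary illustration values.

```python
import torch

from flashinfer.fused_moe import (
    interleave_moe_scales_for_sm90_mixed_gemm,
    interleave_moe_weights_for_sm90_mixed_gemm,
)

# Illustrative sizes (assumptions); only the layouts come from the docstrings.
num_experts, n, k, group_size = 8, 1024, 512, 32

# 4-bit weights packed two per byte, plus E8M0 block scales from the MXFP4
# quantizer, both uint8 on the GPU.
packed_w = torch.randint(0, 256, (num_experts, n, k // 2), dtype=torch.uint8, device="cuda")
scales = torch.randint(0, 256, (num_experts, n, k // group_size), dtype=torch.uint8, device="cuda")

w_interleaved = interleave_moe_weights_for_sm90_mixed_gemm(packed_w, quant_type="fp4")
s_interleaved = interleave_moe_scales_for_sm90_mixed_gemm(scales, group_size=group_size)

# Scales land in the (num_experts, K // (group_size * 4), rows * 4) layout
# expected by the SM90 mixed-input kernel.
assert s_interleaved.shape == (num_experts, k // (group_size * 4), n * 4)
```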

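On the `Sm120BlockScaledDenseGemmKernel` rename: if the decision lands on a
deprecation alias rather than shipping the rename as breaking, a shim along
these lines (hypothetical, not part of this diff) would keep the old import
path working:

```python
# flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py  (hypothetical shim)
"""Deprecated alias module, kept only to preserve the old import path."""

import warnings

from .dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
)

warnings.warn(
    "dense_blockscaled_gemm_sm120 was renamed to dense_blockscaled_gemm_sm120_b12x; "
    "import Sm120B12xBlockScaledDenseGemmKernel from the new module instead.",
    DeprecationWarning,
    stacklevel=2,
)

__all__ = ["Sm120BlockScaledDenseGemmKernel"]
```
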
## Summary by CodeRabbit

* **Patch Release**
  * Version updated to 0.6.10
