
Add support for the combinations of allreduce, allgather, and reducescatter#2563

Merged
aleozlx merged 8 commits into flashinfer-ai:main from jinyangyuan-nvidia:dev/mixed_comm
Apr 23, 2026

Conversation

@jinyangyuan-nvidia
Contributor

@jinyangyuan-nvidia jinyangyuan-nvidia commented Feb 14, 2026

📌 Description

This PR adds support for the following combinations of allreduce, allgather, and reducescatter:

  • allreduce
  • allgather
  • reducescatter
  • allreduce + allgather
  • reducescatter + allreduce

The last two communication patterns occur when tensor parallelism (TP) and data parallelism (DP) are both enabled in the attention part.
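For reference, the semantics of the listed patterns can be sketched rank-locally in plain Python (an illustration of the semantics only, not the PR's kernels; in real TP+DP use, the two steps of a combined pattern run over different process groups):

```python
# Rank-local reference for the collective semantics. `world` is a list of
# per-rank inputs; each function returns the per-rank outputs.

def allreduce(world):
    # Every rank ends up with the elementwise sum across ranks.
    summed = [sum(vals) for vals in zip(*world)]
    return [list(summed) for _ in world]

def allgather(world):
    # Every rank ends up with the concatenation of all ranks' inputs.
    gathered = [x for rank in world for x in rank]
    return [list(gathered) for _ in world]

def reducescatter(world):
    # Elementwise sum, then each rank keeps its 1/world_size shard
    # (assumes the message length divides evenly by the rank count).
    summed = [sum(vals) for vals in zip(*world)]
    n = len(world)
    shard = len(summed) // n
    return [summed[r * shard:(r + 1) * shard] for r in range(n)]

# The two combined patterns from this PR, expressed as compositions
# (here over the same group for simplicity):
def allreduce_allgather(world):
    return allgather(allreduce(world))

def reducescatter_allreduce(world):
    return allreduce(reducescatter(world))

world = [[1.0, 2.0], [3.0, 4.0]]              # 2 ranks, 2 elements each
print(allreduce_allgather(world)[0])          # [4.0, 6.0, 4.0, 6.0]
print(reducescatter_allreduce(world)[0])      # [10.0]
```
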

Besides combining existing NCCL kernels, this PR also implements fused kernels:

  • Intra-node communication is implemented using multicast or CUDA IPC
  • Inter-node communication is implemented using NVSHMEM
  • A pipeline overlaps intra-node and inter-node communication when the message size is large enough
  • To support more use cases, the inputs and outputs are not required to be in symmetric memory even when multicast is used (the internal buffers are in symmetric memory)
  • The size of the internal buffers is independent of the message size because these buffers are reused
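The fixed-size internal buffer point can be illustrated with a minimal staging loop: however long the message, it is pushed through a reusable buffer chunk by chunk, so buffer memory does not grow with message size (a sketch of the idea only; the real kernels stage through symmetric-memory buffers and overlap chunks across streams):

```python
# Stage an arbitrarily long message through one fixed-size, reusable buffer.
BUF_SIZE = 4

def send_through_staging(message, transmit):
    buf = [0] * BUF_SIZE                # allocated once, reused per chunk
    for off in range(0, len(message), BUF_SIZE):
        chunk = message[off:off + BUF_SIZE]
        buf[:len(chunk)] = chunk        # copy into the reusable staging buffer
        transmit(buf[:len(chunk)])      # "communicate" the staged chunk

received = []
send_through_staging(list(range(10)), received.extend)
print(received)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```
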

Update
In addition to the unit tests, correctness has also been verified by running GSM8K and GPQA evaluations with SGLang: https://github.com/jinyangyuan-nvidia/sglang/tree/dev/mixed_comm

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Distributed mixed-communication primitives (all-reduce, all-gather, reduce-scatter) with fused GPU kernels, NVSHMEM-enabled intra-node optimizations, topology-aware TP/DP handling, and an autotune path to pick optimal modes.
    • Public APIs to configure/query parallel topology and run mixed-comm workloads.
  • New Tools

    • Executable benchmark for mixed communication across GPUs/ranks and improved multi-rank GPU timing aggregation for accurate cluster measurements.
  • Tests

    • Distributed correctness and cross-mode consistency tests covering multiple dtypes and local sizes.

@coderabbitai
Contributor

coderabbitai Bot commented Feb 14, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a mixed-communication subsystem: NVSHMEM/CUDA kernels and TVM FFI entrypoints, Python orchestration and JIT generation, benchmarking CLI and timing aggregation, and distributed tests validating fused and fallback collective operators across TP/DP decompositions.

Changes

  • CUDA Declarations (include/flashinfer/comm/mixed_comm_decl.cuh): New device-side enums, constexpr helpers, BufferInfo and MixedCommArgs structs, and templated kernel declarations for mixed collectives and fused variants.
  • CUDA Host Implementation (csrc/mixed_comm.cu): New NVSHMEM lifecycle, unique-id handling, NVSHMEM-backed allocator, kernel dispatch macros, input validation, fused/unfused host entrypoints, and max-block-size discovery exported via TVM FFI.
  • Kernel Instantiation Template (csrc/mixed_comm_kernel_inst.jinja): Jinja template to generate global kernel instantiations across block-size, TP/DP, mode, and dtype combinations.
  • Python Mixed Comm Runtime (flashinfer/comm/mixed_comm.py): New enums (MixedCommOp/Mode), ParallelInfo topology helpers, the MixedCommHandler runtime (CUDA VM, NVSHMEM, handle exchange, autotune), and a public run_mixed_comm API with fused/NCCL fallbacks and shutdown.
  • JIT Generation & Build (flashinfer/jit/comm.py): Added gen_mixed_comm_module() to generate kernel instantiation sources, include NVSHMEM include/link paths, enable device linking, and integrate the generated sources into the JIT build.
  • Benchmark CLI (benchmarks/bench_mixed_comm.py): New executable benchmark that spawns local workers, initializes distributed groups, constructs MixedCommHandler instances, and microbenchmarks ops/modes with aggregated CUDA-graph timing.
  • Timing Utilities (flashinfer/testing/utils.py): Added aggregate_gpu_time_across_ranks and threaded aggregate_op propagation into bench_gpu_time/* so measurements can be aggregated across distributed ranks.
  • Tests & Pytest Config (tests/comm/conftest.py, tests/comm/test_mixed_comm.py): Pytest CLI options/fixtures for multi-node runs and a distributed correctness test that spawns per-local-rank workers, compares fused vs fallback outputs, and validates TP/DP consistency.
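The timing-utilities change can be illustrated with a small sketch: for a collective, the per-iteration time that matters is the slowest rank's, so per-rank measurements are combined with an aggregate op such as max before summarizing (the function below is a pure-Python stand-in for the idea; the real helper's name, signature, and distributed gathering may differ):

```python
import statistics

def aggregate_time_across_ranks(per_rank_times, aggregate_op=max):
    # per_rank_times: list over ranks, each a list of per-iteration times (ms).
    # Combine the ranks' measurements iteration by iteration.
    return [aggregate_op(iter_times) for iter_times in zip(*per_rank_times)]

rank_times = [
    [1.0, 1.1, 0.9],   # rank 0
    [1.2, 1.0, 1.3],   # rank 1
]
agg = aggregate_time_across_ranks(rank_times)
print(agg)                      # [1.2, 1.1, 1.3]
print(statistics.median(agg))   # 1.2
```
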

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant Handler as MixedCommHandler
    participant JIT as JIT / Generated Module
    participant Kernels as CUDA Kernels
    participant NVSHMEM as NVSHMEM
    participant NCCL as NCCL / PyTorch Dist

    App->>Handler: init(world_rank, local_rank, dtype,...)
    Handler->>JIT: build_and_load() / load generated module
    Handler->>NVSHMEM: exchange unique-id / init (inter-node)
    Handler->>Handler: allocate CUDA VM / map IPC handles (local ranks)
    Handler->>NVSHMEM: allocate NVSHMEM workspace (if enabled)
    App->>Handler: run_mixed_comm(op, x_in, mode/AUTOTUNE)
    Handler->>Handler: select_autotune_mode() / _common_check()
    alt fused kernel available
        Handler->>Kernels: launch fused kernel (via JIT)
        Kernels->>NVSHMEM: intra/inter-node collectives via NVSHMEM
        Kernels-->>Handler: results (device buffers)
    else fallback
        Handler->>NCCL: perform collectives via PyTorch/NCCL
        NCCL-->>Handler: results
    end
    Handler-->>App: x_out
    App->>Handler: shutdown()
    Handler->>NVSHMEM: finalize (if initialized)
    Handler->>Handler: unmap/free VM, destroy groups

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • aleozlx
  • bkryu
  • cyx-6
  • yzh119
  • yongwww
  • sricketts
  • jimmyzho
  • nv-yunzheq
  • saltyminty
  • djmmoss

Poem

🐰
I hopped through kernels, warm and bright,
NVSHMEM lanterns lent their light,
Fused hops, fallbacks, ranks in queue,
Latencies measured, proofs anew.
A rabbit cheers: collective flight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 5.38%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title 'Add support for the combinations of allreduce, allgather, and reducescatter' clearly and concisely summarizes the main change, adding support for multiple communication patterns, which is the primary objective of this PR.
  • Description check ✅ Passed: The PR description includes all major required sections: a clear description of what the PR does with detailed context, related issues, and the PR checklist with pre-commit and testing confirmations.



@gemini-code-assist
Contributor

Summary of Changes

Hello @jinyangyuan-nvidia, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the distributed communication capabilities by introducing optimized primitives for combined allreduce-allgather and reducescatter-allreduce operations. These new primitives are designed to improve efficiency in complex distributed training setups, particularly where both Tensor Parallelism and Data Parallelism are employed. The implementation utilizes fused CUDA kernels and integrates NVSHMEM for efficient inter-node data exchange, alongside new benchmarking and testing infrastructure to ensure correctness and performance.

Highlights

  • New Communication Patterns: Added support for combined communication patterns: allreduce + allgather and reducescatter + allreduce, which are crucial for scenarios involving both Tensor Parallelism (TP) and Data Parallelism (DP) in attention mechanisms.
  • Fused Kernel Implementations: Implemented fused CUDA kernels for these new communication patterns, leveraging virtual memory for intra-node communication and NVSHMEM for inter-node communication to optimize performance.
  • Benchmarking and Testing: Introduced a new benchmarking script (bench_mixed_comm.py) and comprehensive unit tests (test_mixed_comm.py) to validate the precision, consistency, and performance of the mixed communication operations.
  • NVSHMEM Integration: Enhanced NVSHMEM bindings to support local PE information and memory allocation with initialization, facilitating multi-node communication within the fused kernels.


Changelog
  • benchmarks/bench_mixed_comm.py
    • Added a new benchmarking script for mixed communication operations.
  • csrc/mixed_comm.cu
    • Implemented fused CUDA kernels for combined allreduce-allgather and reducescatter-allreduce operations, including single-node and multi-node variants.
  • csrc/nvshmem_binding.cu
    • Extended NVSHMEM bindings to include local PE information and a malloc_tensor_with_init function.
  • flashinfer/comm/mixed_comm.py
    • Introduced a new Python module MixedComm to manage and execute mixed communication patterns, including virtual memory and NVSHMEM initialization.
  • flashinfer/comm/mnnvl.py
    • Added a method to TorchDistBackend to retrieve group ranks and adjusted _init_ipc_socket to use this for root determination.
  • flashinfer/jit/comm.py
    • Added a new JIT specification for compiling the mixed communication CUDA kernels.
  • flashinfer/testing/utils.py
    • Enhanced GPU time benchmarking utilities with an aggregate_gpu_time_across_ranks function.
    • Added aggregate_op parameter to existing GPU benchmarking functions.
  • include/flashinfer/comm/mixed_comm.cuh
    • Defined CUDA kernel implementations and helper functions for fused mixed communication operations.
  • tests/comm/conftest.py
    • Added pytest fixtures for distributed testing parameters (num_nodes, node_id, dist_init_method).
  • tests/comm/test_mixed_comm.py
    • Added new unit tests for the MixedComm class, verifying precision and consistency of the new communication patterns.
Activity
  • The author noted two TODO items in the PR description: fixing potential hanging issues in the TRTLLM_LOCAL_INTER benchmark mode and supporting autotuning.
  • Pre-commit checks and tests are marked as incomplete in the PR checklist, indicating further work or verification is needed.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces significant new functionality by adding support for fused communication kernels (allreduce+allgather and reducescatter+allreduce). These kernels leverage advanced techniques like virtual memory for intra-node communication and nvshmem for inter-node communication, which is a great addition for performance. The PR also includes comprehensive benchmarks and tests for these new features. The overall implementation is complex but appears well-structured. My feedback focuses on improving the clarity and maintainability of the new benchmark script and ensuring proper random data generation in both the benchmarks and tests.

…catter

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
@jinyangyuan-nvidia jinyangyuan-nvidia changed the title Add support for the combinations of allreduce + allgather and reducescatter + allreduce Add support for the combinations of allreduce, allgather, and reducescatter Apr 12, 2026
@jinyangyuan-nvidia jinyangyuan-nvidia marked this pull request as ready for review April 12, 2026 16:17
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (5)
tests/comm/test_mixed_comm.py (1)

14-16: Consider removing sys.path manipulation.

This pattern of adding the project root to sys.path is unusual for pytest tests. Pytest typically handles module discovery correctly when run from the project root. If this is needed for a specific reason, consider documenting why.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_mixed_comm.py` around lines 14 - 16, Remove the manual
sys.path manipulation: delete the _project_root =
Path(__file__).parent.parent.parent and the conditional
sys.path.append(str(_project_root)) lines (they are adding the project root to
sys.path) unless there's a documented, test-specific reason; if the import
resolution actually requires it, add a brief comment explaining why and prefer
using pytest configuration (e.g., PYTHONPATH or pytest.ini) instead of modifying
sys.path in tests.
csrc/mixed_comm_kernel_inst.jinja (1)

40-41: Minor: Semicolons after namespace closing braces are unconventional.

While valid C++, the trailing semicolons after namespace closing braces are not standard style.

♻️ Suggested fix
-};  // namespace mixed_comm
-};  // namespace flashinfer
+}  // namespace mixed_comm
+}  // namespace flashinfer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/mixed_comm_kernel_inst.jinja` around lines 40 - 41, The closing
namespace braces for namespaces flashinfer and mixed_comm currently include
trailing semicolons; remove the unnecessary semicolons after the closing braces
for namespace mixed_comm and namespace flashinfer (the lines with "};  //
namespace mixed_comm" and "};  // namespace flashinfer") so they read as
standard namespace closing comments (i.e., " }  // namespace mixed_comm" and " }
// namespace flashinfer") without the extra semicolons.
flashinfer/jit/comm.py (2)

164-165: Consider simplifying list construction.

The list concatenation can be simplified using iterable unpacking for better readability.

♻️ Suggested simplification
     nvcc_flags = ["-rdc=true"]
-    ldflags = [f"-L{str(path_base / 'lib')}"] + ["-lnvshmem_device"]
+    ldflags = [f"-L{path_base / 'lib'}", "-lnvshmem_device"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/comm.py` around lines 164 - 165, The ldflags list currently
builds via concatenation ([f"-L{str(path_base / 'lib')}"] +
["-lnvshmem_device"]); simplify it by creating a single list using iterable
unpacking or by listing both items directly so it's clearer and more readable
(refer to the nvcc_flags and ldflags variables and the path_base / 'lib'
expression when locating the code).

160-174: Verify nvidia.nvshmem package availability at JIT time.

The import of nvidia.nvshmem happens at JIT generation time. If the package is not installed, this will raise an ImportError. Consider adding a more informative error message.

🛡️ Suggested improvement for better error handling
-    import nvidia.nvshmem
+    try:
+        import nvidia.nvshmem
+    except ImportError as e:
+        raise ImportError(
+            "nvidia-nvshmem package is required for mixed_comm module. "
+            "Install it with: pip install nvidia-nvshmem"
+        ) from e

     path_base = pathlib.Path(nvidia.nvshmem.__path__[0])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/comm.py` around lines 160 - 174, The code imports
nvidia.nvshmem at JIT generation time which will raise ImportError if the
package is missing; wrap the import of nvidia.nvshmem in a try/except that
catches ImportError and re-raises a clear, actionable error (or RuntimeError)
explaining that the nvidia.nvshmem package is not installed and is required to
generate the "mixed_comm" JIT spec used by gen_jit_spec, include the original
exception message for debugging, and ensure any subsequent uses (path_base,
extra_include_paths, extra_ldflags, needs_device_linking) only run after the
import succeeds.
csrc/nvshmem_binding.cu (1)

134-156: Host synchronization in allreduce_on_stream_with_copy may impact performance.

The cudaStreamSynchronize at line 155 blocks the host thread. If this function is called frequently in a tight loop, this could become a bottleneck. If the synchronization is intentional for correctness guarantees, consider documenting this.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nvshmem_binding.cu` around lines 134 - 156, The function
allreduce_on_stream_with_copy currently ends with a blocking
cudaStreamSynchronize call which can stall the host; remove the unconditional
cudaStreamSynchronize to make the operation asynchronous and instead either (a)
record and return/use a cudaEvent on the stream for the caller to wait on, or
(b) accept a cudaStream_t or a callback so the caller controls synchronization;
update the function signature/usage accordingly (refer to
allreduce_on_stream_with_copy, the cudaMemcpyAsync calls, and
nvshmemx_barrier_on_stream) and add a comment documenting that the function is
now non-blocking and callers must synchronize when they need completion
guarantees.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/mixed_comm.cu`:
- Around line 96-115: The current memoization uses a single thread_local int
(dummy_smem_size) so get_dummy_smem_size(int device_id) returns the same cached
value for different device_id; change the cache to map device_id->size (e.g.,
thread_local std::unordered_map<int,int> dummy_smem_size_map) and update
get_dummy_smem_size to look up device_id, compute and store the value keyed by
device_id when missing; add `#include` <unordered_map> near the top of the file
and keep the existing helper name get_dummy_smem_size and its behavior
otherwise.

In `@flashinfer/comm/mixed_comm.py`:
- Around line 641-650: The code currently constructs AF_UNIX socket paths via
get_socket_path(pid) => "/tmp/{pid}" and uses that in send_fd to send SCM_RIGHTS
to get_socket_path(local_pid_list[rank]); change this to create and use a secure
per-run temporary directory (mode 0700) and place per-rank socket files there
instead of /tmp; modify get_socket_path to return the path inside that private
temp dir (created once at process startup with os.mkdir/mkdtemp and os.chmod
0o700), update any callers (send_fd and the corresponding receive side
referenced around the other occurrence at the 723-733 region) to use the new
base dir, and ensure old stale paths are cleaned up on exit or recreated
atomically to avoid races.
- Around line 528-537: The call to mixed_comm_module.get_max_block_size inside
max_block_size_dict is passing the original constructor args (local_tp_size,
local_dp_size, inter_tp_size, inter_dp_size) instead of the normalized/resolved
values stored on the ParallelInfo instance; change the call to pass the resolved
attributes (e.g. self.local_tp_size, self.local_dp_size, self.inter_tp_size,
self.inter_dp_size) along with self.dtype, op, and mode so the FFI probe never
receives None.
- Around line 542-552: The loop that is intended to cap entries in
self.max_block_size_dict is a no-op because it only rebinds the local variable
val; update the actual dict entries instead (e.g., for key, val in
self.max_block_size_dict.items(): self.max_block_size_dict[key] = {sub_key:
min(sub_val, max_block_size) for sub_key, sub_val in val.items()} or mutate val
in-place) while keeping the existing assertions on max_block_size, warp_size,
and min_block_size so that the later max(...) computation uses the capped
values.
- Around line 833-840: The current CPU tensor broadcast using
torch.distributed.broadcast on the tensor uid (created via
self.mixed_comm_module.nvshmem_unique_id_size() and filled by
self.mixed_comm_module.nvshmem_get_unique_id(uid) on rank
self.para_info.world_rank == 0) will hit NCCL CPU-backend errors; replace that
call with torch.distributed.broadcast_object_list to broadcast the small CPU
payload safely across backends (create a single-element list containing
uid.numpy().tobytes() or a bytes object on rank 0, call
torch.distributed.broadcast_object_list(list_obj, src=0), and reconstruct the
uid tensor from the received bytes on all ranks) so the NVSHMEM unique id is
distributed correctly before nvshmem_init().
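One of the inline findings above flags a capping loop that only rebinds its loop variable. A minimal reproduction of why that is a no-op in Python, together with the write-back fix (the dict contents here are illustrative stand-ins for max_block_size_dict):

```python
max_block_size = 256
d = {"op_a": {"mode1": 512, "mode2": 128}}

# No-op version: `val` is rebound to a new dict, `d` is untouched.
for key, val in d.items():
    val = {k: min(v, max_block_size) for k, v in val.items()}
assert d["op_a"]["mode1"] == 512   # still uncapped!

# Fix: write the capped sub-dict back into the outer dict.
for key, val in d.items():
    d[key] = {k: min(v, max_block_size) for k, v in val.items()}
assert d["op_a"]["mode1"] == 256   # capped as intended
```
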


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bcebf4b6-8ac1-4cc9-831c-508b612dade7

📥 Commits

Reviewing files that changed from the base of the PR and between a1166dc and fc634c1.

📒 Files selected for processing (11)
  • benchmarks/bench_mixed_comm.py
  • csrc/mixed_comm.cu
  • csrc/mixed_comm_kernel_inst.jinja
  • csrc/nvshmem_binding.cu
  • flashinfer/comm/mixed_comm.py
  • flashinfer/jit/comm.py
  • flashinfer/testing/utils.py
  • include/flashinfer/comm/mixed_comm.cuh
  • include/flashinfer/comm/mixed_comm_decl.cuh
  • tests/comm/conftest.py
  • tests/comm/test_mixed_comm.py

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
@aleozlx aleozlx added the run-ci label Apr 14, 2026
@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !549 has been created, and the CI pipeline #48536652 is currently running. I'll report back once the pipeline job completes.

@aleozlx
Collaborator

aleozlx commented Apr 14, 2026

hi can you annotate public api using @flashinfer_api

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
@jinyangyuan-nvidia
Contributor Author

jinyangyuan-nvidia commented Apr 15, 2026

hi can you annotate public api using @flashinfer_api

Thanks for the suggestion. The code has been modified accordingly.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (1)
flashinfer/comm/mixed_comm.py (1)

747-757: ⚠️ Potential issue | 🔴 Critical

Don’t assume the default process group can broadcast CPU tensors.

Line 756 broadcasts a CPU uid, but this class only requires “an active torch.distributed process group” in its public contract. PyTorch documents broadcast as unsupported for CPU tensors on an NCCL process group; CPU collectives only route to Gloo when a multi-backend group exists, and broadcast_object_list() is the portable option for a small payload like this unique ID. That means callers initialized with plain backend="nccl" can fail here before nvshmem_init() runs. (docs.pytorch.org)

Does torch.distributed.broadcast support CPU tensors when the default process group is initialized with backend="nccl", and what is the recommended way to broadcast a small CPU payload such as an NVSHMEM unique ID?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/mixed_comm.py` around lines 747 - 757, The code in
init_nvshmem uses torch.distributed.broadcast on a CPU tensor which can fail for
NCCL-only default process groups; instead convert the NVSHMEM unique id to a
small Python object (e.g., bytes via uid.cpu().numpy().tobytes() or a list of
ints), use torch.distributed.broadcast_object_list to share it from
para_info.world_rank == 0, then reconstruct the uid tensor (same shape/dtype)
before calling mixed_comm_module.nvshmem_init; update references to
mixed_comm_module.nvshmem_get_unique_id and mixed_comm_module.nvshmem_init
accordingly so the send/receive path uses broadcast_object_list and the tensor
is restored on CPU.
🧹 Nitpick comments (1)
benchmarks/bench_mixed_comm.py (1)

119-157: Please route this through the unified benchmark harness.

This standalone CLI bypasses benchmarks/flashinfer_benchmark.py, so it won’t inherit the repo’s standard benchmark metadata/reporting path even though the timing loop itself already uses bench_gpu_time() correctly.

As per coding guidelines, "Use the unified benchmarking framework in benchmarks/flashinfer_benchmark.py for kernel benchmarking with CUPTI timing support".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_mixed_comm.py` around lines 119 - 157, The script currently
implements a standalone main() that spawns processes and calls _run_worker
directly, bypassing the unified harness in benchmarks/flashinfer_benchmark.py;
refactor so the CLI routes through that harness instead: import the benchmark
entrypoint or runner from flashinfer_benchmark.py and invoke it from main(),
passing the existing argument parsing and the worker entrypoint (_run_worker) or
encapsulating the spawn logic into a function the harness can call; ensure the
timing loop (bench_gpu_time) remains intact but that all runs and
metadata/reporting use the flashinfer_benchmark harness so the standard
reporting path and CUPTI timing support are used.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/comm/mixed_comm.py`:
- Around line 529-550: The default MixedCommMode.AUTOTUNE is removed from
get_valid_mode_list when use_autotune is False, but run_mixed_comm() still
defaults its mode parameter to MixedCommMode.AUTOTUNE causing _common_check() to
reject calls when MixedCommHandler(..., use_autotune=False); fix by making the
default mode consistent: either always include MixedCommMode.AUTOTUNE in
get_valid_mode_list or change run_mixed_comm()'s default to a valid mode (e.g.,
MixedCommMode.NCCL_ONE or None) and update callers; locate get_valid_mode_list,
run_mixed_comm, and _common_check to implement the consistent default behavior
and ensure MixedCommMode.AUTOTUNE is only accepted when use_autotune is True.
- Around line 1224-1237: The shape checks around x_out in mixed_comm.dispatch
(the block that references x_out.shape and MixedCommOp.REDUCESCATTER /
REDUCESCATTER_ALLREDUCE) must also validate that x_out.dtype matches x_in.dtype
and that x_out.device matches x_in.device (or the expected device used by the
comm handler) before returning/dispatching; add explicit checks that raise a
TypeError with a clear message when dtype or device mismatch is detected so
callers get immediate API-bound errors instead of failures inside NCCL/FFI.
- Around line 408-455: MixedCommHandler currently accepts any torch.dtype but
only supports fp16/bf16 kernels; in the MixedCommHandler.__init__ (after
assigning self.dtype and self.device but before calling get_mixed_comm_module()
or probing block sizes such as via mixed_comm_module.get_max_block_size)
validate that self.dtype is one of torch.float16 or torch.bfloat16
(torch.half/torch.bfloat16 aliases are fine) and raise a clear ValueError if
not; update the constructor to perform this check early so unsupported dtypes
fail fast with an actionable error message referencing MixedCommHandler and
dtype.

In `@flashinfer/jit/comm.py`:
- Around line 17-24: The new JIT module generator function gen_*_module is
missing Python-level caching; import functools.cache (or functools) and add the
`@functools.cache` decorator to gen_*_module so repeated calls reuse the rendered
template and the created JitSpec instead of re-rendering; follow the same
pattern used by the other gen_..._module functions in this file (apply the
decorator to any new gen_*_module entrypoints).

In `@tests/comm/test_mixed_comm.py`:
- Around line 201-208: The test_mixed_comm function currently only checks GPU
count but must early-skip unsupported GPU architectures before spawning workers;
add a compute-capability guard at the top of test_mixed_comm that calls the API
helper (e.g., run_mixed_comm.is_compute_capability_supported(...) or
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported()) and calls
pytest.skip(...) if the capability is not supported, placing this check before
any torch.cuda.device_count() or worker spawn logic so child processes do not
error.
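One way the early-skip guard in the test comment above could look (the helper name and threshold are illustrative, not the actual flashinfer.utils API):

```python
import pytest
import torch


def skip_if_unsupported_gpu(min_compute_major: int = 9) -> None:
    # Run this before any torch.cuda.device_count() check or worker
    # spawn, so child processes are never created on unsupported HW.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    major, _minor = torch.cuda.get_device_capability(0)
    if major < min_compute_major:
        pytest.skip(
            f"requires compute capability >= {min_compute_major}.0, got {major}.x"
        )
```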

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b3b907c3-10d0-4f80-84bd-8029b786ebd3

📥 Commits

Reviewing files that changed from the base of the PR and between b54f7fb and 8f6b63f.

📒 Files selected for processing (4)
  • benchmarks/bench_mixed_comm.py
  • flashinfer/comm/mixed_comm.py
  • flashinfer/jit/comm.py
  • tests/comm/test_mixed_comm.py


@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/comm/test_mixed_comm.py (1)

140-144: Minor: Use a generator expression in all().

The list comprehension inside all() creates an unnecessary intermediate list.

Suggested change
-    assert all([val == local_size for val in local_size_all]), (
+    assert all(val == local_size for val in local_size_all), (
         "local_size must be the same on all ranks"
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_mixed_comm.py` around lines 140 - 144, The assertion creates
an unnecessary intermediate list; replace the list comprehension passed to all()
with a generator expression by changing the condition that checks equality
across gathered sizes (currently using local_size_all and local_size with
torch.distributed.all_gather_object) to use a generator expression (e.g.,
all(val == local_size for val in local_size_all)) so no temporary list is
allocated.
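The behavioral difference is small but real: a generator expression lets `all()` short-circuit without materializing a list. A quick illustration:

```python
def all_equal(values, expected):
    # Generator expression: no intermediate list, and iteration stops
    # at the first mismatching element.
    return all(v == expected for v in values)


# Short-circuiting is observable with an iterator: elements after the
# first mismatch are left unconsumed.
it = iter([1, 2, 1, 1])
assert not all(v == 1 for v in it)
assert next(it) == 1
```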


📥 Commits

Reviewing files that changed from the base of the PR and between 8f6b63f and cd729b5b6c1050c092f72bb5e733980430cc7ede.

📒 Files selected for processing (2)
  • flashinfer/comm/mixed_comm.py
  • tests/comm/test_mixed_comm.py

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (1)
flashinfer/comm/mixed_comm.py (1)

750-758: ⚠️ Potential issue | 🔴 Critical

Use broadcast_object_list() (or a CPU-capable group) for the NVSHMEM unique ID.

uid is allocated on CPU, but torch.distributed.broadcast(uid, src=0) goes through the default CUDA/NCCL process group here. That path fails on CPU tensors before nvshmem_init() completes.

Does `torch.distributed.broadcast` support CPU tensors when the process group backend is NCCL, and what is the recommended way to broadcast a small CPU payload like an NVSHMEM unique ID?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/mixed_comm.py` around lines 750 - 758, The code broadcasts a
CPU tensor `uid` using torch.distributed.broadcast which may use the NCCL/CUDA
group and fail for CPU tensors before nvshmem_init; replace that broadcast with
a CPU-safe mechanism (e.g., torch.distributed.broadcast_object_list or a
CPU-capable process group) when sending the NVSHMEM unique ID obtained via
mixed_comm_module.nvshmem_get_unique_id() and sized by
mixed_comm_module.nvshmem_unique_id_size(); ensure the root
(para_info.world_rank == 0) fills `uid` and all ranks call the CPU-safe
broadcast so nvshmem_init() receives the correct ID.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_mixed_comm.py`:
- Around line 149-153: The current loop that waits for workers (iterating
process_list, calling process.join() then asserting exitcode) can deadlock if
one worker fails and peers are stuck in collectives; change it to poll processes
with a timeout and propagate failures immediately: replace the direct
process.join() loop with a monitoring loop that uses join(timeout=...) or checks
process.exitcode periodically, and on detecting any non-zero exitcode from a
process in process_list (or a process that stops responding within the timeout)
terminate/kill the remaining processes and raise an error; reference the
process_list variable and the code that currently calls process.join() and
checks process.exitcode to implement this immediate teardown and failure
propagation.
- Around line 119-157: The script currently implements its own CLI and
multiprocessing orchestration in main() (spawning _run_worker), which bypasses
the repo's unified benchmark entrypoint; refactor by removing the standalone
CLI/main and instead expose a run_benchmark(config) function that accepts parsed
args or a config object and invokes the existing _run_worker logic inside the
framework, then register this benchmark with the shared benchmark framework in
flashinfer_benchmark.py (use the framework's register API to supply a name, an
entry function that calls run_benchmark, and metadata so CUPTI timing and
standard output/registration are used); keep _run_worker intact, move argument
parsing into the shared framework, and ensure the new registration replaces the
if __name__ == "__main__" block.

In `@flashinfer/comm/mixed_comm.py`:
- Around line 1234-1243: The _common_check() branch incorrectly applies the
allgather rule to all non-reduce-scatter ops; update the conditional to treat
MixedCommOp.ALLREDUCE separately: for MixedCommOp.REDUCESCATTER and
MixedCommOp.REDUCESCATTER_ALLREDUCE keep the existing check (x_out.shape[0] *
handler.para_info.dp_size == x_in.shape[0]), for MixedCommOp.ALLREDUCE require
x_out.shape[0] == x_in.shape[0], and for other allgather-like ops keep
x_out.shape[0] == x_in.shape[0] * handler.para_info.dp_size; reference the
_common_check function, MixedCommOp enum values (REDUCESCATTER,
REDUCESCATTER_ALLREDUCE, ALLREDUCE), and handler.para_info.dp_size to implement
the corrected branching.
- Around line 583-593: The POSIX file descriptors returned by
cuMemExportToShareableHandle and received via SCM_RIGHTS are not being closed,
causing fd leaks; update create_and_allgather_uc_handle,
create_and_send_mc_handle, and recv_and_create_mc_handle so that after calling
cuMemExportToShareableHandle you call send_fd(...) and then immediately close
the exported descriptor (uc_fd_send) and after calling recv_fd(...) and
importing via cuMemImportFromShareableHandle you immediately close the received
descriptor (uc_fd_recv); do the same for mc_fd in create_and_send_mc_handle and
recv_and_create_mc_handle (close mc_fd after send and after import) using
os.close or an equivalent close helper so the allocation lifetime remains
managed by cuMem* APIs while avoiding leaking file descriptors.
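The close-after-transfer discipline described above can be sketched with the stdlib SCM_RIGHTS helpers (`socket.send_fds`/`socket.recv_fds`, Python 3.9+, Unix only); the surrounding cuMem* export/import calls are elided:

```python
import os
import socket


def send_fd_and_close(sock: socket.socket, fd: int) -> None:
    # After SCM_RIGHTS duplicates the descriptor into the peer process,
    # the local copy is redundant; closing it does not invalidate the
    # peer's copy or the underlying allocation.
    try:
        socket.send_fds(sock, [b"\0"], [fd])
    finally:
        os.close(fd)


def recv_fd(sock: socket.socket) -> int:
    # The caller must close the returned descriptor once the handle has
    # been imported (e.g. after cuMemImportFromShareableHandle).
    _msg, fds, _flags, _addr = socket.recv_fds(sock, 1, 1)
    return fds[0]
```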

In `@tests/comm/test_mixed_comm.py`:
- Around line 227-231: The test currently waits indefinitely on process.join()
and doesn't stop other workers if one fails; update the loop over process_list
to join each process with a short timeout, check for non-zero exitcode from the
worker started by _run_worker, and on first non-zero exitcode immediately
terminate() (or kill if needed) and join the remaining processes to avoid hangs;
after terminating the rest, re-check exit codes and raise/assert with the
failing process's exitcode and PID/index so the test fails fast and CI won't
hang.
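A sketch of the fail-fast join loop that both the benchmark and test comments above ask for (function and parameter names are illustrative):

```python
import time


def join_all_or_fail(processes, poll_s=1.0, timeout_s=600.0):
    # Poll instead of blocking in join(): if one worker dies while its
    # peers are stuck in a collective, tear everything down immediately
    # instead of hanging CI.
    deadline = time.monotonic() + timeout_s
    while True:
        # is_alive() reaps finished children, so check it before exitcode.
        alive = [p for p in processes if p.is_alive()]
        failed = [p for p in processes if p.exitcode not in (None, 0)]
        if failed or time.monotonic() > deadline:
            for p in alive:
                p.terminate()
            for p in alive:
                p.join()
            if failed:
                p = failed[0]
                raise RuntimeError(
                    f"worker pid={p.pid} exited with code {p.exitcode}"
                )
            raise TimeoutError(f"workers still running after {timeout_s}s")
        if not alive:
            return  # every worker exited with code 0
        time.sleep(poll_s)
```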


📥 Commits

Reviewing files that changed from the base of the PR and between cd729b5b6c1050c092f72bb5e733980430cc7ede and be27834.

📒 Files selected for processing (3)
  • benchmarks/bench_mixed_comm.py
  • flashinfer/comm/mixed_comm.py
  • tests/comm/test_mixed_comm.py

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
@aleozlx aleozlx self-assigned this Apr 23, 2026
@aleozlx aleozlx merged commit c9eb3cd into flashinfer-ai:main Apr 23, 2026
30 of 31 checks passed
@aleozlx aleozlx mentioned this pull request Apr 25, 2026
aleozlx added a commit that referenced this pull request May 5, 2026
## Description

Bump version to 0.6.10 for release.

## Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10

## Reviewer Notes

**API changes review**

API changes since v0.6.9

```diff
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
     register_custom_op,
@@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_trace)
 def silu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -112,7 +118,7 @@ def silu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_tanh_and_mul_trace)
 def gelu_tanh_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -153,7 +159,7 @@ def gelu_tanh_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=gelu_and_mul_trace)
 def gelu_and_mul(
     input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
 ) -> torch.Tensor:
@@ -194,7 +200,7 @@ def gelu_and_mul(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace)
 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index dfb05150..d26d5407 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -543,6 +543,7 @@ def gen_all_modules(
     if add_comm:
         from .jit.comm import (
             gen_comm_alltoall_module,
+            gen_dcp_alltoall_module,
             gen_moe_alltoall_module,
             gen_trtllm_comm_module,
             gen_trtllm_mnnvl_comm_module,
@@ -554,6 +555,11 @@ def gen_all_modules(
             jit_specs.append(gen_trtllm_comm_module())
             jit_specs.append(gen_trtllm_mnnvl_comm_module())
             jit_specs.append(gen_moe_alltoall_module())
+            # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has
--
 
-def flashinfer_api(func: Callable = None) -> Callable:
+# ---------------------------------------------------------------------------
+# Trace template registry
+# ---------------------------------------------------------------------------
+# Populated automatically by _attach_fi_trace whenever @flashinfer_api is
+# given a trace= argument.  Each entry is (original_func, template, label)
+# where label is the template's name_prefix (or op_type as fallback).
+#
+# For dispatch callables (trace=some_fn), every template listed in
+# some_fn.templates is registered if that attribute exists.
+#
+# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover
+# all registered templates without requiring manual maintenance.
+_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = []
+
+
+def _attach_fi_trace(
+    wrapped: Callable,
+    original: Callable,
+    trace_template=None,
+) -> Callable:
+    """Attach a ``fi_trace`` callable to *wrapped*.
+
+    Three resolution strategies, tried in order:
+
--
+
+        warnings.warn(
+            f"[flashinfer] Failed to attach fi_trace to '{_func_name}': "
+            f"{type(_exc).__name__}: {_exc}\n"
+            f"The function will work normally but fi_trace will be unavailable. "
+            f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).",
+            stacklevel=3,
+        )
+    return wrapped
+
+
+def flashinfer_api(func: Callable = None, *, trace=None) -> Callable:
     """
     Decorator to FlashInfer's APIs.
 
@@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable:
     - The %i pattern is automatically replaced with the process ID for multi-process environments.
     - The logger does not propagate to the root logger to avoid duplicate logs.
     """
-    # If logging is disabled, return original function with zero overhead
+    # If logging is disabled, return original function with zero overhead.
+    # We still attach fi_trace so it is always available regardless of log level.
     if _API_LOG_LEVEL == 0:
         if func is None:
-            return lambda f: f
-        return func
--
 @functools.cache
@@ -135,7 +136,7 @@ class BatchAttention:
             causal,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_attention_run_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper
     a convenient interface for using attention sinks during prefill or decode attention.
     """
 
+    # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py
new file mode 100644
index 00000000..3e029627
--- /dev/null
+++ b/flashinfer/attention/cute_dsl/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
--
 
@@ -31,7 +37,7 @@ def get_cascade_module():
     return gen_cascade_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_trace)
 @register_custom_op("flashinfer::merge_state", mutates_args=())
 def merge_state(
     v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -98,7 +104,7 @@ def _fake_merge_state(
     return v, s
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_state_in_place_trace)
 @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
 def merge_state_in_place(
     v: torch.Tensor,
@@ -159,7 +165,7 @@ def _fake_merge_state_in_place(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=merge_states_trace)
 @register_custom_op("flashinfer::merge_states", mutates_args=())
 def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Merge multiple attention states (v, s).
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=multi_level_cascade_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 5f186002..31d23a99 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import (
     moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace,
 )
 
+# DCP A2A (Decode Context Parallel Attention Reduction)
+from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
+from .dcp_alltoall import (
+    decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
+)
+from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
+from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
+
 # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
--
 from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
@@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=allreduce_fusion_trace)
 def allreduce_fusion(
     input: torch.Tensor,
     workspace: AllReduceFusionWorkspace,
diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py
new file mode 100644
index 00000000..3047f76c
--- /dev/null
+++ b/flashinfer/comm/dcp_alltoall.py
@@ -0,0 +1,255 @@
+"""
+DCP All-to-All Operations for DCP Attention Reduction
+
+Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
+attention reduction. Uses SM90+ features (TMA, mbarrier).
+
+Usage protocol::
+
+    # 1. Query workspace size
+    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
+
--
+
+
+# ─── Public API ───────────────────────────────────────────────────────────
+
+
+@flashinfer_api
+def decode_cp_a2a_workspace_size(cp_size: int) -> int:
+    """Return the workspace size **in bytes** per rank for the given CP group size.
+
+    Args:
+        cp_size: Context-parallel group size (number of ranks).
+
+    Returns:
+        Workspace size in bytes per rank.
+
+    Example::
+
+        >>> decode_cp_a2a_workspace_size(4)
+        16778240
+    """
+    return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size)
+
+
+@flashinfer_api
+def decode_cp_a2a_allocate_workspace(
+    cp_size: int,
+    cp_rank: int,
+    *,
+    mapping: Optional[Mapping] = None,
+    mnnvl_config: Optional[MnnvlConfig] = None,
+) -> torch.Tensor:
+    """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+
+    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
+
+    Two allocation modes:
+
+    - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
+      FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
+      cannot see each other's device memory directly.
+    - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
+      allocation. Sufficient for single-node with NVLink P2P.
+
--
+
+    ws_elems_per_rank = (ws_bytes + 7) // 8
+    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+
+
+@flashinfer_api
+def decode_cp_a2a_init_workspace(
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+) -> None:
+    """Initialize the workspace FIFO buffers. Call once before the first alltoall.
+
+    Resets the FIFO buffers in the **local** workspace row
+    (``workspace[cp_rank]``). This function is **synchronous**: when it
+    returns, the GPU memset is guaranteed to have completed.
+
+    .. important::
+        With MNNVL workspaces, **all ranks** must complete
+        ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier
+        (e.g. ``dist.barrier(group)``) before **any** rank calls
+        :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may
+        start writing to a peer's FIFO before that peer has finished
+        initializing → deadlock.
+
+    Args:
--
+    # subsequent cross-GPU alltoall can race with the unfinished memset
+    # on MNNVL memory, causing a deadlock.
+    torch.cuda.current_stream().synchronize()
+
+
+@flashinfer_api(trace=decode_cp_a2a_alltoall_trace)
+def decode_cp_a2a_alltoall(
+    partial_o: torch.Tensor,
+    softmax_stats: torch.Tensor,
+    workspace: torch.Tensor,
+    cp_rank: int,
+    cp_size: int,
+    enable_pdl: Optional[bool] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Perform the DCP all-to-all exchange.
+
+    Each rank sends its ``partial_o[..., peer, :]`` slice to the
+    corresponding peer and receives all peers' contributions into the
+    output tensors.
+
+    Args:
+        partial_o: ``[..., cp_size, D]`` — half or bfloat16.
+            ``D * element_size`` must be 16-byte aligned.
+        softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
+            Batch dimensions must match ``partial_o``.
+        workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
--
+    MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather,
+    MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce,
+}
+
+
+@flashinfer_api
+@backend_requirement(
+    backend_checks={},
+    common_check=_common_check,
+)
+def run_mixed_comm(
+    op: MixedCommOp,
+    handler: MixedCommHandler,
+    x_in: torch.Tensor,
+    x_out: torch.Tensor | None = None,
+    mode: MixedCommMode | None = None,
+) -> torch.Tensor:
+    """Execute a mixed communication operation.
+
+    This is the main entry point for running communication collectives
+    through the mixed communication handler. It supports fused GPU kernels
+    (using virtual memory intra-node and nvshmem inter-node), NCCL-based
+    fallbacks, and autotuned mode selection.
+
+    Args:
+        op: The communication operation to perform.
--
 @functools.cache
@@ -28,7 +29,7 @@ def get_concat_mla_module():
     return gen_concat_mla_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=concat_mla_k_trace)
 def concat_mla_k(
     k: torch.Tensor,
     k_nope: torch.Tensor,
diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py
index 195ca2d4..9b593095 100644
--- a/flashinfer/cudnn/decode.py
+++ b/flashinfer/cudnn/decode.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_decode_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache(
     return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_decode_trace)
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc1bbb5f..b16d6043 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_prefill_trace
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache(
         return out, None
 
 
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_prefill_trace)
 def cudnn_batch_prefill_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
index 0b50c22c..f25aa6fd 100644
--- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
@@ -38,6 +38,7 @@ import torch
 from cutlass import Float32, Int32, Int64, Uint32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import add_rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -1042,7 +1043,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=add_rmsnorm_fp4quant_trace)
 def add_rmsnorm_fp4quant(
     input: torch.Tensor,
     residual: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
index 333697ab..b7aabc36 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
@@ -20,6 +20,7 @@ import torch
 from cutlass import Float32, Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace
 from flashinfer.utils import device_support_pdl
 from flashinfer.cute_dsl.utils import (
     get_max_active_clusters,
@@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper:
                 f"out_dtype={self._o_dtype}"
             )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_mla_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
index 58a24abe..ee0cd5e7 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
@@ -21,6 +21,7 @@ import cutlass.cute as cute
 from cutlass.cute.typing import Int32
 
 from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace
 
 from ..config import AttentionConfig, AttentionFusion
 from ..fusion.mask import MaskType
@@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper:
                     f"device={self._device}"
                 )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace)
     def run(
         self,
         q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
index bc4acffc..97ce68a1 100644
--- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
@@ -32,6 +32,7 @@ import torch
 from cutlass import Float32, Int32, Uint8
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import rmsnorm_fp4quant_trace
 from ..utils import device_support_pdl
 from .fp4_common import (
     # Constants
@@ -771,7 +772,7 @@ def _get_compiled_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_fp4quant_trace)
 def rmsnorm_fp4quant(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 822aca40..5e9eb515 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_decode_trace,
+    single_decode_with_kv_cache_trace,
+    trtllm_batch_decode_trace,
+    xqa_batch_decode_trace,
+)
 
 ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
 from .mla import (
@@ -400,7 +406,7 @@ def single_decode_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_decode_with_kv_cache_trace)
 def single_decode_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
         kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_decode_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra
     :class:`BatchDecodeWithPagedKVCacheWrapper`
     """
 
+    # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper
+    # already decorates __init__, so decorating again produces double log entries.
     def __init__(
         self,
         workspace_buffer: torch.Tensor,
@@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_trace)
 def trtllm_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache(
 
 
 # xqa uses NHD layout
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_trace)
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
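The comment added above `CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__` (skipping the decorator because the parent class already decorates `__init__`) can be demonstrated with a minimal stand-in decorator; `log_api` below is illustrative, not FlashInfer's actual `@flashinfer_api` implementation:

```python
calls = []

def log_api(fn):
    """Stand-in for @flashinfer_api: records every decorated call."""
    def wrapper(*args, **kwargs):
        calls.append(fn.__qualname__)
        return fn(*args, **kwargs)
    return wrapper

class Base:
    @log_api
    def __init__(self):
        pass

class ChildInherits(Base):
    # Like CUDAGraphBatchDecodeWithPagedKVCacheWrapper: no decorator here,
    # because the parent's decorated __init__ already logs the call.
    def __init__(self):
        super().__init__()

class ChildDoubleDecorated(Base):
    @log_api  # decorating again wraps the already-logging parent path
    def __init__(self):
        super().__init__()

ChildInherits()
entries_single = len(calls)   # one log entry, from Base.__init__
calls.clear()
ChildDoubleDecorated()
entries_double = len(calls)   # two log entries for a single construction
```

The doubly decorated subclass produces two log entries per construction, which is exactly the duplication the comment guards against.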
diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py
new file mode 100644
index 00000000..1104eb6f
--- /dev/null
+++ b/flashinfer/fi_trace.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--
+
+"""
+fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_
+compatible definition JSON for FlashInfer APIs.
+
+Every ``@flashinfer_api(trace=<template>)``-decorated function supports two
+usage modes:
+
+Auto-dump (recommended)
+-----------------------
+Set environment variables **before** importing flashinfer, then run your
+workload normally.  No explicit ``fi_trace`` call is needed.
+
+.. code-block:: bash
+
+    FLASHINFER_TRACE_DUMP=1 \\
+    FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\
+    python my_script.py
+
+Every decorated function writes a ``<name>.json`` file on its **first** call
+for each unique set of const-axis values (e.g. head dimensions, vocab size).
+Subsequent calls with the same shape are deduplicated; each file is written
+only once per process.  The output directory is created automatically.
+
+Explicit call (for selective or programmatic use)
+-------------------------------------------------
--
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+# ---------------------------------------------------------------------------
+# Legacy registry — kept for backwards compatibility.
+# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead.
+# ---------------------------------------------------------------------------
+
+_REGISTRY: Dict[str, Any] = {}
+
+
+def register_fi_trace(qualname: str, spec: Any) -> None:
+    """Register a legacy FiTraceSpec for the function with the given qualname.
+
+    .. deprecated::
+        Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead.
+    """
+    _REGISTRY[qualname] = spec
+
+
+def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]:
+    """Build a fi_trace callable from a legacy FiTraceSpec.
+
+    .. deprecated::
+        Use ``TraceTemplate.build_fi_trace_fn`` instead.
+    """
+    # Import the old implementation from the trace package for backwards compat.
+    from .trace.template import (  # noqa: PLC0415,F401
+        Const,
+        Scalar,
+        Tensor,
+        TraceTemplate,
+        Var,
+    )
+    import json  # noqa: PLC0415
+    import os  # noqa: PLC0415
--
+    """Generate a flashinfer-bench definition JSON for any FlashInfer API call.
+
+    Parameters
+    ----------
+    func_or_method:
+        A ``@flashinfer_api``-decorated function or (bound) method.
+    save_dir:
+        Directory where the JSON definition file should be written.
+        Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
+    **kwargs:
+        The same tensor arguments you would pass to the real API.
+
+    Returns
+    -------
+    dict
+        A flashinfer-bench compatible definition dictionary.
+
+    Examples
+    --------
+    Standalone function::
+
+        defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)
+
+    Bound method (instance.run)::
+
+        defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))
--
+    trace_fn = getattr(actual_func, "fi_trace", None)
+    if trace_fn is None:
+        qualname = getattr(actual_func, "__qualname__", repr(actual_func))
+        raise ValueError(
+            f"No fi_trace spec is registered for '{qualname}'. "
+            "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+        )
+    return trace_fn(save_dir=save_dir, **kwargs)
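The first-call deduplication described in the `fi_trace` module docstring (one `<name>.json` per unique const-axis set per process, with the dump directory created automatically) can be sketched as follows. This is an illustrative stand-in, not FlashInfer's actual auto-dump implementation:

```python
import json
import tempfile
from pathlib import Path

def make_auto_dumper(dump_dir):
    """Sketch of per-process trace dedup: each API name writes <name>.json
    once per unique const-axis set. Illustrative only."""
    seen = set()

    def maybe_dump(name, const_axes, definition):
        key = (name, tuple(sorted(const_axes.items())))
        if key in seen:                          # same shape already dumped
            return False
        seen.add(key)
        out = Path(dump_dir)
        out.mkdir(parents=True, exist_ok=True)   # dir created automatically
        (out / f"{name}.json").write_text(json.dumps(definition))
        return True

    return maybe_dump

dump = make_auto_dumper(tempfile.mkdtemp())
wrote_first = dump("rmsnorm", {"hidden_size": 4096}, {"op": "rmsnorm"})
wrote_again = dump("rmsnorm", {"hidden_size": 4096}, {"op": "rmsnorm"})
wrote_other = dump("rmsnorm", {"hidden_size": 8192}, {"op": "rmsnorm"})
```

A repeat call with identical const axes is skipped, while a new hidden size triggers a fresh dump.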
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index df6e1f72..d983f9d4 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -17,6 +17,8 @@ limitations under the License.
 from .core import (
     convert_to_block_layout,
     cutlass_fused_moe,
+    interleave_moe_scales_for_sm90_mixed_gemm,
+    interleave_moe_weights_for_sm90_mixed_gemm,
     gen_cutlass_fused_moe_sm120_module,
     gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
@@ -64,6 +66,8 @@ __all__ = [
     "WeightLayout",
     "convert_to_block_layout",
     "cutlass_fused_moe",
+    "interleave_moe_scales_for_sm90_mixed_gemm",
--
+        ),
     )
 
 
-# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
 @flashinfer_api
+def interleave_moe_scales_for_sm90_mixed_gemm(
+    scales: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
+
+    The kernel expects scales in the layout
+    ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural
+    ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer.
+    This helper performs the reshape + permute equivalent to TensorRT-LLM's
+    ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), using the
+    interleave factor ``128 // group_size`` (4 for the MXFP4 default of 32).
+
+    Parameters
+    ----------
+    scales:
+        ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block
+        scales.
+    group_size:
+        MXFP4 quantization group size (default 32).
--
+        scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous()
+    )
+    return tmp.reshape(e, kgs // factor, rows * factor)
+
+
+@flashinfer_api
+def interleave_moe_weights_for_sm90_mixed_gemm(
+    weight: torch.Tensor,
+    quant_type: str = "fp4",
+) -> torch.Tensor:
+    """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
+
+    The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with
+    ``use_w4_group_scaling=True``) expects weights in a specific interleaved
+    layout; without preprocessing, the LUT-based FP4→BF16 conversion reads
+    bytes from the wrong positions and the output diverges from a dequantized
+    reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent
+    preprocessing at weight-load time (see
+    ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451).
+
+    Parameters
+    ----------
+    weight:
+        ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed
+        two-per-byte).
+    quant_type:
--
+    )
+    return out
+
+
+# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
+@flashinfer_api(trace=cutlass_fused_moe_trace)
 def cutlass_fused_moe(
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
@@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module():
                     DynamicTensorSpec(
                         input_idx,
                         dim_idx,
-                        get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-                        lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                        get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+                        lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
                         initializers,
                     ),
                 ),
@@ -2344,7 +2468,7 @@ def _validate_routing_replay_out(
         raise ValueError("routing_replay_out must be contiguous (packed row-major)")
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_moe_trace)
 def trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2452,7 +2576,7 @@ def trtllm_bf16_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_routed_moe_trace)
 def trtllm_bf16_routed_moe(
     topk_ids: torch.Tensor,
     hidden_states: torch.Tensor,
@@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace)
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace)
 def trtllm_fp8_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe(
         return result
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch)
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace)
 def trtllm_mxint4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
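The scale interleave added earlier in this file's diff (`interleave_moe_scales_for_sm90_mixed_gemm`: reshape to `(e, rows, kgs // factor, factor)`, permute to `(e, kgs // factor, rows, factor)`, then flatten) amounts to the index mapping `(e, r, kg*factor + f) -> (e, kg, r*factor + f)`. A dependency-free reference of that mapping, illustrative rather than the real torch op:

```python
def interleave_scales_reference(scales, group_size=32):
    """Pure-Python equivalent of the reshape/permute in
    interleave_moe_scales_for_sm90_mixed_gemm (illustrative only).
    scales: nested lists of shape [num_experts][rows][K // group_size]."""
    factor = 128 // group_size            # 4 for MXFP4's group_size of 32
    e, rows, kgs = len(scales), len(scales[0]), len(scales[0][0])
    out = [[[0] * (rows * factor) for _ in range(kgs // factor)]
           for _ in range(e)]
    for ei in range(e):
        for r in range(rows):
            for kg in range(kgs // factor):
                for f in range(factor):
                    # (ei, r, kg*factor + f) -> (ei, kg, r*factor + f)
                    out[ei][kg][r * factor + f] = scales[ei][r][kg * factor + f]
    return out

# 1 expert, 2 rows, K = 256 -> 8 scale groups per row
scales = [[[r * 8 + k for k in range(8)] for r in range(2)]]
interleaved = interleave_scales_reference(scales)
```

Each output row of length `rows * factor` holds `factor` consecutive scales from every input row, matching the `(num_experts, K // (group_size * 4), rows * 4)` layout the docstring describes.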
diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
index d2cbc8b0..34916df5 100644
--- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
@@ -42,11 +42,12 @@ from typing import Optional, Tuple
 import torch
 
 from ...api_logging import flashinfer_api
+from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace
 from ...utils import supported_compute_capability
 
 
 @supported_compute_capability([120, 121])
-@flashinfer_api
+@flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
     x: torch.Tensor,
     w1_weight: torch.Tensor,
@@ -293,7 +294,7 @@ class B12xMoEWrapper:
             device=self.device,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
index f6cf1b67..e266cb77 100644
--- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
+++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
@@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import (
     st_global_u64,
     scatter_add_bf16x2,
 )
-from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
-    Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
+from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
+    Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
 )
 
 
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
index e7fdae92..670b3ad8 100644
--
 from .moe_utils import (
@@ -530,7 +534,7 @@ class CuteDslMoEWrapper:
             enable_pdl=enable_pdl,
         )
 
-    @flashinfer_api
+    @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
 
 
 @supported_compute_capability([100, 103])
-@flashinfer_api
+@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace)
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
     x_sf: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index 0cc8628e..636043db 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@ from ...autotuner import (
     TuningConfig,
 )
 from ..utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
 )
 
 logger = logging.getLogger(__name__)
@@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
                 DynamicTensorSpec(
--
 import torch
@@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module():
 
 
 @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
-@flashinfer_api
+@flashinfer_api(trace=fused_topk_deepseek_trace)
 def fused_topk_deepseek(
     scores: torch.Tensor,
     bias: torch.Tensor,
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 004271a1..91f37aa5 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
     return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
 
 
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1."""
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
+_PHASE1_END = 256
--
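The diff replaces the pure power-of-2 token buckets with `get_hybrid_num_tokens_buckets` / `map_to_hybrid_bucket`, and the fragment above shows a `_PHASE1_END = 256` constant. One plausible two-phase scheme is powers of 2 up to the phase-1 cutoff, then a fixed stride beyond it. This is purely a guess for illustration; the actual FlashInfer functions are not shown in this diff and may differ:

```python
_PHASE1_END = 256  # cutoff constant taken from the diff; the scheme is a guess

def hybrid_num_tokens_buckets(max_num_tokens, min_bucket=1):
    """Hypothetical hybrid bucketing: powers of 2 up to _PHASE1_END,
    then multiples of _PHASE1_END up to max_num_tokens."""
    buckets, m = [], min_bucket
    while m <= min(_PHASE1_END, max_num_tokens):  # phase 1: powers of 2
        buckets.append(m)
        m *= 2
    m = _PHASE1_END * 2
    while m <= max_num_tokens:                    # phase 2: fixed stride
        buckets.append(m)
        m += _PHASE1_END
    return buckets

def map_to_hybrid_bucket(x, max_num_tokens):
    """Round x up to the nearest bucket, capped at max_num_tokens
    (mirroring the min(..., tune_max_num_tokens) cap in the old lambda)."""
    for b in hybrid_num_tokens_buckets(max_num_tokens):
        if x <= b:
            return b
    return max_num_tokens
```

Under this sketch, small token counts keep fine-grained power-of-2 buckets while large counts use coarser, evenly spaced ones, reducing the number of autotuning shapes at high batch sizes.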
 
@@ -106,7 +114,7 @@ TILE_V = 8  # pretranspose tile size
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
 def gated_delta_rule_decode(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -535,7 +543,7 @@ def gated_delta_rule_decode(
 # ============================================================================
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_mtp_trace)
 def gated_delta_rule_mtp(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 68398d28..53fe44ce 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel:
 
         gate_handle = load_gate_consumer.wait_and_advance()
 
-        max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
-        cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
+        cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]
 
         valid_state = not is_first_chunk or self.use_initial_state
         if cutlass.const_expr(valid_state):
diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
index 82dcc72b..aafcc671 100644
--- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
--
     register_custom_op,
@@ -95,7 +96,7 @@ def get_gdn_prefill_module():
     return SimpleNamespace(gdn_prefill=gdn_prefill)
 
 
-@flashinfer_api
+@flashinfer_api(trace=gdn_prefill_trace)
 def chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index a7795beb..def82216 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -61,11 +61,11 @@ try:
     from flashinfer.cute_dsl.utils import is_cute_dsl_available
 
     if is_cute_dsl_available():
-        from .kernels.dense_blockscaled_gemm_sm120 import (
-            Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
+        from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+            Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
         )
 
-        _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+        _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
 except ImportError:
--
 from ..utils import (
@@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16(
     common_check=_check_mm_bf16_problem_size,
     heuristic_func=_heuristic_func_mm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_bf16_trace)
 def mm_bf16(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16(
     common_check=_check_bmm_bf16_problem_size,
     heuristic_func=_heuristic_func_bmm_bf16,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_bf16_trace)
 def bmm_bf16(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
         ),
     ),
     constraint_specs=(
@@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig(
         DynamicTensorSpec(
             (0,),  # a_tensor_index
             (-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
--
     constraint_specs=(
@@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=tgv_gemm_sm100_trace)
 def tgv_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1173,8 +1187,8 @@ def tgv_gemm_sm100(
             DynamicTensorSpec(
                 (a_tensor_index,),
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper:
     True
     """
 
+    @flashinfer_api
     def __init__(
         self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
     ) -> None:
@@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper:
         self._float_workspace_buffer = float_workspace_buffer
         self._int_workspace_buffer = int_workspace_buffer
 
-    @flashinfer_api
+    @flashinfer_api(trace=segment_gemm_run_trace)
     def run(
         self,
         x: torch.Tensor,
@@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_fp4; the user-facing mm_fp4 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_fp4_graph_override_shape(
     graph,
     a,
@@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape(
     return graph
 
 
+# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already
+# decorated, so decorating here would double-log the same invocation.
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
--
 ):
@@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
     return (tuple(block_scale_shape), tuple(block_scale_stride))
 
 
-@flashinfer_api
+@flashinfer_api(trace=mm_fp8_trace)
 def mm_fp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8(
     common_check=_check_mm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_mm_mxfp8,  # result stored in mm_mxfp8.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_mxfp8_trace)
 def mm_mxfp8(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner(
     """
     import cutlass
 
-    from .kernels.dense_blockscaled_gemm_sm120 import (
-        Sm120BlockScaledDenseGemmKernel,
+    from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+        Sm120B12xBlockScaledDenseGemmKernel,
     )
 
     cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner(
             ]
             swap_ab = False
             for mma_tiler_mn in sm120_mma_tiler_candidates:
-                if not Sm120BlockScaledDenseGemmKernel.can_implement(
+                if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
--
     constraint_specs=(
@@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig(
     common_check=_check_mm_fp4_problem_size,
     heuristic_func=_heuristic_func_mm_fp4,  # result stored in mm_fp4.suitable_auto_backends
 )
-@flashinfer_api
+@flashinfer_api(trace=mm_fp4_trace)
 def mm_fp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8(
     common_check=_check_bmm_fp8_problem_size,
     heuristic_func=_heuristic_func_bmm_fp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_fp8_trace)
 def bmm_fp8(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise(
     {},
     common_check=_check_batch_deepgemm_fp8_nt_groupwise,
 )
-@flashinfer_api
+@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace)
 def batch_deepgemm_fp8_nt_groupwise(
     a: torch.Tensor,  # (batch_size, m, k)
     b: torch.Tensor,  # (batch_size, n, k)
@@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90():
     return module.init()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace)
 def fp8_blockscale_gemm_sm90(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8(
     common_check=_check_bmm_mxfp8_problem_size,
     heuristic_func=_heuristic_func_bmm_mxfp8,
 )
-@flashinfer_api
+@flashinfer_api(trace=bmm_mxfp8_trace)
 def bmm_mxfp8(
     A: torch.Tensor,
     B: torch.Tensor,
diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
similarity index 99%
rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
index c49bc815..6eee27a7 100644
--- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
+++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
@@ -1550,7 +1550,7 @@ class DenseGemmKernel:
 
 
 # Alias for FlashInfer integration
-Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
+Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
 
 
 class _DenseGemmLaunch:
diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
--
     get_cutlass_dtype,
@@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel(
     return tensor_api
 
 
-@flashinfer_api
+@flashinfer_api(trace=grouped_gemm_nt_masked_trace)
 def grouped_gemm_nt_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py
index cfde7d43..f83c8974 100644
--- a/flashinfer/gemm/routergemm.py
+++ b/flashinfer/gemm/routergemm.py
@@ -1,4 +1,8 @@
 from ..api_logging import flashinfer_api
+from ..trace.templates.gemm import (
+    mm_M1_16_K7168_N256_trace,
+    tinygemm_bf16_trace,
+)
 from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
 import functools
 from types import SimpleNamespace
@@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128(
 
 
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
@@ -324,7 +328,7 @@ def get_tinygemm2_module():
 
 
 @backend_requirement({}, common_check=_tinygemm_bf16_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=tinygemm_bf16_trace)
 def tinygemm_bf16(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 7f36a314..8378e0ab 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module
 from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module
 from .comm import gen_vllm_comm_module as gen_vllm_comm_module
 from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module
+from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module
 from .dsv3_optimizations import (
     gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module,
 )
diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py
index 46768eed..834f77f9 100644
--- a/flashinfer/jit/comm.py
+++ b/flashinfer/jit/comm.py
@@ -15,7 +15,13 @@ limitations under the License.
--
     gen_selective_state_update_sm100_module,
@@ -99,7 +100,7 @@ def get_selective_state_update_module(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=selective_state_update_trace)
 def selective_state_update(
     state: torch.Tensor,
     x: torch.Tensor,
diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index 4e8bdd72..e27e3807 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.attention import (
+    mla_paged_decode_trace,
+    trtllm_batch_decode_mla_trace,
+    xqa_batch_decode_mla_trace,
+)
 from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
 from ..jit.mla import gen_mla_module
 from ..utils import (
@@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper:
         return_lse_base_on_e: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=mla_paged_decode_trace)
     def run(
         self,
         q_nope: torch.Tensor,
@@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper:
         return (out, lse) if return_lse else out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_mla_trace)
 def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
@@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         raise ValueError(f"Backend {backend} not supported")
 
 
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_mla_trace)
 def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index 0f9911a6..ba612b28 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -32,6 +32,16 @@ from typing import Optional, Union
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.norm import (
+    fused_add_rmsnorm_quant_trace,
+    fused_add_rmsnorm_trace,
+    fused_rmsnorm_silu_trace,
+    gemma_fused_add_rmsnorm_trace,
+    gemma_rmsnorm_trace,
+    layernorm_trace,
+    rmsnorm_quant_trace,
+    rmsnorm_trace,
--
     get_compute_capability,
@@ -94,7 +104,7 @@ def _normalize_scale_tensor(
     return scale.contiguous()
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_trace)
 def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -165,7 +175,7 @@ def _rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_quant_trace)
 @register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",))
 def rmsnorm_quant(
     out: torch.Tensor,
@@ -219,7 +229,7 @@ def _rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_trace)
 @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor,
@@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_quant_trace)
 @register_custom_op(
     "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
 )
@@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_rmsnorm_trace)
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace)
 @register_custom_op(
     "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
 )
@@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=layernorm_trace)
 @register_custom_op("flashinfer::layernorm", mutates_args=())
 def layernorm(
     input: torch.Tensor,
@@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fused_rmsnorm_silu_trace)
 def fused_rmsnorm_silu(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/flashinfer/page.py b/flashinfer/page.py
index 12ea3613..7fb33cf3 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.page import (
+    append_paged_kv_cache_trace,
+    append_paged_mla_kv_cache_trace,
+)
 from .jit.page import gen_page_module
 from .utils import (
     TensorLayout,
@@ -222,7 +226,7 @@ def get_seq_lens(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_mla_kv_cache_trace)
 def append_paged_mla_kv_cache(
     append_ckv: torch.Tensor,
     append_kpe: torch.Tensor,
@@ -272,7 +276,7 @@ def append_paged_mla_kv_cache(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=append_paged_kv_cache_trace)
 def append_paged_kv_cache(
     append_key: torch.Tensor,
     append_value: torch.Tensor,
diff --git a/flashinfer/pod.py b/flashinfer/pod.py
index fe2e36c1..4fa2d9bf 100644
--- a/flashinfer/pod.py
+++ b/flashinfer/pod.py
@@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    batch_pod_with_paged_kv_cache_run_trace,
+    pod_with_paged_kv_cache_run_trace,
+)
 from .jit import gen_pod_module, gen_batch_pod_module
 from .page import get_seq_lens
 from .prefill import get_batch_prefill_module
@@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
@@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper:
 
     begin_forward = plan
 
-    @flashinfer_api
+    @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace)
     def run(
         self,
         # Main params (prefill and decode)
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 4ec6a29e..d491dd35 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+    gqa_paged_prefill_trace,
+    gqa_ragged_prefill_trace,
+    single_prefill_with_kv_cache_trace,
+    trtllm_batch_context_trace,
+)
+from .trace.templates.gemm import (
+    fmha_v2_prefill_deepseek_trace,
+    trtllm_ragged_attention_deepseek_trace,
--
     gen_customize_batch_prefill_module,
@@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_api
+@flashinfer_api(trace=single_prefill_with_kv_cache_trace)
 def single_prefill_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
         skip_softmax_threshold_scale_factor: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_paged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
         enable_pdl: Optional[bool] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_api
+    @flashinfer_api(trace=gqa_ragged_prefill_trace)
     def run(
         self,
         q: torch.Tensor,
@@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module():
     return op
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace)
 def trtllm_ragged_attention_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek(
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    backend: str = "trtllm-gen",
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Parameters
@@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek(
         output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
     lse : Optional[torch.Tensor]
         lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]
+    backend : str
+        Attention backend to use. "trtllm-gen" (default) or "cute-dsl".
+        When backend="cute-dsl", query/key/value/out tensors must be
+        front-padded with max_seq_len rows of valid GPU memory before
+        index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details).
--
             "lse assumed not None beyond this point when return_lse is True"
@@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek(
         return out
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_context_trace)
 def trtllm_batch_context_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module():
     return gen_trtllm_fmha_v2_sm120_module().build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module(
     return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()
 
 
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace)
 def trtllm_fmha_v2_prefill(
     qkv: Union[
         torch.Tensor,
diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py
index 4cd5cd34..84f7ade6 100644
--- a/flashinfer/quantization/fp4_quantization.py
+++ b/flashinfer/quantization/fp4_quantization.py
@@ -21,6 +21,12 @@ from typing import List, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import (
+    fp4_quantize_trace,
+    mxfp4_quantize_trace,
+    nvfp4_kv_quantize_trace,
+    nvfp4_quantize_trace,
+)
 from ..jit import JitSpec
 from ..jit import env as jit_env
 from ..jit import (
@@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"):
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=fp4_quantize_trace)
 def fp4_quantize(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
@@ -923,7 +929,7 @@ def shuffle_matrix_sf_a(
     return block_scale_interleave(w_shuffled)
 
 
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_quantize_trace)
 def nvfp4_quantize(
     a,
     a_global_sf,
@@ -1024,7 +1030,7 @@ def nvfp4_quantize(
     return a_fp4, a_sf
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp4_quantize_trace)
 def mxfp4_quantize(
     a: torch.Tensor,
     backend: str = "cuda",
@@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale):
 
 
 @backend_requirement({}, common_check=_nvfp4_kv_quant_check)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_kv_quantize_trace)
 def nvfp4_kv_quantize(
     input: torch.Tensor,
     global_scale: torch.Tensor,
diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py
index f2c9f412..49e13a8b 100644
--- a/flashinfer/quantization/fp8_quantization.py
+++ b/flashinfer/quantization/fp8_quantization.py
@@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple
 import torch
 
 from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import mxfp8_quantize_trace
 from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
 from ..utils import (
     device_support_pdl,
@@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module():
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=mxfp8_quantize_trace)
 def mxfp8_quantize(
     input: torch.Tensor,
     is_sf_swizzled_layout: bool = True,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07..df5c7d4d 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -20,6 +20,21 @@ from typing import Optional, Tuple
 import torch
 
 from .api_logging import flashinfer_api
+from .trace.templates.rope import (
+    apply_llama31_rope_inplace_trace,
+    apply_llama31_rope_pos_ids_inplace_trace,
+    apply_llama31_rope_pos_ids_trace,
+    apply_llama31_rope_trace,
+    apply_rope_inplace_trace,
+    apply_rope_pos_ids_inplace_trace,
+    apply_rope_pos_ids_trace,
+    apply_rope_trace,
--
 
@@ -414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids(
     pass
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_inplace_trace)
 def apply_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -502,7 +517,7 @@ def apply_rope_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace)
 def apply_rope_pos_ids_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace(
     )
 
 
-@flashinfer_api
+@flashinfer_api(trace=apply_llama31_rope_inplace_trace)
 def apply_llama31_rope_inplace(
     q: torch.Tensor,
     k: torch.Tensor,
... (truncated -- see full diff via the command above)
```

**Summary of API changes:**

- **Decorator semantic addition (backward-compatible):**
`@flashinfer_api` now accepts an optional `trace=<TraceTemplate>`
keyword. Bare `@flashinfer_api` still works, and existing call sites of
decorated functions are unaffected. Most of the diff above consists of
mechanical rewrites of `@flashinfer_api` to `@flashinfer_api(trace=...)`,
plus the new `flashinfer/trace/` package and `fi_trace.py` for
flashinfer-bench JSON dumps.
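
The dual-form decorator this implies can be sketched as follows. This is a minimal illustration, not FlashInfer's actual implementation: the real decorator also performs API logging, and the trace-callback signature used here is an assumption.

```python
import functools


def flashinfer_api(func=None, *, trace=None):
    """Decorator usable both bare (@flashinfer_api) and parameterized
    (@flashinfer_api(trace=...)). Sketch only."""

    def decorate(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if trace is not None:
                # Hypothetical hook; the real TraceTemplate interface differs.
                trace(f.__name__, args, kwargs)
            return f(*args, **kwargs)

        return wrapper

    if func is None:
        # Invoked with arguments: @flashinfer_api(trace=...) returns the decorator.
        return decorate
    # Invoked bare: func is the decorated function itself.
    return decorate(func)
```

The `func=None` default plus keyword-only `trace` is what lets the same name serve both decorator forms without breaking existing bare uses.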

- **New public APIs (7):**
  - `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size, decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace, decode_cp_a2a_alltoall}` — DCP all-to-all for context-parallel attention reduction (#2951).
  - `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm, interleave_moe_weights_for_sm90_mixed_gemm}` — SM90 mixed-input MoE GEMM helpers (#3084).
  - `flashinfer.comm.run_mixed_comm` — combinations of allreduce / allgather / reducescatter (#2563).
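
For reference, the semantics of the three collectives being combined can be simulated on plain Python lists with one entry per rank. This is a semantic sketch only and is unrelated to the actual `run_mixed_comm` signature.

```python
def allreduce(shards):
    # Sum elementwise across ranks; every rank receives the full result.
    total = [sum(col) for col in zip(*shards)]
    return [list(total) for _ in shards]


def allgather(shards):
    # Concatenate all ranks' shards; every rank receives the concatenation.
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]


def reducescatter(shards):
    # Sum elementwise across ranks, then rank r keeps only chunk r of the sum.
    total = [sum(col) for col in zip(*shards)]
    chunk = len(total) // len(shards)
    return [total[r * chunk:(r + 1) * chunk] for r in range(len(shards))]
```

Note the classic identity that reducescatter followed by allgather reproduces allreduce (the decomposition ring allreduce is built on), which is why these three collectives compose naturally.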

- **New `@flashinfer_api`-decorated wrapper init:**
  - `SegmentGEMMWrapper.__init__` is now decorated; previously only `run()` was, with the class itself undecorated. No call-site change.

- **Backward-compatible signature additions (defaults preserve old behavior):**
  - `top_k_page_table_transform`: `+dsa_graph_safe: bool = False`, `+row_starts: Optional[torch.Tensor] = None` (#3133).
  - `top_k_ragged_transform`: same two new params (#3133).
  - `trtllm_ragged_attention_deepseek`: `+backend: str = "trtllm-gen"` (cute-dsl backend selection).

- **No breaking signature changes** to any `@flashinfer_api` function.
Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0
removals.

- **Module reorganization to flag (not `@flashinfer_api`, but in public re-export):**
  - `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` → `dense_blockscaled_gemm_sm120_b12x.py`
  - Class renamed: `Sm120BlockScaledDenseGemmKernel` → `Sm120B12xBlockScaledDenseGemmKernel`
  - Re-export in `flashinfer/gemm/__init__.py` updated to the new name only — direct importers of the old name break. Decision needed: ship as breaking, or add a deprecation alias.

- **Internal autotuner helper rename (not public, but used by downstream extensions):**
  - `get_last_power_of_2_num_tokens_buckets` → `get_hybrid_num_tokens_buckets`
  - `last_positive_power_of_2` → `map_to_hybrid_bucket` / `map_to_hybrid_bucket_uncapped`
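
The old helper's behavior, inferred from its name since the function is internal: snap a runtime token count down to the largest power of two not exceeding it, so the autotuner only has to tune a logarithmic number of bucket sizes. A sketch of that inferred semantics (the renamed hybrid-bucket helpers presumably refine this mapping and are not reproduced here):

```python
def last_positive_power_of_2(n: int) -> int:
    """Largest power of two <= n, for n >= 1.

    Semantics inferred from the old helper's name, not copied from
    the FlashInfer source.
    """
    if n < 1:
        raise ValueError("n must be positive")
    return 1 << (n.bit_length() - 1)
```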

> Diff truncated above due to GitHub PR body length limit. Run the
command at the top locally to see the full output.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Patch Release**
  * Version updated to 0.6.10

<!-- end of auto-generated comment: release notes by coderabbit.ai -->