Add support for combinations of allreduce, allgather, and reducescatter #2563
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings, or resume it manually with a review command.
📝 Walkthrough
Adds a mixed-communication subsystem: NVSHMEM/CUDA kernels and TVM FFI entrypoints, Python orchestration and JIT generation, a benchmarking CLI and timing aggregation, and distributed tests validating fused and fallback collective operators across TP/DP decompositions.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Application
    participant Handler as MixedCommHandler
    participant JIT as JIT / Generated Module
    participant Kernels as CUDA Kernels
    participant NVSHMEM as NVSHMEM
    participant NCCL as NCCL / PyTorch Dist

    App->>Handler: init(world_rank, local_rank, dtype, ...)
    Handler->>JIT: build_and_load() / load generated module
    Handler->>NVSHMEM: exchange unique-id / init (inter-node)
    Handler->>Handler: allocate CUDA VM / map IPC handles (local ranks)
    Handler->>NVSHMEM: allocate NVSHMEM workspace (if enabled)
    App->>Handler: run_mixed_comm(op, x_in, mode/AUTOTUNE)
    Handler->>Handler: select_autotune_mode() / _common_check()
    alt fused kernel available
        Handler->>Kernels: launch fused kernel (via JIT)
        Kernels->>NVSHMEM: intra/inter-node collectives via NVSHMEM
        Kernels-->>Handler: results (device buffers)
    else fallback
        Handler->>NCCL: perform collectives via PyTorch/NCCL
        NCCL-->>Handler: results
    end
    Handler-->>App: x_out
    App->>Handler: shutdown()
    Handler->>NVSHMEM: finalize (if initialized)
    Handler->>Handler: unmap/free VM, destroy groups
```
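The diagram above corresponds to a short calling sequence. Below is a minimal usage sketch assuming the names visible in this PR (`run_mixed_comm`'s signature appears in the API diff further down); the `MixedCommHandler` constructor arguments are taken from the diagram and are illustrative only:

```python
# Minimal sketch of the flow in the diagram; constructor argument names are
# assumptions based on the diagram, not the final API.
import torch
import torch.distributed as dist

from flashinfer.comm.mixed_comm import (
    MixedCommHandler,
    MixedCommMode,
    MixedCommOp,
    run_mixed_comm,
)

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

handler = MixedCommHandler(
    world_rank=dist.get_rank(),  # illustrative; see the PR for the real signature
    local_rank=local_rank,
    dtype=torch.bfloat16,
)

x_in = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")

# AUTOTUNE picks between the fused kernel and the NCCL fallback at runtime.
x_out = run_mixed_comm(
    MixedCommOp.ALLREDUCE_ALLGATHER, handler, x_in, mode=MixedCommMode.AUTOTUNE
)

handler.shutdown()
dist.destroy_process_group()
```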
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Summary of Changes

Hello @jinyangyuan-nvidia, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the distributed communication capabilities by introducing optimized primitives for combined allreduce-allgather and reducescatter-allreduce operations. These primitives are designed to improve efficiency in complex distributed training setups, particularly where both Tensor Parallelism and Data Parallelism are employed. The implementation uses fused CUDA kernels and integrates NVSHMEM for efficient inter-node data exchange, alongside new benchmarking and testing infrastructure to ensure correctness and performance.
Code Review
This pull request introduces significant new functionality by adding support for fused communication kernels (allreduce+allgather and reducescatter+allreduce). These kernels leverage advanced techniques like virtual memory for intra-node communication and nvshmem for inter-node communication, which is a great addition for performance. The PR also includes comprehensive benchmarks and tests for these new features. The overall implementation is complex but appears well-structured. My feedback focuses on improving the clarity and maintainability of the new benchmark script and ensuring proper random data generation in both the benchmarks and tests.
Force-pushed f5d0926 to 6efea96 (Compare)
…catter Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Force-pushed 6efea96 to fc634c1 (Compare)
Actionable comments posted: 5
🧹 Nitpick comments (5)
tests/comm/test_mixed_comm.py (1)
14-16: Consider removing sys.path manipulation. This pattern of adding the project root to `sys.path` is unusual for pytest tests. Pytest typically handles module discovery correctly when run from the project root. If this is needed for a specific reason, consider documenting why.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_mixed_comm.py` around lines 14-16: Remove the manual sys.path manipulation: delete the _project_root = Path(__file__).parent.parent.parent and the conditional sys.path.append(str(_project_root)) lines (they are adding the project root to sys.path) unless there's a documented, test-specific reason; if the import resolution actually requires it, add a brief comment explaining why and prefer using pytest configuration (e.g., PYTHONPATH or pytest.ini) instead of modifying sys.path in tests.

csrc/mixed_comm_kernel_inst.jinja (1)
40-41: Minor: Semicolons after namespace closing braces are unconventional. While valid C++, the trailing semicolons after namespace closing braces are not standard style.
♻️ Suggested fix
```diff
-};  // namespace mixed_comm
-};  // namespace flashinfer
+}  // namespace mixed_comm
+}  // namespace flashinfer
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/mixed_comm_kernel_inst.jinja` around lines 40-41: The closing namespace braces for namespaces flashinfer and mixed_comm currently include trailing semicolons; remove the unnecessary semicolons after the closing braces for namespace mixed_comm and namespace flashinfer (the lines with "}; // namespace mixed_comm" and "}; // namespace flashinfer") so they read as standard namespace closing comments (i.e., "} // namespace mixed_comm" and "} // namespace flashinfer") without the extra semicolons.

flashinfer/jit/comm.py (2)
164-165: Consider simplifying list construction. The list concatenation can be simplified using iterable unpacking for better readability.
♻️ Suggested simplification
```diff
 nvcc_flags = ["-rdc=true"]
-ldflags = [f"-L{str(path_base / 'lib')}"] + ["-lnvshmem_device"]
+ldflags = [f"-L{path_base / 'lib'}", "-lnvshmem_device"]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/comm.py` around lines 164 - 165, The ldflags list currently builds via concatenation ([f"-L{str(path_base / 'lib')}"] + ["-lnvshmem_device"]); simplify it by creating a single list using iterable unpacking or by listing both items directly so it's clearer and more readable (refer to the nvcc_flags and ldflags variables and the path_base / 'lib' expression when locating the code).
160-174: Verify nvidia.nvshmem package availability at JIT time. The import of `nvidia.nvshmem` happens at JIT generation time. If the package is not installed, this will raise an `ImportError`. Consider adding a more informative error message.

🛡️ Suggested improvement for better error handling
```diff
-import nvidia.nvshmem
+try:
+    import nvidia.nvshmem
+except ImportError as e:
+    raise ImportError(
+        "nvidia-nvshmem package is required for mixed_comm module. "
+        "Install it with: pip install nvidia-nvshmem"
+    ) from e
 path_base = pathlib.Path(nvidia.nvshmem.__path__[0])
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/comm.py` around lines 160-174: The code imports nvidia.nvshmem at JIT generation time which will raise ImportError if the package is missing; wrap the import of nvidia.nvshmem in a try/except that catches ImportError and re-raises a clear, actionable error (or RuntimeError) explaining that the nvidia.nvshmem package is not installed and is required to generate the "mixed_comm" JIT spec used by gen_jit_spec, include the original exception message for debugging, and ensure any subsequent uses (path_base, extra_include_paths, extra_ldflags, needs_device_linking) only run after the import succeeds.

csrc/nvshmem_binding.cu (1)
134-156: Host synchronization in `allreduce_on_stream_with_copy` may impact performance. The `cudaStreamSynchronize` at line 155 blocks the host thread. If this function is called frequently in a tight loop, this could become a bottleneck. If the synchronization is intentional for correctness guarantees, consider documenting this. A sketch of the caller-synchronized alternative follows the AI prompt below.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/nvshmem_binding.cu` around lines 134 - 156, The function allreduce_on_stream_with_copy currently ends with a blocking cudaStreamSynchronize call which can stall the host; remove the unconditional cudaStreamSynchronize to make the operation asynchronous and instead either (a) record and return/use a cudaEvent on the stream for the caller to wait on, or (b) accept a cudaStream_t or a callback so the caller controls synchronization; update the function signature/usage accordingly (refer to allreduce_on_stream_with_copy, the cudaMemcpyAsync calls, and nvshmemx_barrier_on_stream) and add a comment documenting that the function is now non-blocking and callers must synchronize when they need completion guarantees.
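At the Python layer, option (a) from the prompt above looks roughly like the following; `launch_on_stream` is a hypothetical non-blocking stand-in for the binding, not the actual FFI symbol:

```python
# Hedged analogue of the event-based pattern: enqueue the work, record an
# event, and let the caller decide when (or whether) to block.
import torch

def allreduce_async(x: torch.Tensor, launch_on_stream) -> torch.cuda.Event:
    stream = torch.cuda.current_stream()
    launch_on_stream(x, stream.cuda_stream)  # copies + allreduce + barrier, no host sync
    done = torch.cuda.Event()
    done.record(stream)  # completion marker on the same stream
    return done

# Caller controls synchronization:
#   evt = allreduce_async(x, module.allreduce_on_stream_with_copy)
#   ... overlap independent host/GPU work ...
#   evt.synchronize()  # block only when the result is actually needed
```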
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/mixed_comm.cu`:
- Around line 96-115: The current memoization uses a single thread_local int
(dummy_smem_size) so get_dummy_smem_size(int device_id) returns the same cached
value for different device_id; change the cache to map device_id->size (e.g.,
thread_local std::unordered_map<int,int> dummy_smem_size_map) and update
get_dummy_smem_size to look up device_id, compute and store the value keyed by
device_id when missing; add `#include <unordered_map>` near the top of the file
and keep the existing helper name get_dummy_smem_size and its behavior
otherwise.
In `@flashinfer/comm/mixed_comm.py`:
- Around line 641-650: The code currently constructs AF_UNIX socket paths via
get_socket_path(pid) => "/tmp/{pid}" and uses that in send_fd to send SCM_RIGHTS
to get_socket_path(local_pid_list[rank]); change this to create and use a secure
per-run temporary directory (mode 0700) and place per-rank socket files there
instead of /tmp; modify get_socket_path to return the path inside that private
temp dir (created once at process startup with os.mkdir/mkdtemp and os.chmod
0o700), update any callers (send_fd and the corresponding receive side
referenced around the other occurrence at the 723-733 region) to use the new
base dir, and ensure old stale paths are cleaned up on exit or recreated
atomically to avoid races (see the sketch after this list).
- Around line 528-537: The call to mixed_comm_module.get_max_block_size inside
max_block_size_dict is passing the original constructor args (local_tp_size,
local_dp_size, inter_tp_size, inter_dp_size) instead of the normalized/resolved
values stored on the ParallelInfo instance; change the call to pass the resolved
attributes (e.g. self.local_tp_size, self.local_dp_size, self.inter_tp_size,
self.inter_dp_size) along with self.dtype, op, and mode so the FFI probe never
receives None.
- Around line 542-552: The loop that is intended to cap entries in
self.max_block_size_dict is a no-op because it only rebinds the local variable
val; update the actual dict entries instead (e.g., for key, val in
self.max_block_size_dict.items(): self.max_block_size_dict[key] = {sub_key:
min(sub_val, max_block_size) for sub_key, sub_val in val.items()} or mutate val
in-place) while keeping the existing assertions on max_block_size, warp_size,
and min_block_size so that the later max(...) computation uses the capped
values.
- Around line 833-840: The current CPU tensor broadcast using
torch.distributed.broadcast on the tensor uid (created via
self.mixed_comm_module.nvshmem_unique_id_size() and filled by
self.mixed_comm_module.nvshmem_get_unique_id(uid) on rank
self.para_info.world_rank == 0) will hit NCCL CPU-backend errors; replace that
call with torch.distributed.broadcast_object_list to broadcast the small CPU
payload safely across backends (create a single-element list containing
uid.numpy().tobytes() or a bytes object on rank 0, call
torch.distributed.broadcast_object_list(list_obj, src=0), and reconstruct the
uid tensor from the received bytes on all ranks) so the NVSHMEM unique id is
distributed correctly before nvshmem_init().
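For the socket-path finding above, a minimal sketch of the private-directory approach; the `run_id` parameter and helper names are assumptions (all local ranks must agree on the directory, e.g. via a token broadcast once over torch.distributed or a shared environment variable):

```python
# Sketch only: per-run private directory for the AF_UNIX fd-passing sockets,
# replacing bare "/tmp/{pid}" paths.
import os
import tempfile

def make_socket_dir(run_id: str) -> str:
    path = os.path.join(tempfile.gettempdir(), f"flashinfer_mixed_comm_{run_id}")
    os.makedirs(path, mode=0o700, exist_ok=True)
    os.chmod(path, 0o700)  # fails loudly if another user pre-created the path
    return path

def get_socket_path(socket_dir: str, pid: int) -> str:
    # Per-rank socket file inside the private directory.
    return os.path.join(socket_dir, str(pid))
```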
ℹ️ Review info
Configuration used: defaults | Review profile: CHILL
📒 Files selected for processing (11): benchmarks/bench_mixed_comm.py, csrc/mixed_comm.cu, csrc/mixed_comm_kernel_inst.jinja, csrc/nvshmem_binding.cu, flashinfer/comm/mixed_comm.py, flashinfer/jit/comm.py, flashinfer/testing/utils.py, include/flashinfer/comm/mixed_comm.cuh, include/flashinfer/comm/mixed_comm_decl.cuh, tests/comm/conftest.py, tests/comm/test_mixed_comm.py
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
/bot run
Hi, can you annotate the public API using @flashinfer_api?
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Thanks for the suggestion. The code has been modified accordingly.
Actionable comments posted: 5
♻️ Duplicate comments (1)
flashinfer/comm/mixed_comm.py (1)
747-757: ⚠️ Potential issue | 🔴 Critical: Don't assume the default process group can broadcast CPU tensors.

Line 756 broadcasts a CPU `uid`, but this class only requires "an active `torch.distributed` process group" in its public contract. PyTorch documents `broadcast` as unsupported for CPU tensors on an NCCL process group; CPU collectives only route to Gloo when a multi-backend group exists, and `broadcast_object_list()` is the portable option for a small payload like this unique ID. That means callers initialized with plain `backend="nccl"` can fail here before `nvshmem_init()` runs. (docs.pytorch.org) A sketch of the portable approach follows the AI prompt below.

Does torch.distributed.broadcast support CPU tensors when the default process group is initialized with backend="nccl", and what is the recommended way to broadcast a small CPU payload such as an NVSHMEM unique ID?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/mixed_comm.py` around lines 747 - 757, The code in init_nvshmem uses torch.distributed.broadcast on a CPU tensor which can fail for NCCL-only default process groups; instead convert the NVSHMEM unique id to a small Python object (e.g., bytes via uid.cpu().numpy().tobytes() or a list of ints), use torch.distributed.broadcast_object_list to share it from para_info.world_rank == 0, then reconstruct the uid tensor (same shape/dtype) before calling mixed_comm_module.nvshmem_init; update references to mixed_comm_module.nvshmem_get_unique_id and mixed_comm_module.nvshmem_init accordingly so the send/receive path uses broadcast_object_list and the tensor is restored on CPU.
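A sketch of the portable broadcast, assuming the binding names behave as described above; `mixed_comm_module` and `para_info` refer to the handler attributes named in the review, and the uint8 dtype and `nvshmem_init` arguments are assumptions:

```python
# Sketch: broadcast the NVSHMEM unique ID as a Python object so the path
# works on NCCL-only process groups.
import torch
import torch.distributed as dist

uid = torch.zeros(mixed_comm_module.nvshmem_unique_id_size(), dtype=torch.uint8)
if para_info.world_rank == 0:
    mixed_comm_module.nvshmem_get_unique_id(uid)

# broadcast_object_list is backend-agnostic for small CPU payloads.
payload = [uid.numpy().tobytes() if para_info.world_rank == 0 else None]
dist.broadcast_object_list(payload, src=0)
uid = torch.frombuffer(bytearray(payload[0]), dtype=torch.uint8)

mixed_comm_module.nvshmem_init(uid, para_info.world_rank, para_info.world_size)
```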
🧹 Nitpick comments (1)
benchmarks/bench_mixed_comm.py (1)
119-157: Please route this through the unified benchmark harness. This standalone CLI bypasses `benchmarks/flashinfer_benchmark.py`, so it won't inherit the repo's standard benchmark metadata/reporting path even though the timing loop itself already uses `bench_gpu_time()` correctly.

As per coding guidelines, "Use the unified benchmarking framework in `benchmarks/flashinfer_benchmark.py` for kernel benchmarking with CUPTI timing support".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_mixed_comm.py` around lines 119 - 157, The script currently implements a standalone main() that spawns processes and calls _run_worker directly, bypassing the unified harness in benchmarks/flashinfer_benchmark.py; refactor so the CLI routes through that harness instead: import the benchmark entrypoint or runner from flashinfer_benchmark.py and invoke it from main(), passing the existing argument parsing and the worker entrypoint (_run_worker) or encapsulating the spawn logic into a function the harness can call; ensure the timing loop (bench_gpu_time) remains intact but that all runs and metadata/reporting use the flashinfer_benchmark harness so the standard reporting path and CUPTI timing support are used.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/comm/mixed_comm.py`:
- Around line 529-550: The default MixedCommMode.AUTOTUNE is removed from
get_valid_mode_list when use_autotune is False, but run_mixed_comm() still
defaults its mode parameter to MixedCommMode.AUTOTUNE causing _common_check() to
reject calls when MixedCommHandler(..., use_autotune=False); fix by making the
default mode consistent: either always include MixedCommMode.AUTOTUNE in
get_valid_mode_list or change run_mixed_comm()'s default to a valid mode (e.g.,
MixedCommMode.NCCL_ONE or None) and update callers; locate get_valid_mode_list,
run_mixed_comm, and _common_check to implement the consistent default behavior
and ensure MixedCommMode.AUTOTUNE is only accepted when use_autotune is True.
- Around line 1224-1237: The shape checks around x_out in mixed_comm.dispatch
(the block that references x_out.shape and MixedCommOp.REDUCESCATTER /
REDUCESCATTER_ALLREDUCE) must also validate that x_out.dtype matches x_in.dtype
and that x_out.device matches x_in.device (or the expected device used by the
comm handler) before returning/dispatching; add explicit checks that raise a
TypeError with a clear message when dtype or device mismatch is detected so
callers get immediate API-bound errors instead of failures inside NCCL/FFI.
- Around line 408-455: MixedCommHandler currently accepts any torch.dtype but
only supports fp16/bf16 kernels; in the MixedCommHandler.__init__ (after
assigning self.dtype and self.device but before calling get_mixed_comm_module()
or probing block sizes such as via mixed_comm_module.get_max_block_size)
validate that self.dtype is one of torch.float16 or torch.bfloat16
(torch.half/torch.bfloat16 aliases are fine) and raise a clear ValueError if
not; update the constructor to perform this check early so unsupported dtypes
fail fast with an actionable error message referencing MixedCommHandler and
dtype (see the sketch after this list).
In `@flashinfer/jit/comm.py`:
- Around line 17-24: The new JIT module generator function gen_*_module is
missing Python-level caching; import functools.cache (or functools) and add the
`@functools.cache` decorator to gen_*_module so repeated calls reuse the rendered
template and the created JitSpec instead of re-rendering; follow the same
pattern used by the other gen_..._module functions in this file (apply the
decorator to any new gen_*_module entrypoints).
In `@tests/comm/test_mixed_comm.py`:
- Around line 201-208: The test_mixed_comm function currently only checks GPU
count but must early-skip unsupported GPU architectures before spawning workers;
add a compute-capability guard at the top of test_mixed_comm that calls the API
helper (e.g., run_mixed_comm.is_compute_capability_supported(...) or
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported()) and calls
pytest.skip(...) if the capability is not supported, placing this check before
any torch.cuda.device_count() or worker spawn logic so child processes do not
error.
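The early dtype validation called for above (the item on lines 408-455) is small enough to sketch directly; the supported set (fp16/bf16) is taken from the review comment:

```python
# Sketch of the fail-fast dtype check for MixedCommHandler.__init__.
import torch

_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16)

def _validate_dtype(dtype: torch.dtype) -> None:
    if dtype not in _SUPPORTED_DTYPES:
        raise ValueError(
            f"MixedCommHandler supports only {_SUPPORTED_DTYPES}, got {dtype}; "
            "the fused mixed-comm kernels are compiled for fp16/bf16 only."
        )
```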
ℹ️ Review info
Configuration used: defaults | Review profile: CHILL
📒 Files selected for processing (4): benchmarks/bench_mixed_comm.py, flashinfer/comm/mixed_comm.py, flashinfer/jit/comm.py, tests/comm/test_mixed_comm.py
🧹 Nitpick comments (1)
tests/comm/test_mixed_comm.py (1)
140-144: Minor: Use a generator expression in `all()`. The list comprehension inside `all()` creates an unnecessary intermediate list.

Suggested change

```diff
-    assert all([val == local_size for val in local_size_all]), (
+    assert all(val == local_size for val in local_size_all), (
         "local_size must be the same on all ranks"
     )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_mixed_comm.py` around lines 140 - 144, The assertion creates an unnecessary intermediate list; replace the list comprehension passed to all() with a generator expression by changing the condition that checks equality across gathered sizes (currently using local_size_all and local_size with torch.distributed.all_gather_object) to use a generator expression (e.g., all(val == local_size for val in local_size_all)) so no temporary list is allocated.
ℹ️ Review info
Configuration used: defaults | Review profile: CHILL
📥 Commits: reviewing files that changed from the base of the PR and between 8f6b63f and cd729b5b6c1050c092f72bb5e733980430cc7ede.
📒 Files selected for processing (2): flashinfer/comm/mixed_comm.py, tests/comm/test_mixed_comm.py
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Force-pushed cd729b5 to be27834 (Compare)
Actionable comments posted: 5
♻️ Duplicate comments (1)
flashinfer/comm/mixed_comm.py (1)
750-758: ⚠️ Potential issue | 🔴 Critical: Use `broadcast_object_list()` (or a CPU-capable group) for the NVSHMEM unique ID.

`uid` is allocated on CPU, but `torch.distributed.broadcast(uid, src=0)` goes through the default CUDA/NCCL process group here. That path fails on CPU tensors before `nvshmem_init()` completes.

Does `torch.distributed.broadcast` support CPU tensors when the process group backend is NCCL, and what is the recommended way to broadcast a small CPU payload like an NVSHMEM unique ID?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/mixed_comm.py` around lines 750 - 758, The code broadcasts a CPU tensor `uid` using torch.distributed.broadcast which may use the NCCL/CUDA group and fail for CPU tensors before nvshmem_init; replace that broadcast with a CPU-safe mechanism (e.g., torch.distributed.broadcast_object_list or a CPU-capable process group) when sending the NVSHMEM unique ID obtained via mixed_comm_module.nvshmem_get_unique_id() and sized by mixed_comm_module.nvshmem_unique_id_size(); ensure the root (para_info.world_rank == 0) fills `uid` and all ranks call the CPU-safe broadcast so nvshmem_init() receives the correct ID.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_mixed_comm.py`:
- Around line 149-153: The current loop that waits for workers (iterating
process_list, calling process.join() then asserting exitcode) can deadlock if
one worker fails and peers are stuck in collectives; change it to poll processes
with a timeout and propagate failures immediately: replace the direct
process.join() loop with a monitoring loop that uses join(timeout=...) or checks
process.exitcode periodically, and on detecting any non-zero exitcode from a
process in process_list (or a process that stops responding within the timeout)
terminate/kill the remaining processes and raise an error; reference the
process_list variable and the code that currently calls process.join() and
checks process.exitcode to implement this immediate teardown and failure
propagation.
- Around line 119-157: The script currently implements its own CLI and
multiprocessing orchestration in main() (spawning _run_worker), which bypasses
the repo's unified benchmark entrypoint; refactor by removing the standalone
CLI/main and instead expose a run_benchmark(config) function that accepts parsed
args or a config object and invokes the existing _run_worker logic inside the
framework, then register this benchmark with the shared benchmark framework in
flashinfer_benchmark.py (use the framework's register API to supply a name, an
entry function that calls run_benchmark, and metadata so CUPTI timing and
standard output/registration are used); keep _run_worker intact, move argument
parsing into the shared framework, and ensure the new registration replaces the
if __name__ == "__main__" block.
In `@flashinfer/comm/mixed_comm.py`:
- Around line 1234-1243: The _common_check() branch incorrectly applies the
allgather rule to all non-reduce-scatter ops; update the conditional to treat
MixedCommOp.ALLREDUCE separately: for MixedCommOp.REDUCESCATTER and
MixedCommOp.REDUCESCATTER_ALLREDUCE keep the existing check (x_out.shape[0] *
handler.para_info.dp_size == x_in.shape[0]), for MixedCommOp.ALLREDUCE require
x_out.shape[0] == x_in.shape[0], and for other allgather-like ops keep
x_out.shape[0] == x_in.shape[0] * handler.para_info.dp_size; reference the
_common_check function, MixedCommOp enum values (REDUCESCATTER,
REDUCESCATTER_ALLREDUCE, ALLREDUCE), and handler.para_info.dp_size to implement
the corrected branching.
- Around line 583-593: The POSIX file descriptors returned by
cuMemExportToShareableHandle and received via SCM_RIGHTS are not being closed,
causing fd leaks; update create_and_allgather_uc_handle,
create_and_send_mc_handle, and recv_and_create_mc_handle so that after calling
cuMemExportToShareableHandle you call send_fd(...) and then immediately close
the exported descriptor (uc_fd_send) and after calling recv_fd(...) and
importing via cuMemImportFromShareableHandle you immediately close the received
descriptor (uc_fd_recv); do the same for mc_fd in create_and_send_mc_handle and
recv_and_create_mc_handle (close mc_fd after send and after import) using
os.close or an equivalent close helper so the allocation lifetime remains
managed by cuMem* APIs while avoiding leaking file descriptors.
In `@tests/comm/test_mixed_comm.py`:
- Around line 227-231: The test currently waits indefinitely on process.join()
and doesn't stop other workers if one fails; update the loop over process_list
to join each process with a short timeout, check for non-zero exitcode from the
worker started by _run_worker, and on first non-zero exitcode immediately
terminate() (or kill if needed) and join the remaining processes to avoid hangs;
after terminating the rest, re-check exit codes and raise/assert with the
failing process's exitcode and PID/index so the test fails fast and CI won't
hang (see the sketch after this list).
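Both join-related findings above describe the same pattern; a minimal sketch of the fail-fast monitor, with the poll interval and timeout values as placeholders:

```python
# Sketch: poll workers instead of joining indefinitely; kill the peers as soon
# as one worker fails so a stuck collective cannot hang the benchmark or CI.
import time

def join_all_or_fail(process_list, poll_interval=1.0, timeout=600.0):
    deadline = time.monotonic() + timeout
    while any(p.is_alive() for p in process_list):
        failed = [p for p in process_list if p.exitcode not in (None, 0)]
        if failed or time.monotonic() > deadline:
            for p in process_list:
                if p.is_alive():
                    p.terminate()  # peers may be blocked in a collective
            for p in process_list:
                p.join()
            if failed:
                raise RuntimeError(
                    f"worker pid={failed[0].pid} exited with {failed[0].exitcode}"
                )
            raise RuntimeError("workers did not finish within the timeout")
        time.sleep(poll_interval)
    for p in process_list:
        p.join()
        assert p.exitcode == 0, f"worker pid={p.pid} exited with {p.exitcode}"
```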
ℹ️ Review info
Configuration used: defaults | Review profile: CHILL
📥 Commits: reviewing files that changed from the base of the PR and between cd729b5b6c1050c092f72bb5e733980430cc7ede and be27834.
📒 Files selected for processing (3): benchmarks/bench_mixed_comm.py, flashinfer/comm/mixed_comm.py, tests/comm/test_mixed_comm.py
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
## Description
Bump version to 0.6.10 for release.
## Related Issues (Gated-by PRs)
https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10
## Reviewer Notes
**API changes review**
API changes since v0.6.9
```diff
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
register_custom_op,
@@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
)
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_trace)
def silu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
@@ -112,7 +118,7 @@ def silu_and_mul(
return out
-@flashinfer_api
+@flashinfer_api(trace=gelu_tanh_and_mul_trace)
def gelu_tanh_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
@@ -153,7 +159,7 @@ def gelu_tanh_and_mul(
return out
-@flashinfer_api
+@flashinfer_api(trace=gelu_and_mul_trace)
def gelu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
@@ -194,7 +200,7 @@ def gelu_and_mul(
return out
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace)
def silu_and_mul_scaled_nvfp4_experts_quantize(
a,
mask,
diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index dfb05150..d26d5407 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -543,6 +543,7 @@ def gen_all_modules(
if add_comm:
from .jit.comm import (
gen_comm_alltoall_module,
+ gen_dcp_alltoall_module,
gen_moe_alltoall_module,
gen_trtllm_comm_module,
gen_trtllm_mnnvl_comm_module,
@@ -554,6 +555,11 @@ def gen_all_modules(
jit_specs.append(gen_trtllm_comm_module())
jit_specs.append(gen_trtllm_mnnvl_comm_module())
jit_specs.append(gen_moe_alltoall_module())
+ # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has
--
-def flashinfer_api(func: Callable = None) -> Callable:
+# ---------------------------------------------------------------------------
+# Trace template registry
+# ---------------------------------------------------------------------------
+# Populated automatically by _attach_fi_trace whenever @flashinfer_api is
+# given a trace= argument. Each entry is (original_func, template, label)
+# where label is the template's name_prefix (or op_type as fallback).
+#
+# For dispatch callables (trace=some_fn), every template listed in
+# some_fn.templates is registered if that attribute exists.
+#
+# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover
+# all registered templates without requiring manual maintenance.
+_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = []
+
+
+def _attach_fi_trace(
+ wrapped: Callable,
+ original: Callable,
+ trace_template=None,
+) -> Callable:
+ """Attach a ``fi_trace`` callable to *wrapped*.
+
+ Three resolution strategies, tried in order:
+
--
+
+ warnings.warn(
+ f"[flashinfer] Failed to attach fi_trace to '{_func_name}': "
+ f"{type(_exc).__name__}: {_exc}\n"
+ f"The function will work normally but fi_trace will be unavailable. "
+ f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).",
+ stacklevel=3,
+ )
+ return wrapped
+
+
+def flashinfer_api(func: Callable = None, *, trace=None) -> Callable:
"""
Decorator to FlashInfer's APIs.
@@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable:
- The %i pattern is automatically replaced with the process ID for multi-process environments.
- The logger does not propagate to the root logger to avoid duplicate logs.
"""
- # If logging is disabled, return original function with zero overhead
+ # If logging is disabled, return original function with zero overhead.
+ # We still attach fi_trace so it is always available regardless of log level.
if _API_LOG_LEVEL == 0:
if func is None:
- return lambda f: f
- return func
--
@functools.cache
@@ -135,7 +136,7 @@ class BatchAttention:
causal,
)
- @flashinfer_api
+ @flashinfer_api(trace=batch_attention_run_trace)
def run(
self,
q: torch.Tensor,
@@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper
a convenient interface for using attention sinks during prefill or decode attention.
"""
+ # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper
+ # already decorates __init__, so decorating again produces double log entries.
def __init__(
self,
float_workspace_buffer: torch.Tensor,
diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py
new file mode 100644
index 00000000..3e029627
--- /dev/null
+++ b/flashinfer/attention/cute_dsl/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
--
@@ -31,7 +37,7 @@ def get_cascade_module():
return gen_cascade_module().build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=merge_state_trace)
@register_custom_op("flashinfer::merge_state", mutates_args=())
def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -98,7 +104,7 @@ def _fake_merge_state(
return v, s
-@flashinfer_api
+@flashinfer_api(trace=merge_state_in_place_trace)
@register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
def merge_state_in_place(
v: torch.Tensor,
@@ -159,7 +165,7 @@ def _fake_merge_state_in_place(
pass
-@flashinfer_api
+@flashinfer_api(trace=merge_states_trace)
@register_custom_op("flashinfer::merge_states", mutates_args=())
def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Merge multiple attention states (v, s).
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper:
begin_forward = plan
- @flashinfer_api
+ @flashinfer_api(trace=multi_level_cascade_run_trace)
def run(
self,
q: torch.Tensor,
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 5f186002..31d23a99 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import (
moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace,
)
+# DCP A2A (Decode Context Parallel Attention Reduction)
+from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
+from .dcp_alltoall import (
+ decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
+)
+from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
+from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
+
# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
--
from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
@@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace(
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=allreduce_fusion_trace)
def allreduce_fusion(
input: torch.Tensor,
workspace: AllReduceFusionWorkspace,
diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py
new file mode 100644
index 00000000..3047f76c
--- /dev/null
+++ b/flashinfer/comm/dcp_alltoall.py
@@ -0,0 +1,255 @@
+"""
+DCP All-to-All Operations for DCP Attention Reduction
+
+Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
+attention reduction. Uses SM90+ features (TMA, mbarrier).
+
+Usage protocol::
+
+ # 1. Query workspace size
+ ws_bytes = decode_cp_a2a_workspace_size(cp_size)
+
--
+
+
+# ─── Public API ───────────────────────────────────────────────────────────
+
+
+@flashinfer_api
+def decode_cp_a2a_workspace_size(cp_size: int) -> int:
+ """Return the workspace size **in bytes** per rank for the given CP group size.
+
+ Args:
+ cp_size: Context-parallel group size (number of ranks).
+
+ Returns:
+ Workspace size in bytes per rank.
+
+ Example::
+
+ >>> decode_cp_a2a_workspace_size(4)
+ 16778240
+ """
+ return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size)
+
+
+@flashinfer_api
+def decode_cp_a2a_allocate_workspace(
+ cp_size: int,
+ cp_rank: int,
+ *,
+ mapping: Optional[Mapping] = None,
+ mnnvl_config: Optional[MnnvlConfig] = None,
+) -> torch.Tensor:
+ """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+
+ After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+ cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
+
+ Two allocation modes:
+
+ - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
+ FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
+ cannot see each other's device memory directly.
+ - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
+ allocation. Sufficient for single-node with NVLink P2P.
+
--
+
+ ws_elems_per_rank = (ws_bytes + 7) // 8
+ return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+
+
+@flashinfer_api
+def decode_cp_a2a_init_workspace(
+ workspace: torch.Tensor,
+ cp_rank: int,
+ cp_size: int,
+) -> None:
+ """Initialize the workspace FIFO buffers. Call once before the first alltoall.
+
+ Resets the FIFO buffers in the **local** workspace row
+ (``workspace[cp_rank]``). This function is **synchronous**: when it
+ returns, the GPU memset is guaranteed to have completed.
+
+ .. important::
+ With MNNVL workspaces, **all ranks** must complete
+ ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier
+ (e.g. ``dist.barrier(group)``) before **any** rank calls
+ :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may
+ start writing to a peer's FIFO before that peer has finished
+ initializing → deadlock.
+
+ Args:
--
+ # subsequent cross-GPU alltoall can race with the unfinished memset
+ # on MNNVL memory, causing a deadlock.
+ torch.cuda.current_stream().synchronize()
+
+
+@flashinfer_api(trace=decode_cp_a2a_alltoall_trace)
+def decode_cp_a2a_alltoall(
+ partial_o: torch.Tensor,
+ softmax_stats: torch.Tensor,
+ workspace: torch.Tensor,
+ cp_rank: int,
+ cp_size: int,
+ enable_pdl: Optional[bool] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Perform the DCP all-to-all exchange.
+
+ Each rank sends its ``partial_o[..., peer, :]`` slice to the
+ corresponding peer and receives all peers' contributions into the
+ output tensors.
+
+ Args:
+ partial_o: ``[..., cp_size, D]`` — half or bfloat16.
+ ``D * element_size`` must be 16-byte aligned.
+ softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
+ Batch dimensions must match ``partial_o``.
+ workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
--
+ MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather,
+ MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce,
+}
+
+
+@flashinfer_api
+@backend_requirement(
+ backend_checks={},
+ common_check=_common_check,
+)
+def run_mixed_comm(
+ op: MixedCommOp,
+ handler: MixedCommHandler,
+ x_in: torch.Tensor,
+ x_out: torch.Tensor | None = None,
+ mode: MixedCommMode | None = None,
+) -> torch.Tensor:
+ """Execute a mixed communication operation.
+
+ This is the main entry point for running communication collectives
+ through the mixed communication handler. It supports fused GPU kernels
+ (using virtual memory intra-node and nvshmem inter-node), NCCL-based
+ fallbacks, and autotuned mode selection.
+
+ Args:
+ op: The communication operation to perform.
--
@functools.cache
@@ -28,7 +29,7 @@ def get_concat_mla_module():
return gen_concat_mla_module().build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=concat_mla_k_trace)
def concat_mla_k(
k: torch.Tensor,
k_nope: torch.Tensor,
diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py
index 195ca2d4..9b593095 100644
--- a/flashinfer/cudnn/decode.py
+++ b/flashinfer/cudnn/decode.py
@@ -4,6 +4,7 @@ from typing import Optional
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_decode_trace
from .utils import get_cudnn_fmha_gen_module
try:
@@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache(
return out
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_decode_trace)
def cudnn_batch_decode_with_kv_cache(
q: torch.Tensor,
k_cache: torch.Tensor,
diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc1bbb5f..b16d6043 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -4,6 +4,7 @@ from typing import Optional
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_prefill_trace
from .utils import get_cudnn_fmha_gen_module
try:
@@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache(
return out, None
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_prefill_trace)
def cudnn_batch_prefill_with_kv_cache(
q: torch.Tensor,
k_cache: torch.Tensor,
diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
index 0b50c22c..f25aa6fd 100644
--- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
@@ -38,6 +38,7 @@ import torch
from cutlass import Float32, Int32, Int64, Uint32, Uint8
from ..api_logging import flashinfer_api
+from ..trace.templates.norm import add_rmsnorm_fp4quant_trace
from ..utils import device_support_pdl
from .fp4_common import (
# Constants
@@ -1042,7 +1043,7 @@ def _get_compiled_kernel(
return tensor_api
-@flashinfer_api
+@flashinfer_api(trace=add_rmsnorm_fp4quant_trace)
def add_rmsnorm_fp4quant(
input: torch.Tensor,
residual: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
index 333697ab..b7aabc36 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
@@ -20,6 +20,7 @@ import torch
from cutlass import Float32, Int32
from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace
from flashinfer.utils import device_support_pdl
from flashinfer.cute_dsl.utils import (
get_max_active_clusters,
@@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper:
f"out_dtype={self._o_dtype}"
)
- @flashinfer_api
+ @flashinfer_api(trace=cute_dsl_batch_mla_run_trace)
def run(
self,
q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
index 58a24abe..ee0cd5e7 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
@@ -21,6 +21,7 @@ import cutlass.cute as cute
from cutlass.cute.typing import Int32
from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace
from ..config import AttentionConfig, AttentionFusion
from ..fusion.mask import MaskType
@@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper:
f"device={self._device}"
)
- @flashinfer_api
+ @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace)
def run(
self,
q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
index bc4acffc..97ce68a1 100644
--- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
@@ -32,6 +32,7 @@ import torch
from cutlass import Float32, Int32, Uint8
from ..api_logging import flashinfer_api
+from ..trace.templates.norm import rmsnorm_fp4quant_trace
from ..utils import device_support_pdl
from .fp4_common import (
# Constants
@@ -771,7 +772,7 @@ def _get_compiled_kernel(
return tensor_api
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_fp4quant_trace)
def rmsnorm_fp4quant(
input: torch.Tensor,
weight: torch.Tensor,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 822aca40..5e9eb515 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload
import torch
from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+ gqa_paged_decode_trace,
+ single_decode_with_kv_cache_trace,
+ trtllm_batch_decode_trace,
+ xqa_batch_decode_trace,
+)
## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
from .mla import (
@@ -400,7 +406,7 @@ def single_decode_with_kv_cache(
) -> Tuple[torch.Tensor, torch.Tensor]: ...
-@flashinfer_api
+@flashinfer_api(trace=single_decode_with_kv_cache_trace)
def single_decode_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
@@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=gqa_paged_decode_trace)
def run(
self,
q: torch.Tensor,
@@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra
:class:`BatchDecodeWithPagedKVCacheWrapper`
"""
+ # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper
+ # already decorates __init__, so decorating again produces double log entries.
def __init__(
self,
workspace_buffer: torch.Tensor,
@@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args):
)
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_trace)
def trtllm_batch_decode_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache(
# xqa uses NHD layout
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_trace)
def xqa_batch_decode_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py
new file mode 100644
index 00000000..1104eb6f
--- /dev/null
+++ b/flashinfer/fi_trace.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--
+
+"""
+fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_
+compatible definition JSON for FlashInfer APIs.
+
+Every ``@flashinfer_api(trace=<template>)``-decorated function supports two
+usage modes:
+
+Auto-dump (recommended)
+-----------------------
+Set environment variables **before** importing flashinfer, then run your
+workload normally. No explicit ``fi_trace`` call is needed.
+
+.. code-block:: bash
+
+ FLASHINFER_TRACE_DUMP=1 \\
+ FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\
+ python my_script.py
+
+Every decorated function writes a ``<name>.json`` file on its **first** call
+for each unique set of const-axis values (e.g. head dimensions, vocab size).
+Subsequent calls with the same shape are deduplicated — the file is written
+only once per process. The output directory is created automatically.
+
+Explicit call (for selective or programmatic use)
+-------------------------------------------------
--
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+# ---------------------------------------------------------------------------
+# Legacy registry — kept for backwards compatibility.
+# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead.
+# ---------------------------------------------------------------------------
+
+_REGISTRY: Dict[str, Any] = {}
+
+
+def register_fi_trace(qualname: str, spec: Any) -> None:
+ """Register a legacy FiTraceSpec for the function with the given qualname.
+
+ .. deprecated::
+ Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead.
+ """
+ _REGISTRY[qualname] = spec
+
+
+def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]:
+ """Build a fi_trace callable from a legacy FiTraceSpec.
+
+ .. deprecated::
+ Use ``TraceTemplate.build_fi_trace_fn`` instead.
+ """
+ # Import the old implementation from the trace package for backwards compat.
+ from .trace.template import ( # noqa: PLC0415,F401
+ Const,
+ Scalar,
+ Tensor,
+ TraceTemplate,
+ Var,
+ )
+ import json # noqa: PLC0415
+ import os # noqa: PLC0415
--
+ """Generate a flashinfer-bench definition JSON for any FlashInfer API call.
+
+ Parameters
+ ----------
+ func_or_method:
+ A ``@flashinfer_api``-decorated function or (bound) method.
+ save_dir:
+ Directory where the JSON definition file should be written.
+ Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
+ **kwargs:
+ The same tensor arguments you would pass to the real API.
+
+ Returns
+ -------
+ dict
+ A flashinfer-bench compatible definition dictionary.
+
+ Examples
+ --------
+ Standalone function::
+
+ defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)
+
+ Bound method (instance.run)::
+
+ defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))
--
+ trace_fn = getattr(actual_func, "fi_trace", None)
+ if trace_fn is None:
+ qualname = getattr(actual_func, "__qualname__", repr(actual_func))
+ raise ValueError(
+ f"No fi_trace spec is registered for '{qualname}'. "
+ "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+ )
+ return trace_fn(save_dir=save_dir, **kwargs)
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index df6e1f72..d983f9d4 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -17,6 +17,8 @@ limitations under the License.
from .core import (
convert_to_block_layout,
cutlass_fused_moe,
+ interleave_moe_scales_for_sm90_mixed_gemm,
+ interleave_moe_weights_for_sm90_mixed_gemm,
gen_cutlass_fused_moe_sm120_module,
gen_cutlass_fused_moe_sm103_module,
gen_cutlass_fused_moe_sm100_module,
@@ -64,6 +66,8 @@ __all__ = [
"WeightLayout",
"convert_to_block_layout",
"cutlass_fused_moe",
+ "interleave_moe_scales_for_sm90_mixed_gemm",
--
+ ),
)
-# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
@flashinfer_api
+def interleave_moe_scales_for_sm90_mixed_gemm(
+ scales: torch.Tensor,
+ group_size: int = 32,
+) -> torch.Tensor:
+ """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
+
+ The kernel expects scales in layout
+ ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural
+ ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer.
+ This helper performs the reshape + permute equivalent to TensorRT-LLM's
+ ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), with the fixed
+ interleave factor of ``128 // group_size`` used for MXFP4.
+
+ Parameters
+ ----------
+ scales:
+ ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block
+ scales.
+ group_size:
+ MXFP4 quantization group size (default 32).
--
+ scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous()
+ )
+ return tmp.reshape(e, kgs // factor, rows * factor)
+
+
+@flashinfer_api
+def interleave_moe_weights_for_sm90_mixed_gemm(
+ weight: torch.Tensor,
+ quant_type: str = "fp4",
+) -> torch.Tensor:
+ """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
+
+ The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with
+ ``use_w4_group_scaling=True``) expects weights in a specific interleaved
+ layout; without preprocessing, the LUT-based FP4→BF16 conversion reads
+ bytes from the wrong positions and the output diverges from a dequantized
+ reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent
+ preprocessing at weight-load time (see
+ ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451).
+
+ Parameters
+ ----------
+ weight:
+ ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed
+ two-per-byte).
+ quant_type:
--
+ )
+ return out
+
+
+# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
+@flashinfer_api(trace=cutlass_fused_moe_trace)
def cutlass_fused_moe(
input: torch.Tensor,
token_selected_experts: torch.Tensor,
@@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module():
DynamicTensorSpec(
input_idx,
dim_idx,
- get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
- lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+ get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+ lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
initializers,
),
),
@@ -2344,7 +2468,7 @@ def _validate_routing_replay_out(
raise ValueError("routing_replay_out must be contiguous (packed row-major)")
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_moe_trace)
def trtllm_bf16_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2452,7 +2576,7 @@ def trtllm_bf16_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_routed_moe_trace)
def trtllm_bf16_routed_moe(
topk_ids: torch.Tensor,
hidden_states: torch.Tensor,
@@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace)
def trtllm_fp8_per_tensor_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
def trtllm_fp8_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace)
def trtllm_fp8_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch)
def trtllm_fp4_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe(
)
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
def trtllm_fp4_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe(
)
-@flashinfer_api
+@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace)
def trtllm_mxint4_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
index d2cbc8b0..34916df5 100644
--- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
@@ -42,11 +42,12 @@ from typing import Optional, Tuple
import torch
from ...api_logging import flashinfer_api
+from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace
from ...utils import supported_compute_capability
@supported_compute_capability([120, 121])
-@flashinfer_api
+@flashinfer_api(trace=b12x_fused_moe_trace)
def b12x_fused_moe(
x: torch.Tensor,
w1_weight: torch.Tensor,
@@ -293,7 +294,7 @@ class B12xMoEWrapper:
device=self.device,
)
- @flashinfer_api
+ @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
def run(
self,
x: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
index f6cf1b67..e266cb77 100644
--- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
+++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
@@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import (
st_global_u64,
scatter_add_bf16x2,
)
-from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
- Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
+from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
+ Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
)
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
index e7fdae92..670b3ad8 100644
--
from .moe_utils import (
@@ -530,7 +534,7 @@ class CuteDslMoEWrapper:
enable_pdl=enable_pdl,
)
- @flashinfer_api
+ @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace)
def run(
self,
x: torch.Tensor,
@@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
@supported_compute_capability([100, 103])
-@flashinfer_api
+@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace)
def cute_dsl_fused_moe_nvfp4(
x: torch.Tensor,
x_sf: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index 0cc8628e..636043db 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@ from ...autotuner import (
TuningConfig,
)
from ..utils import (
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
+ get_hybrid_num_tokens_buckets,
+ map_to_hybrid_bucket,
)
logger = logging.getLogger(__name__)
@@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
DynamicTensorSpec(
--
import torch
@@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module():
@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
-@flashinfer_api
+@flashinfer_api(trace=fused_topk_deepseek_trace)
def fused_topk_deepseek(
scores: torch.Tensor,
bias: torch.Tensor,
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 004271a1..91f37aa5 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
- """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1."""
- max_num_tokens = next_positive_power_of_2(max_num_tokens)
- num_token_buckets = []
- m = max_num_tokens
- while m >= 1:
- num_token_buckets.append(m)
- m //= 2
+_PHASE1_END = 256
--
@@ -106,7 +114,7 @@ TILE_V = 8 # pretranspose tile size
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
def gated_delta_rule_decode_pretranspose(
q: torch.Tensor,
k: torch.Tensor,
@@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose(
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
def gated_delta_rule_decode(
q: torch.Tensor,
k: torch.Tensor,
@@ -535,7 +543,7 @@ def gated_delta_rule_decode(
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=gdn_mtp_trace)
def gated_delta_rule_mtp(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 68398d28..53fe44ce 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel:
gate_handle = load_gate_consumer.wait_and_advance()
- max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
- cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
+ cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]
valid_state = not is_first_chunk or self.use_initial_state
if cutlass.const_expr(valid_state):
diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
index 82dcc72b..aafcc671 100644
--- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
--
register_custom_op,
@@ -95,7 +96,7 @@ def get_gdn_prefill_module():
return SimpleNamespace(gdn_prefill=gdn_prefill)
-@flashinfer_api
+@flashinfer_api(trace=gdn_prefill_trace)
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index a7795beb..def82216 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -61,11 +61,11 @@ try:
from flashinfer.cute_dsl.utils import is_cute_dsl_available
if is_cute_dsl_available():
- from .kernels.dense_blockscaled_gemm_sm120 import (
- Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
+ from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+ Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
)
- _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+ _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
except ImportError:
--
from ..utils import (
@@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16(
common_check=_check_mm_bf16_problem_size,
heuristic_func=_heuristic_func_mm_bf16,
)
-@flashinfer_api
+@flashinfer_api(trace=mm_bf16_trace)
def mm_bf16(
a: torch.Tensor,
b: torch.Tensor,
@@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16(
common_check=_check_bmm_bf16_problem_size,
heuristic_func=_heuristic_func_bmm_bf16,
)
-@flashinfer_api
+@flashinfer_api(trace=bmm_bf16_trace)
def bmm_bf16(
A: torch.Tensor,
B: torch.Tensor,
@@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
DynamicTensorSpec(
(0,), # a_tensor_index
(-2,),
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
+ get_hybrid_num_tokens_buckets,
+ map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig(
DynamicTensorSpec(
(0,), # a_tensor_index
(-2,),
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
--
constraint_specs=(
@@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module(
)
-@flashinfer_api
+@flashinfer_api(trace=tgv_gemm_sm100_trace)
def tgv_gemm_sm100(
a: torch.Tensor,
b: torch.Tensor,
@@ -1173,8 +1187,8 @@ def tgv_gemm_sm100(
DynamicTensorSpec(
(a_tensor_index,),
(-2,),
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
+ get_hybrid_num_tokens_buckets,
+ map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper:
True
"""
+ @flashinfer_api
def __init__(
self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
) -> None:
@@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper:
self._float_workspace_buffer = float_workspace_buffer
self._int_workspace_buffer = int_workspace_buffer
- @flashinfer_api
+ @flashinfer_api(trace=segment_gemm_run_trace)
def run(
self,
x: torch.Tensor,
@@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape(
return graph
+# Internal helper called from mm_fp4; the user-facing mm_fp4 is already
+# decorated, so decorating here would double-log the same invocation.
def execute_cudnn_gemm_fp4_graph_override_shape(
graph,
a,
@@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape(
return graph
+# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already
+# decorated, so decorating here would double-log the same invocation.
def execute_cudnn_gemm_mxfp8_graph_override_shape(
graph,
--
):
@@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
return (tuple(block_scale_shape), tuple(block_scale_stride))
-@flashinfer_api
+@flashinfer_api(trace=mm_fp8_trace)
def mm_fp8(
a: torch.Tensor,
b: torch.Tensor,
@@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8(
common_check=_check_mm_mxfp8_problem_size,
heuristic_func=_heuristic_func_mm_mxfp8, # result stored in mm_mxfp8.suitable_auto_backends
)
-@flashinfer_api
+@flashinfer_api(trace=mm_mxfp8_trace)
def mm_mxfp8(
a: torch.Tensor,
b: torch.Tensor,
@@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner(
"""
import cutlass
- from .kernels.dense_blockscaled_gemm_sm120 import (
- Sm120BlockScaledDenseGemmKernel,
+ from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+ Sm120B12xBlockScaledDenseGemmKernel,
)
cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner(
]
swap_ab = False
for mma_tiler_mn in sm120_mma_tiler_candidates:
- if not Sm120BlockScaledDenseGemmKernel.can_implement(
+ if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
--
constraint_specs=(
@@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig(
common_check=_check_mm_fp4_problem_size,
heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends
)
-@flashinfer_api
+@flashinfer_api(trace=mm_fp4_trace)
def mm_fp4(
a: torch.Tensor,
b: torch.Tensor,
@@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8(
common_check=_check_bmm_fp8_problem_size,
heuristic_func=_heuristic_func_bmm_fp8,
)
-@flashinfer_api
+@flashinfer_api(trace=bmm_fp8_trace)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
@@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise(
{},
common_check=_check_batch_deepgemm_fp8_nt_groupwise,
)
-@flashinfer_api
+@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace)
def batch_deepgemm_fp8_nt_groupwise(
a: torch.Tensor, # (batch_size, m, k)
b: torch.Tensor, # (batch_size, n, k)
@@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90():
return module.init()
-@flashinfer_api
+@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace)
def fp8_blockscale_gemm_sm90(
input: torch.Tensor,
weight: torch.Tensor,
@@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8(
common_check=_check_bmm_mxfp8_problem_size,
heuristic_func=_heuristic_func_bmm_mxfp8,
)
-@flashinfer_api
+@flashinfer_api(trace=bmm_mxfp8_trace)
def bmm_mxfp8(
A: torch.Tensor,
B: torch.Tensor,
diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
similarity index 99%
rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
index c49bc815..6eee27a7 100644
--- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
+++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
@@ -1550,7 +1550,7 @@ class DenseGemmKernel:
# Alias for FlashInfer integration
-Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
+Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
class _DenseGemmLaunch:
diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
--
get_cutlass_dtype,
@@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel(
return tensor_api
-@flashinfer_api
+@flashinfer_api(trace=grouped_gemm_nt_masked_trace)
def grouped_gemm_nt_masked(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py
index cfde7d43..f83c8974 100644
--- a/flashinfer/gemm/routergemm.py
+++ b/flashinfer/gemm/routergemm.py
@@ -1,4 +1,8 @@
from ..api_logging import flashinfer_api
+from ..trace.templates.gemm import (
+ mm_M1_16_K7168_N256_trace,
+ tinygemm_bf16_trace,
+)
from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
import functools
from types import SimpleNamespace
@@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128(
@backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
def mm_M1_16_K7168_N256(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
@@ -324,7 +328,7 @@ def get_tinygemm2_module():
@backend_requirement({}, common_check=_tinygemm_bf16_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=tinygemm_bf16_trace)
def tinygemm_bf16(
input: torch.Tensor,
weight: torch.Tensor,
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 7f36a314..8378e0ab 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module
from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module
from .comm import gen_vllm_comm_module as gen_vllm_comm_module
from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module
+from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module
from .dsv3_optimizations import (
gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module,
)
diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py
index 46768eed..834f77f9 100644
--- a/flashinfer/jit/comm.py
+++ b/flashinfer/jit/comm.py
@@ -15,7 +15,13 @@ limitations under the License.
--
gen_selective_state_update_sm100_module,
@@ -99,7 +100,7 @@ def get_selective_state_update_module(
)
-@flashinfer_api
+@flashinfer_api(trace=selective_state_update_trace)
def selective_state_update(
state: torch.Tensor,
x: torch.Tensor,
diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index 4e8bdd72..e27e3807 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.attention import (
+ mla_paged_decode_trace,
+ trtllm_batch_decode_mla_trace,
+ xqa_batch_decode_mla_trace,
+)
from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
from ..jit.mla import gen_mla_module
from ..utils import (
@@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper:
return_lse_base_on_e: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=mla_paged_decode_trace)
def run(
self,
q_nope: torch.Tensor,
@@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper:
return (out, lse) if return_lse else out
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_mla_trace)
def trtllm_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
@@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
raise ValueError(f"Backend {backend} not supported")
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_mla_trace)
def xqa_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index 0f9911a6..ba612b28 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -32,6 +32,16 @@ from typing import Optional, Union
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.norm import (
+ fused_add_rmsnorm_quant_trace,
+ fused_add_rmsnorm_trace,
+ fused_rmsnorm_silu_trace,
+ gemma_fused_add_rmsnorm_trace,
+ gemma_rmsnorm_trace,
+ layernorm_trace,
+ rmsnorm_quant_trace,
+ rmsnorm_trace,
--
get_compute_capability,
@@ -94,7 +104,7 @@ def _normalize_scale_tensor(
return scale.contiguous()
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_trace)
def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
@@ -165,7 +175,7 @@ def _rmsnorm_impl_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_quant_trace)
@register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",))
def rmsnorm_quant(
out: torch.Tensor,
@@ -219,7 +229,7 @@ def _rmsnorm_quant_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_trace)
@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
def fused_add_rmsnorm(
input: torch.Tensor,
@@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_quant_trace)
@register_custom_op(
"flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
)
@@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=gemma_rmsnorm_trace)
def gemma_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
@@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace)
@register_custom_op(
"flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
)
@@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=layernorm_trace)
@register_custom_op("flashinfer::layernorm", mutates_args=())
def layernorm(
input: torch.Tensor,
@@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype):
)
-@flashinfer_api
+@flashinfer_api(trace=fused_rmsnorm_silu_trace)
def fused_rmsnorm_silu(
input: torch.Tensor,
weight: torch.Tensor,
diff --git a/flashinfer/page.py b/flashinfer/page.py
index 12ea3613..7fb33cf3 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union
import torch
from .api_logging import flashinfer_api
+from .trace.templates.page import (
+ append_paged_kv_cache_trace,
+ append_paged_mla_kv_cache_trace,
+)
from .jit.page import gen_page_module
from .utils import (
TensorLayout,
@@ -222,7 +226,7 @@ def get_seq_lens(
)
-@flashinfer_api
+@flashinfer_api(trace=append_paged_mla_kv_cache_trace)
def append_paged_mla_kv_cache(
append_ckv: torch.Tensor,
append_kpe: torch.Tensor,
@@ -272,7 +276,7 @@ def append_paged_mla_kv_cache(
)
-@flashinfer_api
+@flashinfer_api(trace=append_paged_kv_cache_trace)
def append_paged_kv_cache(
append_key: torch.Tensor,
append_value: torch.Tensor,
diff --git a/flashinfer/pod.py b/flashinfer/pod.py
index fe2e36c1..4fa2d9bf 100644
--- a/flashinfer/pod.py
+++ b/flashinfer/pod.py
@@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union
import torch
from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+ batch_pod_with_paged_kv_cache_run_trace,
+ pod_with_paged_kv_cache_run_trace,
+)
from .jit import gen_pod_module, gen_batch_pod_module
from .page import get_seq_lens
from .prefill import get_batch_prefill_module
@@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper:
begin_forward = plan
- @flashinfer_api
+ @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace)
def run(
self,
# Main params (prefill and decode)
@@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper:
begin_forward = plan
- @flashinfer_api
+ @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace)
def run(
self,
# Main params (prefill and decode)
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 4ec6a29e..d491dd35 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
import torch
from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+ gqa_paged_prefill_trace,
+ gqa_ragged_prefill_trace,
+ single_prefill_with_kv_cache_trace,
+ trtllm_batch_context_trace,
+)
+from .trace.templates.gemm import (
+ fmha_v2_prefill_deepseek_trace,
+ trtllm_ragged_attention_deepseek_trace,
--
gen_customize_batch_prefill_module,
@@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache(
) -> Tuple[torch.Tensor, torch.Tensor]: ...
-@flashinfer_api
+@flashinfer_api(trace=single_prefill_with_kv_cache_trace)
def single_prefill_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
@@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
skip_softmax_threshold_scale_factor: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=gqa_paged_prefill_trace)
def run(
self,
q: torch.Tensor,
@@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
enable_pdl: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=gqa_ragged_prefill_trace)
def run(
self,
q: torch.Tensor,
@@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module():
return op
-@flashinfer_api
+@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace)
def trtllm_ragged_attention_deepseek(
query: torch.Tensor,
key: torch.Tensor,
@@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek(
skip_softmax_threshold_scale_factor: Optional[float] = None,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
+ backend: str = "trtllm-gen",
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Parameters
@@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek(
output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
lse : Optional[torch.Tensor]
lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]
+ backend : str
+ Attention backend to use. "trtllm-gen" (default) or "cute-dsl".
+ When backend="cute-dsl", query/key/value/out tensors must be
+ front-padded with max_seq_len rows of valid GPU memory before
+ index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details).
--
"lse assumed not None beyond this point when return_lse is True"
@@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek(
return out
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_context_trace)
def trtllm_batch_context_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module():
return gen_trtllm_fmha_v2_sm120_module().build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
def fmha_v2_prefill_deepseek(
query: torch.Tensor,
key: torch.Tensor,
@@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module(
return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace)
def trtllm_fmha_v2_prefill(
qkv: Union[
torch.Tensor,
diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py
index 4cd5cd34..84f7ade6 100644
--- a/flashinfer/quantization/fp4_quantization.py
+++ b/flashinfer/quantization/fp4_quantization.py
@@ -21,6 +21,12 @@ from typing import List, Optional, Tuple
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import (
+ fp4_quantize_trace,
+ mxfp4_quantize_trace,
+ nvfp4_kv_quantize_trace,
+ nvfp4_quantize_trace,
+)
from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import (
@@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"):
)
-@flashinfer_api
+@flashinfer_api(trace=fp4_quantize_trace)
def fp4_quantize(
input: torch.Tensor,
global_scale: Optional[torch.Tensor] = None,
@@ -923,7 +929,7 @@ def shuffle_matrix_sf_a(
return block_scale_interleave(w_shuffled)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_quantize_trace)
def nvfp4_quantize(
a,
a_global_sf,
@@ -1024,7 +1030,7 @@ def nvfp4_quantize(
return a_fp4, a_sf
-@flashinfer_api
+@flashinfer_api(trace=mxfp4_quantize_trace)
def mxfp4_quantize(
a: torch.Tensor,
backend: str = "cuda",
@@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale):
@backend_requirement({}, common_check=_nvfp4_kv_quant_check)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_kv_quantize_trace)
def nvfp4_kv_quantize(
input: torch.Tensor,
global_scale: torch.Tensor,
diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py
index f2c9f412..49e13a8b 100644
--- a/flashinfer/quantization/fp8_quantization.py
+++ b/flashinfer/quantization/fp8_quantization.py
@@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import mxfp8_quantize_trace
from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
from ..utils import (
device_support_pdl,
@@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module():
)
-@flashinfer_api
+@flashinfer_api(trace=mxfp8_quantize_trace)
def mxfp8_quantize(
input: torch.Tensor,
is_sf_swizzled_layout: bool = True,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07..df5c7d4d 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -20,6 +20,21 @@ from typing import Optional, Tuple
import torch
from .api_logging import flashinfer_api
+from .trace.templates.rope import (
+ apply_llama31_rope_inplace_trace,
+ apply_llama31_rope_pos_ids_inplace_trace,
+ apply_llama31_rope_pos_ids_trace,
+ apply_llama31_rope_trace,
+ apply_rope_inplace_trace,
+ apply_rope_pos_ids_inplace_trace,
+ apply_rope_pos_ids_trace,
+ apply_rope_trace,
--
@@ -414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids(
pass
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_inplace_trace)
def apply_rope_inplace(
q: torch.Tensor,
k: torch.Tensor,
@@ -502,7 +517,7 @@ def apply_rope_inplace(
)
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace)
def apply_rope_pos_ids_inplace(
q: torch.Tensor,
k: torch.Tensor,
@@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace(
)
-@flashinfer_api
+@flashinfer_api(trace=apply_llama31_rope_inplace_trace)
def apply_llama31_rope_inplace(
q: torch.Tensor,
k: torch.Tensor,
... (truncated -- see full diff via the command above)
```
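To make the new trace flow concrete, here is a minimal usage sketch grounded in the `fi_trace` docstring from the diff above; the `flashinfer.fi_trace` import path and the tensor shapes are assumptions for illustration, not confirmed by the truncated diff.

```python
# Sketch only: grounded in the fi_trace docstring above. The import path
# (flashinfer.fi_trace) and tensor shapes are assumptions.
import torch
import flashinfer
from flashinfer.fi_trace import fi_trace  # assumed module path

hidden = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
weight = torch.randn(4096, device="cuda", dtype=torch.float16)

# Ambient mode: any @flashinfer_api(trace=...) call dumps <name>.json into
# FLASHINFER_TRACE_DUMP_DIR on its first call per unique const-axis values.
# Explicit mode, per the docstring: pass the decorated function plus the
# same tensor kwargs you would pass to the real API.
defn = fi_trace(flashinfer.norm.rmsnorm, save_dir="traces",
                input=hidden, weight=weight)
assert isinstance(defn, dict)  # flashinfer-bench compatible definition
```

Per the error branch in the diff, calling `fi_trace` on a function decorated bare (without `trace=...`) raises a `ValueError`.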
**Summary of API changes:**
- **Decorator semantic addition (backward-compatible):**
`@flashinfer_api` now accepts an optional `trace=<TraceTemplate>`
keyword. Bare `@flashinfer_api` still works. Existing call sites of
decorated functions are unaffected. Most of the diff above is mechanical
rewrites of existing `@flashinfer_api` to `@flashinfer_api(trace=...)`,
plus the new `flashinfer/trace/` package and `fi_trace.py` for
flashinfer-bench JSON dumps.
- **New public APIs (7):**
- `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size,
decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace,
decode_cp_a2a_alltoall}` — DCP all-to-all for context-parallel attention
reduction (#2951).
- `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm,
interleave_moe_weights_for_sm90_mixed_gemm}` — SM90 mixed-input MoE GEMM
  helpers (#3084); see the shape-check sketch below.
- `flashinfer.comm.run_mixed_comm` — combinations of allreduce /
allgather / reducescatter (#2563).
- **New `@flashinfer_api`-decorated wrapper init:**
  - `SegmentGEMMWrapper.__init__` is now decorated; previously only `run()`
  was. No call-site change.
- **Backward-compatible signature additions (defaults preserve old
behavior):**
- `top_k_page_table_transform`: `+dsa_graph_safe: bool = False`,
`+row_starts: Optional[torch.Tensor] = None` (#3133).
- `top_k_ragged_transform`: same two new params (#3133).
- `trtllm_ragged_attention_deepseek`: `+backend: str = "trtllm-gen"`
(cute-dsl backend selection).
- **No breaking signature changes** to any `@flashinfer_api` function.
Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0
removals.
- **Module reorganization to flag (not `@flashinfer_api`, but in public
re-export):**
- `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` →
`dense_blockscaled_gemm_sm120_b12x.py`
- Class renamed: `Sm120BlockScaledDenseGemmKernel` →
`Sm120B12xBlockScaledDenseGemmKernel`
- Re-export in `flashinfer/gemm/__init__.py` updated to the new name
only — direct importers of the old name break. Decision needed: ship as
  breaking, or add a deprecation alias (a sketch of the alias option follows
  this list).
- **Internal autotuner helper rename (not public, but used by downstream
extensions):**
- `get_last_power_of_2_num_tokens_buckets` →
`get_hybrid_num_tokens_buckets`
- `last_positive_power_of_2` → `map_to_hybrid_bucket` /
`map_to_hybrid_bucket_uncapped`
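If the non-breaking route is taken, a deprecation shim could look like the sketch below: a hypothetical compatibility file restoring the old module path, not part of this PR.

```python
# Hypothetical shim at flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
# (not in this PR): keeps the old import path alive while steering users to
# the renamed kernel.
import warnings

from .dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel,
)


def __getattr__(name: str):
    # PEP 562 module-level __getattr__: warn only when the legacy name is used.
    if name == "Sm120BlockScaledDenseGemmKernel":
        warnings.warn(
            "Sm120BlockScaledDenseGemmKernel has been renamed to "
            "Sm120B12xBlockScaledDenseGemmKernel and will be removed.",
            DeprecationWarning,
            stacklevel=2,
        )
        return Sm120B12xBlockScaledDenseGemmKernel
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```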
> Diff truncated above due to GitHub PR body length limit. Run the
command at the top locally to see the full output.
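For the SM90 MoE scale helper listed above, the layout contract documented in the diff can be sanity-checked with a short sketch; the concrete sizes here are arbitrary choices for illustration.

```python
# Shape check for interleave_moe_scales_for_sm90_mixed_gemm, following the
# layout documented in the diff: natural (E, rows, K // group_size) in,
# (E, K // (group_size * 4), rows * 4) out. Sizes are illustrative.
import torch
from flashinfer.fused_moe import interleave_moe_scales_for_sm90_mixed_gemm

e, rows, k, gs = 8, 512, 2048, 32  # interleave factor = 128 // gs = 4
scales = torch.randint(
    0, 256, (e, rows, k // gs), dtype=torch.uint8, device="cuda"
)
out = interleave_moe_scales_for_sm90_mixed_gemm(scales, group_size=gs)
assert out.shape == (e, k // (gs * 4), rows * 4)  # (8, 16, 2048)
```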
📌 Description
This PR adds support for the combinations of allreduce, allgather, and reducescatter:
The last two communication patterns occur when TP and DP are both enabled in the attention part.
Besides combining existing NCCL kernels, this PR also implements fused kernels:
Update
Besides the unit tests, correctness has also been verified by running GSM8K and GPQA evaluations with SGLang: https://github.com/jinyangyuan-nvidia/sglang/tree/dev/mixed_comm
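For reference, these combined patterns can also be composed from existing torch.distributed (NCCL) collectives. A minimal sketch follows; which process group each step runs over is an assumption for illustration, not necessarily the decomposition used by the fused kernels.

```python
# Unfused baseline sketch: compose the combined patterns from standard
# torch.distributed collectives. The group assignments are illustrative;
# the PR's fused kernels replace sequences like these with single kernels.
import torch
import torch.distributed as dist


def allreduce_then_allgather(x, reduce_group, gather_group):
    dist.all_reduce(x, op=dist.ReduceOp.SUM, group=reduce_group)
    out = torch.empty(
        (dist.get_world_size(gather_group) * x.shape[0], *x.shape[1:]),
        dtype=x.dtype, device=x.device,
    )
    dist.all_gather_into_tensor(out, x, group=gather_group)
    return out


def reducescatter_then_allreduce(x, scatter_group, reduce_group):
    out = torch.empty(
        (x.shape[0] // dist.get_world_size(scatter_group), *x.shape[1:]),
        dtype=x.dtype, device=x.device,
    )
    dist.reduce_scatter_tensor(out, x, op=dist.ReduceOp.SUM, group=scatter_group)
    dist.all_reduce(out, op=dist.ReduceOp.SUM, group=reduce_group)
    return out
```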
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method), installed the hooks with pre-commit install, and run pre-commit run --all-files, fixing any reported issues.
🧪 Tests
Tests have been added or updated as needed, and all tests are passing (unittest, etc.).
Reviewer Notes