feat: Add DCP All-to-All kernel for context-parallel attention reduction #2951
aleozlx merged 12 commits into flashinfer-ai:main from
Conversation
Add the DCP (Decode Context Parallel) LL128 FIFO-based all-to-all kernel, ported from TensorRT-LLM. The kernel fuses the exchange of partial attention outputs (bf16/fp16) and softmax statistics (fp32) into a single kernel launch over MNNVL cross-GPU memory, replacing two NCCL all_to_all_single calls. Requires SM90+ (Hopper/Blackwell); an MNNVL workspace is required for multi-GPU operation.
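A minimal end-to-end usage sketch of the new Python API (assembled from the docstrings quoted later in this thread; the tensor shapes, the NCCL process-group setup, and the `dist.barrier()` placement are illustrative assumptions rather than code from this PR):

```python
import torch
import torch.distributed as dist

from flashinfer.comm import (
    decode_cp_a2a_allocate_workspace,
    decode_cp_a2a_alltoall,
    decode_cp_a2a_init_workspace,
)

# Assumes torch.distributed (NCCL) is already initialized, one rank per GPU.
cp_size = dist.get_world_size()
cp_rank = dist.get_rank()

# 1. Allocate the per-rank FIFO workspace (pass `mapping=` for an MNNVL-backed
#    allocation; see the docstrings quoted below for the two allocation modes).
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank)

# 2. Initialize the local FIFO metadata, then barrier so no rank starts the
#    exchange before its peers have finished initializing.
decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)
dist.barrier()

# 3. Fused exchange of partial attention outputs (bf16/fp16) and softmax stats (fp32).
partial_o = torch.randn(8, 32, cp_size, 128, dtype=torch.bfloat16, device="cuda")
softmax_stats = torch.randn(8, 32, cp_size, 2, dtype=torch.float32, device="cuda")
recv_o, recv_stats = decode_cp_a2a_alltoall(
    partial_o, softmax_stats, workspace, cp_rank, cp_size
)
```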
📝 Walkthrough

Adds a DCP (Decode Context Parallel) Helix all-to-all implementation: device cp.async/mbarrier primitives and LL128 protocol, Helix CUDA kernel with workspace sizing/initialization/launch, TVM/JIT bindings and Python APIs (optional MNNVL-backed workspace), AOT/JIT integration, benchmark, and single-/multi-GPU tests.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyAPI as Python API
    participant JIT as JIT/FFI Module
    participant Kernel as Helix Kernel
    participant Workspace as FIFO Workspace
    participant Device as CUDA Device
    participant MPI as MPI/NCCL
    User->>PyAPI: decode_cp_a2a_allocate_workspace(cp_size, cp_rank)
    PyAPI->>Device: allocate tensor (MNNVL or torch.cuda)
    Device-->>PyAPI: workspace tensor
    PyAPI-->>User: workspace
    User->>PyAPI: decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)
    PyAPI->>JIT: initialize_dcp_workspace(workspace, cp_rank, cp_size)
    JIT->>Kernel: initializeHelixWorkspace(workspace_ptr)
    Kernel->>Workspace: memset FIFO metadata
    Kernel->>Device: stream sync
    JIT-->>PyAPI: init complete
    User->>PyAPI: decode_cp_a2a_alltoall(partial_o, softmax_stats, workspace, cp_rank, cp_size)
    PyAPI->>JIT: alltoall_dcp_native(inputs, workspace, cp_rank, cp_size)
    JIT->>Kernel: launchHelixAllToAll(params)
    Kernel->>Workspace: cp.async / LL128 pack writes
    Kernel->>Kernel: mbarrier / FIFO handshake
    Kernel->>Workspace: LL128 unpack reads
    Kernel-->>JIT: outputs tensors
    JIT-->>PyAPI: return outputs
    PyAPI-->>User: outputs
```
Code Review
This pull request introduces the DCP (Decode Context Parallel) All-to-All communication operation, featuring a fused LL128 FIFO kernel (Helix) optimized for SM90+ architectures. The implementation includes core CUDA kernels, LL128 protocol handling, TVM FFI wrappers, and a Python API with JIT support. Review feedback identifies several critical issues: Programmatic Dependent Launch (PDL) is incorrectly enabled by default on unsupported pre-SM90 architectures, the kernel launch configuration cache lacks thread-safety for concurrent environments, and static caching of workspace parameters fails to account for varying input arguments, potentially leading to incorrect memory allocations.
```cpp
static bool enablePDL = true;

std::call_once(flag, [&]() {
  if (getSMVersion() >= 90) {
    char const* env = std::getenv("TRTLLM_ENABLE_PDL");
    if (env) {
      if (env[0] == '1' && env[1] == '\0') {
        enablePDL = true;
      } else if (env[0] == '0' && env[1] == '\0') {
        enablePDL = false;
      }
    };
  }
});
```
The enablePDL variable is initialized to true and only modified if getSMVersion() >= 90. This means Programmatic Dependent Launch (PDL) will be incorrectly reported as enabled on older GPU architectures (SM < 90), which do not support this feature. This could lead to kernel launch failures when cudaLaunchKernelEx is called with PDL attributes. It should default to false and only be enabled for SM90+.
```cpp
static bool enablePDL = false;
std::call_once(flag, [&]() {
  if (getSMVersion() >= 90) {
    enablePDL = true;
    char const* env = std::getenv("TRTLLM_ENABLE_PDL");
    if (env && env[0] == '0' && env[1] == '\0') {
      enablePDL = false;
    }
  }
});
```
I don't think a lock is needed, as the host-side launch path is single-threaded in practice.
```cpp
static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key>
    cache;
int deviceId = 0;
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
auto it = cache.find(key);
if (it != cache.end()) {
  return it->second;
}
```
The cache variable is a static std::unordered_map accessed without any synchronization. In a multi-threaded environment (e.g., a multi-stream inference server), concurrent calls to launchHelixAllToAll could lead to race conditions, data corruption, or crashes when accessing or modifying this map. Additionally, the cudaFuncSetAttribute call on line 515 is not thread-safe and modifies global state for the kernel function. Consider using a mutex to protect access to the cache and the attribute setting logic.
```cpp
static int maxChannelCount = 0;
if (maxChannelCount == 0) {
  maxChannelCount = computeHelixMaxChannelCount(cpSize);
}
```
The maxChannelCount is cached as a static variable, but its value depends on the cpSize argument. If computeHelixWorkspaceSizePerRank is called with different cpSize values in the same process (e.g., in a multi-model or multi-tenant setup), it will return an incorrect workspace size for all calls after the first one. The caching should be removed or made dependent on cpSize.
```cpp
int maxChannelCount = computeHelixMaxChannelCount(cpSize);
```
Although cpSize will not change during a run, the computation is just a lightweight division, so I fixed it now: maxChannelCount is no longer static and is recomputed on every call.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/aot.py (1)
546-550: ⚠️ Potential issue | 🟠 Major: DCP module only generated for SM100+ builds despite supporting SM90.

The `gen_dcp_alltoall_module()` call is placed inside the `if has_sm100:` block, but its `supported_major_versions=[9, 10]` declaration indicates support for SM90 as well. This means SM90-only AOT builds (e.g., H100 without Blackwell) won't include the DCP all-to-all module, despite the kernel supporting it. The kernel itself contains no SM100-specific code.

Move the call outside the SM100-only conditional:
Suggested fix
```diff
     jit_specs.append(gen_comm_alltoall_module())
+    if has_sm90 or has_sm100:
+        jit_specs.append(gen_dcp_alltoall_module())
     if has_sm100:
         jit_specs.append(gen_trtllm_comm_module())
         jit_specs.append(gen_trtllm_mnnvl_comm_module())
         jit_specs.append(gen_moe_alltoall_module())
-        jit_specs.append(gen_dcp_alltoall_module())
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/aot.py` around lines 546 - 550, The DCP all-to-all module is incorrectly only appended when has_sm100 is true; move the call to gen_dcp_alltoall_module() out of the has_sm100 conditional so it is appended to jit_specs for SM90 and SM100 builds (keep gen_trtllm_comm_module(), gen_trtllm_mnnvl_comm_module(), gen_moe_alltoall_module() inside the if block and only relocate the gen_dcp_alltoall_module() invocation so that jit_specs.append(gen_dcp_alltoall_module()) runs regardless of has_sm100).
🧹 Nitpick comments (8)
tests/comm/test_dcp_alltoall.py (2)
54-58: Session-scoped fixture may not seed all tests as expected.

The `setup_test_environment` fixture sets `torch.manual_seed(0xA2A)` once per session, but pytest may run tests in different orders, and torch's RNG state is mutated by each test. Consider making this function-scoped or removing `autouse` and calling it explicitly where determinism is needed.

```diff
-@pytest.fixture(autouse=True, scope="session")
+@pytest.fixture(autouse=True, scope="function")
 def setup_test_environment():
     """Set torch seed for deterministic tests."""
     torch.manual_seed(0xA2A)
     yield
```

This ensures each test starts with the same RNG state.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_dcp_alltoall.py` around lines 54 - 58, The session-scoped autouse fixture setup_test_environment only seeds torch once per session, so tests mutate RNG state and won't start from the same seed; change the fixture scope to "function" (or remove autouse and call the fixture explicitly in tests that need determinism) so torch.manual_seed(0xA2A) runs before each test, ensuring each test begins with the same RNG state; locate and update the setup_test_environment fixture definition accordingly.
38-51: Consider using `flashinfer.utils.get_compute_capability()` per coding guidelines.

The custom `_sm90_available()` function reimplements SM capability detection. Per coding guidelines, tests should use `flashinfer.utils` functions for architecture checks.

```diff
-def _sm90_available() -> bool:
-    try:
-        if not torch.cuda.is_available():
-            return False
-        major, _ = torch.cuda.get_device_capability(0)
-        return major >= 9
-    except Exception:
-        return False
+from flashinfer.utils import get_compute_capability
+
+def _sm90_available() -> bool:
+    try:
+        major, _ = get_compute_capability(0)
+        return major >= 9
+    except Exception:
+        return False
```

This also addresses the static analysis hint about catching blind exceptions, as the exception handling is limited to a fallback in a skip-check context.

As per coding guidelines: "Use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures"
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_dcp_alltoall.py` around lines 38 - 51, Replace the custom _sm90_available() and pytestmark check with the utility functions from flashinfer.utils: call flashinfer.utils.get_compute_capability() or use flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() to determine support and base the pytest.mark.skipif on that result; update the symbol references in the module (remove _sm90_available and change pytestmark to call the flashinfer.utils helper) so the test uses the canonical capability-checking helpers rather than reimplementing device queries and broad exception handling.

flashinfer/comm/dcp_alltoall.py (2)
101-116: Consider adding `@backend_requirement` decorator for SM90+ compute capability check.

Per coding guidelines, APIs with compute capability requirements should use the `@backend_requirement` decorator. The DCP all-to-all requires SM90+ (Hopper/Blackwell).

```diff
+from ..compilation_context import backend_requirement
+
+@backend_requirement(sm_major=9)
 @flashinfer_api
 def decode_cp_a2a_workspace_size(cp_size: int) -> int:
```

This would provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods for runtime checks.

As per coding guidelines: "Use `@backend_requirement` decorator on APIs that have compute capability requirements"
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/dcp_alltoall.py` around lines 101 - 116, The API decode_cp_a2a_workspace_size lacks the required compute-capability guard; wrap it with the `@backend_requirement` decorator configured for SM90+ so runtime checks are available. Update the function definition to use `@backend_requirement`(...) before `@flashinfer_api` (or alongside as per project convention) and ensure calls to get_dcp_alltoall_module()/get_workspace_size_per_rank are protected by the decorator’s provided helpers (is_compute_capability_supported(cc) and is_backend_supported()) so the API only runs on supported SM90+ backends.
243-248: Minor: `__all__` is not sorted alphabetically.

Static analysis (RUF022) flags unsorted `__all__`. Consider sorting for consistency:

```diff
 __all__ = [
-    "decode_cp_a2a_workspace_size",
     "decode_cp_a2a_allocate_workspace",
+    "decode_cp_a2a_alltoall",
     "decode_cp_a2a_init_workspace",
-    "decode_cp_a2a_alltoall",
+    "decode_cp_a2a_workspace_size",
 ]
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/dcp_alltoall.py` around lines 243 - 248, The __all__ list in dcp_alltoall.py is not alphabetically sorted; reorder the export names so they are lexically ascending (e.g. decode_cp_a2a_alltoall, decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace, decode_cp_a2a_workspace_size) to satisfy static analysis (RUF022) and keep consistency when referencing exports like decode_cp_a2a_workspace_size, decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace, and decode_cp_a2a_alltoall.

csrc/nv_internal/cpp/common/envUtils.cpp (1)
359-376: Consider setting `enablePDL = false` for SM < 90.

The function returns `true` by default (line 361), and only processes the environment variable when `getSMVersion() >= 90` (line 364). This means on pre-SM90 hardware, it always returns `true`. While the CUDA runtime likely ignores the `programmaticStreamSerializationAllowed` attribute on older architectures, semantically it would be cleaner to default to `false` when the feature isn't supported:

```diff
 bool getEnvEnablePDL() {
   static std::once_flag flag;
-  static bool enablePDL = true;
+  static bool enablePDL = false;
   std::call_once(flag, [&]() {
     if (getSMVersion() >= 90) {
+      enablePDL = true;  // default to enabled on supported hardware
       char const* env = std::getenv("TRTLLM_ENABLE_PDL");
```

This makes the default state consistent with hardware capability.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/nv_internal/cpp/common/envUtils.cpp` around lines 359 - 376, The function getEnvEnablePDL currently defaults enablePDL to true and only reads TRTLLM_ENABLE_PDL when getSMVersion() >= 90; change the default to false so that for SM < 90 getEnvEnablePDL() returns false (feature unsupported), and keep the existing std::call_once logic to override enablePDL only when getSMVersion() >= 90 by reading the TRTLLM_ENABLE_PDL env var; update the static bool enablePDL initialization and ensure getEnvEnablePDL, getSMVersion, and the TRTLLM_ENABLE_PDL handling are the only touched symbols.

benchmarks/bench_dcp_alltoall.py (2)
141-147: Prefix unused return values with underscore to suppress linter warnings.

The `recv_o` and `recv_s` variables are intentionally unused during warmup iterations. Prefixing them with underscore makes this intent explicit and silences the Ruff RUF059 warning.

♻️ Suggested fix

```diff
 # Warmup
 for _ in range(warmup):
-    recv_o, recv_s = decode_cp_a2a_alltoall(
+    _recv_o, _recv_s = decode_cp_a2a_alltoall(
         partial_o, softmax_stats, workspace, rank, cp_size
     )
```

Similarly for the timed loop at lines 158-160:

```diff
-    recv_o, recv_s = decode_cp_a2a_alltoall(
+    _recv_o, _recv_s = decode_cp_a2a_alltoall(
         partial_o, softmax_stats, workspace, rank, cp_size
     )
```
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_dcp_alltoall.py` around lines 141 - 147, The warmup loop (and the timed loop that also calls decode_cp_a2a_alltoall) currently assigns returned values to recv_o and recv_s but never uses them; rename those local variables to _recv_o and _recv_s when calling decode_cp_a2a_alltoall to indicate they are intentionally unused and silence the RUF059 linter warning, e.g., replace recv_o, recv_s = decode_cp_a2a_alltoall(...) with _recv_o, _recv_s = decode_cp_a2a_alltoall(...) in both the warmup and the timed invocation sites.
243-244: Prefix unused `local_rank` with underscore.

The `local_rank` is computed inside `setup_mpi()` and used there to set the CUDA device, but the returned value is never used at the call site.

♻️ Suggested fix

```diff
-    rank, world_size, mpi_comm, local_rank = setup_mpi()
+    rank, world_size, mpi_comm, _local_rank = setup_mpi()
```
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_dcp_alltoall.py` around lines 243 - 244, The unpacking of setup_mpi() returns a local_rank value that isn’t used; change the tuple unpack to ignore it by prefixing with an underscore (e.g., replace local_rank with _local_rank or use a single underscore _) so the call becomes rank, world_size, mpi_comm, _local_rank = setup_mpi() which makes intent explicit while leaving setup_mpi(), rank, world_size and mpi_comm unchanged.tests/comm/test_mnnvl_dcp_alltoall.py (1)
tests/comm/test_mnnvl_dcp_alltoall.py (1)

53-60: Consider catching a more specific exception or using flashinfer utility.

The blind `Exception` catch (Ruff BLE001) could mask unexpected errors. However, this is a safety guard for capability checking in an MPI context where various errors could occur.

♻️ Alternative using specific exceptions

```diff
 def _sm90_available() -> bool:
     try:
         if not torch.cuda.is_available():
             return False
         major, _ = torch.cuda.get_device_capability(0)
         return major >= 9
-    except Exception:
+    except (RuntimeError, AssertionError):
         return False
```

Alternatively, consider using `flashinfer.utils.get_compute_capability()` if it's available in the MPI-launched test context. As per coding guidelines, tests should use `flashinfer.utils` functions for architecture checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu`:
- Around line 486-534: computeChannelAndGroupCount uses a static
std::unordered_map named cache that is mutated without synchronization; protect
concurrent access by adding a static mutex (e.g., cache_mutex) and using a
std::lock_guard<std::mutex> (or similar) around lookups/insertions of cache
keyed by key so find/insert are atomic; for better performance do a
double-checked pattern: do an unsynchronized find, and if not found take the
lock, re-check cache.find(key), and only then compute/insert the value (refer to
computeChannelAndGroupCount, cache, key, and the returned value tuple).
- Around line 595-609: computeHelixWorkspaceSizePerRank currently uses a static
maxChannelCount cached across calls, but computeHelixMaxChannelCount(cpSize)
depends on cpSize so the cached value can be wrong for different cpSize; fix by
removing the static single-value cache and compute maxChannelCount =
computeHelixMaxChannelCount(cpSize) each call (or replace the static with a
small cache keyed by cpSize, e.g., a map from cpSize to channel count) so
computeHelixWorkspaceSizePerRank and its fifo/sender/receiver size calculations
always use the correct channel count.
In `@csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h`:
- Around line 74-84: Replace the current overflow-prone implementation of
ceil_div and add type constraints: add a static_assert(std::is_integral_v<T>) to
both ceil_div and align_up to enforce integral types, reimplement ceil_div as "a
/ b + (a % b != 0)" (which avoids a+b overflow and is constexpr-friendly) and
ensure align_up continues to call ceil_div(value, alignment) * alignment; also
consider validating alignment > 0 (e.g., via assert or precondition) to avoid
division by zero.
📒 Files selected for processing (16)
- benchmarks/bench_dcp_alltoall.py
- csrc/nv_internal/cpp/common/envUtils.cpp
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/cudaAsyncOps.cuh
- csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu
- csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.h
- csrc/nv_internal/tensorrt_llm/kernels/ll128Proto.cuh
- csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h
- csrc/trtllm_dcp_alltoall.cu
- flashinfer/aot.py
- flashinfer/comm/__init__.py
- flashinfer/comm/dcp_alltoall.py
- flashinfer/jit/__init__.py
- flashinfer/jit/comm.py
- tests/comm/test_dcp_alltoall.py
- tests/comm/test_mnnvl_dcp_alltoall.py
```cpp
std::tuple<int, int, int> computeChannelAndGroupCount(int cpSize, HelixFieldInfo const* fields) {
  static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key>
      cache;
  int deviceId = 0;
  TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
  int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
  auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second;
  }

  int maxGroupCountPerCta = std::min(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
  int groupCountPerCta = maxGroupCountPerCta;  // Start with max
  int totalDynamicShmemSize = singleShmSize * groupCountPerCta;
  int maxDynamicShmSize = 0;
  TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxDynamicShmSize,
                                         cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId));

  while (totalDynamicShmemSize > maxDynamicShmSize) {
    groupCountPerCta--;
    totalDynamicShmemSize = singleShmSize * groupCountPerCta;
  }

  TLLM_CHECK_WITH_INFO(totalDynamicShmemSize <= maxDynamicShmSize,
                       "Single packed size %d exceeds limit %d", singleShmSize, maxDynamicShmSize);

  // Set shared memory attribute if needed
  if (totalDynamicShmemSize > 48 * 1024) {
    TLLM_CUDA_CHECK(cudaFuncSetAttribute(helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>,
                                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                                         totalDynamicShmemSize));
  }

  int blockCountPerChannel = ceil_div(cpSize, groupCountPerCta);
  blockCountPerChannel *= 2;  // for send and recv

  int smCount = 0;
  TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
  // TODO: we might only want to use half the SMs to overlap with other kernels.
  // note that overlap with FMHA is almost impossible because it must use
  // all SMs and probably uses >50% shmem per SM.
  // overlap with the subsequent BMM / out proj GEMMs might be possible,
  // so we need experiments to see whether it makes sense.
  int channelCount = std::max(smCount / blockCountPerChannel, 1);
  auto value = std::make_tuple(channelCount, groupCountPerCta, totalDynamicShmemSize);
  cache[key] = value;
  return value;
}
```
Static cache in computeChannelAndGroupCount is not thread-safe for concurrent writes.
The static cache map is written to without synchronization. If multiple threads call this function concurrently with different keys, there's a potential data race during map insertion.
For single-threaded host code paths (typical for kernel launch configuration), this may not be an issue in practice. However, if this code could be called from multiple host threads, consider adding a mutex.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu` around lines 486 -
534, computeChannelAndGroupCount uses a static std::unordered_map named cache
that is mutated without synchronization; protect concurrent access by adding a
static mutex (e.g., cache_mutex) and using a std::lock_guard<std::mutex> (or
similar) around lookups/insertions of cache keyed by key so find/insert are
atomic; for better performance do a double-checked pattern: do an unsynchronized
find, and if not found take the lock, re-check cache.find(key), and only then
compute/insert the value (refer to computeChannelAndGroupCount, cache, key, and
the returned value tuple).
```cpp
inline constexpr T ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

/**
 * Align value up to nearest multiple of alignment
 */
template <typename T>
inline constexpr T align_up(T value, T alignment) {
  return ceil_div(value, alignment) * alignment;
}
```
Improve robustness of ceil_div and align_up for signed types.
The current implementation (a + b - 1) / b can overflow for signed types when a + b approaches the type's maximum. Consider using the alternative formula a / b + (a % b != 0), which avoids overflow and works in constexpr contexts. Additionally, add static_assert(std::is_integral_v<T>) to enforce integral types.
Suggested change
```diff
+#include <type_traits>
+
 template <typename T>
 inline constexpr T ceil_div(T a, T b) {
+  static_assert(std::is_integral_v<T>, "ceil_div requires integral type");
-  return (a + b - 1) / b;
+  return a / b + (a % b != 0);
 }
 template <typename T>
 inline constexpr T align_up(T value, T alignment) {
+  static_assert(std::is_integral_v<T>, "align_up requires integral type");
   return ceil_div(value, alignment) * alignment;
 }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/nv_internal/tensorrt_llm/kernels/moeCommKernelsCommon.h` around lines 74
- 84, Replace the current overflow-prone implementation of ceil_div and add type
constraints: add a static_assert(std::is_integral_v<T>) to both ceil_div and
align_up to enforce integral types, reimplement ceil_div as "a / b + (a % b !=
0)" (which avoids a+b overflow and is constexpr-friendly) and ensure align_up
continues to call ceil_div(value, alignment) * alignment; also consider
validating alignment > 0 (e.g., via assert or precondition) to avoid division by
zero.
gen_trtllm_comm_module, gen_moe_alltoall_module, and gen_dcp_alltoall_module all declare supported_major_versions=[9, 10] but were only included in AOT builds when has_sm100 was true. SM90 (Hopper) users would miss these modules in pre-compiled packages.
Re: aot.py — DCP module only generated for SM100+ (CodeRabbit). Fixed in 9f34778 — changed the `has_sm100` gate to `has_sm90 or has_sm100`; this also covers the other SM90-capable comm modules mentioned above. Verified all 30 DCP unit tests pass on H200 (SM90).
/bot run
[FAILED] Pipeline #47588815: 5/20 passed
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/comm/test_mnnvl_dcp_alltoall.py`:
- Line 47: pynvml.nvmlInit() is executed at import time and can cause collection
to fail; move it behind the existing skip protection or guard it so it only runs
when GPUs/driver are present: either relocate the call into the conditional that
checks _sm90_available() and mnnvl_available() (the same block around the
guarded workspace allocation at line 107) or wrap the call in a try/except that
catches initialization errors and sets a flag so tests will be skipped;
reference the symbols pynvml.nvmlInit(), _sm90_available(), and
mnnvl_available() when making the change.
- Around line 53-60: The helper _sm90_available currently duplicates CUDA
capability detection and swallows errors; replace its probe to use
flashinfer.utils.get_compute_capability instead of
torch.cuda.get_device_capability and remove the blanket except so CUDA init
errors surface. Specifically, import get_compute_capability and implement
_sm90_available to first check torch.cuda.is_available(), then call
get_compute_capability(torch.device("cuda")) to obtain (major, minor) and return
major >= 9, leaving exceptions to propagate.
- Around line 332-333: The test runner currently ignores pytest.main()'s return
code in the if __name__ == "__main__" block; capture the return value from
pytest.main([__file__, "-v", "-s"]) and propagate it by calling
sys.exit(return_code) so the process exits nonzero on test failures, and add an
import for sys if not present; update the main block around pytest.main to use
the captured result and sys.exit.
📥 Commits
Reviewing files that changed from the base of the PR and between 9f34778 and 30a52894d9b6e3de0d8e572797d59cce3b187783.
📒 Files selected for processing (1)
tests/comm/test_mnnvl_dcp_alltoall.py
```python
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
```
Propagate the pytest exit code here.
pytest.main() returns the process status, but this block ignores it, so python tests/comm/test_mnnvl_dcp_alltoall.py exits 0 even when tests fail.
♻️ Proposed fix
```diff
 if __name__ == "__main__":
-    pytest.main([__file__, "-v", "-s"])
+    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
```
The module-level MNNVL workspace allocation runs during pytest collection, before pytestmark skipif conditions take effect. In CI environments without SYS_PTRACE capability, this causes a collection error instead of a graceful skip. Guard the allocation behind the same mnnvl_available() check used by pytestmark.
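A rough sketch of the collection-time guard this commit describes (illustrative only; `cp_size`, `cp_rank`, and `mapping` stand in for whatever the test module derives from the MPI communicator):

```python
# Module-level setup in the MNNVL test, sketched. Nothing GPU- or driver-facing
# runs unless the same conditions that pytestmark checks are satisfied, so pytest
# collection succeeds (and the tests skip cleanly) on unsupported CI hosts.
workspace = None
if _sm90_available() and mnnvl_available():
    pynvml.nvmlInit()
    workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank, mapping=mapping)
```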
- Wrap all MNNVL test bodies in try/finally to ensure _comm.Barrier() runs even if assertions fail (prevents peer rank hangs); see the sketch below
- Remove static cache in computeHelixWorkspaceSizePerRank so workspace size is correct if cpSize varies across calls
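A sketch of the try/finally pattern from the first bullet (illustrative; the tensors and reference values are placeholders, and `_comm` is the MPI communicator the MNNVL tests already hold):

```python
def test_dcp_alltoall_mnnvl():  # hypothetical test name
    try:
        recv_o, recv_s = decode_cp_a2a_alltoall(
            partial_o, softmax_stats, workspace, cp_rank, cp_size
        )
        torch.testing.assert_close(recv_o, ref_o, rtol=1e-2, atol=1e-2)
        torch.testing.assert_close(recv_s, ref_s)
    finally:
        # Reached even when an assertion above fails, so peer ranks are not
        # left waiting forever in their own Barrier().
        _comm.Barrier()
```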
possible AOT test regression on main from another PR
/bot run
…as_sm90+ The previous gate change (has_sm100 -> has_sm90 or has_sm100) caused mnnvl_moe_alltoall and trtllm_comm/mnnvl_comm to be compiled on CUDA 12.6 (SM90), but those modules use SM100-only PTX instructions (fence .release/.acquire, cp.async.bulk with fabric state space). Split the gate so only gen_dcp_alltoall_module() is under has_sm90, while the existing SM100-only modules stay behind has_sm100.
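For reference, a simplified sketch of the AOT gate after this commit (not a verbatim copy of flashinfer/aot.py):

```python
if has_sm90 or has_sm100:
    # dcp_alltoall only needs SM90-level features (TMA, mbarrier, PDL),
    # so it can also ship in SM90-only AOT builds.
    jit_specs.append(gen_dcp_alltoall_module())
if has_sm100:
    # These modules use SM100-only PTX and stay behind the SM100 gate.
    jit_specs.append(gen_trtllm_comm_module())
    jit_specs.append(gen_trtllm_mnnvl_comm_module())
    jit_specs.append(gen_moe_alltoall_module())
```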
Head branch was pushed to by a user without write access
helixAllToAll.cu uses cp.async.bulk instructions that trigger a known ptxas state-space inference bug in CUDA 12.6.0 on arm64. Gate dcp_alltoall under has_sm100 (requires CUDA >= 12.8) alongside the other comm modules. SM90 users still get dcp_alltoall via JIT fallback at runtime.
/bot run
The kernel is compiled with supported_major_versions=[9, 10] in gen_dcp_alltoall_module, so running on SM11x/SM12x GPUs triggers 'No supported CUDA architectures found for major versions [9, 10].' at JIT time. Use is_sm90a_supported/is_sm100a_supported (which also verify the CUDA toolkit version) so the skip matches the kernel's actual compilation gate.
The helix A2A kernel only requires TMA + mbarrier + PDL (`__CUDA_ARCH__ >= 900` baseline) — no SM100-exclusive PTX. TRT-LLM upstream compiles it for SM80/86/89/90/100/103/120 with an internal `#if __CUDA_ARCH__ >= 900` guard, so FlashInfer's previous `supported_major_versions=[9, 10]` gate was unnecessarily narrow and broke JIT on SM11x/SM12x runners.

- jit/comm.py: supported_major_versions -> [9, 10, 11, 12]
- aot.py: gen_dcp_alltoall_module now runs under any SM90+ family, not just has_sm100
- tests: skipif now accepts major in (9, 10, 11, 12) (see the sketch below)
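A sketch of the widened skip condition from the tests bullet (illustrative, not the exact code in the test files):

```python
import pytest
import torch

from flashinfer.utils import get_compute_capability


def _dcp_supported() -> bool:
    # The DCP kernel is built for the SM9x/SM10x/SM11x/SM12x families only.
    if not torch.cuda.is_available():
        return False
    major, _ = get_compute_capability(torch.device("cuda"))
    return major in (9, 10, 11, 12)


pytestmark = pytest.mark.skipif(
    not _dcp_supported(), reason="DCP all-to-all requires an SM90+ family GPU"
)
```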
/bot run
…0 bug ptxas 12.6.0 has a known bug with cp.async.bulk state-space inference that aborts compilation of helixAllToAll.cu (fixed in 12.6.3). Widening the AOT gate to `has_sm90 or ...` in the previous commit caused the cu126 x64 AOT build to hit this bug (it targets only sm_90a, which is exactly the path that trips ptxas). Keep the JIT arch list widened to [9, 10, 11, 12] — that was the fix for SM12x JIT runners, and is independent of AOT. has_sm100 implies CUDA >= 12.8, so the AOT build naturally avoids the buggy compiler. SM90/SM12x wheel users still get dcp_alltoall via JIT on their own machines.
/bot run
The FlashInfer decode_cp_a2a_alltoall kernel addresses peer FIFOs via a single workspace base pointer + per-rank stride, which only works when the workspace is unified VA (MNNVL fabric memory). The torch.zeros fallback we previously selected for single-node CP groups deadlocks the kernel. Match TRT-LLM upstream and require MNNVL. Refs: flashinfer-ai/flashinfer#2951 follow-up
## Description
Bump version to 0.6.10 for release.
## Related Issues (Gated-by PRs)
https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.10
## Reviewer Notes
**API changes review**
API changes since v0.6.9
```diff
$ git diff v0.6.9..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
register_custom_op,
@@ -67,7 +73,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
)
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_trace)
def silu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
@@ -112,7 +118,7 @@ def silu_and_mul(
return out
-@flashinfer_api
+@flashinfer_api(trace=gelu_tanh_and_mul_trace)
def gelu_tanh_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
@@ -153,7 +159,7 @@ def gelu_tanh_and_mul(
return out
-@flashinfer_api
+@flashinfer_api(trace=gelu_and_mul_trace)
def gelu_and_mul(
input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None
) -> torch.Tensor:
@@ -194,7 +200,7 @@ def gelu_and_mul(
return out
-@flashinfer_api
+@flashinfer_api(trace=silu_and_mul_scaled_nvfp4_experts_quantize_trace)
def silu_and_mul_scaled_nvfp4_experts_quantize(
a,
mask,
diff --git a/flashinfer/aot.py b/flashinfer/aot.py
index dfb05150..d26d5407 100644
--- a/flashinfer/aot.py
+++ b/flashinfer/aot.py
@@ -543,6 +543,7 @@ def gen_all_modules(
if add_comm:
from .jit.comm import (
gen_comm_alltoall_module,
+ gen_dcp_alltoall_module,
gen_moe_alltoall_module,
gen_trtllm_comm_module,
gen_trtllm_mnnvl_comm_module,
@@ -554,6 +555,11 @@ def gen_all_modules(
jit_specs.append(gen_trtllm_comm_module())
jit_specs.append(gen_trtllm_mnnvl_comm_module())
jit_specs.append(gen_moe_alltoall_module())
+ # dcp_alltoall: kernel itself supports SM90+, but ptxas 12.6.0 has
--
-def flashinfer_api(func: Callable = None) -> Callable:
+# ---------------------------------------------------------------------------
+# Trace template registry
+# ---------------------------------------------------------------------------
+# Populated automatically by _attach_fi_trace whenever @flashinfer_api is
+# given a trace= argument. Each entry is (original_func, template, label)
+# where label is the template's name_prefix (or op_type as fallback).
+#
+# For dispatch callables (trace=some_fn), every template listed in
+# some_fn.templates is registered if that attribute exists.
+#
+# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover
+# all registered templates without requiring manual maintenance.
+_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = []
+
+
+def _attach_fi_trace(
+ wrapped: Callable,
+ original: Callable,
+ trace_template=None,
+) -> Callable:
+ """Attach a ``fi_trace`` callable to *wrapped*.
+
+ Three resolution strategies, tried in order:
+
--
+
+ warnings.warn(
+ f"[flashinfer] Failed to attach fi_trace to '{_func_name}': "
+ f"{type(_exc).__name__}: {_exc}\n"
+ f"The function will work normally but fi_trace will be unavailable. "
+ f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).",
+ stacklevel=3,
+ )
+ return wrapped
+
+
+def flashinfer_api(func: Callable = None, *, trace=None) -> Callable:
"""
Decorator to FlashInfer's APIs.
@@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable:
- The %i pattern is automatically replaced with the process ID for multi-process environments.
- The logger does not propagate to the root logger to avoid duplicate logs.
"""
- # If logging is disabled, return original function with zero overhead
+ # If logging is disabled, return original function with zero overhead.
+ # We still attach fi_trace so it is always available regardless of log level.
if _API_LOG_LEVEL == 0:
if func is None:
- return lambda f: f
- return func
--
@functools.cache
@@ -135,7 +136,7 @@ class BatchAttention:
causal,
)
- @flashinfer_api
+ @flashinfer_api(trace=batch_attention_run_trace)
def run(
self,
q: torch.Tensor,
@@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper
a convenient interface for using attention sinks during prefill or decode attention.
"""
+ # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper
+ # already decorates __init__, so decorating again produces double log entries.
def __init__(
self,
float_workspace_buffer: torch.Tensor,
diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py
new file mode 100644
index 00000000..3e029627
--- /dev/null
+++ b/flashinfer/attention/cute_dsl/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
--
@@ -31,7 +37,7 @@ def get_cascade_module():
return gen_cascade_module().build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=merge_state_trace)
@register_custom_op("flashinfer::merge_state", mutates_args=())
def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -98,7 +104,7 @@ def _fake_merge_state(
return v, s
-@flashinfer_api
+@flashinfer_api(trace=merge_state_in_place_trace)
@register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s"))
def merge_state_in_place(
v: torch.Tensor,
@@ -159,7 +165,7 @@ def _fake_merge_state_in_place(
pass
-@flashinfer_api
+@flashinfer_api(trace=merge_states_trace)
@register_custom_op("flashinfer::merge_states", mutates_args=())
def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Merge multiple attention states (v, s).
@@ -512,7 +518,7 @@ class MultiLevelCascadeAttentionWrapper:
begin_forward = plan
- @flashinfer_api
+ @flashinfer_api(trace=multi_level_cascade_run_trace)
def run(
self,
q: torch.Tensor,
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 5f186002..31d23a99 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -65,4 +65,15 @@ from .trtllm_moe_alltoall import (
moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace,
)
+# DCP A2A (Decode Context Parallel Attention Reduction)
+from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
+from .dcp_alltoall import (
+ decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
+)
+from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
+from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
+
# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
--
from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
@@ -449,7 +450,7 @@ def create_allreduce_fusion_workspace(
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=allreduce_fusion_trace)
def allreduce_fusion(
input: torch.Tensor,
workspace: AllReduceFusionWorkspace,
diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py
new file mode 100644
index 00000000..3047f76c
--- /dev/null
+++ b/flashinfer/comm/dcp_alltoall.py
@@ -0,0 +1,255 @@
+"""
+DCP All-to-All Operations for DCP Attention Reduction
+
+Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
+attention reduction. Uses SM90+ features (TMA, mbarrier).
+
+Usage protocol::
+
+ # 1. Query workspace size
+ ws_bytes = decode_cp_a2a_workspace_size(cp_size)
+
--
+
+
+# ─── Public API ───────────────────────────────────────────────────────────
+
+
+@flashinfer_api
+def decode_cp_a2a_workspace_size(cp_size: int) -> int:
+ """Return the workspace size **in bytes** per rank for the given CP group size.
+
+ Args:
+ cp_size: Context-parallel group size (number of ranks).
+
+ Returns:
+ Workspace size in bytes per rank.
+
+ Example::
+
+ >>> decode_cp_a2a_workspace_size(4)
+ 16778240
+ """
+ return get_dcp_alltoall_module().get_workspace_size_per_rank(cp_size)
+
+
+@flashinfer_api
+def decode_cp_a2a_allocate_workspace(
+ cp_size: int,
+ cp_rank: int,
+ *,
+ mapping: Optional[Mapping] = None,
+ mnnvl_config: Optional[MnnvlConfig] = None,
+) -> torch.Tensor:
+ """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+
+ After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+ cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
+
+ Two allocation modes:
+
+ - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
+ FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
+ cannot see each other's device memory directly.
+ - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
+ allocation. Sufficient for single-node with NVLink P2P.
+
--
+
+ ws_elems_per_rank = (ws_bytes + 7) // 8
+ return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+
+
+@flashinfer_api
+def decode_cp_a2a_init_workspace(
+ workspace: torch.Tensor,
+ cp_rank: int,
+ cp_size: int,
+) -> None:
+ """Initialize the workspace FIFO buffers. Call once before the first alltoall.
+
+ Resets the FIFO buffers in the **local** workspace row
+ (``workspace[cp_rank]``). This function is **synchronous**: when it
+ returns, the GPU memset is guaranteed to have completed.
+
+ .. important::
+ With MNNVL workspaces, **all ranks** must complete
+ ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier
+ (e.g. ``dist.barrier(group)``) before **any** rank calls
+ :func:`decode_cp_a2a_alltoall`. Without the barrier, a rank may
+ start writing to a peer's FIFO before that peer has finished
+ initializing → deadlock.
+
+ Args:
--
+ # subsequent cross-GPU alltoall can race with the unfinished memset
+ # on MNNVL memory, causing a deadlock.
+ torch.cuda.current_stream().synchronize()
+
+
+@flashinfer_api(trace=decode_cp_a2a_alltoall_trace)
+def decode_cp_a2a_alltoall(
+ partial_o: torch.Tensor,
+ softmax_stats: torch.Tensor,
+ workspace: torch.Tensor,
+ cp_rank: int,
+ cp_size: int,
+ enable_pdl: Optional[bool] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Perform the DCP all-to-all exchange.
+
+ Each rank sends its ``partial_o[..., peer, :]`` slice to the
+ corresponding peer and receives all peers' contributions into the
+ output tensors.
+
+ Args:
+ partial_o: ``[..., cp_size, D]`` — half or bfloat16.
+ ``D * element_size`` must be 16-byte aligned.
+ softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
+ Batch dimensions must match ``partial_o``.
+ workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
--
+ MixedCommOp.ALLREDUCE_ALLGATHER: _allreduce_allgather,
+ MixedCommOp.REDUCESCATTER_ALLREDUCE: _reducescatter_allreduce,
+}
+
+
+@flashinfer_api
+@backend_requirement(
+ backend_checks={},
+ common_check=_common_check,
+)
+def run_mixed_comm(
+ op: MixedCommOp,
+ handler: MixedCommHandler,
+ x_in: torch.Tensor,
+ x_out: torch.Tensor | None = None,
+ mode: MixedCommMode | None = None,
+) -> torch.Tensor:
+ """Execute a mixed communication operation.
+
+ This is the main entry point for running communication collectives
+ through the mixed communication handler. It supports fused GPU kernels
+ (using virtual memory intra-node and nvshmem inter-node), NCCL-based
+ fallbacks, and autotuned mode selection.
+
+ Args:
+ op: The communication operation to perform.
--
@functools.cache
@@ -28,7 +29,7 @@ def get_concat_mla_module():
return gen_concat_mla_module().build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=concat_mla_k_trace)
def concat_mla_k(
k: torch.Tensor,
k_nope: torch.Tensor,
diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py
index 195ca2d4..9b593095 100644
--- a/flashinfer/cudnn/decode.py
+++ b/flashinfer/cudnn/decode.py
@@ -4,6 +4,7 @@ from typing import Optional
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_decode_trace
from .utils import get_cudnn_fmha_gen_module
try:
@@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache(
return out
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_decode_trace)
def cudnn_batch_decode_with_kv_cache(
q: torch.Tensor,
k_cache: torch.Tensor,
diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc1bbb5f..b16d6043 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -4,6 +4,7 @@ from typing import Optional
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.attention import cudnn_batch_prefill_trace
from .utils import get_cudnn_fmha_gen_module
try:
@@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache(
return out, None
-@flashinfer_api
+@flashinfer_api(trace=cudnn_batch_prefill_trace)
def cudnn_batch_prefill_with_kv_cache(
q: torch.Tensor,
k_cache: torch.Tensor,
diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
index 0b50c22c..f25aa6fd 100644
--- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
@@ -38,6 +38,7 @@ import torch
from cutlass import Float32, Int32, Int64, Uint32, Uint8
from ..api_logging import flashinfer_api
+from ..trace.templates.norm import add_rmsnorm_fp4quant_trace
from ..utils import device_support_pdl
from .fp4_common import (
# Constants
@@ -1042,7 +1043,7 @@ def _get_compiled_kernel(
return tensor_api
-@flashinfer_api
+@flashinfer_api(trace=add_rmsnorm_fp4quant_trace)
def add_rmsnorm_fp4quant(
input: torch.Tensor,
residual: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
index 333697ab..b7aabc36 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py
@@ -20,6 +20,7 @@ import torch
from cutlass import Float32, Int32
from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_mla_run_trace
from flashinfer.utils import device_support_pdl
from flashinfer.cute_dsl.utils import (
get_max_active_clusters,
@@ -519,7 +520,7 @@ class BatchMLADecodeCuteDSLWrapper:
f"out_dtype={self._o_dtype}"
)
- @flashinfer_api
+ @flashinfer_api(trace=cute_dsl_batch_mla_run_trace)
def run(
self,
q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
index 58a24abe..ee0cd5e7 100644
--- a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
+++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
@@ -21,6 +21,7 @@ import cutlass.cute as cute
from cutlass.cute.typing import Int32
from flashinfer.api_logging import flashinfer_api
+from flashinfer.trace.templates.attention import cute_dsl_batch_prefill_run_trace
from ..config import AttentionConfig, AttentionFusion
from ..fusion.mask import MaskType
@@ -371,7 +372,7 @@ class BatchPrefillCuteDSLWrapper:
f"device={self._device}"
)
- @flashinfer_api
+ @flashinfer_api(trace=cute_dsl_batch_prefill_run_trace)
def run(
self,
q: torch.Tensor,
diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
index bc4acffc..97ce68a1 100644
--- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py
+++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py
@@ -32,6 +32,7 @@ import torch
from cutlass import Float32, Int32, Uint8
from ..api_logging import flashinfer_api
+from ..trace.templates.norm import rmsnorm_fp4quant_trace
from ..utils import device_support_pdl
from .fp4_common import (
# Constants
@@ -771,7 +772,7 @@ def _get_compiled_kernel(
return tensor_api
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_fp4quant_trace)
def rmsnorm_fp4quant(
input: torch.Tensor,
weight: torch.Tensor,
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 822aca40..5e9eb515 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,6 +22,12 @@ from typing import Any, List, Literal, Optional, Tuple, Union, overload
import torch
from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+ gqa_paged_decode_trace,
+ single_decode_with_kv_cache_trace,
+ trtllm_batch_decode_trace,
+ xqa_batch_decode_trace,
+)
## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
from .mla import (
@@ -400,7 +406,7 @@ def single_decode_with_kv_cache(
) -> Tuple[torch.Tensor, torch.Tensor]: ...
-@flashinfer_api
+@flashinfer_api(trace=single_decode_with_kv_cache_trace)
def single_decode_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
@@ -1215,7 +1221,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=gqa_paged_decode_trace)
def run(
self,
q: torch.Tensor,
@@ -1577,6 +1583,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra
:class:`BatchDecodeWithPagedKVCacheWrapper`
"""
+ # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper
+ # already decorates __init__, so decorating again produces double log entries.
def __init__(
self,
workspace_buffer: torch.Tensor,
@@ -2232,7 +2240,7 @@ def get_trtllm_gen_decode_module(*args):
)
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_trace)
def trtllm_batch_decode_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2618,7 +2626,7 @@ def trtllm_batch_decode_with_kv_cache(
# xqa uses NHD layout
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_trace)
def xqa_batch_decode_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py
new file mode 100644
index 00000000..1104eb6f
--- /dev/null
+++ b/flashinfer/fi_trace.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--
+
+"""
+fi_trace: Generate `flashinfer-bench <https://github.com/flashinfer-ai/flashinfer-bench>`_
+compatible definition JSON for FlashInfer APIs.
+
+Every ``@flashinfer_api(trace=<template>)``-decorated function supports two
+usage modes:
+
+Auto-dump (recommended)
+-----------------------
+Set environment variables **before** importing flashinfer, then run your
+workload normally. No explicit ``fi_trace`` call is needed.
+
+.. code-block:: bash
+
+ FLASHINFER_TRACE_DUMP=1 \\
+ FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out \\
+ python my_script.py
+
+Every decorated function writes a ``<name>.json`` file on its **first** call
+for each unique set of const-axis values (e.g. head dimensions, vocab size).
+Subsequent calls with the same shape are deduplicated — the file is written
+only once per process. The output directory is created automatically.
+
+Explicit call (for selective or programmatic use)
+-------------------------------------------------
--
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+# ---------------------------------------------------------------------------
+# Legacy registry — kept for backwards compatibility.
+# New code should use @flashinfer_api(trace=TraceTemplate(...)) instead.
+# ---------------------------------------------------------------------------
+
+_REGISTRY: Dict[str, Any] = {}
+
+
+def register_fi_trace(qualname: str, spec: Any) -> None:
+ """Register a legacy FiTraceSpec for the function with the given qualname.
+
+ .. deprecated::
+ Use ``@flashinfer_api(trace=TraceTemplate(...))`` instead.
+ """
+ _REGISTRY[qualname] = spec
+
+
+def build_fi_trace_fn(spec: Any) -> Callable[..., Dict[str, Any]]:
+ """Build a fi_trace callable from a legacy FiTraceSpec.
+
+ .. deprecated::
+ Use ``TraceTemplate.build_fi_trace_fn`` instead.
+ """
+ # Import the old implementation from the trace package for backwards compat.
+ from .trace.template import ( # noqa: PLC0415,F401
+ Const,
+ Scalar,
+ Tensor,
+ TraceTemplate,
+ Var,
+ )
+ import json # noqa: PLC0415
+ import os # noqa: PLC0415
--
+ """Generate a flashinfer-bench definition JSON for any FlashInfer API call.
+
+ Parameters
+ ----------
+ func_or_method:
+ A ``@flashinfer_api``-decorated function or (bound) method.
+ save_dir:
+ Directory where the JSON definition file should be written.
+ Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
+ **kwargs:
+ The same tensor arguments you would pass to the real API.
+
+ Returns
+ -------
+ dict
+ A flashinfer-bench compatible definition dictionary.
+
+ Examples
+ --------
+ Standalone function::
+
+ defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)
+
+ Bound method (instance.run)::
+
+ defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))
--
+ trace_fn = getattr(actual_func, "fi_trace", None)
+ if trace_fn is None:
+ qualname = getattr(actual_func, "__qualname__", repr(actual_func))
+ raise ValueError(
+ f"No fi_trace spec is registered for '{qualname}'. "
+ "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+ )
+ return trace_fn(save_dir=save_dir, **kwargs)
diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py
index df6e1f72..d983f9d4 100644
--- a/flashinfer/fused_moe/__init__.py
+++ b/flashinfer/fused_moe/__init__.py
@@ -17,6 +17,8 @@ limitations under the License.
from .core import (
convert_to_block_layout,
cutlass_fused_moe,
+ interleave_moe_scales_for_sm90_mixed_gemm,
+ interleave_moe_weights_for_sm90_mixed_gemm,
gen_cutlass_fused_moe_sm120_module,
gen_cutlass_fused_moe_sm103_module,
gen_cutlass_fused_moe_sm100_module,
@@ -64,6 +66,8 @@ __all__ = [
"WeightLayout",
"convert_to_block_layout",
"cutlass_fused_moe",
+ "interleave_moe_scales_for_sm90_mixed_gemm",
--
+ ),
)
-# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
@flashinfer_api
+def interleave_moe_scales_for_sm90_mixed_gemm(
+ scales: torch.Tensor,
+ group_size: int = 32,
+) -> torch.Tensor:
+ """Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
+
+ The kernel expects scales in layout
+ ``(num_experts, K // (group_size * 4), rows * 4)`` rather than the natural
+ ``(num_experts, rows, K // group_size)`` produced by the MXFP4 quantizer.
+ This helper performs the reshape + permute equivalent to TensorRT-LLM's
+ ``WFP4A16FusedMoEMethod.load_quant_scales`` (PR #12451), with the fixed
+ interleave factor of ``128 // group_size`` used for MXFP4.
+
+ Parameters
+ ----------
+ scales:
+ ``[num_experts, rows, K // group_size]`` uint8 tensor of E8M0 block
+ scales.
+ group_size:
+ MXFP4 quantization group size (default 32).
--
+ scales.reshape(e, rows, kgs // factor, factor).permute(0, 2, 1, 3).contiguous()
+ )
+ return tmp.reshape(e, kgs // factor, rows * factor)
+
+
+@flashinfer_api
+def interleave_moe_weights_for_sm90_mixed_gemm(
+ weight: torch.Tensor,
+ quant_type: str = "fp4",
+) -> torch.Tensor:
+ """Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
+
+ The SM90 mixed-dtype MoE GEMM (used by ``cutlass_fused_moe`` with
+ ``use_w4_group_scaling=True``) expects weights in a specific interleaved
+ layout; without preprocessing, the LUT-based FP4→BF16 conversion reads
+ bytes from the wrong positions and the output diverges from a dequantized
+ reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent
+ preprocessing at weight-load time (see
+ ``interleave_4bit_weights_for_Hopper_mixed_gemm`` in TRT-LLM PR #12451).
+
+ Parameters
+ ----------
+ weight:
+ ``[num_experts, n, k // 2]`` uint8 CUDA tensor (4-bit values packed
+ two-per-byte).
+ quant_type:
--
+ )
+ return out
+
+
+# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
+@flashinfer_api(trace=cutlass_fused_moe_trace)
def cutlass_fused_moe(
input: torch.Tensor,
token_selected_experts: torch.Tensor,
@@ -1027,8 +1151,8 @@ def get_trtllm_moe_sm100_module():
DynamicTensorSpec(
input_idx,
dim_idx,
- get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
- lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+ get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+ lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
initializers,
),
),
@@ -2344,7 +2468,7 @@ def _validate_routing_replay_out(
raise ValueError("routing_replay_out must be contiguous (packed row-major)")
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_moe_trace)
def trtllm_bf16_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2452,7 +2576,7 @@ def trtllm_bf16_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_bf16_routed_moe_trace)
def trtllm_bf16_routed_moe(
topk_ids: torch.Tensor,
hidden_states: torch.Tensor,
@@ -2557,7 +2681,7 @@ def trtllm_bf16_routed_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_per_tensor_scale_moe_trace)
def trtllm_fp8_per_tensor_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2658,7 +2782,7 @@ def trtllm_fp8_per_tensor_scale_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
def trtllm_fp8_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2779,7 +2903,7 @@ def trtllm_fp8_block_scale_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp8_block_scale_routed_moe_trace)
def trtllm_fp8_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -2893,7 +3017,7 @@ def trtllm_fp8_block_scale_routed_moe(
return result
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_moe_trace_dispatch)
def trtllm_fp4_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -3030,7 +3154,7 @@ def trtllm_fp4_block_scale_moe(
)
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
def trtllm_fp4_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
@@ -3165,7 +3289,7 @@ def trtllm_fp4_block_scale_routed_moe(
)
-@flashinfer_api
+@flashinfer_api(trace=trtllm_mxint4_block_scale_moe_trace)
def trtllm_mxint4_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
diff --git a/flashinfer/fused_moe/cute_dsl/b12x_moe.py b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
index d2cbc8b0..34916df5 100644
--- a/flashinfer/fused_moe/cute_dsl/b12x_moe.py
+++ b/flashinfer/fused_moe/cute_dsl/b12x_moe.py
@@ -42,11 +42,12 @@ from typing import Optional, Tuple
import torch
from ...api_logging import flashinfer_api
+from ...trace.templates.moe import b12x_fused_moe_trace, b12x_moe_wrapper_run_trace
from ...utils import supported_compute_capability
@supported_compute_capability([120, 121])
-@flashinfer_api
+@flashinfer_api(trace=b12x_fused_moe_trace)
def b12x_fused_moe(
x: torch.Tensor,
w1_weight: torch.Tensor,
@@ -293,7 +294,7 @@ class B12xMoEWrapper:
device=self.device,
)
- @flashinfer_api
+ @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
def run(
self,
x: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
index f6cf1b67..e266cb77 100644
--- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
+++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
@@ -89,8 +89,8 @@ from flashinfer.cute_dsl.fp4_common import (
st_global_u64,
scatter_add_bf16x2,
)
-from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120 import (
- Sm120BlockScaledDenseGemmKernel as DenseGemmKernel,
+from flashinfer.gemm.kernels.dense_blockscaled_gemm_sm120_b12x import (
+ Sm120B12xBlockScaledDenseGemmKernel as DenseGemmKernel,
)
diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
index e7fdae92..670b3ad8 100644
--
from .moe_utils import (
@@ -530,7 +534,7 @@ class CuteDslMoEWrapper:
enable_pdl=enable_pdl,
)
- @flashinfer_api
+ @flashinfer_api(trace=cute_dsl_moe_wrapper_run_trace)
def run(
self,
x: torch.Tensor,
@@ -686,7 +690,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
@supported_compute_capability([100, 103])
-@flashinfer_api
+@flashinfer_api(trace=cute_dsl_fused_moe_nvfp4_trace)
def cute_dsl_fused_moe_nvfp4(
x: torch.Tensor,
x_sf: torch.Tensor,
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index 0cc8628e..636043db 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@ from ...autotuner import (
TuningConfig,
)
from ..utils import (
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
+ get_hybrid_num_tokens_buckets,
+ map_to_hybrid_bucket,
)
logger = logging.getLogger(__name__)
@@ -273,10 +273,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
DynamicTensorSpec(
--
import torch
@@ -137,7 +138,7 @@ def get_dsv3_fused_routing_module():
@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
-@flashinfer_api
+@flashinfer_api(trace=fused_topk_deepseek_trace)
def fused_topk_deepseek(
scores: torch.Tensor,
bias: torch.Tensor,
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 004271a1..91f37aa5 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -209,29 +209,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
- """Return descending power-of-2 buckets from ``next_power_of_2(max_num_tokens)`` down to 1."""
- max_num_tokens = next_positive_power_of_2(max_num_tokens)
- num_token_buckets = []
- m = max_num_tokens
- while m >= 1:
- num_token_buckets.append(m)
- m //= 2
+_PHASE1_END = 256
--
@@ -106,7 +114,7 @@ TILE_V = 8 # pretranspose tile size
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
def gated_delta_rule_decode_pretranspose(
q: torch.Tensor,
k: torch.Tensor,
@@ -394,7 +402,7 @@ def gated_delta_rule_decode_pretranspose(
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=gated_delta_rule_decode_trace)
def gated_delta_rule_decode(
q: torch.Tensor,
k: torch.Tensor,
@@ -535,7 +543,7 @@ def gated_delta_rule_decode(
# ============================================================================
-@flashinfer_api
+@flashinfer_api(trace=gdn_mtp_trace)
def gated_delta_rule_mtp(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 68398d28..53fe44ce 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -3333,8 +3333,7 @@ class GatedDeltaNetChunkedKernel:
gate_handle = load_gate_consumer.wait_and_advance()
- max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
- cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
+ cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]
valid_state = not is_first_chunk or self.use_initial_state
if cutlass.const_expr(valid_state):
diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
index 82dcc72b..aafcc671 100644
--- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py
--
register_custom_op,
@@ -95,7 +96,7 @@ def get_gdn_prefill_module():
return SimpleNamespace(gdn_prefill=gdn_prefill)
-@flashinfer_api
+@flashinfer_api(trace=gdn_prefill_trace)
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index a7795beb..def82216 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -61,11 +61,11 @@ try:
from flashinfer.cute_dsl.utils import is_cute_dsl_available
if is_cute_dsl_available():
- from .kernels.dense_blockscaled_gemm_sm120 import (
- Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
+ from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+ Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
)
- _cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+ _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
except ImportError:
--
from ..utils import (
@@ -325,7 +339,7 @@ def _heuristic_func_mm_bf16(
common_check=_check_mm_bf16_problem_size,
heuristic_func=_heuristic_func_mm_bf16,
)
-@flashinfer_api
+@flashinfer_api(trace=mm_bf16_trace)
def mm_bf16(
a: torch.Tensor,
b: torch.Tensor,
@@ -514,7 +528,7 @@ def _heuristic_func_bmm_bf16(
common_check=_check_bmm_bf16_problem_size,
heuristic_func=_heuristic_func_bmm_bf16,
)
-@flashinfer_api
+@flashinfer_api(trace=bmm_bf16_trace)
def bmm_bf16(
A: torch.Tensor,
B: torch.Tensor,
@@ -815,8 +829,8 @@ _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
DynamicTensorSpec(
(0,), # a_tensor_index
(-2,),
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
+ get_hybrid_num_tokens_buckets,
+ map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -871,8 +885,8 @@ _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig(
DynamicTensorSpec(
(0,), # a_tensor_index
(-2,),
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
--
constraint_specs=(
@@ -1095,7 +1109,7 @@ def get_tgv_gemm_sm10x_module(
)
-@flashinfer_api
+@flashinfer_api(trace=tgv_gemm_sm100_trace)
def tgv_gemm_sm100(
a: torch.Tensor,
b: torch.Tensor,
@@ -1173,8 +1187,8 @@ def tgv_gemm_sm100(
DynamicTensorSpec(
(a_tensor_index,),
(-2,),
- get_last_power_of_2_num_tokens_buckets,
- last_positive_power_of_2,
+ get_hybrid_num_tokens_buckets,
+ map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -1437,6 +1451,7 @@ class SegmentGEMMWrapper:
True
"""
+ @flashinfer_api
def __init__(
self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
) -> None:
@@ -1469,7 +1484,7 @@ class SegmentGEMMWrapper:
self._float_workspace_buffer = float_workspace_buffer
self._int_workspace_buffer = int_workspace_buffer
- @flashinfer_api
+ @flashinfer_api(trace=segment_gemm_run_trace)
def run(
self,
x: torch.Tensor,
@@ -2084,6 +2099,8 @@ def build_cudnn_gemm_fp4_graph_override_shape(
return graph
+# Internal helper called from mm_fp4; the user-facing mm_fp4 is already
+# decorated, so decorating here would double-log the same invocation.
def execute_cudnn_gemm_fp4_graph_override_shape(
graph,
a,
@@ -2319,6 +2336,8 @@ def build_cudnn_gemm_mxfp8_graph_override_shape(
return graph
+# Internal helper called from mm_mxfp8; the user-facing mm_mxfp8 is already
+# decorated, so decorating here would double-log the same invocation.
def execute_cudnn_gemm_mxfp8_graph_override_shape(
graph,
--
):
@@ -3161,7 +3184,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
return (tuple(block_scale_shape), tuple(block_scale_stride))
-@flashinfer_api
+@flashinfer_api(trace=mm_fp8_trace)
def mm_fp8(
a: torch.Tensor,
b: torch.Tensor,
@@ -3990,7 +4013,7 @@ def _heuristic_func_mm_mxfp8(
common_check=_check_mm_mxfp8_problem_size,
heuristic_func=_heuristic_func_mm_mxfp8, # result stored in mm_mxfp8.suitable_auto_backends
)
-@flashinfer_api
+@flashinfer_api(trace=mm_mxfp8_trace)
def mm_mxfp8(
a: torch.Tensor,
b: torch.Tensor,
@@ -4858,8 +4881,8 @@ def _b12x_gemm_fp4_runner(
"""
import cutlass
- from .kernels.dense_blockscaled_gemm_sm120 import (
- Sm120BlockScaledDenseGemmKernel,
+ from .kernels.dense_blockscaled_gemm_sm120_b12x import (
+ Sm120B12xBlockScaledDenseGemmKernel,
)
cutlass_dtype_attr = _TORCH_TO_CUTLASS_DTYPE_ATTR.get(out_dtype)
@@ -4905,7 +4928,7 @@ def _b12x_gemm_fp4_runner(
]
swap_ab = False
for mma_tiler_mn in sm120_mma_tiler_candidates:
- if not Sm120BlockScaledDenseGemmKernel.can_implement(
+ if not Sm120B12xBlockScaledDenseGemmKernel.can_implement(
--
constraint_specs=(
@@ -5195,7 +5217,7 @@ _MM_MXFP8_TUNING_CONFIG = TuningConfig(
common_check=_check_mm_fp4_problem_size,
heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends
)
-@flashinfer_api
+@flashinfer_api(trace=mm_fp4_trace)
def mm_fp4(
a: torch.Tensor,
b: torch.Tensor,
@@ -5449,7 +5471,7 @@ def _heuristic_func_bmm_fp8(
common_check=_check_bmm_fp8_problem_size,
heuristic_func=_heuristic_func_bmm_fp8,
)
-@flashinfer_api
+@flashinfer_api(trace=bmm_fp8_trace)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
@@ -6862,7 +6884,7 @@ def _check_batch_deepgemm_fp8_nt_groupwise(
{},
common_check=_check_batch_deepgemm_fp8_nt_groupwise,
)
-@flashinfer_api
+@flashinfer_api(trace=batch_deepgemm_fp8_nt_groupwise_trace)
def batch_deepgemm_fp8_nt_groupwise(
a: torch.Tensor, # (batch_size, m, k)
b: torch.Tensor, # (batch_size, n, k)
@@ -7006,7 +7028,7 @@ def get_fp8_blockscale_gemm_runner_sm90():
return module.init()
-@flashinfer_api
+@flashinfer_api(trace=fp8_blockscale_gemm_sm90_trace)
def fp8_blockscale_gemm_sm90(
input: torch.Tensor,
weight: torch.Tensor,
@@ -7588,7 +7610,7 @@ def _heuristic_func_bmm_mxfp8(
common_check=_check_bmm_mxfp8_problem_size,
heuristic_func=_heuristic_func_bmm_mxfp8,
)
-@flashinfer_api
+@flashinfer_api(trace=bmm_mxfp8_trace)
def bmm_mxfp8(
A: torch.Tensor,
B: torch.Tensor,
diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
similarity index 99%
rename from flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
rename to flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
index c49bc815..6eee27a7 100644
--- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
+++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
@@ -1550,7 +1550,7 @@ class DenseGemmKernel:
# Alias for FlashInfer integration
-Sm120BlockScaledDenseGemmKernel = DenseGemmKernel
+Sm120B12xBlockScaledDenseGemmKernel = DenseGemmKernel
class _DenseGemmLaunch:
diff --git a/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py b/flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
--
get_cutlass_dtype,
@@ -2951,7 +2952,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel(
return tensor_api
-@flashinfer_api
+@flashinfer_api(trace=grouped_gemm_nt_masked_trace)
def grouped_gemm_nt_masked(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
diff --git a/flashinfer/gemm/routergemm.py b/flashinfer/gemm/routergemm.py
index cfde7d43..f83c8974 100644
--- a/flashinfer/gemm/routergemm.py
+++ b/flashinfer/gemm/routergemm.py
@@ -1,4 +1,8 @@
from ..api_logging import flashinfer_api
+from ..trace.templates.gemm import (
+ mm_M1_16_K7168_N256_trace,
+ tinygemm_bf16_trace,
+)
from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
import functools
from types import SimpleNamespace
@@ -176,7 +180,7 @@ def mm_M1_16_K7168_N128(
@backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
def mm_M1_16_K7168_N256(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
@@ -324,7 +328,7 @@ def get_tinygemm2_module():
@backend_requirement({}, common_check=_tinygemm_bf16_shape_checks)
-@flashinfer_api
+@flashinfer_api(trace=tinygemm_bf16_trace)
def tinygemm_bf16(
input: torch.Tensor,
weight: torch.Tensor,
diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py
index 7f36a314..8378e0ab 100644
--- a/flashinfer/jit/__init__.py
+++ b/flashinfer/jit/__init__.py
@@ -82,6 +82,7 @@ from .comm import gen_trtllm_mnnvl_comm_module as gen_trtllm_mnnvl_comm_module
from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module
from .comm import gen_vllm_comm_module as gen_vllm_comm_module
from .comm import gen_moe_alltoall_module as gen_moe_alltoall_module
+from .comm import gen_dcp_alltoall_module as gen_dcp_alltoall_module
from .dsv3_optimizations import (
gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module,
)
diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py
index 46768eed..834f77f9 100644
--- a/flashinfer/jit/comm.py
+++ b/flashinfer/jit/comm.py
@@ -15,7 +15,13 @@ limitations under the License.
--
gen_selective_state_update_sm100_module,
@@ -99,7 +100,7 @@ def get_selective_state_update_module(
)
-@flashinfer_api
+@flashinfer_api(trace=selective_state_update_trace)
def selective_state_update(
state: torch.Tensor,
x: torch.Tensor,
diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index 4e8bdd72..e27e3807 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -21,6 +21,11 @@ from typing import List, Literal, Optional, Tuple, Union, overload
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.attention import (
+ mla_paged_decode_trace,
+ trtllm_batch_decode_mla_trace,
+ xqa_batch_decode_mla_trace,
+)
from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
from ..jit.mla import gen_mla_module
from ..utils import (
@@ -469,7 +474,7 @@ class BatchMLAPagedAttentionWrapper:
return_lse_base_on_e: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=mla_paged_decode_trace)
def run(
self,
q_nope: torch.Tensor,
@@ -588,7 +593,7 @@ class BatchMLAPagedAttentionWrapper:
return (out, lse) if return_lse else out
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_decode_mla_trace)
def trtllm_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
@@ -856,7 +861,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
raise ValueError(f"Backend {backend} not supported")
-@flashinfer_api
+@flashinfer_api(trace=xqa_batch_decode_mla_trace)
def xqa_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py
index 0f9911a6..ba612b28 100644
--- a/flashinfer/norm/__init__.py
+++ b/flashinfer/norm/__init__.py
@@ -32,6 +32,16 @@ from typing import Optional, Union
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.norm import (
+ fused_add_rmsnorm_quant_trace,
+ fused_add_rmsnorm_trace,
+ fused_rmsnorm_silu_trace,
+ gemma_fused_add_rmsnorm_trace,
+ gemma_rmsnorm_trace,
+ layernorm_trace,
+ rmsnorm_quant_trace,
+ rmsnorm_trace,
--
get_compute_capability,
@@ -94,7 +104,7 @@ def _normalize_scale_tensor(
return scale.contiguous()
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_trace)
def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
@@ -165,7 +175,7 @@ def _rmsnorm_impl_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=rmsnorm_quant_trace)
@register_custom_op("flashinfer::rmsnorm_quant", mutates_args=("out",))
def rmsnorm_quant(
out: torch.Tensor,
@@ -219,7 +229,7 @@ def _rmsnorm_quant_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_trace)
@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
def fused_add_rmsnorm(
input: torch.Tensor,
@@ -271,7 +281,7 @@ def _fused_add_rmsnorm_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=fused_add_rmsnorm_quant_trace)
@register_custom_op(
"flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
)
@@ -343,7 +353,7 @@ def _fused_add_rmsnorm_quant_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=gemma_rmsnorm_trace)
def gemma_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
@@ -414,7 +424,7 @@ def _gemma_rmsnorm_impl_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=gemma_fused_add_rmsnorm_trace)
@register_custom_op(
"flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
)
@@ -470,7 +480,7 @@ def _gemma_fused_add_rmsnorm_fake(
pass
-@flashinfer_api
+@flashinfer_api(trace=layernorm_trace)
@register_custom_op("flashinfer::layernorm", mutates_args=())
def layernorm(
input: torch.Tensor,
@@ -590,7 +600,7 @@ def _torch_dtype_to_str(dtype):
)
-@flashinfer_api
+@flashinfer_api(trace=fused_rmsnorm_silu_trace)
def fused_rmsnorm_silu(
input: torch.Tensor,
weight: torch.Tensor,
diff --git a/flashinfer/page.py b/flashinfer/page.py
index 12ea3613..7fb33cf3 100644
--- a/flashinfer/page.py
+++ b/flashinfer/page.py
@@ -20,6 +20,10 @@ from typing import Optional, Tuple, Union
import torch
from .api_logging import flashinfer_api
+from .trace.templates.page import (
+ append_paged_kv_cache_trace,
+ append_paged_mla_kv_cache_trace,
+)
from .jit.page import gen_page_module
from .utils import (
TensorLayout,
@@ -222,7 +226,7 @@ def get_seq_lens(
)
-@flashinfer_api
+@flashinfer_api(trace=append_paged_mla_kv_cache_trace)
def append_paged_mla_kv_cache(
append_ckv: torch.Tensor,
append_kpe: torch.Tensor,
@@ -272,7 +276,7 @@ def append_paged_mla_kv_cache(
)
-@flashinfer_api
+@flashinfer_api(trace=append_paged_kv_cache_trace)
def append_paged_kv_cache(
append_key: torch.Tensor,
append_value: torch.Tensor,
diff --git a/flashinfer/pod.py b/flashinfer/pod.py
index fe2e36c1..4fa2d9bf 100644
--- a/flashinfer/pod.py
+++ b/flashinfer/pod.py
@@ -22,6 +22,10 @@ from typing import Any, List, Optional, Tuple, Union
import torch
from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+ batch_pod_with_paged_kv_cache_run_trace,
+ pod_with_paged_kv_cache_run_trace,
+)
from .jit import gen_pod_module, gen_batch_pod_module
from .page import get_seq_lens
from .prefill import get_batch_prefill_module
@@ -435,7 +439,7 @@ class PODWithPagedKVCacheWrapper:
begin_forward = plan
- @flashinfer_api
+ @flashinfer_api(trace=pod_with_paged_kv_cache_run_trace)
def run(
self,
# Main params (prefill and decode)
@@ -1015,7 +1019,7 @@ class BatchPODWithPagedKVCacheWrapper:
begin_forward = plan
- @flashinfer_api
+ @flashinfer_api(trace=batch_pod_with_paged_kv_cache_run_trace)
def run(
self,
# Main params (prefill and decode)
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 4ec6a29e..d491dd35 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -23,6 +23,17 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
import torch
from .api_logging import flashinfer_api
+from .trace.templates.attention import (
+ gqa_paged_prefill_trace,
+ gqa_ragged_prefill_trace,
+ single_prefill_with_kv_cache_trace,
+ trtllm_batch_context_trace,
+)
+from .trace.templates.gemm import (
+ fmha_v2_prefill_deepseek_trace,
+ trtllm_ragged_attention_deepseek_trace,
--
gen_customize_batch_prefill_module,
@@ -1099,7 +1110,7 @@ def single_prefill_with_kv_cache(
) -> Tuple[torch.Tensor, torch.Tensor]: ...
-@flashinfer_api
+@flashinfer_api(trace=single_prefill_with_kv_cache_trace)
def single_prefill_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
@@ -2132,7 +2143,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
skip_softmax_threshold_scale_factor: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=gqa_paged_prefill_trace)
def run(
self,
q: torch.Tensor,
@@ -3186,7 +3197,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
enable_pdl: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...
- @flashinfer_api
+ @flashinfer_api(trace=gqa_ragged_prefill_trace)
def run(
self,
q: torch.Tensor,
@@ -3669,7 +3680,7 @@ def get_trtllm_gen_fmha_module():
return op
-@flashinfer_api
+@flashinfer_api(trace=trtllm_ragged_attention_deepseek_trace)
def trtllm_ragged_attention_deepseek(
query: torch.Tensor,
key: torch.Tensor,
@@ -3692,6 +3703,7 @@ def trtllm_ragged_attention_deepseek(
skip_softmax_threshold_scale_factor: Optional[float] = None,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
+ backend: str = "trtllm-gen",
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Parameters
@@ -3742,6 +3754,12 @@ def trtllm_ragged_attention_deepseek(
output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
lse : Optional[torch.Tensor]
lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]]
+ backend : str
+ Attention backend to use. "trtllm-gen" (default) or "cute-dsl".
+ When backend="cute-dsl", query/key/value/out tensors must be
+ front-padded with max_seq_len rows of valid GPU memory before
+ index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details).
--
"lse assumed not None beyond this point when return_lse is True"
@@ -3839,7 +3917,7 @@ def trtllm_ragged_attention_deepseek(
return out
-@flashinfer_api
+@flashinfer_api(trace=trtllm_batch_context_trace)
def trtllm_batch_context_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -4138,7 +4216,7 @@ def get_trtllm_fmha_v2_sm120_module():
return gen_trtllm_fmha_v2_sm120_module().build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
def fmha_v2_prefill_deepseek(
query: torch.Tensor,
key: torch.Tensor,
@@ -4228,7 +4306,7 @@ def get_trtllm_fmha_v2_module(
return gen_fmha_v2_module(input_layout, input_dtype, output_dtype).build_and_load()
-@flashinfer_api
+@flashinfer_api(trace=trtllm_fmha_v2_prefill_trace)
def trtllm_fmha_v2_prefill(
qkv: Union[
torch.Tensor,
diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py
index 4cd5cd34..84f7ade6 100644
--- a/flashinfer/quantization/fp4_quantization.py
+++ b/flashinfer/quantization/fp4_quantization.py
@@ -21,6 +21,12 @@ from typing import List, Optional, Tuple
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import (
+ fp4_quantize_trace,
+ mxfp4_quantize_trace,
+ nvfp4_kv_quantize_trace,
+ nvfp4_quantize_trace,
+)
from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import (
@@ -648,7 +654,7 @@ def get_fp4_quantization_module(backend: str = "100"):
)
-@flashinfer_api
+@flashinfer_api(trace=fp4_quantize_trace)
def fp4_quantize(
input: torch.Tensor,
global_scale: Optional[torch.Tensor] = None,
@@ -923,7 +929,7 @@ def shuffle_matrix_sf_a(
return block_scale_interleave(w_shuffled)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_quantize_trace)
def nvfp4_quantize(
a,
a_global_sf,
@@ -1024,7 +1030,7 @@ def nvfp4_quantize(
return a_fp4, a_sf
-@flashinfer_api
+@flashinfer_api(trace=mxfp4_quantize_trace)
def mxfp4_quantize(
a: torch.Tensor,
backend: str = "cuda",
@@ -1441,7 +1447,7 @@ def _nvfp4_kv_quant_check(input, global_scale):
@backend_requirement({}, common_check=_nvfp4_kv_quant_check)
-@flashinfer_api
+@flashinfer_api(trace=nvfp4_kv_quantize_trace)
def nvfp4_kv_quantize(
input: torch.Tensor,
global_scale: torch.Tensor,
diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py
index f2c9f412..49e13a8b 100644
--- a/flashinfer/quantization/fp8_quantization.py
+++ b/flashinfer/quantization/fp8_quantization.py
@@ -5,6 +5,7 @@ from typing import Literal, Optional, Tuple
import torch
from ..api_logging import flashinfer_api
+from ..trace.templates.quantize import mxfp8_quantize_trace
from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
from ..utils import (
device_support_pdl,
@@ -158,7 +159,7 @@ def get_mxfp8_quantization_sm100_module():
)
-@flashinfer_api
+@flashinfer_api(trace=mxfp8_quantize_trace)
def mxfp8_quantize(
input: torch.Tensor,
is_sf_swizzled_layout: bool = True,
diff --git a/flashinfer/rope.py b/flashinfer/rope.py
index d39d2e07..df5c7d4d 100644
--- a/flashinfer/rope.py
+++ b/flashinfer/rope.py
@@ -20,6 +20,21 @@ from typing import Optional, Tuple
import torch
from .api_logging import flashinfer_api
+from .trace.templates.rope import (
+ apply_llama31_rope_inplace_trace,
+ apply_llama31_rope_pos_ids_inplace_trace,
+ apply_llama31_rope_pos_ids_trace,
+ apply_llama31_rope_trace,
+ apply_rope_inplace_trace,
+ apply_rope_pos_ids_inplace_trace,
+ apply_rope_pos_ids_trace,
+ apply_rope_trace,
--
@@ -414,7 +429,7 @@ def _fake_apply_llama31_rope_pos_ids(
pass
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_inplace_trace)
def apply_rope_inplace(
q: torch.Tensor,
k: torch.Tensor,
@@ -502,7 +517,7 @@ def apply_rope_inplace(
)
-@flashinfer_api
+@flashinfer_api(trace=apply_rope_pos_ids_inplace_trace)
def apply_rope_pos_ids_inplace(
q: torch.Tensor,
k: torch.Tensor,
@@ -561,7 +576,7 @@ def apply_rope_pos_ids_inplace(
)
-@flashinfer_api
+@flashinfer_api(trace=apply_llama31_rope_inplace_trace)
def apply_llama31_rope_inplace(
q: torch.Tensor,
k: torch.Tensor,
... (truncated -- see full diff via the command above)
```
**Summary of API changes:**
- **Decorator semantic addition (backward-compatible):**
`@flashinfer_api` now accepts an optional `trace=<TraceTemplate>`
keyword. Bare `@flashinfer_api` still works. Existing call sites of
decorated functions are unaffected. Most of the diff above is mechanical
rewrites of existing `@flashinfer_api` to `@flashinfer_api(trace=...)`,
plus the new `flashinfer/trace/` package and `fi_trace.py` for
flashinfer-bench JSON dumps.
- **New public APIs (7):**
- `flashinfer.comm.dcp_alltoall.{decode_cp_a2a_workspace_size,
decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace,
decode_cp_a2a_alltoall}` — DCP all-to-all for context-parallel attention
reduction (#2951).
- `flashinfer.fused_moe.{interleave_moe_scales_for_sm90_mixed_gemm,
interleave_moe_weights_for_sm90_mixed_gemm}` — SM90 mixed-input MoE GEMM
helpers (#3084).
- `flashinfer.comm.run_mixed_comm` — combinations of allreduce /
allgather / reducescatter (#2563).
- **New `@flashinfer_api`-decorated wrapper init:**
- `SegmentGEMMWrapper.__init__` is now decorated. Previously the class
itself was undecorated; `run()` already was. No call-site change.
- **Backward-compatible signature additions (defaults preserve old
behavior):**
- `top_k_page_table_transform`: `+dsa_graph_safe: bool = False`,
`+row_starts: Optional[torch.Tensor] = None` (#3133).
- `top_k_ragged_transform`: same two new params (#3133).
- `trtllm_ragged_attention_deepseek`: `+backend: str = "trtllm-gen"`
(cute-dsl backend selection).
- **No breaking signature changes** to any `@flashinfer_api` function.
Net public surface delta: +7 functions, +1 newly-decorated `__init__`, 0
removals.
- **Module reorganization to flag (not `@flashinfer_api`, but in public
re-export):**
- `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` →
`dense_blockscaled_gemm_sm120_b12x.py`
- Class renamed: `Sm120BlockScaledDenseGemmKernel` →
`Sm120B12xBlockScaledDenseGemmKernel`
  - Re-export in `flashinfer/gemm/__init__.py` updated to the new name
only — direct importers of the old name break. Decision needed: ship as
breaking, or add a deprecation alias (a possible alias shim is sketched
after this list).
- **Internal autotuner helper rename (not public, but used by downstream
extensions):**
- `get_last_power_of_2_num_tokens_buckets` →
`get_hybrid_num_tokens_buckets`
- `last_positive_power_of_2` → `map_to_hybrid_bucket` /
`map_to_hybrid_bucket_uncapped`
> Diff truncated above due to GitHub PR body length limit. Run the
command at the top locally to see the full output.
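
If the deprecation-alias route is chosen for the `dense_blockscaled_gemm_sm120` rename flagged above, one possible shape for it is a thin shim module kept at the old path. The file below is a hypothetical sketch, not part of this diff:

```python
# Hypothetical shim at the old path flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py.
# Re-exports the renamed class so direct importers keep working while emitting
# a DeprecationWarning. Not part of this PR; only an illustration of the alias option.
import warnings

from .dense_blockscaled_gemm_sm120_b12x import (  # noqa: F401
    Sm120B12xBlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
)

warnings.warn(
    "dense_blockscaled_gemm_sm120 has been renamed to dense_blockscaled_gemm_sm120_b12x; "
    "import Sm120B12xBlockScaledDenseGemmKernel from the new module instead.",
    DeprecationWarning,
    stacklevel=2,
)
```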
## Summary by CodeRabbit
* **Patch Release**
* Version updated to 0.6.10
…path (#3210)

## Summary

Follow-up to #2951. The merged DCP A2A code shipped with two latent foot-guns that this PR cleans up:

1. The `mapping=None` branch in `decode_cp_a2a_allocate_workspace` returns a per-rank `torch.zeros` tensor — this deadlocks at runtime on any real multi-GPU setup.
2. The MNNVL workspace tensor's lifetime is pinned via a private attribute, which is silently lost across ordinary tensor operations and would let the fabric memory be unmapped while the kernel still holds raw pointers into it.

Both are addressed below.

## Bug 1: workspace VA mismatch (silent deadlock)

`getFifoBasePtr` in `csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu:177` addresses peer FIFOs via:

```cpp
auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64;
```

This pointer arithmetic only resolves correctly when the workspace is a single unified VA spanning all CP ranks — i.e. MNNVL fabric memory. With per-rank `torch.zeros`, rank 0 writing to "rank 1's FIFO" lands in rank 0's own memory; peer 1 never sees it and the FIFO consumer signaling hangs forever.

Reproduced on H200 NVL with 4 GH200s — first `decode_cp_a2a_alltoall` call hangs, no CUDA error, no assertion.

TRT-LLM upstream (`tensorrt_llm/_torch/distributed/ops.py:386`) only supports MNNVL — there is no plain-memory branch. The fallback added during the FlashInfer port was wrong from the start.

### Why CI didn't catch Bug 1

1. `tests/comm/test_dcp_alltoall.py` simulates `cp_size` ranks on **one** GPU with **one shared** workspace tensor — pointer arithmetic on the same allocation works, so the bug is invisible.
2. `tests/comm/test_mnnvl_dcp_alltoall.py::TestMnnvlDcpAlltoall` exercises real multi-GPU but only on the MNNVL path.
3. `TestMnnvlDcpDeviceMemoryFallback` asserts shape only — never actually calls `alltoall`, so the deadlock never fires.

## Bug 2: workspace keep-alive via tensor private attribute

The previous code did:

```python
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
return workspace
```

`MnnvlMemory.__del__` calls `close_mnnvl_memory`, which unmaps the underlying fabric VA. The `workspace._mnnvl_mem` private attribute is the only thing keeping the wrapper alive. **But torch tensor private attributes are NOT preserved across ordinary operations** — `workspace[r]`, `.view()`, `.contiguous()`, `.clone()`, `.to(...)`, slicing, and indexing all return a fresh tensor without the attribute. Any caller that derives a view or slice and drops the original would silently free the workspace while the kernel still holds raw pointers into it. Current direct callers happen to be safe, but this is a buried mine.

## Changes

### Bug 1 — drop the broken plain-memory path

- **Rename** `decode_cp_a2a_allocate_workspace` → `decode_cp_a2a_allocate_mnnvl_workspace`, matching `trtllm_mnnvl_ar` naming style. The MNNVL requirement is now obvious at the call site.
- **Make `mapping` the only required argument.** Drop the `mapping=None` branch entirely. The redundant `cp_size` and `cp_rank` parameters were also removed — `mapping` already carries that info, and a separate path was a double-source-of-truth footgun.
- **Refactor the single-GPU sim test** to use a local `_alloc_sim_workspace(cp_size)` helper that does `torch.zeros(...)` directly — that's what the test actually needs and what its docstring claims (it does *not* need the public allocator).
- **Drop `TestMnnvlDcpDeviceMemoryFallback`** — the path it covered no longer exists.
- **Use the public allocator in the multi-rank test** (`test_mnnvl_dcp_alltoall.py::_allocate_mnnvl_workspace_once`) instead of manually instantiating `MnnvlMemory`, so the public API is exercised end-to-end.
- Refresh the module docstring to make the MNNVL requirement explicit.

### Bug 2 — robust workspace keep-alive

- Replace `workspace._mnnvl_mem = mnnvl_mem` with a module-level `_workspace_keepalive: Dict[int, MnnvlMemory]` keyed by `workspace.data_ptr()`. The dict pins each `MnnvlMemory` for the process lifetime, so the workspace stays mapped regardless of what the caller does with the returned tensor. This matches how TRT-LLM holds workspace ownership in its `HelixAllToAllNative._cache` class-level dict.

### Final allocator signature

```python
decode_cp_a2a_allocate_mnnvl_workspace(mapping: Mapping, *, mnnvl_config: Optional[MnnvlConfig] = None) -> torch.Tensor
```

## Test plan

All verified on dlcluster GB200-NVL72 (compute capability 10.0a, CUDA 13):

- [x] `tests/comm/test_dcp_alltoall.py` (single-GPU sim, container `flashinfer/flashinfer-ci-cu130`) — **29/29 PASSED in 5.16s**
- [x] `mpirun -launcher fork -np 4 pytest tests/comm/test_mnnvl_dcp_alltoall.py` — single-node 4-rank MNNVL — **8/8 PASSED on all 4 ranks in 1.34s**
- [x] `srun -N 2 --ntasks-per-node=4 --mpi=pmix pytest tests/comm/test_mnnvl_dcp_alltoall.py` — **multi-node 8-rank MNNVL** (2 nodes × 4 GPUs, cp_size=8, container `nvcr.io/nvidia/pytorch:26.02-py3` with HPC-X) — **8/8 PASSED on all 8 ranks in ~17s**

The multi-node run exercises real cross-node fabric memory allocation via `cuMemCreate` with FABRIC handles — the same code path Helix production uses on NVL72.

🤖 AI-assisted (Claude Code)

---------

Co-authored-by: Alex Yang <aleyang@nvidia.com>
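
A minimal sketch of the Bug 2 remedy described above. Only `as_torch_strided_tensor`, the allocator signature, and the `data_ptr()`-keyed dict come from the PR text; the import paths, the `MnnvlMemory` constructor call, and the size constant are assumptions:

```python
# Sketch of the keep-alive fix: pin each MnnvlMemory in a module-level dict keyed
# by the workspace tensor's data pointer, instead of a private tensor attribute.
from typing import Dict, Optional

import torch

# Assumed import paths; the real module layout may differ.
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig, MnnvlMemory

_workspace_keepalive: Dict[int, MnnvlMemory] = {}


def decode_cp_a2a_allocate_mnnvl_workspace(
    mapping: Mapping, *, mnnvl_config: Optional[MnnvlConfig] = None
) -> torch.Tensor:
    # In the real allocator the byte count comes from the workspace-size helper,
    # and mnnvl_config is forwarded to the MNNVL setup; a fixed size is used here.
    workspace_bytes = 1 << 22
    mnnvl_mem = MnnvlMemory(mapping, workspace_bytes)  # constructor call is an assumption
    workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
    # The dict entry (not a tensor attribute) keeps the fabric VA mapped even if
    # the caller only retains views or slices of `workspace`.
    _workspace_keepalive[workspace.data_ptr()] = mnnvl_mem
    return workspace
```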
What is DCP?
DCP (Decode Context Parallelism) splits the KV cache across multiple GPUs for long-context inference. During the decode phase, each GPU computes partial attention over its local KV shard, producing two tensors:
- `partial_o` (bf16/fp16): partial attention output, shape `[B, cp_size, head_dim]`
- `softmax_stats` (fp32): softmax statistics for online correction, shape `[B, cp_size, stats_dim]`

These must be exchanged across all ranks (all-to-all) before the final attention reduction.
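
For context on why the stats travel with the outputs: once every rank holds all `cp_size` shards, the partial outputs are combined with the standard log-sum-exp correction. The sketch below assumes the stats are per-head log-sum-exp values and uses generic `[cp_size, B, H, D]` / `[cp_size, B, H]` layouts, not this kernel's exact `stats_dim` layout:

```python
# Reference for the final reduction that the exchanged softmax statistics enable:
# each rank's partial output is weighted by exp(lse_i - logsumexp(lse)).
import torch


def merge_partials(partial_o: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # partial_o: [cp_size, B, H, D], lse: [cp_size, B, H] (fp32)
    lse_total = torch.logsumexp(lse, dim=0)                # [B, H]
    weights = torch.exp(lse - lse_total.unsqueeze(0))      # [cp_size, B, H]
    return (weights.unsqueeze(-1) * partial_o).sum(dim=0)  # [B, H, D]
```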
Why this kernel?
The current approach uses 2× NCCL `all_to_all_single` — one call per tensor, since they have different dtypes. For the small message sizes in DCP (KB-range), NCCL protocol overhead dominates latency.

This PR adds a fused LL128 FIFO kernel (ported from TensorRT-LLM) that exchanges both tensors in a single kernel launch via MNNVL cross-GPU direct memory writes, eliminating NCCL protocol overhead.
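
For reference, the baseline being replaced issues one collective per dtype. A rough `torch.distributed` sketch; the production integration may lay out or split the tensors differently:

```python
# Rough sketch of the 2x NCCL baseline: one all_to_all_single per dtype, hence
# two kernel launches and two rounds of protocol overhead for KB-sized payloads.
import torch
import torch.distributed as dist


def nccl_baseline(partial_o: torch.Tensor, softmax_stats: torch.Tensor, group=None):
    # all_to_all_single splits and concatenates along dim 0, so move the
    # cp_size axis (dim 1 in the [B, cp_size, ...] layout above) to the front.
    o_in = partial_o.movedim(1, 0).contiguous()
    s_in = softmax_stats.movedim(1, 0).contiguous()
    o_out = torch.empty_like(o_in)        # bf16/fp16 payload
    s_out = torch.empty_like(s_in)        # fp32 payload
    dist.all_to_all_single(o_out, o_in, group=group)
    dist.all_to_all_single(s_out, s_in, group=group)
    return o_out.movedim(0, 1), s_out.movedim(0, 1)
```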
Kernel-level microbenchmark
GB200 NVL72, 4 GPUs, cp_size=4. Measures raw kernel time only (not end-to-end pipeline).
Speedup is latency-dominated — these tensors are small, so NCCL protocol overhead is the bottleneck.
Components
- `csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.{cu,h}` — LL128 FIFO protocol, SM90+
- `csrc/trtllm_dcp_alltoall.cu`
- `flashinfer/jit/comm.py` (`gen_dcp_alltoall_module`)
- `flashinfer/comm/dcp_alltoall.py`
  - `decode_cp_a2a_workspace_size` — query workspace bytes per rank
  - `decode_cp_a2a_allocate_workspace` — allocate workspace (MNNVL or device memory)
  - `decode_cp_a2a_init_workspace` — initialize FIFO buffers
  - `decode_cp_a2a_alltoall` — run all-to-all exchange
- `tests/comm/test_dcp_alltoall.py` — 30 single-GPU tests (correctness, edge cases, input validation)
- `tests/comm/test_mnnvl_dcp_alltoall.py` — 9 multi-GPU MNNVL tests (workspace shape, cross-rank visibility, alltoall correctness, FIFO reuse)
- `benchmarks/bench_dcp_alltoall.py` — native LL128 vs NCCL baseline
- `flashinfer/aot.py`
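
A minimal sketch of how the four Python entry points above chain together on one CP rank. Argument names and ordering beyond what is listed here are assumptions; see `flashinfer/comm/dcp_alltoall.py` for the real signatures:

```python
# Per-rank flow through the public DCP A2A API. Exact signatures are a sketch.
# (The allocator was later renamed decode_cp_a2a_allocate_mnnvl_workspace(mapping)
# in follow-up #3210.)
import torch

from flashinfer.comm.dcp_alltoall import (
    decode_cp_a2a_alltoall,
    decode_cp_a2a_allocate_workspace,
    decode_cp_a2a_init_workspace,
    decode_cp_a2a_workspace_size,
)

cp_size, cp_rank = 4, 0  # taken from the CP process group in a real run
B, head_dim, stats_dim = 8, 128, 2  # illustrative sizes

nbytes = decode_cp_a2a_workspace_size(cp_size)                  # workspace bytes per rank
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank)  # MNNVL or device memory
decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)       # zero FIFO state, once per workspace

partial_o = torch.randn(B, cp_size, head_dim, device="cuda", dtype=torch.bfloat16)
softmax_stats = torch.randn(B, cp_size, stats_dim, device="cuda", dtype=torch.float32)

o_out, stats_out = decode_cp_a2a_alltoall(
    partial_o, softmax_stats, workspace, cp_rank, cp_size
)
```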
Requirements
- `mpi4py` for multi-GPU tests and benchmark
Test plan
- Single-GPU tests (`pytest tests/comm/test_dcp_alltoall.py`)
- Multi-GPU MNNVL tests (`mpirun -np 4 pytest tests/comm/test_mnnvl_dcp_alltoall.py`)
Summary by CodeRabbit
- New Features
- Tests
- Chores