fix(dcp_alltoall): require MNNVL workspace, drop broken plain-memory path #3210

Merged

aleozlx merged 4 commits into flashinfer-ai:main from davidjpyu:fix/dcp-a2a-mnnvl-workspace-required on May 5, 2026

Conversation

Contributor

@davidjpyu davidjpyu commented Apr 30, 2026

Summary

Follow-up to #2951. The merged DCP A2A code shipped with two latent foot-guns that this PR cleans up:

  1. The mapping=None branch in decode_cp_a2a_allocate_workspace returns a per-rank torch.zeros tensor — this deadlocks at runtime on any real multi-GPU setup.
  2. The MNNVL workspace tensor's lifetime is pinned via a private attribute, which is silently lost across ordinary tensor operations and would let the fabric memory be unmapped while the kernel still holds raw pointers into it.

Both are addressed below.

Bug 1: workspace VA mismatch (silent deadlock)

getFifoBasePtr in csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu:177 addresses peer FIFOs via:

```cpp
auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64;
```

This pointer arithmetic only resolves correctly when the workspace is a single unified VA spanning all CP ranks — i.e. MNNVL fabric memory. With per-rank torch.zeros, rank 0 writing to "rank 1's FIFO" lands in rank 0's own memory; peer 1 never sees it and the FIFO consumer signaling hangs forever.
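For intuition, here is the addressing scheme transcribed into a small Python sketch (names are illustrative; the real arithmetic runs in the CUDA kernel on raw u64 pointers):

```python
# Illustrative model of getFifoBasePtr's pointer arithmetic (hypothetical
# names; the real code operates on device pointers inside the CUDA kernel).
def fifo_base_ptr(workspace_base: int, peer_rank: int, stride_bytes: int) -> int:
    # Valid only if workspace_base starts ONE virtual-address range that
    # contains every CP rank's slice, which is what MNNVL fabric memory
    # provides. Each rank's slice lives at base + rank * stride.
    return workspace_base + peer_rank * stride_bytes

# With MNNVL, all ranks map the same unified VA, so rank 0's computed
# "peer 1" address really is peer 1's FIFO slice. With per-rank
# torch.zeros, each rank owns an unrelated allocation: the computed
# address stays inside the local buffer, the write never reaches peer 1,
# and the consumer side waits forever.
```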

Reproduced on H200 NVL with 4 GH200s — first decode_cp_a2a_alltoall call hangs, no CUDA error, no assertion. TRT-LLM upstream (tensorrt_llm/_torch/distributed/ops.py:386) only supports MNNVL — there is no plain-memory branch. The fallback added during the FlashInfer port was wrong from the start.

Why CI didn't catch Bug 1

  1. tests/comm/test_dcp_alltoall.py simulates cp_size ranks on one GPU with one shared workspace tensor — pointer arithmetic on the same allocation works, so the bug is invisible.
  2. tests/comm/test_mnnvl_dcp_alltoall.py::TestMnnvlDcpAlltoall exercises real multi-GPU but only on the MNNVL path.
  3. TestMnnvlDcpDeviceMemoryFallback asserts shape only — never actually calls alltoall, so the deadlock never fires.

Bug 2: workspace keep-alive via tensor private attribute

The previous code did:

```python
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
return workspace
```

MnnvlMemory.__del__ calls close_mnnvl_memory, which unmaps the underlying fabric VA. The workspace._mnnvl_mem private attribute is the only thing keeping the wrapper alive. But torch tensor private attributes are NOT preserved across ordinary operations: workspace[r], .view(), .contiguous(), .clone(), .to(...), slicing, and indexing all return a fresh tensor without the attribute. Any caller that derives a view or slice and drops the original would silently free the workspace while the kernel still holds raw pointers into it. Current direct callers happen to be safe, but this is a buried mine.
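The attribute loss is easy to demonstrate in stock PyTorch (no FlashInfer involved), since derived tensors are fresh Python objects:

```python
import torch

t = torch.zeros(4, 8)
t._keepalive = object()  # stand-in for the MnnvlMemory wrapper

# Every derived tensor is a new Python object without the attribute,
# even though it aliases the same storage.
assert not hasattr(t.view(-1), "_keepalive")
assert not hasattr(t[0], "_keepalive")
assert not hasattr(t.clone(), "_keepalive")

# If the original `t` is dropped, nothing references the wrapper anymore;
# for MnnvlMemory that would trigger __del__ and unmap the fabric VA
# while surviving views still alias the now-unmapped storage.
```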

Changes

Bug 1 — drop the broken plain-memory path

  • Rename decode_cp_a2a_allocate_workspace → decode_cp_a2a_allocate_mnnvl_workspace, matching trtllm_mnnvl_ar naming style. The MNNVL requirement is now obvious at the call site.
  • Make mapping the only required argument. Drop the mapping=None branch entirely. The redundant cp_size and cp_rank parameters were also removed — mapping already carries that info, and a separate path was a double-source-of-truth footgun.
  • Refactor the single-GPU sim test to use a local _alloc_sim_workspace(cp_size) helper that does torch.zeros(...) directly — that's what the test actually needs and what its docstring claims (it does not need the public allocator); see the sketch after this list.
  • Drop TestMnnvlDcpDeviceMemoryFallback — the path it covered no longer exists.
  • Use the public allocator in the multi-rank test (test_mnnvl_dcp_alltoall.py::_allocate_mnnvl_workspace_once) instead of manually instantiating MnnvlMemory, so the public API is exercised end-to-end.
  • Refresh module docstring to make the MNNVL requirement explicit.
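
For reference, a plausible shape of that sim helper (a sketch, assuming decode_cp_a2a_workspace_size is importable from flashinfer.comm as in the existing tests; the element rounding mirrors the removed fallback code):

```python
import torch
from flashinfer.comm import decode_cp_a2a_workspace_size  # assumed import path

def _alloc_sim_workspace(cp_size: int) -> torch.Tensor:
    """One shared allocation standing in for all simulated ranks.

    Fine on a single GPU (every simulated 'rank' indexes the same tensor);
    never valid across real GPUs, which is exactly Bug 1.
    """
    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
    ws_elems_per_rank = (ws_bytes + 7) // 8  # round up to whole int64 elements
    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
```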

Bug 2 — robust workspace keep-alive

  • Replace workspace._mnnvl_mem = mnnvl_mem with a module-level _workspace_keepalive: Dict[int, MnnvlMemory] keyed by workspace.data_ptr(). The dict pins each MnnvlMemory for the process lifetime, so the workspace stays mapped regardless of what the caller does with the returned tensor. This matches how TRT-LLM holds workspace ownership in its HelixAllToAllNative._cache class-level dict.
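
The keep-alive mechanism, as a minimal sketch (module-level state in dcp_alltoall.py; the exact code is not shown in this view):

```python
from typing import Dict

# Module-level registry: pins every allocated MnnvlMemory for the process
# lifetime, keyed by the workspace tensor's base device address.
_workspace_keepalive: Dict[int, "MnnvlMemory"] = {}

def _pin(workspace: "torch.Tensor", mnnvl_mem: "MnnvlMemory") -> None:
    # Views and slices of `workspace` alias the same storage, so the
    # mapping survives no matter which derived tensor the caller keeps.
    _workspace_keepalive[workspace.data_ptr()] = mnnvl_mem
```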

Final allocator signature

```python
decode_cp_a2a_allocate_mnnvl_workspace(mapping: Mapping, *, mnnvl_config: Optional[MnnvlConfig] = None) -> torch.Tensor
```
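A hedged usage sketch of the resulting protocol (allocate, init, barrier, then alltoall, per the docstring quoted above); the init-call signature is assumed, and Mapping construction depends on the launcher:

```python
import torch.distributed as dist
from flashinfer.comm import (
    decode_cp_a2a_allocate_mnnvl_workspace,
    decode_cp_a2a_init_workspace,
)

def setup_dcp_a2a(mapping):
    # `mapping` carries cp_size/cp_rank plus the pp/tp ranks used to
    # split the communicator; construct it per your launcher.
    workspace = decode_cp_a2a_allocate_mnnvl_workspace(mapping)
    decode_cp_a2a_init_workspace(workspace)  # per-rank init (signature assumed)
    dist.barrier()  # every rank must init before the first alltoall call
    return workspace  # then passed to decode_cp_a2a_alltoall each decode step
```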

Test plan

All verified on dlcluster GB200-NVL72 (compute capability 10.0a, CUDA 13):

  • tests/comm/test_dcp_alltoall.py (single-GPU sim, container flashinfer/flashinfer-ci-cu130) — 29/29 PASSED in 5.16s
  • mpirun -launcher fork -np 4 pytest tests/comm/test_mnnvl_dcp_alltoall.py — single-node 4-rank MNNVL — 8/8 PASSED on all 4 ranks in 1.34s
  • srun -N 2 --ntasks-per-node=4 --mpi=pmix pytest tests/comm/test_mnnvl_dcp_alltoall.py — multi-node 8-rank MNNVL (2 nodes × 4 GPUs, cp_size=8, container nvcr.io/nvidia/pytorch:26.02-py3 with HPC-X) — 8/8 PASSED on all 8 ranks in ~17s

The multi-node run exercises real cross-node fabric memory allocation via cuMemCreate with FABRIC handles — same code path Helix production uses on NVL72.

🤖 AI-assisted (Claude Code)

…path

Follow-up to flashinfer-ai#2951. The merged code exposed a non-functional
`mapping=None` branch in `decode_cp_a2a_allocate_workspace` that returns
a per-rank `torch.zeros` tensor. This deadlocks at runtime on any real
multi-GPU setup: the kernel addresses peer FIFOs via
`params.workspace + peer_rank * stride`, which only resolves correctly
when the workspace is a single unified VA spanning all CP ranks (i.e.
MNNVL fabric memory). With per-rank `torch.zeros`, a write to peer N's
FIFO lands in the local rank's own memory and the FIFO consumer signaling
hangs forever.

CI did not catch this:
  - Single-GPU sim test (`test_dcp_alltoall.py`) shares one workspace
    tensor across all simulated ranks, so cross-rank pointer arithmetic
    works on the same allocation.
  - The MNNVL multi-GPU test (`TestMnnvlDcpAlltoall`) only exercises the
    fabric path.
  - `TestMnnvlDcpDeviceMemoryFallback` only asserted shape, never called
    alltoall.

Reproduced on H200 NVL with 4 GH200s — first `decode_cp_a2a_alltoall`
call hangs, no CUDA error. TRT-LLM upstream
(`tensorrt_llm/_torch/distributed/ops.py:386`) only supports MNNVL.

Changes:
  - Rename `decode_cp_a2a_allocate_workspace` ->
    `decode_cp_a2a_allocate_mnnvl_workspace`, matching `trtllm_mnnvl_ar`
    naming style. `mapping` is now a required positional argument.
  - Drop the `mapping=None` branch entirely.
  - Update single-GPU sim test to allocate `torch.zeros` directly via a
    local `_alloc_sim_workspace` helper — that is what the test actually
    needs and what its docstring claims.
  - Drop `TestMnnvlDcpDeviceMemoryFallback` (the path it tested no
    longer exists).
  - Refresh module docstring to make the MNNVL requirement explicit.
Contributor

coderabbitai Bot commented Apr 30, 2026

📝 Walkthrough

Replaces the generic DCP A2A workspace allocator with an MNNVL-specific allocator that requires a mapping; the allocator always returns an MNNVL-backed int64 strided tensor and retains the MNNVL handle in a module-level keepalive. Tests updated to use a single simulated CUDA workspace; device-memory fallback tests removed.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Public API export: `flashinfer/comm/__init__.py` | Re-export changed: `decode_cp_a2a_allocate_workspace` → `decode_cp_a2a_allocate_mnnvl_workspace`. |
| Allocator implementation: `flashinfer/comm/dcp_alltoall.py` | Removed the generic allocator; added `decode_cp_a2a_allocate_mnnvl_workspace(mapping, *, mnnvl_config=None)`, which always allocates via `MnnvlMemory`, returns an int64 strided tensor sized from `mapping.cp_size`, documents workspace semantics, stores the `MnnvlMemory` in a module-level `_workspace_keepalive` keyed by `workspace.data_ptr()`, and updates `__all__`. Deadlock messaging generalized. |
| Tests — simulation helper: `tests/comm/test_dcp_alltoall.py` | Introduces `_alloc_sim_workspace(cp_size)` to create a shared CUDA `torch.int64` zeroed tensor sized via `decode_cp_a2a_workspace_size`; replaces prior allocator usage with this simulator across single-GPU multi-rank tests; removes the allocator shape/dtype unit check. |
| Tests — MNNVL usage / remove fallback: `tests/comm/test_mnnvl_dcp_alltoall.py` | Replaces manual MNNVL memory construction with `decode_cp_a2a_allocate_mnnvl_workspace(mapping)`; removes the `TestMnnvlDcpDeviceMemoryFallback` class and imports of the removed allocator. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • aleozlx
  • jimmyzho
  • nv-yunzheq

Poem

🐰 I nibble bytes and cuddle heap,
MNNVL cradles workspace deep.
One mapping sings, one tensor stays,
Tests now share the simulated ways.
Hop—allocation made neat and spry.

🚥 Pre-merge checks: ✅ 5 passed

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main change: requiring an MNNVL workspace and removing the broken plain-memory fallback path for DCP all-to-all operations. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 89.47%, above the required threshold of 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description is comprehensive, addressing two specific bugs with detailed technical context, reproduction steps, and a clear test plan. |


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request renames the workspace allocator to decode_cp_a2a_allocate_mnnvl_workspace and makes the mapping argument mandatory, enforcing MNNVL fabric memory for multi-rank operations. Single-GPU tests are updated to use a private simulation helper. Feedback suggests simplifying the API by removing redundant cp_size and cp_rank parameters already available in the mapping object, utilizing the public API within tests to ensure proper validation, and removing redundant test cases.

```python
from flashinfer.comm import (
    decode_cp_a2a_alltoall,
    decode_cp_a2a_allocate_workspace,
    decode_cp_a2a_init_workspace,
```

Contributor

high

The new public API decode_cp_a2a_allocate_mnnvl_workspace is not imported or used in this test file. Instead, the module-level helper _allocate_mnnvl_workspace_once (lines 121-144) manually reimplements the allocation logic. It is highly recommended to import and use the public API in the tests to ensure it is properly exercised and to reduce code duplication.

Suggested change:

```diff
 decode_cp_a2a_init_workspace,
+decode_cp_a2a_allocate_mnnvl_workspace,
```

Comment thread flashinfer/comm/dcp_alltoall.py Outdated
```diff
-# 2. Allocate workspace (MNNVL or plain device memory)
-workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank, mapping=mapping)
+# 2. Allocate MNNVL-backed workspace (Mapping is required)
+workspace = decode_cp_a2a_allocate_mnnvl_workspace(cp_size, cp_rank, mapping)
```

Contributor

medium

The usage protocol should be updated to reflect the simplified API if the redundant cp_size and cp_rank arguments are removed from the allocator.

Suggested change:

```diff
-workspace = decode_cp_a2a_allocate_mnnvl_workspace(cp_size, cp_rank, mapping)
+workspace = decode_cp_a2a_allocate_mnnvl_workspace(mapping)
```

Comment on lines +127 to 152

```diff
 def decode_cp_a2a_allocate_mnnvl_workspace(
     cp_size: int,
     cp_rank: int,
+    mapping: Mapping,
     *,
-    mapping: Optional[Mapping] = None,
     mnnvl_config: Optional[MnnvlConfig] = None,
 ) -> torch.Tensor:
-    """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+    """Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``.
+
+    The DCP A2A kernel requires a single unified VA spanning all CP ranks
+    (see module docstring), so workspace allocation must go through MNNVL
+    fabric memory. This function is the only supported allocator.
+
+    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
-
-    Two allocation modes:
-
-    - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
-      FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
-      cannot see each other's device memory directly.
-    - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
-      allocation. Sufficient for single-node with NVLink P2P.

     Args:
         cp_size: Context-parallel group size.
-        cp_rank: This rank's position in the CP group.
-        mapping: Mapping object for MNNVL allocation. If provided, MNNVL is
-            used. The mapping must have ``cp_size`` set correctly. The
-            communicator is split using ``mapping.pp_rank``, ``mapping.cp_rank``,
-            and ``mapping.tp_rank``.
+        cp_rank: This rank's position in the CP group (used for logging only;
+            ``mapping`` carries the authoritative rank info).
+        mapping: Mapping object for MNNVL allocation. Must have ``cp_size``
+            set correctly. The communicator is split using ``mapping.pp_rank``,
+            ``mapping.cp_rank``, and ``mapping.tp_rank``.
         mnnvl_config: Configuration for the MNNVL communication backend.
             Required when using MNNVL with ``torch.distributed`` (pass
             ``MnnvlConfig(comm_backend=TorchDistBackend(group))``).
```

Contributor

medium

The cp_size and cp_rank arguments are redundant because the mapping object (which is now a required positional argument) already contains this information. As noted in the docstring, mapping carries the authoritative rank info. Removing these redundant parameters simplifies the API and eliminates the risk of passing inconsistent values.

```python
@flashinfer_api
def decode_cp_a2a_allocate_mnnvl_workspace(
    mapping: Mapping,
    *,
    mnnvl_config: Optional[MnnvlConfig] = None,
) -> torch.Tensor:
    """Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``.

    The DCP A2A kernel requires a single unified VA spanning all CP ranks
    (see module docstring), so workspace allocation must go through MNNVL
    fabric memory. This function is the only supported allocator.

    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.

    Args:
        mapping: Mapping object for MNNVL allocation. Must have ``cp_size``
            set correctly. The communicator is split using ``mapping.pp_rank``,
            ``mapping.cp_rank``, and ``mapping.tp_rank``.
        mnnvl_config: Configuration for the MNNVL communication backend.
            Required when using MNNVL with ``torch.distributed`` (pass
            ``MnnvlConfig(comm_backend=TorchDistBackend(group))``).
```
Comment thread flashinfer/comm/dcp_alltoall.py Outdated
Comment on lines +157 to +172

```diff
     ws_bytes = decode_cp_a2a_workspace_size(cp_size)

-    if mapping is not None:
-        MnnvlMemory.initialize()
-        if mnnvl_config:
-            MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
-
-        mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
-        workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
-        workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
-        logger.info(
-            "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
-            cp_rank,
-            list(workspace.shape),
-            list(workspace.stride()),
-        )
-        return workspace
-
-    ws_elems_per_rank = (ws_bytes + 7) // 8
-    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+    MnnvlMemory.initialize()
+    if mnnvl_config:
+        MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
+
+    mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
+    workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
+    workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
+    logger.info(
+        "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
+        cp_rank,
+        list(workspace.shape),
+        list(workspace.stride()),
+    )
+    return workspace
```

Contributor

medium

If the redundant cp_size and cp_rank arguments are removed from the function signature, the implementation should be updated to extract these values directly from the mapping object.

Suggested change:

```python
    cp_size = mapping.cp_size
    cp_rank = mapping.cp_rank
    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
    MnnvlMemory.initialize()
    if mnnvl_config:
        MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
    mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
    workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
    workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
    logger.info(
        "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
        cp_rank,
        list(workspace.shape),
        list(workspace.stride()),
    )
    return workspace
```

Comment thread tests/comm/test_dcp_alltoall.py Outdated
```diff
 for cp_size in [2, 4]:
     ws_bytes = decode_cp_a2a_workspace_size(cp_size)
-    workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
+    workspace = _alloc_sim_workspace(cp_size)
```

Contributor

medium

This test case now uses the local helper _alloc_sim_workspace instead of the public API decode_cp_a2a_allocate_mnnvl_workspace. While this allows the test to pass in a single-GPU environment, it means the public allocator is no longer being tested here. Since the workspace shape and dtype are already verified in tests/comm/test_mnnvl_dcp_alltoall.py using real MNNVL allocations, this test case is now redundant and could be removed to keep the test suite focused.

Contributor

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/dcp_alltoall.py (1)

127-172: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate the topology inputs before allocating.

cp_size is still used to size the workspace, but this function never checks that it matches mapping.cp_size or that cp_rank matches mapping.cp_rank. A stale caller can now allocate for one rank layout and initialize the MNNVL handle for another, which can reintroduce the deadlock/corruption this PR is trying to remove.

Suggested guard:

```diff
 def decode_cp_a2a_allocate_mnnvl_workspace(
     cp_size: int,
     cp_rank: int,
     mapping: Mapping,
     *,
     mnnvl_config: Optional[MnnvlConfig] = None,
 ) -> torch.Tensor:
+    if cp_size != mapping.cp_size:
+        raise ValueError(
+            f"cp_size={cp_size} does not match mapping.cp_size={mapping.cp_size}"
+        )
+    if cp_rank != mapping.cp_rank:
+        raise ValueError(
+            f"cp_rank={cp_rank} does not match mapping.cp_rank={mapping.cp_rank}"
+        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/dcp_alltoall.py` around lines 127 - 172, In
decode_cp_a2a_allocate_mnnvl_workspace validate that the provided cp_size and
cp_rank match the topology carried by mapping before initializing/allocating
MnnvlMemory: check mapping.cp_size == cp_size and mapping.cp_rank == cp_rank
(and raise a ValueError with a clear message including both expected and actual
values if they differ); perform these checks at the top of the function (before
MnnvlMemory.initialize()/MnnvlMemory(...) and before logging) so a stale caller
cannot allocate/initialize using a mismatched layout.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7587b59-f6d0-4604-bac5-4ad8e4030b2d

📥 Commits

Reviewing files that changed from the base of the PR and between ed70283 and 4e11900.

📒 Files selected for processing (4)
  • flashinfer/comm/__init__.py
  • flashinfer/comm/dcp_alltoall.py
  • tests/comm/test_dcp_alltoall.py
  • tests/comm/test_mnnvl_dcp_alltoall.py
💤 Files with no reviewable changes (1)
  • tests/comm/test_mnnvl_dcp_alltoall.py

The previous code attached the MnnvlMemory wrapper as a private attribute
on the returned tensor (`workspace._mnnvl_mem = mnnvl_mem`) to keep it
from being garbage-collected. That works for direct callers, but private
attributes on a torch tensor are NOT preserved across ordinary tensor
operations: `workspace[r]`, `.view()`, `.contiguous()`, `.clone()`,
`.to(...)`, slicing, indexing, etc. all return a fresh tensor without the
attribute. Any caller that derives a view / slice and drops the original
would silently free the underlying fabric memory while the kernel still
holds raw pointers into it.

Move the keep-alive to a module-level `Dict[int, MnnvlMemory]` keyed by
`workspace.data_ptr()`. The dict pins each MnnvlMemory wrapper for the
process lifetime, so the workspace stays mapped regardless of what the
caller does with the returned tensor. This matches the spirit of how
TRT-LLM holds workspace ownership in its `HelixAllToAllNative._cache`
class-level dict.

Same single-GPU and 4-rank MNNVL test results as before.
@aleozlx aleozlx added the v0.6.11 release blocker label on Apr 30, 2026
Per Gemini code review on PR flashinfer-ai#3210: `mapping` already carries the
authoritative cp_size and cp_rank, so passing them as separate
parameters creates a double-source-of-truth footgun. Simplify the
public API to take only `mapping` and a keyword-only `mnnvl_config`.

Same change ripples to tests:
  - `tests/comm/test_mnnvl_dcp_alltoall.py::_allocate_mnnvl_workspace_once`
    now calls the public API instead of manually instantiating
    MnnvlMemory, so the multi-rank test exercises the public allocator.
  - `tests/comm/test_dcp_alltoall.py::test_allocate_returns_correct_shape_and_dtype`
    is removed; after the rename the single-GPU test allocates via a
    local `torch.zeros` helper, so the assertion is tautological. The
    real shape/dtype assertion lives in
    `test_mnnvl_dcp_alltoall.py::test_workspace_shape`.

No kernel or behavior changes; pure API simplification.
@davidjpyu
Contributor Author

Thanks @gemini-code-assist and @coderabbitai for the review.

Addressed all 5 Gemini suggestions in commit b87aa37:

  • HIGH tests/comm/test_mnnvl_dcp_alltoall.py:38 — _allocate_mnnvl_workspace_once now calls the public decode_cp_a2a_allocate_mnnvl_workspace instead of manually instantiating MnnvlMemory. The multi-rank test now exercises the public API end-to-end.
  • MEDIUM flashinfer/comm/dcp_alltoall.py:161 (and 18, 181) — Dropped redundant cp_size and cp_rank parameters; mapping is now the only required arg. Final signature: decode_cp_a2a_allocate_mnnvl_workspace(mapping, *, mnnvl_config=None).
  • MEDIUM tests/comm/test_dcp_alltoall.py:185 — Removed the now-tautological test_allocate_returns_correct_shape_and_dtype (it asserted properties of a local torch.zeros helper). Real shape/dtype assertion lives in test_mnnvl_dcp_alltoall.py::test_workspace_shape.

CodeRabbit's suggestion to validate cp_size == mapping.cp_size and cp_rank == mapping.cp_rank is moot now that those parameters no longer exist — the inconsistent-input footgun is removed at the API level rather than caught at runtime.

Re-tested on dlcluster GB200-NVL72:

  • single-GPU sim: 29/29 passed (was 30 before removing the redundant test)
  • single-node 4-rank MNNVL via mpirun: 8/8 passed on all ranks
  • multi-node 8-rank MNNVL via srun --mpi=pmix (cp_size=8, 2 nodes × 4 GPUs): 8/8 passed on all ranks (already verified at the previous commit and unchanged by this API-only refactor)

@gemini-code-assist
Contributor

Thanks for the update, @davidjpyu. It's great to see the MNNVL workspace requirement enforced at the API level and the keep-alive mechanism made more robust. The test suite adjustments also look correct for the new API contract. Since you've verified the changes across single-GPU, single-node, and multi-node environments, this looks solid.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 30, 2026


@davidjpyu: Thank you for the thorough update! The API-level fix is cleaner than runtime validation — removing cp_size/cp_rank entirely eliminates the inconsistency footgun at the source. The module-level _workspace_keepalive dict keyed by data_ptr() is a solid approach for lifetime management, and the test coverage across single-GPU sim, single-node 4-rank, and multi-node 8-rank scenarios looks comprehensive. Everything looks good! 🎉

(\_/)
(^.^)
(> 🐇)


@aleozlx aleozlx added the run-ci label May 4, 2026
@aleozlx
Collaborator

aleozlx commented May 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !625 has been created, and the CI pipeline #50264343 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx merged commit b307819 into flashinfer-ai:main May 5, 2026
28 of 37 checks passed
@aleozlx aleozlx added the v0.6.10 release blocker label on May 7, 2026
aleozlx added a commit that referenced this pull request May 7, 2026
…path (#3210)

Co-authored-by: Alex Yang <aleyang@nvidia.com>

Labels

op: comm, run-ci, v0.6.10 release blocker, v0.6.11 release blocker
