fix(dcp_alltoall): require MNNVL workspace, drop broken plain-memory path #3210
Conversation
Follow-up to flashinfer-ai#2951. The merged code exposed a non-functional `mapping=None` branch in `decode_cp_a2a_allocate_workspace` that returns a per-rank `torch.zeros` tensor. This deadlocks at runtime on any real multi-GPU setup: the kernel addresses peer FIFOs via `params.workspace + peer_rank * stride`, which only resolves correctly when the workspace is a single unified VA spanning all CP ranks (i.e. MNNVL fabric memory). With per-rank `torch.zeros`, a write to peer N's FIFO lands in the local rank's own memory and the FIFO consumer signaling hangs forever.

CI did not catch this:

- Single-GPU sim test (`test_dcp_alltoall.py`) shares one workspace tensor across all simulated ranks, so cross-rank pointer arithmetic works on the same allocation.
- The MNNVL multi-GPU test (`TestMnnvlDcpAlltoall`) only exercises the fabric path.
- `TestMnnvlDcpDeviceMemoryFallback` only asserted shape, never called alltoall.

Reproduced on H200 NVL with 4 GH200s — first `decode_cp_a2a_alltoall` call hangs, no CUDA error. TRT-LLM upstream (`tensorrt_llm/_torch/distributed/ops.py:386`) only supports MNNVL.

Changes:

- Rename `decode_cp_a2a_allocate_workspace` -> `decode_cp_a2a_allocate_mnnvl_workspace`, matching `trtllm_mnnvl_ar` naming style. `mapping` is now a required positional argument.
- Drop the `mapping=None` branch entirely.
- Update single-GPU sim test to allocate `torch.zeros` directly via a local `_alloc_sim_workspace` helper — that is what the test actually needs and what its docstring claims.
- Drop `TestMnnvlDcpDeviceMemoryFallback` (the path it tested no longer exists).
- Refresh module docstring to make the MNNVL requirement explicit.
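As a rough host-side illustration of why the per-rank fallback deadlocks (a minimal sketch; the concrete sizes and ranks below are assumptions, and the real address arithmetic runs inside the CUDA kernel):

```python
import torch

# Assumed illustration values: the real workspace size comes from
# decode_cp_a2a_workspace_size and the arithmetic happens on the device.
cp_size, peer_rank = 4, 2
ws_elems_per_rank = 1024

workspace = torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")

# Kernel-side addressing: params.workspace + peer_rank * workspaceStrideInU64
stride_bytes = workspace.stride(0) * workspace.element_size()
peer_fifo_ptr = workspace.data_ptr() + peer_rank * stride_bytes

# On MNNVL fabric memory, the single VA spans all CP ranks, so peer_fifo_ptr
# resolves to peer_rank's FIFO on another GPU. With a per-rank torch.zeros
# tensor, it just points deeper into this rank's local allocation; the peer
# never sees the write and the FIFO handshake hangs.
```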
📝 Walkthrough

Replaces the generic DCP A2A workspace allocator with an MNNVL-specific allocator that requires a `Mapping`.
🚥 Pre-merge checks: ✅ 5 passed
Code Review
This pull request renames the workspace allocator to decode_cp_a2a_allocate_mnnvl_workspace and makes the mapping argument mandatory, enforcing MNNVL fabric memory for multi-rank operations. Single-GPU tests are updated to use a private simulation helper. Feedback suggests simplifying the API by removing redundant cp_size and cp_rank parameters already available in the mapping object, utilizing the public API within tests to ensure proper validation, and removing redundant test cases.
```python
from flashinfer.comm import (
    decode_cp_a2a_alltoall,
    decode_cp_a2a_allocate_workspace,
    decode_cp_a2a_init_workspace,
```
The new public API decode_cp_a2a_allocate_mnnvl_workspace is not imported or used in this test file. Instead, the module-level helper _allocate_mnnvl_workspace_once (lines 121-144) manually reimplements the allocation logic. It is highly recommended to import and use the public API in the tests to ensure it is properly exercised and to reduce code duplication.
```diff
     decode_cp_a2a_init_workspace,
+    decode_cp_a2a_allocate_mnnvl_workspace,
```
```diff
-# 2. Allocate workspace (MNNVL or plain device memory)
-workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank, mapping=mapping)
+# 2. Allocate MNNVL-backed workspace (Mapping is required)
+workspace = decode_cp_a2a_allocate_mnnvl_workspace(cp_size, cp_rank, mapping)
```
The usage protocol should be updated to reflect the simplified API if the redundant cp_size and cp_rank arguments are removed from the allocator.
```diff
-workspace = decode_cp_a2a_allocate_mnnvl_workspace(cp_size, cp_rank, mapping)
+workspace = decode_cp_a2a_allocate_mnnvl_workspace(mapping)
```
```diff
-def decode_cp_a2a_allocate_workspace(
+def decode_cp_a2a_allocate_mnnvl_workspace(
     cp_size: int,
     cp_rank: int,
+    mapping: Mapping,
     *,
-    mapping: Optional[Mapping] = None,
     mnnvl_config: Optional[MnnvlConfig] = None,
 ) -> torch.Tensor:
-    """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
+    """Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``.
+
+    The DCP A2A kernel requires a single unified VA spanning all CP ranks
+    (see module docstring), so workspace allocation must go through MNNVL
+    fabric memory. This function is the only supported allocator.
+
+    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
+    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.
 
-    Two allocation modes:
-
-    - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
-      FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
-      cannot see each other's device memory directly.
-    - **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
-      allocation. Sufficient for single-node with NVLink P2P.
-
     Args:
         cp_size: Context-parallel group size.
-        cp_rank: This rank's position in the CP group.
-        mapping: Mapping object for MNNVL allocation. If provided, MNNVL is
-            used. The mapping must have ``cp_size`` set correctly. The
-            communicator is split using ``mapping.pp_rank``, ``mapping.cp_rank``,
-            and ``mapping.tp_rank``.
+        cp_rank: This rank's position in the CP group (used for logging only;
+            ``mapping`` carries the authoritative rank info).
+        mapping: Mapping object for MNNVL allocation. Must have ``cp_size``
+            set correctly. The communicator is split using ``mapping.pp_rank``,
+            ``mapping.cp_rank``, and ``mapping.tp_rank``.
         mnnvl_config: Configuration for the MNNVL communication backend.
             Required when using MNNVL with ``torch.distributed`` (pass
             ``MnnvlConfig(comm_backend=TorchDistBackend(group))``).
```
The cp_size and cp_rank arguments are redundant because the mapping object (which is now a required positional argument) already contains this information. As noted in the docstring, mapping carries the authoritative rank info. Removing these redundant parameters simplifies the API and eliminates the risk of passing inconsistent values.
```python
@flashinfer_api
def decode_cp_a2a_allocate_mnnvl_workspace(
    mapping: Mapping,
    *,
    mnnvl_config: Optional[MnnvlConfig] = None,
) -> torch.Tensor:
    """Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``.

    The DCP A2A kernel requires a single unified VA spanning all CP ranks
    (see module docstring), so workspace allocation must go through MNNVL
    fabric memory. This function is the only supported allocator.

    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.

    Args:
        mapping: Mapping object for MNNVL allocation. Must have ``cp_size``
            set correctly. The communicator is split using ``mapping.pp_rank``,
            ``mapping.cp_rank``, and ``mapping.tp_rank``.
        mnnvl_config: Configuration for the MNNVL communication backend.
            Required when using MNNVL with ``torch.distributed`` (pass
            ``MnnvlConfig(comm_backend=TorchDistBackend(group))``).
```

```diff
     ws_bytes = decode_cp_a2a_workspace_size(cp_size)
 
-    if mapping is not None:
-        MnnvlMemory.initialize()
-        if mnnvl_config:
-            MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
-
-        mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
-        workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
-        workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
-        logger.info(
-            "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
-            cp_rank,
-            list(workspace.shape),
-            list(workspace.stride()),
-        )
-        return workspace
-
-    ws_elems_per_rank = (ws_bytes + 7) // 8
-    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
+    MnnvlMemory.initialize()
+    if mnnvl_config:
+        MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
+
+    mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
+    workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
+    workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
+    logger.info(
+        "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
+        cp_rank,
+        list(workspace.shape),
+        list(workspace.stride()),
+    )
+    return workspace
```
If the redundant cp_size and cp_rank arguments are removed from the function signature, the implementation should be updated to extract these values directly from the mapping object.
Suggested replacement, deriving `cp_size` and `cp_rank` from `mapping` inside the function:

```python
    cp_size = mapping.cp_size
    cp_rank = mapping.cp_rank
    ws_bytes = decode_cp_a2a_workspace_size(cp_size)

    MnnvlMemory.initialize()
    if mnnvl_config:
        MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)

    mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
    workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
    workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
    logger.info(
        "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
        cp_rank,
        list(workspace.shape),
        list(workspace.stride()),
    )
    return workspace
```
```diff
     for cp_size in [2, 4]:
         ws_bytes = decode_cp_a2a_workspace_size(cp_size)
-        workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
+        workspace = _alloc_sim_workspace(cp_size)
```
This test case now uses the local helper _alloc_sim_workspace instead of the public API decode_cp_a2a_allocate_mnnvl_workspace. While this allows the test to pass in a single-GPU environment, it means the public allocator is no longer being tested here. Since the workspace shape and dtype are already verified in tests/comm/test_mnnvl_dcp_alltoall.py using real MNNVL allocations, this test case is now redundant and could be removed to keep the test suite focused.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/dcp_alltoall.py (1)
127-172: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win: Validate the topology inputs before allocating.
`cp_size` is still used to size the workspace, but this function never checks that it matches `mapping.cp_size` or that `cp_rank` matches `mapping.cp_rank`. A stale caller can now allocate for one rank layout and initialize the MNNVL handle for another, which can reintroduce the deadlock/corruption this PR is trying to remove.

Suggested guard:
```diff
 def decode_cp_a2a_allocate_mnnvl_workspace(
     cp_size: int,
     cp_rank: int,
     mapping: Mapping,
     *,
     mnnvl_config: Optional[MnnvlConfig] = None,
 ) -> torch.Tensor:
+    if cp_size != mapping.cp_size:
+        raise ValueError(
+            f"cp_size={cp_size} does not match mapping.cp_size={mapping.cp_size}"
+        )
+    if cp_rank != mapping.cp_rank:
+        raise ValueError(
+            f"cp_rank={cp_rank} does not match mapping.cp_rank={mapping.cp_rank}"
+        )
```
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/comm/dcp_alltoall.py`:
- Around line 127-172: In decode_cp_a2a_allocate_mnnvl_workspace validate that
the provided cp_size and cp_rank match the topology carried by mapping before
initializing/allocating MnnvlMemory: check mapping.cp_size == cp_size and
mapping.cp_rank == cp_rank (and raise a ValueError with a clear message
including both expected and actual values if they differ); perform these checks
at the top of the function (before MnnvlMemory.initialize()/MnnvlMemory(...) and
before logging) so a stale caller cannot allocate/initialize using a mismatched
layout.
📒 Files selected for processing (4)
- flashinfer/comm/__init__.py
- flashinfer/comm/dcp_alltoall.py
- tests/comm/test_dcp_alltoall.py
- tests/comm/test_mnnvl_dcp_alltoall.py
💤 Files with no reviewable changes (1)
- tests/comm/test_mnnvl_dcp_alltoall.py
The previous code attached the MnnvlMemory wrapper as a private attribute on the returned tensor (`workspace._mnnvl_mem = mnnvl_mem`) to keep it from being garbage-collected. That works for direct callers, but private attributes on a torch tensor are NOT preserved across ordinary tensor operations: `workspace[r]`, `.view()`, `.contiguous()`, `.clone()`, `.to(...)`, slicing, indexing, etc. all return a fresh tensor without the attribute. Any caller that derives a view / slice and drops the original would silently free the underlying fabric memory while the kernel still holds raw pointers into it.

Move the keep-alive to a module-level `Dict[int, MnnvlMemory]` keyed by `workspace.data_ptr()`. The dict pins each MnnvlMemory wrapper for the process lifetime, so the workspace stays mapped regardless of what the caller does with the returned tensor. This matches the spirit of how TRT-LLM holds workspace ownership in its `HelixAllToAllNative._cache` class-level dict.

Same single-GPU and 4-rank MNNVL test results as before.
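A minimal sketch of the scheme described above, plus a demonstration of the attribute-loss problem it avoids (the `_register_workspace` helper is hypothetical; only the module-level dict keyed by `data_ptr()` is named in the change):

```python
from typing import Dict

import torch

# Process-lifetime pin: workspace.data_ptr() -> MnnvlMemory wrapper, so the
# fabric mapping stays alive no matter what callers do with the tensor.
_workspace_keepalive: Dict[int, "MnnvlMemory"] = {}


def _register_workspace(workspace: torch.Tensor, mnnvl_mem: "MnnvlMemory") -> None:
    # Hypothetical helper: record the wrapper under the tensor's base address.
    _workspace_keepalive[workspace.data_ptr()] = mnnvl_mem


# Why the old approach was fragile: Python attributes set on a tensor are not
# propagated to tensors returned by ordinary operations.
t = torch.zeros(2, 4)
t._mnnvl_mem = object()
view = t.view(-1)
assert not hasattr(view, "_mnnvl_mem")  # the keep-alive reference is gone
```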
Per Gemini code review on PR flashinfer-ai#3210: `mapping` already carries the authoritative cp_size and cp_rank, so passing them as separate parameters creates a double-source-of-truth footgun. Simplify the public API to take only `mapping` and a keyword-only `mnnvl_config`.

Same change ripples to tests:

- `tests/comm/test_mnnvl_dcp_alltoall.py::_allocate_mnnvl_workspace_once` now calls the public API instead of manually instantiating MnnvlMemory, so the multi-rank test exercises the public allocator.
- `tests/comm/test_dcp_alltoall.py::test_allocate_returns_correct_shape_and_dtype` is removed; after the rename the single-GPU test allocates via a local `torch.zeros` helper, so the assertion is tautological. The real shape/dtype assertion lives in `test_mnnvl_dcp_alltoall.py::test_workspace_shape`.

No kernel or behavior changes; pure API simplification.
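For reference, a sketch of what such a local helper can look like (assumption: it mirrors the removed plain-memory branch; the exact body in the test may differ, and the `decode_cp_a2a_workspace_size` import path is taken from the test's existing imports):

```python
import torch

from flashinfer.comm import decode_cp_a2a_workspace_size


def _alloc_sim_workspace(cp_size: int) -> torch.Tensor:
    # Plain device-memory workspace for the single-GPU simulation: one shared
    # tensor stands in for all simulated CP ranks, so no fabric memory is needed.
    ws_bytes = decode_cp_a2a_workspace_size(cp_size)
    ws_elems_per_rank = (ws_bytes + 7) // 8  # bytes -> int64 elements, rounded up
    return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
```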
Thanks @gemini-code-assist and @coderabbitai for the review. Addressed all 5 Gemini suggestions in commit b87aa37:
CodeRabbit's suggestion to validate `cp_size`/`cp_rank` against `mapping` no longer applies, since those parameters were removed from the signature.

Re-tested on dlcluster GB200-NVL72:
Thanks for the update, @davidjpyu. It's great to see the MNNVL workspace requirement enforced at the API level and the keep-alive mechanism made more robust. The test suite adjustments also look correct for the new API contract. Since you've verified the changes across single-GPU, single-node, and multi-node environments, this looks solid.
/bot run
Summary
Follow-up to #2951. The merged DCP A2A code shipped with two latent foot-guns that this PR cleans up:
1. The `mapping=None` branch in `decode_cp_a2a_allocate_workspace` returns a per-rank `torch.zeros` tensor — this deadlocks at runtime on any real multi-GPU setup.
2. The MNNVL workspace tensor's lifetime is pinned via a private attribute, which is silently lost across ordinary tensor operations and would let the fabric memory be unmapped while the kernel still holds raw pointers into it.

Both are addressed below.
Bug 1: workspace VA mismatch (silent deadlock)
`getFifoBasePtr` in `csrc/nv_internal/tensorrt_llm/kernels/helixAllToAll.cu:177` addresses peer FIFOs via:

```cpp
auto* mappedMemory = params.workspace + mappedMemoryRank * params.workspaceStrideInU64;
```

This pointer arithmetic only resolves correctly when the workspace is a single unified VA spanning all CP ranks — i.e. MNNVL fabric memory. With per-rank `torch.zeros`, rank 0 writing to "rank 1's FIFO" lands in rank 0's own memory; peer 1 never sees it and the FIFO consumer signaling hangs forever.

Reproduced on H200 NVL with 4 GH200s — first `decode_cp_a2a_alltoall` call hangs, no CUDA error, no assertion. TRT-LLM upstream (`tensorrt_llm/_torch/distributed/ops.py:386`) only supports MNNVL — there is no plain-memory branch. The fallback added during the FlashInfer port was wrong from the start.

Why CI didn't catch Bug 1

1. `tests/comm/test_dcp_alltoall.py` simulates `cp_size` ranks on one GPU with one shared workspace tensor — pointer arithmetic on the same allocation works, so the bug is invisible.
2. `tests/comm/test_mnnvl_dcp_alltoall.py::TestMnnvlDcpAlltoall` exercises real multi-GPU but only on the MNNVL path.
3. `TestMnnvlDcpDeviceMemoryFallback` asserts shape only — never actually calls `alltoall`, so the deadlock never fires.

Bug 2: workspace keep-alive via tensor private attribute
The previous code did:
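```python
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
workspace._mnnvl_mem = mnnvl_mem  # prevent GC of MNNVL handle
return workspace
```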
`MnnvlMemory.__del__` calls `close_mnnvl_memory`, which unmaps the underlying fabric VA. The `workspace._mnnvl_mem` private attribute is the only thing keeping the wrapper alive. But torch tensor private attributes are NOT preserved across ordinary operations — `workspace[r]`, `.view()`, `.contiguous()`, `.clone()`, `.to(...)`, slicing, indexing all return a fresh tensor without the attribute. Any caller that derives a view or slice and drops the original would silently free the workspace while the kernel still holds raw pointers into it. Current direct callers happen to be safe, but this is a buried mine.

Changes
Bug 1 — drop the broken plain-memory path
- Rename `decode_cp_a2a_allocate_workspace` → `decode_cp_a2a_allocate_mnnvl_workspace`, matching `trtllm_mnnvl_ar` naming style. The MNNVL requirement is now obvious at the call site.
- Make `mapping` the only required argument. Drop the `mapping=None` branch entirely. The redundant `cp_size` and `cp_rank` parameters were also removed — `mapping` already carries that info, and a separate path was a double-source-of-truth footgun.
- Refactor the single-GPU sim test to use a local `_alloc_sim_workspace(cp_size)` helper that does `torch.zeros(...)` directly — that's what the test actually needs and what its docstring claims (it does not need the public allocator).
- Drop `TestMnnvlDcpDeviceMemoryFallback` — the path it covered no longer exists.
- Use the public allocator in the multi-rank test (`test_mnnvl_dcp_alltoall.py::_allocate_mnnvl_workspace_once`) instead of manually instantiating `MnnvlMemory`, so the public API is exercised end-to-end.
- Refresh the module docstring to make the MNNVL requirement explicit.

Bug 2 — robust workspace keep-alive
- Replace `workspace._mnnvl_mem = mnnvl_mem` with a module-level `_workspace_keepalive: Dict[int, MnnvlMemory]` keyed by `workspace.data_ptr()`. The dict pins each `MnnvlMemory` for the process lifetime, so the workspace stays mapped regardless of what the caller does with the returned tensor. This matches how TRT-LLM holds workspace ownership in its `HelixAllToAllNative._cache` class-level dict.

Final allocator signature
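```python
decode_cp_a2a_allocate_mnnvl_workspace(
    mapping: Mapping, *, mnnvl_config: Optional[MnnvlConfig] = None
) -> torch.Tensor
```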
Test plan
All verified on dlcluster GB200-NVL72 (compute capability 10.0a, CUDA 13):
- [x] `tests/comm/test_dcp_alltoall.py` (single-GPU sim, container `flashinfer/flashinfer-ci-cu130`) — 29/29 PASSED in 5.16s
- [x] `mpirun -launcher fork -np 4 pytest tests/comm/test_mnnvl_dcp_alltoall.py` — single-node 4-rank MNNVL — 8/8 PASSED on all 4 ranks in 1.34s
- [x] `srun -N 2 --ntasks-per-node=4 --mpi=pmix pytest tests/comm/test_mnnvl_dcp_alltoall.py` — multi-node 8-rank MNNVL (2 nodes × 4 GPUs, cp_size=8, container `nvcr.io/nvidia/pytorch:26.02-py3` with HPC-X) — 8/8 PASSED on all 8 ranks in ~17s

The multi-node run exercises real cross-node fabric memory allocation via `cuMemCreate` with FABRIC handles — same code path Helix production uses on NVL72.

🤖 AI-assisted (Claude Code)