
A unified API for the MNNVL and single-node/multi-GPU AllReduce kernels.#2130

Merged
yzh119 merged 22 commits into flashinfer-ai:main from nvmbreughe:mbreughe/allreduce_unified
Dec 17, 2025

Conversation

@nvmbreughe
Contributor

@nvmbreughe nvmbreughe commented Nov 21, 2025

📌 Description

A unified API for the MNNVL and single-node AllReduce kernels.

  • This introduces the APIs create_allreduce_fusion_workspace and allreduce_fusion (a minimal usage sketch follows below).
  • The backend ("trtllm" or "mnnvl") is chosen during workspace creation. We can either pick it explicitly or use the "auto" backend to have a heuristic pick the best one.
  • The API can be used for both single-node and multi-node allreduce.
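
A minimal usage sketch (import paths and the pattern constant are my best guess from the files touched in this PR; rank/world size would normally come from the launcher):

import os
import torch
from flashinfer.comm import create_allreduce_fusion_workspace, allreduce_fusion
from flashinfer.comm.trtllm_ar import AllReduceFusionPattern  # import path assumed

# Rank/world size normally come from mpirun/torchrun; env vars used as placeholders here.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
torch.cuda.set_device(rank)

# The backend ("trtllm", "mnnvl", or "auto") is fixed at workspace creation time.
workspace = create_allreduce_fusion_workspace(
    backend="auto",
    world_size=world_size,
    rank=rank,
    max_token_num=2048,
    hidden_dim=4096,
    dtype=torch.bfloat16,
    topology="single_node",
)

x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
# The workspace type determines which backend allreduce_fusion() dispatches to.
out = allreduce_fusion(
    input=x,
    workspace=workspace,
    pattern=AllReduceFusionPattern.kAllReduce,
)

workspace.destroy()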

Test with

mpirun -np 4 pytest tests/comm/test_allreduce_unified_api.py
mpirun -np 4 pytest tests/comm/test_allreduce_negative.py

Note: mpirun is needed for the mnnvl backend, as illustrated in the test commands above.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Unified AllReduce Fusion API with multi-backend support (auto/TRTLLM/MNNVL) and public workspace types
    • Common workspace base with lifecycle management and automatic cleanup warnings
  • Bug Fixes / Validation

    • Stronger input/workspace validation with aggregated error messages
    • Idempotent destroy semantics for safer resource cleanup
  • Deprecations

    • Legacy AllReduce APIs deprecated in favor of the unified API
  • Tests

    • Expanded negative, integration, and cross-backend correctness tests plus MPI/CUDA test helpers


@coderabbitai
Contributor

coderabbitai Bot commented Nov 21, 2025

Walkthrough

Adds a unified AllReduce fusion API: introduces an abstract AllReduceFusionWorkspace base class, concrete TRTLLM/MNNVL workspace wrappers, a factory create_allreduce_fusion_workspace(), dispatcher allreduce_fusion(), backend checks/heuristics, updated backend modules and deprecation messages, and comprehensive MPI/CUDA tests and helpers.

Changes

Cohort / File(s) Change Summary
Base API
flashinfer/comm/workspace_base.py
New abstract AllReduceFusionWorkspace with world_size/rank, abstract backend, destroy, is_buffer_size_sufficient, and safety __del__ that warns and attempts cleanup.
Unified API & Dispatcher
flashinfer/comm/allreduce.py
Adds create_allreduce_fusion_workspace() factory, allreduce_fusion() dispatcher, TRTLLMAllReduceFusionWorkspace wrapper, backend checks/heuristics, and wiring to MNNVL workspace APIs.
Module Exports
flashinfer/comm/__init__.py
Exports added: AllReduceFusionWorkspace, TRTLLMAllReduceFusionWorkspace, MNNVLAllReduceFusionWorkspace, allreduce_fusion, create_allreduce_fusion_workspace.
TRTLLM Backend
flashinfer/comm/trtllm_ar.py
Updated deprecation messages; new check_trtllm_allreduce_fusion_workspace_metadata() helper; trtllm_allreduce_fusion() accepts optional metadata and delegates validation to helper.
MNNVL Backend
flashinfer/comm/trtllm_mnnvl_ar.py
Renames/aligns to MNNVLAllReduceFusionWorkspace inheriting from AllReduceFusionWorkspace, adds backend property and idempotent destroy(), updates public function signatures to use new type.
Tests — unified & helpers
tests/comm/test_allreduce_unified_api.py, tests/test_helpers/comm.py
New unified API test harness and MPI/CUDA helpers: workspace creation, correctness checks across backends, dtype/seq_len parameterization, and dist init/cleanup utilities.
Tests — negative
tests/comm/test_allreduce_negative.py
New negative tests covering MNNVL unsupported patterns/layouts, missing parameters, and buffer-size sufficiency checks across backends.
Tests — adaptations
tests/comm/test_trtllm_allreduce_fusion.py, tests/comm/test_trtllm_mnnvl_allreduce.py
Add legacy_api toggle to exercise legacy vs unified paths, fix MNNVL workspace type/name inconsistencies, and branch cleanup accordingly.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant API as AllReduce API\n(flashinfer/comm/allreduce.py)
    participant Selector as Backend\nSelector/Heuristic
    participant TRT as TRTLLM\nBackend (legacy IPC / trtllm_ar)
    participant MNN as MNNVL\nBackend (trtllm_mnnvl_ar)

    User->>API: create_allreduce_fusion_workspace(backend, topology, ...)
    API->>Selector: evaluate topology & heuristics
    alt choose TRTLLM
        Selector-->>API: select trtllm
        API->>TRT: create/wrap IPC workspace
        TRT-->>API: TRTLLMAllReduceFusionWorkspace
    else choose MNNVL
        Selector-->>API: select mnnvl
        API->>MNN: create/wrap MNNVL workspace
        MNN-->>API: MNNVLAllReduceFusionWorkspace
    end
    API-->>User: workspace

    User->>API: allreduce_fusion(input, workspace, pattern, ...)
    API->>API: inspect workspace.backend
    alt TRTLLM workspace
        API->>API: flatten inputs, build metadata
        API->>TRT: trtllm_allreduce_fusion(flattened, metadata)
        TRT-->>API: flattened output
        API->>API: reshape output
        API-->>User: output
    else MNNVL workspace
        API->>API: validate pattern & required tensors
        alt kARResidualRMSNorm
            API->>MNN: trtllm_mnnvl_fused_allreduce_add_rmsnorm(...)
        else kAllReduce
            API->>MNN: trtllm_mnnvl_allreduce(...)
        end
        MNN-->>API: output tensor(s)
        API-->>User: output
    end

    User->>API: workspace.destroy()
    API->>TRT/MNN: cleanup resources (idempotent)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Focus areas:
    • Backend selection heuristic and topology checks in create_allreduce_fusion_workspace()
    • TRTLLM flatten/reshape and metadata validation/propagation in trtllm_allreduce_fusion()
    • MNNVL pattern validation, layout/quantization error paths in MNNVL APIs
    • Workspace lifecycle: destroy() semantics and __del__ safety behavior
    • MPI/CUDA test harness, distributed synchronization, and test parameterization

Possibly related PRs

Suggested reviewers

  • cyx-6
  • djmmoss
  • wenscarl
  • yongwww
  • jiahanc
  • kahyunnam
  • IwakuraRein

Poem

🐰 I found a workspace, neat and spry,

world_size, rank — we hop and try.
TRT flattens, MNNVL hums along,
heuristics pick which backend’s song.
Destroy, cleanup — then nibble a carrot high. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 69.64%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title clearly and specifically describes the main change: introducing a unified API for AllReduce kernels across MNNVL and single-node/multi-GPU backends.
  • Description check (✅ Passed): The description includes key sections from the template and provides comprehensive details about the unified API, backend selection, usage modes, test commands, and pre-commit/test completion status.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 78c6977 and 8fa9ccd.

📒 Files selected for processing (1)
  • flashinfer/comm/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/__init__.py (3)
flashinfer/comm/workspace_base.py (1)
  • AllReduceFusionWorkspace (23-89)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • MNNVLAllReduceFusionWorkspace (51-241)
flashinfer/comm/allreduce.py (3)
  • TRTLLMAllReduceFusionWorkspace (88-169)
  • allreduce_fusion (452-702)
  • create_allreduce_fusion_workspace (286-444)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/comm/__init__.py (1)

42-51: Syntax error fixed; import structure is correct.

The unclosed parenthesis from the previous review has been correctly fixed. The import of AllReduceFusionWorkspace from .allreduce is intentional—allreduce.py re-exports it from workspace_base.py at line 52, providing a unified API namespace for workspace classes. This pattern is consistent with the documented AllReduce Fusion API structure.



Comment thread flashinfer/comm/allreduce.py Outdated
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
topology: str,
Collaborator


Suggested change
topology: str,
topology: Literal["single_node", "multi_node"],

Contributor Author


Done. Thanks!

Comment thread flashinfer/comm/allreduce.py Outdated
max_token_num: int = None,
hidden_dim: int = None,
dtype: torch.dtype = None,
topology: str = "single_node",

I don't think it is needed longer term, since we will use the same PyTorch symmetric-memory API to allocate symmetric memory for single- and multi-node (under the covers, PyTorch/NCCL/NVSHMEM will detect the platform and decide on the right memory allocation handle).

input: torch.Tensor,
workspace: AllReduceFusionWorkspace,
pattern: int,
launch_with_pdl: bool = False,

What is the advantage of giving PDL control to the user?

Contributor Author


We have been doing this for all our APIs, but I am not sure why. Maybe because not all archs support it?

Collaborator


The main purpose of having this control is debugging, as we observed some issues with PDL on sm121.

For deployment it should always be turned on IMO, so I suppose the default value should be None and we would call

def device_support_pdl(device: torch.device) -> bool:

to determine it automatically.
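
A small sketch of that default-to-None approach (device_support_pdl is the helper named above; its import path here is an assumption):

from typing import Optional
import torch
from flashinfer.utils import device_support_pdl  # assumed location of the helper

def resolve_launch_with_pdl(launch_with_pdl: Optional[bool], device: torch.device) -> bool:
    # An explicit True/False from the caller wins (useful for debugging, e.g. the sm121 issue).
    if launch_with_pdl is not None:
        return launch_with_pdl
    # Otherwise enable PDL whenever the device supports it.
    return device_support_pdl(device)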

Args:
input: Input tensor [token_num, hidden_dim]
workspace: Workspace object (type determines backend)
pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5)

Are they all 2-kernel overlaps, or are some real fusion kernels?

Contributor Author


With one-shot mnnvl it's real fusion, and I think it's similar for the trtllm_ar kernels. It's just two-shot mnnvl that is the 2-kernel overlap.

},
heuristic_func=_workspace_creation_heuristic,
)
def create_allreduce_fusion_workspace(

Could create_allreduce_fusion_workspace take an optional workspace argument? If workspace is big enough or too big this is a noop (maybe just updating backend selection). If it is too small, destroy current workspace and allocate a bigger one.

When we switch to mem pool, we should be able to call create_allreduce_fusion_workspace at each forward pass and memory will just get reused from the mempool (instead of new allocations).
CC @Amir-19

Contributor Author


What's the advantage of having
workspace= create_allreduce_fusion_workspace(old_workspace)
vs

workspace = old_workspace if condition else create_allreduce_fusion_workspace()

?


It was for the case where you would support growing the workspace:

if needed_size > old_workspace.size:
    old_workspace.destroy()
    return new_workspace

But we discussed that we may not be able to destroy the current workspace due to previous CUDA graph captures.
So you can ignore my comment for now.

- Workspace(max_token_num=2048, hidden_dim=4096) can handle:
- (token_num=2048, hidden_dim=4096) ✓
- (token_num=1024, hidden_dim=4096) ✓
- (token_num=4096, hidden_dim=2048) ✓ (same total size)

I only see the FW adjusting the number of tokens, but hidden_dim should be fixed per model.

Contributor Author


We discussed this live; the outcome was to delay the one-shot vs. two-shot decision until runtime (i.e., not at workspace creation time).
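
For illustration, the "same total size" reasoning in the docstring lines quoted at the top of this thread reduces to a check like this (names are illustrative, not the actual workspace API):

def fits_in_workspace(ws_max_token_num: int, ws_hidden_dim: int,
                      token_num: int, hidden_dim: int) -> bool:
    # The workspace buffer is sized by total element count, so any shape whose
    # token_num * hidden_dim does not exceed the workspace's total fits.
    return token_num * hidden_dim <= ws_max_token_num * ws_hidden_dim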

... max_token_num=2048,
... hidden_dim=4096,
... dtype=torch.bfloat16,
... topology="single_node"

Could we add a check now to detect the topology, before we switch to the mempool allocation?

Contributor Author


Hi @nvcastet, could you elaborate on what you mean by a topology check?


Something like that check: https://github.com/pytorch/pytorch/blob/eb66e1d2b590865be962b8677acc31728d3ad953/aten/src/ATen/cuda/PeerToPeerAccess.cpp#L146-L180

If fabric is enabled you would get a fabric workspace (for MNNVL); otherwise, default to a single-node workspace.
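
A rough sketch of such a check using the CUDA driver API via cuda-python (the attribute name is from CUDA 12.4+; verify it against the installed bindings before relying on it):

from cuda import cuda

def fabric_handles_supported(device_ordinal: int = 0) -> bool:
    # Query whether the device supports fabric memory handles (needed for MNNVL);
    # if not, fall back to a single-node workspace.
    (err,) = cuda.cuInit(0)
    if err != cuda.CUresult.CUDA_SUCCESS:
        return False
    err, dev = cuda.cuDeviceGet(device_ordinal)
    if err != cuda.CUresult.CUDA_SUCCESS:
        return False
    err, supported = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
        dev,
    )
    return err == cuda.CUresult.CUDA_SUCCESS and supported == 1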

@nvmbreughe nvmbreughe force-pushed the mbreughe/allreduce_unified branch from 437c7df to 10554e5 on December 9, 2025 23:09
@nvmbreughe nvmbreughe marked this pull request as ready for review December 12, 2025 00:23
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)

132-150: Return type bug: annotated -> int but returns None
alloc_and_copy_to_cuda() is typed to return int but returns None when host_ptr_array is empty. This will break callers expecting an integer device pointer.

def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
   """
   A helper function that allocates memory on cuda and copies the data from the host to the device.
   """
   if not host_ptr_array:
-      return None
+      return 0
🧹 Nitpick comments (18)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)

218-242: Lamport sync helpers are fragile w.r.t. early-exit / partial participation
LamportFlags::ctaArrive() uses __syncthreads() / cluster.sync(), which will deadlock if any threads in the block exit early before reaching it. Today you mostly avoid OOB by shaping blockSize, but this is a sharp edge for future tuning (e.g., cluster sizing, loadsPerThread, partial tiles). Consider reworking ctaArrive() to avoid requiring full-block participation (e.g., a single designated thread does the atomic without a block-wide barrier), or enforce the invariant with stronger checks/docs at the call sites.

Also applies to: 385-418, 509-651, 760-888


924-950: Minor: offsets array is oversized (stack/register pressure)
In rmsNormLamport, offsets is declared as LoadsPerThread * kELTS_PER_LOAD but indexed as offsets[i] for i < LoadsPerThread. Tighten this to LoadsPerThread.

-  uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD];
+  uint32_t offsets[LoadsPerThread];
csrc/trtllm_mnnvl_allreduce.cu (1)

29-35: Add basic range/null guards for pointer-like int64 args
A null/invalid multicast_buffer_ptr, buffer_ptrs_dev, buffer_ptr_local, or empty buffer_flags_mnnvl will turn into UB quickly. Even lightweight > 0 / non-empty checks would make failures actionable.

Also applies to: 79-90

flashinfer/comm/mnnvl.py (1)

640-656: Minor: unused recvmsg tuple parts (ruff RUF059)
Prefix unused msg/flags/addr with _ to avoid lint noise.

tests/comm/test_allreduce_negative.py (1)

36-70: Consider extracting duplicated fixture setup to reduce code repetition.

Both TestMNNVLUnsupportedPatterns and TestMNNVLMissingRequiredParameters have nearly identical setup fixtures. Consider extracting the common workspace creation and teardown logic into a shared fixture or base class to improve maintainability.

@pytest.fixture
def mnnvl_workspace():
    """Shared fixture for MNNVL workspace setup and teardown."""
    rank, world_size, gpus_per_node = setup_mpi_and_cuda()
    
    workspace = create_allreduce_fusion_workspace(
        backend="mnnvl",
        world_size=world_size,
        rank=rank,
        max_token_num=128,
        hidden_dim=2880,
        dtype=torch.float16,
        topology="single_node",
        gpus_per_node=gpus_per_node,
    )
    
    yield workspace, rank, world_size, gpus_per_node
    
    if workspace is not None:
        workspace.destroy()
    trtllm_mnnvl_ar.mpi_barrier()
flashinfer/comm/workspace_base.py (1)

52-61: Consider using a more specific type for use_oneshot parameter.

The use_oneshot: Optional[Any] parameter is very permissive. Based on the relevant code snippets, the TRTLLM backend uses bool while MNNVL uses MNNVLAllreduceFusionStrategy. Consider using a Union type or documenting the expected types per backend.

+from typing import Optional, Any, Union
+
 @abstractmethod
 def is_buffer_size_sufficient(
     self,
     tp_size: int,
     num_tokens: int,
     hidden_dim: int,
     dtype: torch.dtype,
-    use_oneshot: Optional[Any] = None,
+    strategy: Optional[Any] = None,  # Backend-specific: bool for TRTLLM, MNNVLAllreduceFusionStrategy for MNNVL
 ) -> bool:
     pass
tests/comm/test_trtllm_mnnvl_allreduce.py (2)

17-103: Consider inlining the nested func function.

The nested func function doesn't capture any variables from the outer scope that aren't already passed as parameters, and it's only called once. Consider inlining it for clarity.


233-271: Fix type annotation for reference_output.

The variable is initialized to None but typed as Tuple[torch.Tensor, ...]. Use Optional[Tuple[torch.Tensor, ...]] or remove the annotation since Python will infer it correctly from the subsequent assignments.

     x_local = x_full[rank, :, :]
-    reference_output: Tuple[torch.Tensor, ...] = None
+    reference_output: Optional[Tuple[torch.Tensor, ...]] = None
     if fusion:

Note: Optional is already imported from typing.

tests/comm/test_allreduce_unified_api.py (4)

26-46: Hardcoded port may cause conflicts in parallel test runs.

The hardcoded MASTER_PORT = "29500" could cause port conflicts if multiple test instances run concurrently on the same machine.

Consider using a dynamic port or environment variable fallback:

-    os.environ["MASTER_PORT"] = "29500"
+    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")

133-146: Consider documenting why tolerances are set so high.

The tolerance values (rtol=0.05, atol=0.15) are quite loose for numerical validation. While this may be necessary for distributed bf16/fp16 operations, a brief comment explaining this would help future maintainers understand this isn't masking bugs.


190-197: Remove unused monkeypatch parameter.

The monkeypatch parameter is declared but never used in this function, as flagged by static analysis.

 def run_allreduce_test(
-    monkeypatch,
     seq_lens: list[int],
     fusion: bool,
     dtype: torch.dtype,
     hidden_size: int,
     backend: str,
 ):

Also update the call site in test_allreduce_unified accordingly.


314-334: Consider adding pytest markers for CI optimization.

The parametrized test creates 144 combinations (6 seq_lens × 2 fusion × 2 dtype × 2 hidden_size × 3 backend). Consider adding @pytest.mark.slow or similar markers to allow selective test execution in CI.

flashinfer/comm/trtllm_mnnvl_ar.py (2)

79-81: Use logging instead of print for production code.

Line 81 uses print() while other parts of this class use logging.debug(). For consistency and proper log level control, use logging:

-        print("Allocating MNNVL Allreduce Fusion Workspace...")
+        logging.info("Allocating MNNVL Allreduce Fusion Workspace...")

179-197: functools.cache on instance method can cause memory leaks.

Using @functools.cache on an instance method binds self to the cache key, preventing the instance from being garbage collected even after all other references are dropped. Since workspaces may be created and destroyed frequently, this could accumulate memory.

Consider using lru_cache with a bounded size, or move the caching to the static method get_required_buffer_size_bytes which doesn't have this issue:

-    @functools.cache
     def is_buffer_size_sufficient(
         self,
         tp_size: int,
         num_tokens: int,
         hidden_dim: int,
         dtype: torch.dtype,
         strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
     ) -> bool:
-        """
-        Calculate the required buffer size for a given problem size.
-        """
+        """Check if the buffer is sufficient for a given problem size."""
         required_buffer_size = self.get_required_buffer_size_bytes(
             tp_size, num_tokens, hidden_dim, dtype, strategy
         )
-        if required_buffer_size > self.buffer_size_bytes:
-            return False
-        else:
-            return True
+        return required_buffer_size <= self.buffer_size_bytes
flashinfer/comm/allreduce.py (4)

146-161: Use logging instead of print, and address unused parameter.

  1. Line 160 uses print() instead of logging for consistency with other modules.
  2. The use_oneshot parameter is unused (static analysis ARG002) but may be kept for API consistency with MNNVLAllReduceFusionWorkspace.is_buffer_size_sufficient.
+import logging
+
 ...
     def is_buffer_size_sufficient(
         self,
         tp_size: int,
         num_tokens: int,
         hidden_dim: int,
         dtype: torch.dtype,
-        use_oneshot: Optional[Any] = None,
+        use_oneshot: Optional[Any] = None,  # Unused, kept for API consistency with MNNVL
     ) -> bool:
         try:
             check_trtllm_allreduce_fusion_workspace_metadata(
                 num_tokens, hidden_dim, tp_size, dtype, self.metadata
             )
             return True
         except ValueError as e:
-            print(f"Workspace is insufficient for problem size. {e}")
+            logging.debug(f"Workspace is insufficient for problem size: {e}")
             return False

200-217: Simplify _mnnvl_workspace_check - currently always returns True.

The function has a redundant conditional structure since both branches return True. If this is intentional scaffolding for future checks, consider adding a comment; otherwise simplify:

 def _mnnvl_workspace_check(
     backend: str,
     world_size: int,
     rank: int,
     max_token_num: int,
     hidden_dim: int,
     dtype: torch.dtype,
     topology: Literal["single_node", "multi_node"],
 ) -> bool:
     """
     Check if mnnvl backend CAN be used for workspace creation.
-
     """
-
-    if topology == "multi_node":
-        return True
-
-    return True
+    # MNNVL supports both single-node and multi-node topologies
+    return True

286-297: Use explicit Optional type hints for parameters defaulting to None.

Parameters with None defaults should explicitly declare Optional types for PEP 484 compliance and better IDE support.

 def create_allreduce_fusion_workspace(
     backend: Literal["trtllm", "mnnvl", "auto"] = "auto",
-    world_size: int = None,
-    rank: int = None,
-    max_token_num: int = None,
-    hidden_dim: int = None,
-    dtype: torch.dtype = None,
+    world_size: Optional[int] = None,
+    rank: Optional[int] = None,
+    max_token_num: Optional[int] = None,
+    hidden_dim: Optional[int] = None,
+    dtype: Optional[torch.dtype] = None,
     topology: Literal["single_node", "multi_node"] = "single_node",
     process_group: Optional["torch.distributed.ProcessGroup"] = None,
-    gpus_per_node: int = None,
+    gpus_per_node: Optional[int] = None,
     comm_backend: Optional[CommBackend] = None,
 ) -> AllReduceFusionWorkspace:

682-693: Prefix unused variable with underscore.

The residual_result variable is unpacked but never used, as flagged by static analysis.

             # Call the MNNVL fusion function
-            norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+            norm_result, _residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
                 input=input,
                 residual_in=residual_in,
                 gamma=rms_gamma,
                 workspace=workspace,
                 epsilon=rms_eps,
                 output=norm_out,
                 residual_out=residual_out,
                 launch_with_pdl=launch_with_pdl,
             )
             return norm_result
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6bb01d1 and 9e672bf7434e07037423c5f9dc324042492e1bb8.

📒 Files selected for processing (13)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/__init__.py (1 hunks)
  • flashinfer/comm/allreduce.py (1 hunks)
  • flashinfer/comm/mnnvl.py (19 hunks)
  • flashinfer/comm/trtllm_ar.py (5 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
  • flashinfer/comm/workspace_base.py (1 hunks)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
  • include/flashinfer/utils.cuh (2 hunks)
  • tests/comm/test_allreduce_negative.py (1 hunks)
  • tests/comm/test_allreduce_unified_api.py (1 hunks)
  • tests/comm/test_trtllm_allreduce_fusion.py (6 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
tests/comm/test_allreduce_negative.py (4)
flashinfer/comm/allreduce.py (4)
  • create_allreduce_fusion_workspace (286-444)
  • allreduce_fusion (452-702)
  • backend (135-136)
  • destroy (163-169)
flashinfer/comm/trtllm_ar.py (2)
  • AllReduceFusionPattern (64-77)
  • QuantizationSFLayout (80-95)
flashinfer/comm/mapping.py (1)
  • local_rank (391-392)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • backend (229-230)
  • destroy (232-242)
  • mpi_barrier (24-28)
flashinfer/comm/workspace_base.py (2)
flashinfer/comm/allreduce.py (3)
  • backend (135-136)
  • destroy (163-169)
  • is_buffer_size_sufficient (146-161)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • backend (229-230)
  • destroy (232-242)
  • is_buffer_size_sufficient (180-197)
tests/comm/test_trtllm_allreduce_fusion.py (2)
flashinfer/comm/trtllm_ar.py (4)
  • trtllm_create_ipc_workspace_for_all_reduce_fusion (506-645)
  • trtllm_allreduce_fusion (230-275)
  • trtllm_allreduce_fusion (857-963)
  • trtllm_destroy_ipc_workspace_for_all_reduce_fusion (648-664)
flashinfer/comm/allreduce.py (4)
  • create_allreduce_fusion_workspace (286-444)
  • backend (135-136)
  • allreduce_fusion (452-702)
  • destroy (163-169)
tests/comm/test_trtllm_mnnvl_allreduce.py (5)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • MNNVLAllReduceFusionWorkspace (51-242)
  • trtllm_mnnvl_allreduce (327-402)
  • get_allreduce_mnnvl_workspace (503-559)
flashinfer/comm/mnnvl.py (8)
  • barrier (168-168)
  • barrier (227-228)
  • get_multicast_ptr (968-972)
  • get_multicast_ptr (1270-1272)
  • get_buffer_ptrs_dev (954-956)
  • get_buffer_ptrs_dev (1278-1280)
  • get_unicast_ptr (958-966)
  • get_unicast_ptr (1274-1276)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)
  • barrier (60-64)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • hidden_size (265-265)
tests/comm/test_allreduce_unified_api.py (6)
flashinfer/comm/allreduce.py (4)
  • create_allreduce_fusion_workspace (286-444)
  • allreduce_fusion (452-702)
  • backend (135-136)
  • destroy (163-169)
flashinfer/comm/trtllm_ar.py (1)
  • AllReduceFusionPattern (64-77)
flashinfer/comm/workspace_base.py (1)
  • AllReduceFusionWorkspace (23-89)
flashinfer/comm/trtllm_mnnvl_ar.py (2)
  • backend (229-230)
  • destroy (232-242)
csrc/tvm_ffi_utils.h (1)
  • Tensor (287-289)
flashinfer/comm/mapping.py (1)
  • local_rank (391-392)
flashinfer/comm/trtllm_ar.py (1)
flashinfer/logits_processor/types.py (1)
  • dtype (126-130)
flashinfer/comm/mnnvl.py (2)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
flashinfer/utils.py (1)
  • round_up (631-633)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1199-1280)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1179-1196)
  • lamport_initialize (1239-1240)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (954-956)
  • get_buffer_ptrs_dev (1278-1280)
  • get_unicast_ptr (958-966)
  • get_unicast_ptr (1274-1276)
  • get_multicast_ptr (968-972)
  • get_multicast_ptr (1270-1272)
flashinfer/comm/workspace_base.py (4)
  • AllReduceFusionWorkspace (23-89)
  • is_buffer_size_sufficient (53-61)
  • backend (38-40)
  • destroy (43-50)
csrc/trtllm_mnnvl_allreduce.cu (2)
csrc/tvm_ffi_utils.h (1)
  • get_stream (277-279)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (269-320)
🪛 Ruff (0.14.8)
flashinfer/comm/workspace_base.py

83-83: Do not catch blind exception: Exception

(BLE001)

flashinfer/comm/allreduce.py

141-143: Avoid specifying long messages outside the exception class

(TRY003)


152-152: Unused method argument: use_oneshot

(ARG002)


158-158: Consider moving this statement to an else block

(TRY300)


178-178: Unused function argument: backend

(ARG001)


179-179: Unused function argument: world_size

(ARG001)


180-180: Unused function argument: rank

(ARG001)


181-181: Unused function argument: max_token_num

(ARG001)


182-182: Unused function argument: hidden_dim

(ARG001)


183-183: Unused function argument: dtype

(ARG001)


201-201: Unused function argument: backend

(ARG001)


202-202: Unused function argument: world_size

(ARG001)


203-203: Unused function argument: rank

(ARG001)


204-204: Unused function argument: max_token_num

(ARG001)


205-205: Unused function argument: hidden_dim

(ARG001)


206-206: Unused function argument: dtype

(ARG001)


227-227: Unused function argument: backend

(ARG001)


228-228: Unused function argument: world_size

(ARG001)


229-229: Unused function argument: rank

(ARG001)


230-230: Unused function argument: max_token_num

(ARG001)


231-231: Unused function argument: hidden_dim

(ARG001)


232-232: Unused function argument: dtype

(ARG001)


288-288: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


289-289: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


290-290: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


291-291: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


295-295: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


398-401: Avoid specifying long messages outside the exception class

(TRY003)


444-444: Avoid specifying long messages outside the exception class

(TRY003)


584-584: Avoid specifying long messages outside the exception class

(TRY003)


646-648: Avoid specifying long messages outside the exception class

(TRY003)


651-653: Avoid specifying long messages outside the exception class

(TRY003)


672-672: Avoid specifying long messages outside the exception class

(TRY003)


674-674: Avoid specifying long messages outside the exception class

(TRY003)


683-683: Unpacked variable residual_result is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


696-696: Avoid specifying long messages outside the exception class

(TRY003)


699-702: Avoid specifying long messages outside the exception class

(TRY003)

tests/comm/test_allreduce_unified_api.py

191-191: Unused function argument: monkeypatch

(ARG001)

flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


726-726: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


1011-1011: Do not catch blind exception: Exception

(BLE001)

flashinfer/comm/trtllm_mnnvl_ar.py

122-124: Avoid specifying long messages outside the exception class

(TRY003)


179-179: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)


359-361: Avoid specifying long messages outside the exception class

(TRY003)


366-368: Avoid specifying long messages outside the exception class

(TRY003)


380-382: Avoid specifying long messages outside the exception class

(TRY003)


442-444: Avoid specifying long messages outside the exception class

(TRY003)


446-448: Avoid specifying long messages outside the exception class

(TRY003)


450-452: Avoid specifying long messages outside the exception class

(TRY003)


456-458: Avoid specifying long messages outside the exception class

(TRY003)


462-464: Avoid specifying long messages outside the exception class

(TRY003)


475-477: Avoid specifying long messages outside the exception class

(TRY003)


608-610: Avoid specifying long messages outside the exception class

(TRY003)


679-681: Avoid specifying long messages outside the exception class

(TRY003)


685-687: Avoid specifying long messages outside the exception class

(TRY003)


690-692: Avoid specifying long messages outside the exception class

(TRY003)


694-696: Avoid specifying long messages outside the exception class

(TRY003)


699-701: Avoid specifying long messages outside the exception class

(TRY003)


704-706: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (30)
flashinfer/comm/__init__.py (1)

42-51: Re-exports look good; just watch for import cycles
Public surface wiring is straightforward; please just verify importing flashinfer.comm doesn’t create a circular import via .allreduce / .trtllm_mnnvl_ar.

tests/comm/test_allreduce_negative.py (3)

18-33: LGTM - Clean MPI/CUDA setup helper.

The helper properly handles skip conditions for missing CUDA devices and insufficient MPI ranks, and correctly computes local_rank for multi-GPU nodes.


72-127: LGTM - Comprehensive negative tests for unsupported patterns.

The parameterized tests properly verify that the MNNVL backend rejects quantization fusion patterns and layout codes with appropriate error messages.


130-185: LGTM - Required parameter validation tests are correct.

The tests properly verify that kARResidualRMSNorm pattern requires both residual_in and rms_gamma inputs, and that meaningful error messages are raised when they're missing.

flashinfer/comm/workspace_base.py (2)

23-34: LGTM - Clean abstract base class definition.

The base class properly establishes the contract for AllReduce fusion workspaces with appropriate type annotations and initialization logic.


63-89: LGTM - Appropriate destructor safety net.

The __del__ implementation correctly handles the constraints of Python destructors. The broad Exception catch (flagged by Ruff BLE001) is intentional here since exceptions cannot propagate from __del__, and the warning provides useful diagnostic information. The docstring clearly communicates that users should not rely on this behavior.
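
A minimal sketch of that destructor-as-safety-net pattern (illustrative only, not the actual workspace_base.py code):

import warnings

class _WorkspaceSketch:
    def __init__(self):
        self._destroyed = False

    def destroy(self):
        if self._destroyed:
            return  # idempotent: safe to call more than once
        # ... release backend resources here ...
        self._destroyed = True

    def __del__(self):
        # Safety net only; users should still call destroy() explicitly.
        if not self._destroyed:
            try:
                warnings.warn(
                    "Workspace garbage-collected without destroy(); attempting cleanup.",
                    ResourceWarning,
                )
                self.destroy()
            except Exception:
                pass  # exceptions must never escape __del__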

flashinfer/comm/trtllm_ar.py (6)

125-127: LGTM - Clear deprecation notice.

The deprecation message clearly directs users to the new unified API.


400-402: LGTM - Consistent deprecation messaging.


503-506: LGTM - Informative deprecation with migration guidance.


809-852: LGTM - Well-structured workspace metadata validation.

The validation function properly checks required keys first (failing fast if missing), then validates the values. The error accumulation pattern provides comprehensive error messages when multiple validation issues exist.
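
A small sketch of that fail-fast-then-accumulate pattern (key names and messages are illustrative, not the actual helper's):

def check_workspace_metadata_sketch(metadata: dict) -> None:
    # Fail fast if required keys are missing entirely.
    required = ("max_token_num", "hidden_dim", "tp_size", "dtype")
    missing = [k for k in required if k not in metadata]
    if missing:
        raise ValueError(f"workspace metadata is missing keys: {missing}")
    # Otherwise accumulate all value-level problems into one message.
    errors = []
    if metadata["max_token_num"] <= 0:
        errors.append("max_token_num must be positive")
    if metadata["hidden_dim"] <= 0:
        errors.append("hidden_dim must be positive")
    if metadata["tp_size"] < 2:
        errors.append("tp_size must be at least 2")
    if errors:
        raise ValueError("; ".join(errors))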


854-856: LGTM - Deprecation notice for legacy fusion function.


909-913: LGTM - Clean integration of metadata validation.

The optional metadata validation is properly integrated into the existing function flow.

tests/comm/test_trtllm_allreduce_fusion.py (6)

25-27: LGTM - Function signature updated to support both APIs.

The legacy_api=True default maintains backward compatibility while allowing explicit selection of the unified API.


62-88: LGTM - Clean workspace creation branching.

The code properly separates legacy and unified workspace creation paths, with the legacy path returning metadata for validation and the unified path using the new workspace abstraction.


392-403: LGTM - Proper cleanup logic for both API paths.

The cleanup correctly branches based on the API path and handles the case where the unified workspace may be None (if creation failed).


448-471: LGTM - Well-structured test parameterization.

Adding the legacy_api parameter doubles the test coverage to ensure both API paths are validated across all configuration combinations.


474-479: LGTM - Manual test entry point for both APIs.


209-240: The unified API correctly handles trigger_completion_at_end internally. At line 616 of allreduce.py, the parameter is automatically set as trigger_completion_at_end=launch_with_pdl, with an inline comment indicating they have the same meaning. The unified API intentionally doesn't expose this as an explicit parameter; instead, it's controlled via the launch_with_pdl argument passed to allreduce_fusion. This is a cleaner API design with equivalent behavior to the legacy path.

tests/comm/test_trtllm_mnnvl_allreduce.py (3)

2-2: LGTM - Added traceback import for better error diagnostics.


105-227: LGTM - Legacy API test function properly preserved.

The legacy function maintains the original API surface with explicit buffer pointers, enabling backward compatibility testing.


435-460: LGTM - Clear test separation between refactored and legacy APIs.

The parameterization differences between the two tests appropriately reflect the different capabilities of each API path.

tests/comm/test_allreduce_unified_api.py (1)

1-24: LGTM! Well-structured imports and module setup.

The imports are appropriate for MPI-based distributed testing with the unified API.

flashinfer/comm/trtllm_mnnvl_ar.py (5)

31-48: LGTM! Clean strategy enum with threshold-based selection.

The MNNVLAllreduceFusionStrategy enum and select_strategy method provide a clear interface for choosing between oneshot and twoshot strategies based on problem size.
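
A toy sketch of threshold-based selection along those lines (enum values and the threshold are illustrative):

import enum

class _StrategySketch(enum.Enum):
    ONESHOT = 1
    TWOSHOT = 2

def select_strategy_sketch(num_tokens: int, oneshot_max_tokens: int = 64) -> _StrategySketch:
    # One-shot favors small token counts; two-shot amortizes better for large ones.
    if num_tokens <= oneshot_max_tokens:
        return _StrategySketch.ONESHOT
    return _StrategySketch.TWOSHOT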


129-130: Clarify rank attribute semantics.

Line 129 overwrites self.rank (set by base class __init__ to mapping.rank) with mapping.tp_rank. If rank and tp_rank differ, this creates ambiguity about which rank value the workspace represents. The base class AllReduceFusionWorkspace expects rank to be the process rank, but this stores the TP rank.

Please verify this is intentional. If both values are needed, consider using distinct attribute names:

-        self.rank = mapping.tp_rank
+        self.tp_rank = mapping.tp_rank

327-402: LGTM! Clean implementation with good input validation.

The function provides clear input validation and error messages. The strategy selection and buffer size checking are well-implemented.


405-496: LGTM! Comprehensive input validation for fused operation.

The function properly validates all tensor dimensions and provides informative error messages. The default epsilon handling is appropriate.


499-633: Good deprecation strategy for backward compatibility.

The @deprecated decorators with clear migration messages provide a smooth upgrade path for users of the legacy API.

flashinfer/comm/allreduce.py (3)

1-71: LGTM! Well-documented module with clear example usage.

The module docstring provides a good overview and usage example. Imports are appropriately organized.


568-639: LGTM! Clean TRTLLM backend dispatch with proper tensor handling.

The contiguity check and flattening logic is well-documented. The comment explaining why contiguous tensors are required for view semantics is helpful.


641-702: LGTM! Proper pattern validation for MNNVL backend.

Good validation of supported patterns with clear error messages directing users to TRTLLM for unsupported fusion patterns.

Comment on lines +39 to +74
DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
// Extract parameters from tensors
int64_t num_tokens = in.size(0);
int64_t token_dim = in.size(1);
int64_t num_tokens = input.size(0);
int64_t token_dim = input.size(1);

// Validate input parameters
TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0)
<< "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type);
TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0)
<< "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type);
TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1))
<< "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << output.size(0) << ", " << output.size(1) << ")";
TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64)
<< "nranks must be between 2 and 64, got " << nranks;
TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
<< "rank must be between 0 and nranks-1, got " << rank;
TVM_FFI_ICHECK(out.has_value() || !wait_for_results)
<< "out tensor must be provided if wait_for_results is true";
TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
epsilon.has_value()) ||
!rmsnorm_fusion)
<< "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
"true";

if (rmsnorm_fusion) {
TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
residual_in.value().size(1) == token_dim)
<< "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
<< ")";
TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
residual_out.value().size(1) == token_dim)
<< "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
<< ")";
TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
<< "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
<< gamma.value().size(0) << ")";
}
Contributor


⚠️ Potential issue | 🔴 Critical

Missing dtype validation across tensors (risk: type confusion / memory corruption)
c_type is derived from input.dtype(), but output (and residual_in/out, gamma) are not checked to have the same dtype. If a caller passes mismatched dtypes, the kernel will reinterpret memory with the wrong element size. Add dtype checks for all participating tensors (especially in rmsnorm_fusion).

   DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
+    TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(output.dtype()), encode_dlpack_dtype(input.dtype()))
+        << "output dtype must match input dtype";
     ...
     if (rmsnorm_fusion) {
+      TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(residual_in.value().dtype()), encode_dlpack_dtype(input.dtype()))
+          << "residual_in dtype must match input dtype";
+      TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(residual_out.value().dtype()), encode_dlpack_dtype(input.dtype()))
+          << "residual_out dtype must match input dtype";
+      TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(gamma.value().dtype()), encode_dlpack_dtype(input.dtype()))
+          << "gamma dtype must match input dtype";
       ...
     }

Also applies to: 79-103

🤖 Prompt for AI Agents
In csrc/trtllm_mnnvl_allreduce.cu around lines 39 to 74 (and similarly apply
checks at lines ~79-103), the code derives c_type from input.dtype() but never
verifies that output and optional tensors share the same dtype; this can cause
type confusion and memory corruption. Add TVM_FFI_ICHECK (or equivalent) asserts
that output.dtype() == input.dtype(), and when rmsnorm_fusion is true also
assert residual_in.dtype(), residual_out.dtype(), and gamma.dtype() match
input.dtype() (and epsilon if it is a tensor), updating the error messages to
name the offending tensor and expected dtype. Ensure these checks are placed
before any pointer casts or size/divisibility checks so dtype mismatches are
caught early.

Comment thread flashinfer/comm/allreduce.py Outdated
Comment on lines +314 to +338
Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use.

Args:
backend: Backend to use ("trtllm", "mnnvl", or "auto")
"auto" uses heuristic to select best backend based on topology
and problem size
world_size: Number of ranks in the process group
rank: Current rank ID
max_token_num: Maximum number of tokens to support
hidden_dim: Hidden dimension size
dtype: Data type for communication tensors
topology: Network topology hint for backend selection
"single_node" - All ranks on one node (default)
"multi_node" - Ranks span multiple nodes
process_group: PyTorch distributed process group (for trtllm backend).
gpus_per_node: Number of GPUs per node (for multi-node topology).
comm_backend: Communication backend to use (for multi-node topology).

Returns:
Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
The workspace type determines which backend will be used in allreduce_fusion()

Raises:
BackendSupportedError: If no suitable backend available for the configuration
ValueError: If problem size not supported for the specified backend
Contributor


⚠️ Potential issue | 🟡 Minor

Documentation inconsistencies with implementation.

  1. Line 314 references workspace.is_sufficient_for() but the actual method is is_buffer_size_sufficient().
  2. Line 337 mentions BackendSupportedError but the code raises ValueError (Line 398).
-    Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use.
+    Use `workspace.is_buffer_size_sufficient(tp_size, token_num, hidden_dim, dtype)` to check before use.
 ...
     Raises:
-        BackendSupportedError: If no suitable backend available for the configuration
+        ValueError: If no suitable backend available for the configuration
         ValueError: If problem size not supported for the specified backend
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use.
Args:
backend: Backend to use ("trtllm", "mnnvl", or "auto")
"auto" uses heuristic to select best backend based on topology
and problem size
world_size: Number of ranks in the process group
rank: Current rank ID
max_token_num: Maximum number of tokens to support
hidden_dim: Hidden dimension size
dtype: Data type for communication tensors
topology: Network topology hint for backend selection
"single_node" - All ranks on one node (default)
"multi_node" - Ranks span multiple nodes
process_group: PyTorch distributed process group (for trtllm backend).
gpus_per_node: Number of GPUs per node (for multi-node topology).
comm_backend: Communication backend to use (for multi-node topology).
Returns:
Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
The workspace type determines which backend will be used in allreduce_fusion()
Raises:
BackendSupportedError: If no suitable backend available for the configuration
ValueError: If problem size not supported for the specified backend
Use `workspace.is_buffer_size_sufficient(tp_size, token_num, hidden_dim, dtype)` to check before use.
Args:
backend: Backend to use ("trtllm", "mnnvl", or "auto")
"auto" uses heuristic to select best backend based on topology
and problem size
world_size: Number of ranks in the process group
rank: Current rank ID
max_token_num: Maximum number of tokens to support
hidden_dim: Hidden dimension size
dtype: Data type for communication tensors
topology: Network topology hint for backend selection
"single_node" - All ranks on one node (default)
"multi_node" - Ranks span multiple nodes
process_group: PyTorch distributed process group (for trtllm backend).
gpus_per_node: Number of GPUs per node (for multi-node topology).
comm_backend: Communication backend to use (for multi-node topology).
Returns:
Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
The workspace type determines which backend will be used in allreduce_fusion()
Raises:
ValueError: If no suitable backend available for the configuration
ValueError: If problem size not supported for the specified backend
🤖 Prompt for AI Agents
In flashinfer/comm/allreduce.py around lines 314 to 338, the docstring is
inconsistent with the implementation: replace the reference to
workspace.is_sufficient_for(...) with the actual method name
workspace.is_buffer_size_sufficient(...), and change the listed raised exception
BackendSupportedError to ValueError to match the current code (or alternatively
implement a BackendSupportedError and raise it at line ~398 if you prefer
changing code instead of docs); update the docstring accordingly so method names
and exception types match the implementation.

Comment thread flashinfer/comm/mnnvl.py
Comment on lines +566 to +665
# The helper class for passing the FD handle over the socket.
class IpcSocket:
"""Unix Domain Socket for IPC file descriptor passing"""

def __init__(self, rank: int, op_id: int, use_abstract=True):
"""
Initialize IPC socket

Args:
rank: Process rank
op_id: Unique operation ID (hash)
use_abstract: Use Linux abstract socket namespace
"""
self.rank = rank
self.op_id = op_id
self.use_abstract = use_abstract

# Create Unix domain socket (DGRAM for compatibility with C code)
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)

# Create unique socket name
socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}"

if use_abstract:
# Linux abstract socket: prepend null byte
self.socket_path = "\0" + socket_name
else:
self.socket_path = socket_name
# Remove existing socket file if it exists
with contextlib.suppress(FileNotFoundError):
os.unlink(socket_name)

# Bind socket
self.sock.bind(self.socket_path)

def send_fd(self, fd: int, dest_rank: int, dest_op_id: Optional[int] = None):
"""
Send a file descriptor to another process

Args:
fd: File descriptor to send
dest_rank: Destination process rank
dest_op_id: Destination operation ID
"""
# Construct destination socket path
dest_op_id = dest_op_id or self.op_id
dest_socket_name = f"/tmp/mcastmem-socket-{dest_rank}-{dest_op_id:x}"

if self.use_abstract:
dest_path = "\0" + dest_socket_name
else:
dest_path = dest_socket_name

# Prepare message with file descriptor
# Send dummy byte as data (required)
dummy_data = b"\x00"

# Pack file descriptor in ancillary data (SCM_RIGHTS)
fds = array.array("i", [fd])
ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds.tobytes())]

# Send message with file descriptor
self.sock.sendmsg([dummy_data], ancillary, 0, dest_path)

def recv_fd(self):
"""
Receive a file descriptor from another process

Returns:
int: Received file descriptor
"""
# Receive message with ancillary data
# Maximum size for ancillary data containing one fd
fds = array.array("i")
msg, ancdata, flags, addr = self.sock.recvmsg(
1,
socket.CMSG_SPACE(
fds.itemsize
), # Buffer size for dummy data # Ancillary data size
)

# Extract file descriptor from ancillary data
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
fds = array.array("i")
fds.frombytes(
cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]
)
return fds[0]

raise RuntimeError("No file descriptor received")

def close(self):
"""Close the socket"""
self.sock.close()
if not self.use_abstract and self.socket_path:
with contextlib.suppress(FileNotFoundError):
os.unlink(self.socket_path)


Contributor


⚠️ Potential issue | 🟠 Major

IPC socket naming + op-id generation: avoid /tmp semantics and random
Even with abstract sockets, embedding "/tmp/..." in the abstract name is misleading (and ruff flags it); also random.randint() is avoidable. Prefer an abstract name without a filesystem-looking prefix and generate opId via secrets.randbits(64) (collision-resistant, no extra baggage).

-import random
+import secrets
 ...
 class PosixFDHandleExchanger(HandleExchanger):
   def _init_ipc_socket(self) -> IpcSocket:
     if self.rank == 0:
-        opId = random.randint(0, 2**64 - 1)
+        opId = secrets.randbits(64)
-        socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}"
+        socket_name = f"mcastmem-socket-{rank}-{op_id:x}"

Also applies to: 717-731

🧰 Tools
🪛 Ruff (0.14.8)

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
flashinfer/comm/mnnvl.py lines 566-665 (and similar block at 717-731): the code
embeds a filesystem-looking "/tmp/..." string into abstract socket names and
relies on random.randint for op_id generation; change the socket name to a
concise abstract-only identifier (e.g., f"mcastmem-{rank}-{op_id:x}" without any
leading "/tmp") and keep the null-byte prefix only when use_abstract is True;
for non-abstract sockets keep the full filesystem path but do not use that
pattern for abstract names; replace any random.randint op_id generation sites
with secrets.randbits(64) to produce collision-resistant IDs; ensure unlinking
of socket files only happens for non-abstract filesystem paths and update
docstrings/comments accordingly.

Comment thread flashinfer/comm/mnnvl.py
Comment on lines +21 to 25
#include <cuda_runtime.h>

#include <iostream>
#include <type_traits>

Contributor

⚠️ Potential issue | 🟠 Major

Host-side __CUDA_ARCH__ gating won’t do what you expect (cluster features likely disabled)
adjustGridConfig() and the host dispatch code use #if defined(__CUDA_ARCH__) to select cluster sizing and to set cudaLaunchAttributeClusterDimension, but __CUDA_ARCH__ is never defined during host compilation, so these branches are compiled out of the host path and never take effect, even on SM90+. Prefer a runtime compute-capability or attribute check (e.g., cudaDeviceGetAttribute or flashinfer::GetCudaComputeCapability()) and gate the cluster attributes on that result.

Also applies to: 54-171, 448-498, 655-704, 1124-1150

🤖 Prompt for AI Agents
In include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh around lines 21-25 (and
also apply similar changes to ranges 54-171, 448-498, 655-704, 1124-1150): the
code incorrectly uses compile-time __CUDA_ARCH__ to gate host-side logic (so
cluster/SM90 features are never enabled at runtime). Replace those #if
defined(__CUDA_ARCH__) host branches with a runtime check—use
cudaDeviceGetAttribute (or flashinfer::GetCudaComputeCapability()) to query the
device compute capability or the relevant CUDA attribute and only set/use
cudaLaunchAttributeClusterDimension when the runtime capability indicates
cluster support (e.g., major>=9 or specific attribute). Update
adjustGridConfig() and the host dispatch code to call the runtime check, branch
based on its result, and ensure any cudaLaunchKernel attribute setting is
guarded by that runtime boolean.

Comment on lines 23 to 25
#include <iostream>
#include <type_traits>

Contributor

⚠️ Potential issue | 🔴 Critical

Missing direct standard headers for new types (std::array, std::tuple)
This file uses std::array and std::tuple but doesn’t include <array> / <tuple> locally. Relying on transitive includes is brittle and can break depending on include order.

 #include <iostream>
 #include <type_traits>
+ #include <array>
+ #include <tuple>

Also applies to: 123-142, 143-163, 243-289, 450-498

🤖 Prompt for AI Agents
In include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh around lines 23-25 (and
also relevant to usages in ranges 123-142, 143-163, 243-289, 450-498), the file
uses std::array and std::tuple but does not include the direct standard headers;
add #include <array> and #include <tuple> at the top of this file (near the
existing <type_traits> include) so the types are guaranteed to be available
regardless of transitive includes, then rebuild to ensure no other
missing-direct-header warnings remain.

Comment on lines +560 to +580
  PackedVec<PackedType, float> valuesLamport[WorldSize];
  while (1) {
    bool valid = true;
#pragma unroll
    for (int r = 0; r < WorldSize; r++) {
      valuesLamport[r].packed = loadPackedVolatile<PackedType>(
          &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim +
                         packedIdx * kELTS_PER_THREAD]);

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#pragma unroll
      for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) {
        valid &= !isNegZero(valuesLamport[r].elements[i]);
      }
    }
    if (valid) {
      break;
    }
  }

  auto values = reinterpret_cast<PackedVec<PackedType, T>*>(valuesLamport);
  // ======================= Reduction =============================
Contributor

⚠️ Potential issue | 🟠 Major

UB risk: type-punning via reinterpret_cast between different PackedVec instantiations
auto values = reinterpret_cast<PackedVec<PackedType, T>*>(valuesLamport); relies on aliasing between PackedVec<PackedType, float> and PackedVec<PackedType, T>. That’s undefined behavior in C++ and can miscompile under optimization. Consider loading into a PackedVec<PackedType, T> per-rank (or use memcpy/bit_cast-style copying) rather than pointer re-interpretation.

Also applies to: 825-847

🤖 Prompt for AI Agents
In include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh around lines 560-580 (and
similarly at 825-847), the code performs type-punning via reinterpret_cast from
PackedVec<PackedType,float> to PackedVec<PackedType,T>, which is undefined
behavior; replace the reinterpret_cast by loading/storing into the correct
target type: either declare the per-rank buffer as PackedVec<PackedType,T> and
call the appropriate load routine into it, or explicitly copy the bytes from the
float-instantiated buffer into a PackedVec<PackedType,T> instance using a safe
byte-wise copy (memcpy or std::bit_cast-like copy) for each element before use;
ensure volatile semantics are preserved for the load step and remove the
reinterpret_cast to eliminate aliasing UB.

@@ -21,6 +21,7 @@
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include <atomic>
Contributor

⚠️ Potential issue | 🟠 Major

Cache is not device-aware + caches failures silently
GetCudaMultiProcessorCount() caches a single sm_count for the whole process and never revalidates on device changes; also it doesn’t check CUDA return codes, so a transient failure can get cached as “0”. Consider caching (device_id, sm_count) and only storing on success.

 inline int GetCudaMultiProcessorCount() {
-  static std::atomic<int> sm_count{0};
-  int cached = sm_count.load(std::memory_order_relaxed);
-  if (cached == 0) {
-    int device_id;
-    cudaGetDevice(&device_id);
-    cudaDeviceProp device_prop;
-    cudaGetDeviceProperties(&device_prop, device_id);
-    cached = device_prop.multiProcessorCount;
-    sm_count.store(cached, std::memory_order_relaxed);
-  }
-  return cached;
+  static std::atomic<int> cached_device{-1};
+  static std::atomic<int> cached_sm{0};
+
+  int device_id = 0;
+  if (cudaGetDevice(&device_id) != cudaSuccess) return 0;
+
+  int sm = cached_sm.load(std::memory_order_relaxed);
+  int dev = cached_device.load(std::memory_order_relaxed);
+  if (sm != 0 && dev == device_id) return sm;
+
+  cudaDeviceProp prop{};
+  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) return 0;
+
+  cached_device.store(device_id, std::memory_order_relaxed);
+  cached_sm.store(prop.multiProcessorCount, std::memory_order_relaxed);
+  return prop.multiProcessorCount;
 }

Also applies to: 339-353

Comment on lines +149 to +187
def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool):
    """Prepare test data distributed across MPI ranks."""
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    if rank == 0:
        x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype)
        residual = torch.randn((seq_len, hidden_size), dtype=dtype)
        norm_weight = torch.randn((hidden_size,), dtype=dtype)
    else:
        x_full = None
        residual = None
        norm_weight = None

    # Use lowercase bcast() for Python object broadcasting
    x_full = comm.bcast(x_full, root=0)
    residual = comm.bcast(residual, root=0)
    norm_weight = comm.bcast(norm_weight, root=0)

    x_full = x_full.cuda()
    residual = residual.cuda()
    norm_weight = norm_weight.cuda()

    x_local = x_full[rank, :, :]
    reference_output: Tuple[torch.Tensor, ...] = None
    if fusion:
        # Fused case: AllReduce + Residual Add + RMS Norm
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        residual_out = allreduce_result + residual  # Add residual
        norm_out = rmsnorm(
            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
        )

        reference_output = (norm_out, residual_out)
    else:
        # Non-fused case: Only AllReduce
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        reference_output = (allreduce_result,)
    return (x_local, residual, norm_weight), reference_output
Contributor

⚠️ Potential issue | 🟡 Minor

Epsilon inconsistency between reference and test.

The reference RMSNorm at Line 179 uses torch.finfo(dtype).eps, but run_allreduce_test passes eps = 1e-5 (Line 234) to the actual fusion operation. This inconsistency could cause subtle differences between reference and test outputs.

Consider passing the same epsilon value to both:

-def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool):
+def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool, eps: float):
     ...
         norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
+            residual_out, norm_weight, eps, enable_pdl=False
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool, eps: float):
    """Prepare test data distributed across MPI ranks."""
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    if rank == 0:
        x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype)
        residual = torch.randn((seq_len, hidden_size), dtype=dtype)
        norm_weight = torch.randn((hidden_size,), dtype=dtype)
    else:
        x_full = None
        residual = None
        norm_weight = None

    # Use lowercase bcast() for Python object broadcasting
    x_full = comm.bcast(x_full, root=0)
    residual = comm.bcast(residual, root=0)
    norm_weight = comm.bcast(norm_weight, root=0)

    x_full = x_full.cuda()
    residual = residual.cuda()
    norm_weight = norm_weight.cuda()

    x_local = x_full[rank, :, :]
    reference_output: Tuple[torch.Tensor, ...] = None
    if fusion:
        # Fused case: AllReduce + Residual Add + RMS Norm
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        residual_out = allreduce_result + residual  # Add residual
        norm_out = rmsnorm(
            residual_out, norm_weight, eps, enable_pdl=False
        )

        reference_output = (norm_out, residual_out)
    else:
        # Non-fused case: Only AllReduce
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        reference_output = (allreduce_result,)
    return (x_local, residual, norm_weight), reference_output
🤖 Prompt for AI Agents
In tests/comm/test_allreduce_unified_api.py around lines 149 to 187, the
reference RMSNorm uses torch.finfo(dtype).eps while the actual fusion
run_allreduce_test uses eps = 1e-5, causing a mismatch; fix by computing a
single eps value (e.g., eps = 1e-5 or derive it once based on dtype) before
building reference_output and pass that same eps into the rmsnorm call so both
reference and tested code use identical epsilon.

Comment thread tests/comm/test_trtllm_mnnvl_allreduce.py
@nvmbreughe nvmbreughe force-pushed the mbreughe/allreduce_unified branch from 9e672bf to f707678 on December 12, 2025 20:59
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
flashinfer/comm/allreduce.py (1)

314-338: Documentation inconsistencies remain from previous review.

As noted in a previous review comment:

  1. Line 314 references workspace.is_sufficient_for() but the actual method is is_buffer_size_sufficient()
  2. Line 337 mentions BackendSupportedError but the code raises ValueError at line 398
tests/comm/test_allreduce_unified_api.py (1)

149-187: Epsilon inconsistency between reference and test remains.

As noted in a previous review: the reference RMSNorm at line 179 uses torch.finfo(dtype).eps, but run_allreduce_test uses eps = 1e-5 (line 234). This can cause subtle numerical differences between reference and test outputs.

For float16, torch.finfo(dtype).eps ≈ 9.77e-4, which differs significantly from 1e-5.
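
A quick way to see the gap being discussed (standalone, no GPU needed):

```python
import torch

# Machine epsilon of the test dtypes vs. the hardcoded test epsilon.
print(torch.finfo(torch.float16).eps)   # ~9.77e-04
print(torch.finfo(torch.bfloat16).eps)  # ~7.81e-03
print(1e-5)                             # value passed to the fused kernel in the test
```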

🧹 Nitpick comments (7)
flashinfer/comm/allreduce.py (5)

138-144: __getattr__ delegation may cause infinite recursion.

The __getattr__ method attempts to delegate attribute access to _internal_workspace, but if _internal_workspace itself doesn't exist (e.g., during initialization failure), this will cause infinite recursion. The check for name.startswith("_") helps but doesn't fully protect against this edge case.

Consider adding a safeguard:

 def __getattr__(self, name):
     """Delegate attribute access to internal workspace if not found."""
     if name.startswith("_"):
         raise AttributeError(
             f"'{type(self).__name__}' object has no attribute '{name}'"
         )
+    if "_internal_workspace" not in self.__dict__:
+        raise AttributeError(
+            f"'{type(self).__name__}' object has no attribute '{name}'"
+        )
     return getattr(self._internal_workspace, name)

146-161: Consider using logging instead of print for validation failures.

The print(f"Workspace is insufficient...") on line 160 outputs directly to stdout. For a library function, consider using logging.warning() or returning a more structured error message. This aligns with the logging usage elsewhere in the codebase (e.g., trtllm_ar.py uses logging.warning).

The unused use_oneshot parameter (Ruff ARG002) is intentional: it keeps the signature aligned with the ABC interface implemented by MNNVLAllReduceFusionWorkspace.is_buffer_size_sufficient, which does use this parameter.


200-218: Simplify _mnnvl_workspace_check - currently always returns True.

The function has a redundant conditional structure:

if topology == "multi_node":
    return True
return True

Both branches return True. Consider simplifying to just return True with a comment explaining that MNNVL works with both topologies, or remove the function entirely and inline the check. The unused parameters are acceptable for interface consistency with _trtllm_workspace_check.

 def _mnnvl_workspace_check(
     backend: str,
     world_size: int,
     rank: int,
     max_token_num: int,
     hidden_dim: int,
     dtype: torch.dtype,
     topology: Literal["single_node", "multi_node"],
 ) -> bool:
     """
     Check if mnnvl backend CAN be used for workspace creation.
-
+    MNNVL supports both single-node and multi-node topologies.
     """
-
-    if topology == "multi_node":
-        return True
-
     return True

682-693: Unused residual_result variable.

The unpacked residual_result is never used (Ruff RUF059). The residual output is written to the residual_out tensor passed as an argument. Consider using _ to indicate the value is intentionally ignored:

-            norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+            norm_result, _ = trtllm_mnnvl_fused_allreduce_add_rmsnorm(

286-297: Consider explicit Optional types for parameters with None defaults.

Several parameters have None as default but aren't explicitly typed as Optional:

world_size: int = None,  # Should be Optional[int] = None
rank: int = None,
max_token_num: int = None,
hidden_dim: int = None,
gpus_per_node: int = None,

Per PEP 484, implicit Optional is discouraged (Ruff RUF013). However, these appear to be required parameters in practice - None values would cause runtime errors. Consider either:

  1. Adding validation and raising early if None
  2. Making them required (no default)
  3. Explicitly typing as Optional[int]
tests/comm/test_allreduce_unified_api.py (2)

26-45: Hardcoded MASTER_PORT may cause port conflicts.

The MASTER_PORT = "29500" is hardcoded. If another process is using this port or if multiple test runs execute concurrently, initialization will fail. Consider using a dynamic port similar to get_open_port() in test_trtllm_allreduce_fusion.py:

import socket

def get_open_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return str(s.getsockname()[1])

However, since MPI tests typically run in isolation and this port is only used for torch.distributed NCCL rendezvous, the risk is lower in practice.


190-197: Unused monkeypatch parameter.

The monkeypatch fixture is passed to run_allreduce_test but never used (Ruff ARG001). Consider removing it if not needed, or prefix with underscore if planned for future use:

 def run_allreduce_test(
-    monkeypatch,
+    _monkeypatch,  # Reserved for future use
     seq_lens: list[int],
     ...

Or simply remove from both run_allreduce_test and test_allreduce_unified:

-def test_allreduce_unified(
-    monkeypatch,
+def test_allreduce_unified(
     seq_lens: list[int],
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9e672bf7434e07037423c5f9dc324042492e1bb8 and f707678.

📒 Files selected for processing (9)
  • flashinfer/comm/__init__.py (1 hunks)
  • flashinfer/comm/allreduce.py (1 hunks)
  • flashinfer/comm/trtllm_ar.py (5 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (10 hunks)
  • flashinfer/comm/workspace_base.py (1 hunks)
  • tests/comm/test_allreduce_negative.py (1 hunks)
  • tests/comm/test_allreduce_unified_api.py (1 hunks)
  • tests/comm/test_trtllm_allreduce_fusion.py (6 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • flashinfer/comm/__init__.py
  • tests/comm/test_trtllm_mnnvl_allreduce.py
  • flashinfer/comm/trtllm_mnnvl_ar.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/allreduce.py
🧬 Code graph analysis (4)
flashinfer/comm/trtllm_ar.py (1)
flashinfer/logits_processor/types.py (1)
  • dtype (126-130)
tests/comm/test_allreduce_unified_api.py (7)
flashinfer/comm/allreduce.py (4)
  • create_allreduce_fusion_workspace (286-444)
  • allreduce_fusion (452-702)
  • backend (135-136)
  • destroy (163-169)
flashinfer/comm/trtllm_ar.py (1)
  • AllReduceFusionPattern (64-77)
flashinfer/comm/workspace_base.py (1)
  • AllReduceFusionWorkspace (23-89)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • backend (229-230)
  • mpi_barrier (24-28)
  • destroy (232-242)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • hidden_size (265-265)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
flashinfer/comm/mapping.py (1)
  • local_rank (391-392)
flashinfer/comm/allreduce.py (4)
flashinfer/comm/workspace_base.py (4)
  • AllReduceFusionWorkspace (23-89)
  • backend (38-40)
  • is_buffer_size_sufficient (53-61)
  • destroy (43-50)
flashinfer/comm/trtllm_ar.py (3)
  • trtllm_allreduce_fusion (230-275)
  • trtllm_allreduce_fusion (857-963)
  • AllReduceFusionPattern (64-77)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
  • MNNVLAllReduceFusionWorkspace (51-242)
  • backend (229-230)
  • is_buffer_size_sufficient (180-197)
  • destroy (232-242)
flashinfer/comm/workspace_base.py (2)
flashinfer/comm/allreduce.py (3)
  • backend (135-136)
  • destroy (163-169)
  • is_buffer_size_sufficient (146-161)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • backend (229-230)
  • destroy (232-242)
  • is_buffer_size_sufficient (180-197)
🪛 Ruff (0.14.8)
tests/comm/test_allreduce_unified_api.py

191-191: Unused function argument: monkeypatch

(ARG001)

flashinfer/comm/allreduce.py

141-143: Avoid specifying long messages outside the exception class

(TRY003)


152-152: Unused method argument: use_oneshot

(ARG002)


158-158: Consider moving this statement to an else block

(TRY300)


178-178: Unused function argument: backend

(ARG001)


179-179: Unused function argument: world_size

(ARG001)


180-180: Unused function argument: rank

(ARG001)


181-181: Unused function argument: max_token_num

(ARG001)


182-182: Unused function argument: hidden_dim

(ARG001)


183-183: Unused function argument: dtype

(ARG001)


201-201: Unused function argument: backend

(ARG001)


202-202: Unused function argument: world_size

(ARG001)


203-203: Unused function argument: rank

(ARG001)


204-204: Unused function argument: max_token_num

(ARG001)


205-205: Unused function argument: hidden_dim

(ARG001)


206-206: Unused function argument: dtype

(ARG001)


227-227: Unused function argument: backend

(ARG001)


228-228: Unused function argument: world_size

(ARG001)


229-229: Unused function argument: rank

(ARG001)


230-230: Unused function argument: max_token_num

(ARG001)


231-231: Unused function argument: hidden_dim

(ARG001)


232-232: Unused function argument: dtype

(ARG001)


288-288: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


289-289: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


290-290: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


291-291: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


295-295: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


398-401: Avoid specifying long messages outside the exception class

(TRY003)


444-444: Avoid specifying long messages outside the exception class

(TRY003)


584-584: Avoid specifying long messages outside the exception class

(TRY003)


646-648: Avoid specifying long messages outside the exception class

(TRY003)


651-653: Avoid specifying long messages outside the exception class

(TRY003)


672-672: Avoid specifying long messages outside the exception class

(TRY003)


674-674: Avoid specifying long messages outside the exception class

(TRY003)


683-683: Unpacked variable residual_result is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


696-696: Avoid specifying long messages outside the exception class

(TRY003)


699-702: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/workspace_base.py

83-83: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
flashinfer/comm/workspace_base.py (2)

63-89: LGTM - safety net destructor with appropriate error handling.

The __del__ implementation correctly handles the edge case where destroy() wasn't called explicitly. Catching bare Exception (flagged by Ruff BLE001) is intentional here since __del__ cannot propagate exceptions and we need to warn about any cleanup failure.

One minor note: stacklevel=2 in warnings from __del__ may not point to the expected call site since __del__ is called by the garbage collector, not directly by user code. Consider using stacklevel=1 or omitting it, but this is cosmetic.


23-61: Well-designed abstract base class.

The AllReduceFusionWorkspace ABC establishes a clean contract for workspace lifecycle management. The constructor initializes common attributes, and abstract methods ensure consistent interface across TRTLLMAllReduceFusionWorkspace and MNNVLAllReduceFusionWorkspace subclasses.
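
For readers who have not opened workspace_base.py, a rough sketch of the lifecycle contract being described (a simplified stand-in; only world_size, rank, backend, destroy, and the _destroyed guard come from the review context, the rest is assumed):

```python
import warnings
from abc import ABC, abstractmethod


class WorkspaceSketch(ABC):
    """Hypothetical, stripped-down analogue of AllReduceFusionWorkspace."""

    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank
        self._destroyed = False

    @property
    @abstractmethod
    def backend(self) -> str: ...

    @abstractmethod
    def destroy(self) -> None:
        """Release buffers; implementations should be idempotent."""

    def __del__(self):
        # Safety net only: warn and attempt cleanup if destroy() was never called.
        if not getattr(self, "_destroyed", True):
            warnings.warn("workspace was garbage-collected without destroy()")
            try:
                self.destroy()
            except Exception:
                pass  # __del__ must never raise
```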

flashinfer/comm/trtllm_ar.py (2)

809-852: Clear metadata validation with good error aggregation.

The check_trtllm_allreduce_fusion_workspace_metadata function properly validates workspace compatibility:

  • Checks required keys first before accessing them
  • Aggregates multiple errors for better developer experience
  • Validates world_size, buffer capacity, and dtype alignment

One observation: if any required keys were missing, the subsequent value checks would fail with KeyError rather than a clear message. The implementation avoids this by raising early (if errors: raise) immediately after the key checks, before any values are read.
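
The shape of that check-then-aggregate pattern, roughly (an illustrative sketch; the metadata keys and messages here are assumptions, not the actual implementation):

```python
import torch


def check_metadata_sketch(num_tokens, hidden_dim, tp_size, dtype, metadata: dict):
    # 1. Fail fast if required keys are absent, before touching any values.
    required = ("world_size", "max_token_num", "hidden_dim", "dtype")
    missing = [k for k in required if k not in metadata]
    if missing:
        raise ValueError(f"workspace metadata is missing keys: {missing}")

    # 2. Collect all value-level problems, then raise once with everything.
    errors = []
    if metadata["world_size"] != tp_size:
        errors.append(f"world_size mismatch: {metadata['world_size']} != {tp_size}")
    if num_tokens > metadata["max_token_num"] or hidden_dim != metadata["hidden_dim"]:
        errors.append("requested problem size exceeds the workspace buffers")
    if metadata["dtype"] != dtype:
        errors.append(f"dtype mismatch: {metadata['dtype']} != {dtype}")
    if errors:
        raise ValueError("; ".join(errors))


check_metadata_sketch(
    num_tokens=128, hidden_dim=1024, tp_size=4, dtype=torch.bfloat16,
    metadata={"world_size": 4, "max_token_num": 1024, "hidden_dim": 1024,
              "dtype": torch.bfloat16},
)
```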


503-506: Consistent deprecation messaging guiding users to unified API.

The deprecation decorators now consistently direct users to the unified allreduce.py API, which aligns with the PR's goal of consolidating the AllReduce interface.

tests/comm/test_allreduce_negative.py (3)

36-61: Well-structured test fixture with proper cleanup.

The autouse fixture pattern ensures workspace cleanup happens regardless of test outcome. Using yield before cleanup is correct for pytest fixtures.
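
The fixture shape being praised here is roughly the following (a self-contained sketch; the real fixture builds an actual workspace and adds MPI synchronization):

```python
import pytest


class DummyWorkspace:
    """Stand-in for a real workspace object in this sketch."""

    def destroy(self):
        print("workspace destroyed")


@pytest.fixture(autouse=True)
def workspace():
    ws = DummyWorkspace()  # the real fixture would call create_allreduce_fusion_workspace(...)
    yield ws               # the test body runs here
    ws.destroy()           # teardown runs whether the test passed or failed


def test_uses_workspace(workspace):
    assert workspace is not None
```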


72-128: Good parameterized coverage of unsupported patterns and layout codes.

The tests comprehensively verify that MNNVL backend rejects:

  1. Quantization fusion patterns (FP8/FP4 variants)
  2. Any layout_code specification

Error message patterns in pytest.raises(match=...) align with the implementation in allreduce.py.


130-185: Required parameter validation tests are thorough.

Tests verify that kARResidualRMSNorm pattern correctly raises ValueError when:

  • residual_in is missing
  • rms_gamma is missing

This aligns with the validation logic in allreduce_fusion() for the MNNVL backend.

tests/comm/test_trtllm_allreduce_fusion.py (3)

62-88: Clean separation of legacy and unified workspace creation paths.

The conditional workspace creation properly handles both APIs:

  • Legacy: Uses trtllm_create_ipc_workspace_for_all_reduce_fusion with metadata
  • Unified: Uses create_allreduce_fusion_workspace returning a workspace object

Note: Line 77 initializes workspace = None before the unified API path, which is good for the cleanup logic at line 400.


209-240: Unified API integration correctly adapts tensor shapes.

The unified API calls correctly reshape 1D tensors to 2D [token_num, hidden_dim] format using .view(). Since the original tensors are contiguous, these views maintain contiguity as required by allreduce_fusion().
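
The view/contiguity point can be checked directly with a standalone snippet (shapes chosen arbitrarily):

```python
import torch

token_num, hidden_dim = 8, 16
flat = torch.zeros(token_num * hidden_dim)  # contiguous 1D buffer
two_d = flat.view(token_num, hidden_dim)    # no copy: same underlying storage

assert two_d.is_contiguous()
two_d[0, 0] = 1.0             # a write through the 2D view...
assert flat[0].item() == 1.0  # ...is visible in the original 1D tensor
```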


448-471: Excellent dual-API test coverage.

Parameterizing legacy_api as [True, False] ensures both the legacy and unified APIs are validated against the same correctness criteria. This is valuable for ensuring the unified API maintains backward compatibility with existing behavior.

tests/comm/test_allreduce_unified_api.py (1)

239-311: Robust error handling and cleanup pattern.

The test properly:

  1. Wraps execution in try/except/finally
  2. Gathers failure status from all ranks for diagnostics
  3. Ensures workspace cleanup in finally block
  4. Cleans up torch.distributed if initialized

This pattern ensures resources are freed even on test failures in distributed scenarios.

@nvmbreughe nvmbreughe changed the title A unified API for the MNNVL and single-node AllReduce kernels. A unified API for the MNNVL and single-node/multi-GPU AllReduce kernels. Dec 12, 2025
@wenscarl
Collaborator

test_trtllm_mnnvl_allreduce_custom_comm.py was not updated in #2118. Could you update it in this PR?

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
tests/comm/test_allreduce_unified_api.py (1)

155-156: Epsilon inconsistency between reference and actual test remains.

The reference RMSNorm at Line 156 uses torch.finfo(dtype).eps, but run_allreduce_test uses a hardcoded eps = 1e-5 (Line 211) which is passed to the actual fusion operation. This mismatch could cause subtle differences between reference and test outputs.

Consider using the same epsilon for both:

-    eps = 1e-5
+    eps = torch.finfo(dtype).eps

Or pass eps to prepare_test_data and use it consistently.

Also applies to: 211-211

flashinfer/comm/allreduce.py (1)

314-314: Documentation inconsistencies persist.

  1. Line 314 references workspace.is_sufficient_for(...) but the actual method is is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype).

  2. Line 337 mentions BackendSupportedError but the code raises ValueError at Line 398.

-    Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use.
+    Use `workspace.is_buffer_size_sufficient(tp_size, token_num, hidden_dim, dtype)` to check before use.
 ...
     Raises:
-        BackendSupportedError: If no suitable backend available for the configuration
+        ValueError: If no suitable backend available for the configuration
         ValueError: If problem size not supported for the specified backend

Also applies to: 337-338

🧹 Nitpick comments (4)
tests/test_helpers/comm.py (1)

48-52: Consider environment variable fallbacks for port configuration.

The hardcoded MASTER_PORT = "29500" could cause conflicts if the port is already in use. Consider checking for existing environment variables first:

     # Set environment variables for torch.distributed
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29500"
+    os.environ.setdefault("MASTER_ADDR", "localhost")
+    os.environ.setdefault("MASTER_PORT", "29500")
     os.environ["RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)

This allows tests to override the port via environment variables when needed.

tests/comm/test_allreduce_unified_api.py (1)

167-168: Remove unused monkeypatch parameter.

The monkeypatch parameter is declared but never used in the function, as flagged by static analysis.

 def run_allreduce_test(
-    monkeypatch,
     seq_lens: list[int],
     fusion: bool,
     dtype: torch.dtype,
     hidden_size: int,
     backend: str,
 ):

Also update the call site in test_allreduce_unified:

 def test_allreduce_unified(
-    monkeypatch,
     seq_lens: list[int],
     fusion: bool,
     dtype: torch.dtype,
     hidden_size: int,
     backend: str,
 ):
-    run_allreduce_test(monkeypatch, seq_lens, fusion, dtype, hidden_size, backend)
+    run_allreduce_test(seq_lens, fusion, dtype, hidden_size, backend)
flashinfer/comm/allreduce.py (2)

159-160: Consider using logging instead of print for buffer insufficiency messages.

Line 160 uses print() which could be noisy in production environments. Consider using the logging module for consistency with other debug output in the codebase:

+import logging
+
 def is_buffer_size_sufficient(
     ...
 ) -> bool:
     try:
         check_trtllm_allreduce_fusion_workspace_metadata(
             num_tokens, hidden_dim, tp_size, dtype, self.metadata
         )
         return True
     except ValueError as e:
-        print(f"Workspace is insufficient for problem size. {e}")
+        logging.debug(f"Workspace is insufficient for problem size. {e}")
         return False

683-683: Prefix unused variable with underscore.

The residual_result variable is unpacked but never used (flagged by static analysis RUF059). Use underscore to indicate it's intentionally ignored:

-            norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+            norm_result, _residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f707678 and c6477c3.

📒 Files selected for processing (5)
  • flashinfer/comm/allreduce.py (1 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (10 hunks)
  • tests/comm/test_allreduce_negative.py (1 hunks)
  • tests/comm/test_allreduce_unified_api.py (1 hunks)
  • tests/test_helpers/comm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/allreduce.py
🧬 Code graph analysis (2)
tests/test_helpers/comm.py (1)
flashinfer/comm/mapping.py (1)
  • local_rank (391-392)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
flashinfer/comm/workspace_base.py (3)
  • AllReduceFusionWorkspace (23-89)
  • backend (38-40)
  • destroy (43-50)
flashinfer/comm/mapping.py (2)
  • rank (311-312)
  • rank (315-322)
flashinfer/comm/allreduce.py (2)
  • backend (135-136)
  • destroy (163-169)
🪛 Ruff (0.14.8)
tests/comm/test_allreduce_unified_api.py

168-168: Unused function argument: monkeypatch

(ARG001)

flashinfer/comm/allreduce.py

141-143: Avoid specifying long messages outside the exception class

(TRY003)


152-152: Unused method argument: use_oneshot

(ARG002)


158-158: Consider moving this statement to an else block

(TRY300)


178-178: Unused function argument: backend

(ARG001)


179-179: Unused function argument: world_size

(ARG001)


180-180: Unused function argument: rank

(ARG001)


181-181: Unused function argument: max_token_num

(ARG001)


182-182: Unused function argument: hidden_dim

(ARG001)


183-183: Unused function argument: dtype

(ARG001)


201-201: Unused function argument: backend

(ARG001)


202-202: Unused function argument: world_size

(ARG001)


203-203: Unused function argument: rank

(ARG001)


204-204: Unused function argument: max_token_num

(ARG001)


205-205: Unused function argument: hidden_dim

(ARG001)


206-206: Unused function argument: dtype

(ARG001)


227-227: Unused function argument: backend

(ARG001)


228-228: Unused function argument: world_size

(ARG001)


229-229: Unused function argument: rank

(ARG001)


230-230: Unused function argument: max_token_num

(ARG001)


231-231: Unused function argument: hidden_dim

(ARG001)


232-232: Unused function argument: dtype

(ARG001)


288-288: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


289-289: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


290-290: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


291-291: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


295-295: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


398-401: Avoid specifying long messages outside the exception class

(TRY003)


444-444: Avoid specifying long messages outside the exception class

(TRY003)


584-584: Avoid specifying long messages outside the exception class

(TRY003)


646-648: Avoid specifying long messages outside the exception class

(TRY003)


651-653: Avoid specifying long messages outside the exception class

(TRY003)


672-672: Avoid specifying long messages outside the exception class

(TRY003)


674-674: Avoid specifying long messages outside the exception class

(TRY003)


683-683: Unpacked variable residual_result is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


696-696: Avoid specifying long messages outside the exception class

(TRY003)


699-702: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (15)
tests/test_helpers/comm.py (2)

10-32: LGTM!

The MPI/CUDA setup function correctly validates prerequisites and assigns CUDA devices based on local rank. The local_rank = rank % gpus_per_node calculation is consistent with the pattern used in flashinfer/comm/mapping.py.


61-67: LGTM!

The cleanup function correctly guards against calling destroy_process_group when not initialized.

flashinfer/comm/trtllm_mnnvl_ar.py (5)

79-79: Potential inconsistency: self.rank is set twice with potentially different values.

Line 79 passes mapping.rank to the base class __init__, which sets self.rank = rank. However, Line 128 then overwrites self.rank = mapping.tp_rank. If mapping.rank differs from mapping.tp_rank, this creates an inconsistency.

Please verify this is intentional. If tp_rank is the correct value for workspace operations, consider passing it to the base class:

-        super().__init__(mapping.world_size, mapping.rank)
+        super().__init__(mapping.world_size, mapping.tp_rank)

Or if both values are needed, use a different attribute name for the TP-specific rank.

Also applies to: 128-129


227-229: LGTM!

The backend property correctly implements the abstract method from the base class.


231-241: LGTM!

The destroy() method correctly implements idempotent cleanup with the _destroyed guard, consistent with the pattern used in TRTLLMAllReduceFusionWorkspace.


326-332: LGTM!

The function signatures and docstrings are correctly updated to use the new MNNVLAllReduceFusionWorkspace type.

Also applies to: 404-414


498-501: LGTM!

The deprecation message and legacy function correctly redirect users to the new MNNVLAllReduceFusionWorkspace class.

Also applies to: 542-547

tests/comm/test_allreduce_negative.py (3)

25-116: LGTM!

The test class properly validates that the MNNVL backend rejects unsupported quantization patterns and layout codes. The fixture correctly handles workspace cleanup and MPI synchronization.


119-174: LGTM!

The test class correctly validates that required parameters (residual_in, rms_gamma) are enforced for the kARResidualRMSNorm pattern.


177-272: LGTM!

The test class correctly validates buffer size sufficiency checks across both backends. The parametrized approach ensures consistent behavior between MNNVL and TRTLLM implementations.

tests/comm/test_allreduce_unified_api.py (2)

31-123: LGTM!

The test function correctly exercises both fused and unfused allreduce patterns and validates outputs against reference implementations.


291-311: LGTM!

The parametrized test provides comprehensive coverage across sequence lengths, fusion modes, data types, hidden dimensions, and backends.

flashinfer/comm/allreduce.py (3)

177-217: LGTM!

The backend check functions correctly implement topology-based requirements. The unused parameters (flagged by static analysis) appear intentional for a consistent interface that allows future extensibility without signature changes.


225-278: LGTM!

The heuristic function correctly implements backend selection based on topology and benchmark data, with appropriate fallbacks.
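
As a rough picture of what such a heuristic looks like (illustrative only; the thresholds below are placeholders, not the benchmark-derived values in this PR):

```python
def choose_backend_sketch(topology: str, world_size: int, hidden_dim: int) -> str:
    # Multi-node deployments can only use the MNNVL path.
    if topology == "multi_node":
        return "mnnvl"
    # Single node: default to trtllm, fall back to mnnvl for very large problems.
    # The cutoff here is a made-up placeholder.
    if world_size * hidden_dim > 1 << 20:
        return "mnnvl"
    return "trtllm"


print(choose_backend_sketch("single_node", world_size=4, hidden_dim=4096))  # trtllm
print(choose_backend_sketch("multi_node", world_size=16, hidden_dim=4096))  # mnnvl
```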


452-472: LGTM!

The allreduce_fusion dispatcher correctly routes to backend-specific implementations based on workspace type. The contiguity check before flattening ensures writes to flattened tensors are reflected in the original 2D tensors.

Also applies to: 568-639
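
A toy sketch of that type-based dispatch (self-contained stand-ins; the real dispatcher calls the trtllm/mnnvl kernels and validates fusion arguments as well):

```python
import torch


class TRTLLMWorkspaceSketch:  # stand-in for TRTLLMAllReduceFusionWorkspace
    backend = "trtllm"


class MNNVLWorkspaceSketch:   # stand-in for MNNVLAllReduceFusionWorkspace
    backend = "mnnvl"


def allreduce_fusion_sketch(workspace, inp: torch.Tensor) -> str:
    # The backend is inferred from the workspace object, so callers never pass
    # a backend string at call time.
    if not inp.is_contiguous():
        raise ValueError("input must be contiguous so flattened views alias its storage")
    if isinstance(workspace, TRTLLMWorkspaceSketch):
        return "dispatched to trtllm kernels"
    if isinstance(workspace, MNNVLWorkspaceSketch):
        return "dispatched to mnnvl kernels"
    raise TypeError(f"unsupported workspace type: {type(workspace).__name__}")


print(allreduce_fusion_sketch(MNNVLWorkspaceSketch(), torch.zeros(4, 8)))
```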

Collaborator

@yzh119 yzh119 left a comment

@nvmbreughe this is huge and thank you for working on the refactor, LGTM overall.

>>> prenorm = torch.empty_like(hidden_states)
>>> normed = torch.empty_like(hidden_states)
>>>
>>> # Call fusion - backend inferred from workspace type
Collaborator

Good design :)

print(f"Rank {rank} using GPU {torch.cuda.current_device()}")

eps = 1e-5
torch.manual_seed(42 + rank)
Collaborator

@yzh119 yzh119 Dec 14, 2025

Is it designed to make sure each rank was initialized in a different way? (I'm curious if we should have a standard for data initialization for communication APIs).

Contributor Author

Good catch! I think it must have been a copy-paste error: the data is only initialized on rank 0, so the per-rank seed has no effect.

torch.testing.assert_close(
output[0],
reference_output[0],
rtol=0.05,
Collaborator

We might consider different tolerances for different output data types in the future (if we add output data types with other precisions).

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c6477c3 and 78c6977.

📒 Files selected for processing (1)
  • flashinfer/comm/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/__init__.py (3)
flashinfer/comm/workspace_base.py (1)
  • AllReduceFusionWorkspace (23-89)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • MNNVLAllReduceFusionWorkspace (51-241)
flashinfer/comm/allreduce.py (3)
  • TRTLLMAllReduceFusionWorkspace (88-169)
  • allreduce_fusion (452-702)
  • create_allreduce_fusion_workspace (286-444)
🪛 GitHub Actions: pre-commit
flashinfer/comm/__init__.py

[error] 49-49: Mypy: Syntax error in flashinfer/comm/__init__.py: '(' was never closed.


[error] 50-50: Ruff-check: invalid-syntax in import block (Expected ')' before newline) in flashinfer/comm/__init__.py.


[error] 50-50: Ruff-format: Failed to parse flashinfer/comm/__init__.py due to syntax error.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment thread flashinfer/comm/__init__.py
@yzh119 yzh119 enabled auto-merge (squash) December 17, 2025 16:51
@yzh119 yzh119 merged commit fd0c2f1 into flashinfer-ai:main Dec 17, 2025
4 checks passed
@coderabbitai coderabbitai Bot mentioned this pull request Dec 18, 2025
@coderabbitai coderabbitai Bot mentioned this pull request Mar 5, 2026
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…ls. (flashinfer-ai#2130)

Co-authored-by: Zihao Ye <expye@outlook.com>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
@coderabbitai coderabbitai Bot mentioned this pull request Apr 16, 2026