A unified API for the MNNVL and single-node/multi-GPU AllReduce kernels #2130
Walkthrough

Adds a unified AllReduce fusion API: introduces an abstract `AllReduceFusionWorkspace` base class, concrete TRTLLM/MNNVL workspace wrappers, a factory `create_allreduce_fusion_workspace()`, a dispatcher `allreduce_fusion()`, backend checks/heuristics, updated backend modules and deprecation messages, and comprehensive MPI/CUDA tests and helpers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant User
participant API as AllReduce API\n(flashinfer/comm/allreduce.py)
participant Selector as Backend\nSelector/Heuristic
participant TRT as TRTLLM\nBackend (legacy IPC / trtllm_ar)
participant MNN as MNNVL\nBackend (trtllm_mnnvl_ar)
User->>API: create_allreduce_fusion_workspace(backend, topology, ...)
API->>Selector: evaluate topology & heuristics
alt choose TRTLLM
Selector-->>API: select trtllm
API->>TRT: create/wrap IPC workspace
TRT-->>API: TRTLLMAllReduceFusionWorkspace
else choose MNNVL
Selector-->>API: select mnnvl
API->>MNN: create/wrap MNNVL workspace
MNN-->>API: MNNVLAllReduceFusionWorkspace
end
API-->>User: workspace
User->>API: allreduce_fusion(input, workspace, pattern, ...)
API->>API: inspect workspace.backend
alt TRTLLM workspace
API->>API: flatten inputs, build metadata
API->>TRT: trtllm_allreduce_fusion(flattened, metadata)
TRT-->>API: flattened output
API->>API: reshape output
API-->>User: output
else MNNVL workspace
API->>API: validate pattern & required tensors
alt kARResidualRMSNorm
API->>MNN: trtllm_mnnvl_fused_allreduce_add_rmsnorm(...)
else kAllReduce
API->>MNN: trtllm_mnnvl_allreduce(...)
end
MNN-->>API: output tensor(s)
API-->>User: output
end
User->>API: workspace.destroy()
API->>TRT/MNN: cleanup resources (idempotent)
```
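For orientation, here is a minimal usage sketch of the flow in the diagram above. It is an assumption-laden sketch, not the final API: the import paths, the `AllReduceFusionPattern.kAllReduce` constant, and the parameter names are taken from examples quoted elsewhere in this review and may differ from the merged code. The script assumes it is launched with `mpirun` so that each rank owns one GPU.

```python
import torch
from mpi4py import MPI

from flashinfer.comm import (
    AllReduceFusionPattern,
    allreduce_fusion,
    create_allreduce_fusion_workspace,
)

comm = MPI.COMM_WORLD
rank, world_size = comm.Get_rank(), comm.Get_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# One workspace per process; "auto" lets the heuristic pick TRTLLM or MNNVL.
workspace = create_allreduce_fusion_workspace(
    backend="auto",
    world_size=world_size,
    rank=rank,
    max_token_num=2048,
    hidden_dim=4096,
    dtype=torch.bfloat16,
    topology="single_node",
)

# The workspace type determines which backend allreduce_fusion() dispatches to.
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
out = allreduce_fusion(
    input=x,
    workspace=workspace,
    pattern=AllReduceFusionPattern.kAllReduce,
    launch_with_pdl=False,
)

workspace.destroy()
```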
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ 1 failed check (warning), ✅ 2 passed checks.
```python
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
topology: str,
```

Suggested change:

```diff
-    topology: str,
+    topology: Literal["single_node", "multi_node"],
```
```python
max_token_num: int = None,
hidden_dim: int = None,
dtype: torch.dtype = None,
topology: str = "single_node",
```

I don't think it is needed longer term, since we will use the same PyTorch symmetric-memory API to allocate symmetric memory for single- and multi-node (under the covers, PyTorch/NCCL/NVSHMEM will detect the platform and decide the right memory-allocation handle).
```python
input: torch.Tensor,
workspace: AllReduceFusionWorkspace,
pattern: int,
launch_with_pdl: bool = False,
```

What is the advantage of giving PDL control to the user?

We have been doing this for all our APIs, but I am not sure why. Maybe because not all archs support it?

The main purpose of having this control is for debugging, as we observed some issues with PDL on sm121. For deployment it should always be turned on IMO, so I suppose the default value should be None and we will call the helper in flashinfer/flashinfer/utils.py (line 615 at 3a301a1).
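A small sketch of the None-default idea from the reply above. The `device_support_pdl()` helper here is a hypothetical stand-in for the utility referenced at flashinfer/utils.py line 615 (whose actual name and signature are not shown in this thread); the capability cutoff inside it is likewise just a placeholder policy.

```python
from typing import Optional

import torch


def device_support_pdl(device: torch.device) -> bool:
    # Hypothetical stand-in: treat PDL as available on Hopper-class (SM90+) GPUs.
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9


def resolve_launch_with_pdl(launch_with_pdl: Optional[bool], device: torch.device) -> bool:
    # None means "decide per device"; an explicit bool is a manual/debugging override.
    if launch_with_pdl is None:
        return device_support_pdl(device)
    return launch_with_pdl
```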
```python
Args:
    input: Input tensor [token_num, hidden_dim]
    workspace: Workspace object (type determines backend)
    pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5)
```

Are they all 2-kernel overlaps, or are some real fusion kernels?

With one-shot MNNVL it's real fusion, and I think similar for the trtllm_ar kernels. It's just two-shot MNNVL that is the 2-kernel overlap.
```python
    },
    heuristic_func=_workspace_creation_heuristic,
)
def create_allreduce_fusion_workspace(
```

Could `create_allreduce_fusion_workspace` take an optional workspace argument? If the workspace is big enough (or too big) this is a no-op (maybe just updating backend selection). If it is too small, destroy the current workspace and allocate a bigger one.

When we switch to a mem pool, we should be able to call `create_allreduce_fusion_workspace` at each forward pass and memory will just get reused from the mempool (instead of new allocations).

CC @Amir-19

What's the advantage of having `workspace = create_allreduce_fusion_workspace(old_workspace)` vs `workspace = old_workspace if condition else create_allreduce_fusion_workspace()`?

It was for the case when you would support growing the workspace:

```python
if needed_size > old_workspace.size:
    old_workspace.destroy()
    return new_workspace
```

But we discussed that we may not be able to destroy the current workspace due to previous CUDA graph captures, so you can ignore my comment for now.
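For reference, a sketch of the grow-or-reuse pattern being discussed (ignoring the CUDA-graph caveat that ended the thread). The wrapper below is hypothetical; the actual factory does not accept an existing workspace, and the import path and `is_buffer_size_sufficient` signature are assumed from other parts of this review.

```python
from typing import Optional

import torch

from flashinfer.comm import AllReduceFusionWorkspace, create_allreduce_fusion_workspace


def get_or_create_workspace(
    old_workspace: Optional[AllReduceFusionWorkspace],
    tp_size: int,
    num_tokens: int,
    hidden_dim: int,
    dtype: torch.dtype,
    **create_kwargs,
) -> AllReduceFusionWorkspace:
    # Reuse the existing workspace when its buffers already cover the problem size.
    if old_workspace is not None and old_workspace.is_buffer_size_sufficient(
        tp_size, num_tokens, hidden_dim, dtype
    ):
        return old_workspace
    # Otherwise release it and allocate a larger one.
    # NOTE: unsafe if the old workspace was captured in a CUDA graph (see discussion above).
    if old_workspace is not None:
        old_workspace.destroy()
    return create_allreduce_fusion_workspace(
        max_token_num=num_tokens,
        hidden_dim=hidden_dim,
        dtype=dtype,
        **create_kwargs,
    )
```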
```
- Workspace(max_token_num=2048, hidden_dim=4096) can handle:
  - (token_num=2048, hidden_dim=4096) ✓
  - (token_num=1024, hidden_dim=4096) ✓
  - (token_num=4096, hidden_dim=2048) ✓ (same total size)
```

I only see FW adjusting the number of tokens, but hidden_dim should be fixed per model.

Discussed this live; the outcome is that the one-shot vs. two-shot decision is delayed until runtime (i.e., not made at workspace creation time).
```python
...     max_token_num=2048,
...     hidden_dim=4096,
...     dtype=torch.bfloat16,
...     topology="single_node"
```

Could we add a check now to detect topology, before we switch to the mempool allocation?

Hi @nvcastet, could you elaborate on what you mean by a topology check?

Something like this check: https://github.com/pytorch/pytorch/blob/eb66e1d2b590865be962b8677acc31728d3ad953/aten/src/ATen/cuda/PeerToPeerAccess.cpp#L146-L180
If fabric is enabled you would get a fabric workspace (for MNNVL); otherwise, default to a single-node workspace.
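A rough sketch of one possible detection heuristic, purely as an illustration: it infers the topology from the process layout rather than querying the CUDA driver's fabric-handle attribute the way the linked PyTorch check does, so it is a weaker, assumption-based substitute.

```python
from typing import Optional

import torch
import torch.distributed as dist


def detect_topology(gpus_per_node: Optional[int] = None) -> str:
    """Best-effort guess: if every rank fits on the local node's GPUs, call it single-node."""
    gpus_per_node = gpus_per_node or torch.cuda.device_count()
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    return "single_node" if world_size <= gpus_per_node else "multi_node"
```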
437c7df to 10554e5 (Compare)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
132-150: Return type bug: annotated `-> int` but returns `None`

`alloc_and_copy_to_cuda()` is typed to return `int` but returns `None` when `host_ptr_array` is empty. This will break callers expecting an integer device pointer.

```diff
 def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
     """
     A helper function that allocates memory on cuda and copies the data from the host to the device.
     """
     if not host_ptr_array:
-        return None
+        return 0
```
🧹 Nitpick comments (18)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)
218-242: Lamport sync helpers are fragile w.r.t. early-exit / partial participation
`LamportFlags::ctaArrive()` uses `__syncthreads()`/`cluster.sync()`, which will deadlock if any threads in the block exit early before reaching it. Today you mostly avoid OOB by shaping `blockSize`, but this is a sharp edge for future tuning (e.g., cluster sizing, `loadsPerThread`, partial tiles). Consider reworking `ctaArrive()` to avoid requiring full-block participation (e.g., a single designated thread does the atomic without a block-wide barrier), or enforce the invariant with stronger checks/docs at the call sites.

Also applies to: 385-418, 509-651, 760-888
924-950: Minor: `offsets` array is oversized (stack/register pressure)

In `rmsNormLamport`, `offsets` is declared as `LoadsPerThread * kELTS_PER_LOAD` but indexed as `offsets[i]` for `i < LoadsPerThread`. Tighten this to `LoadsPerThread`.

```diff
-  uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD];
+  uint32_t offsets[LoadsPerThread];
```

csrc/trtllm_mnnvl_allreduce.cu (1)
29-35: Add basic range/null guards for pointer-like int64 args
A null/invalid `multicast_buffer_ptr`, `buffer_ptrs_dev`, `buffer_ptr_local`, or empty `buffer_flags_mnnvl` will turn into UB quickly. Even lightweight `> 0` / non-empty checks would make failures actionable.

Also applies to: 79-90
flashinfer/comm/mnnvl.py (1)
640-656: Minor: unused `recvmsg` tuple parts (ruff RUF059)

Prefix unused `msg`/`flags`/`addr` with `_` to avoid lint noise.

tests/comm/test_allreduce_negative.py (1)
36-70: Consider extracting duplicated fixture setup to reduce code repetition.

Both `TestMNNVLUnsupportedPatterns` and `TestMNNVLMissingRequiredParameters` have nearly identical `setup` fixtures. Consider extracting the common workspace creation and teardown logic into a shared fixture or base class to improve maintainability.

```python
@pytest.fixture
def mnnvl_workspace():
    """Shared fixture for MNNVL workspace setup and teardown."""
    rank, world_size, gpus_per_node = setup_mpi_and_cuda()
    workspace = create_allreduce_fusion_workspace(
        backend="mnnvl",
        world_size=world_size,
        rank=rank,
        max_token_num=128,
        hidden_dim=2880,
        dtype=torch.float16,
        topology="single_node",
        gpus_per_node=gpus_per_node,
    )
    yield workspace, rank, world_size, gpus_per_node
    if workspace is not None:
        workspace.destroy()
    trtllm_mnnvl_ar.mpi_barrier()
```

flashinfer/comm/workspace_base.py (1)
52-61: Consider using a more specific type for the `use_oneshot` parameter.

The `use_oneshot: Optional[Any]` parameter is very permissive. Based on the relevant code snippets, the TRTLLM backend uses `bool` while MNNVL uses `MNNVLAllreduceFusionStrategy`. Consider using a `Union` type or documenting the expected types per backend.

```diff
+from typing import Optional, Any, Union
+
 @abstractmethod
 def is_buffer_size_sufficient(
     self,
     tp_size: int,
     num_tokens: int,
     hidden_dim: int,
     dtype: torch.dtype,
-    use_oneshot: Optional[Any] = None,
+    strategy: Optional[Any] = None,  # Backend-specific: bool for TRTLLM, MNNVLAllreduceFusionStrategy for MNNVL
 ) -> bool:
     pass
```

tests/comm/test_trtllm_mnnvl_allreduce.py (2)
17-103: Consider inlining the nested `func` function.

The nested `func` function doesn't capture any variables from the outer scope that aren't already passed as parameters, and it's only called once. Consider inlining it for clarity.
233-271: Fix the type annotation for `reference_output`.

The variable is initialized to `None` but typed as `Tuple[torch.Tensor, ...]`. Use `Optional[Tuple[torch.Tensor, ...]]` or remove the annotation since Python will infer it correctly from the subsequent assignments.

```diff
 x_local = x_full[rank, :, :]
-reference_output: Tuple[torch.Tensor, ...] = None
+reference_output: Optional[Tuple[torch.Tensor, ...]] = None
 if fusion:
```

Note: `Optional` is already imported from `typing`.

tests/comm/test_allreduce_unified_api.py (4)
26-46: Hardcoded port may cause conflicts in parallel test runs.

The hardcoded `MASTER_PORT = "29500"` could cause port conflicts if multiple test instances run concurrently on the same machine. Consider using a dynamic port or an environment variable fallback:

```diff
-    os.environ["MASTER_PORT"] = "29500"
+    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
```
133-146: Consider documenting why tolerances are set so high.

The tolerance values (`rtol=0.05`, `atol=0.15`) are quite loose for numerical validation. While this may be necessary for distributed bf16/fp16 operations, a brief comment explaining this would help future maintainers understand this isn't masking bugs.
190-197: Remove the unused `monkeypatch` parameter.

The `monkeypatch` parameter is declared but never used in this function, as flagged by static analysis.

```diff
 def run_allreduce_test(
-    monkeypatch,
     seq_lens: list[int],
     fusion: bool,
     dtype: torch.dtype,
     hidden_size: int,
     backend: str,
 ):
```

Also update the call site in `test_allreduce_unified` accordingly.
314-334: Consider adding pytest markers for CI optimization.

The parametrized test creates 144 combinations (6 seq_lens × 2 fusion × 2 dtype × 2 hidden_size × 3 backend). Consider adding `@pytest.mark.slow` or similar markers to allow selective test execution in CI.

flashinfer/comm/trtllm_mnnvl_ar.py (2)
79-81: Use logging instead of print for production code.

Line 81 uses `print()` while other parts of this class use `logging.debug()`. For consistency and proper log-level control, use logging:

```diff
-        print("Allocating MNNVL Allreduce Fusion Workspace...")
+        logging.info("Allocating MNNVL Allreduce Fusion Workspace...")
```
179-197: `functools.cache` on an instance method can cause memory leaks.

Using `@functools.cache` on an instance method binds `self` into the cache key, preventing the instance from being garbage collected even after all other references are dropped. Since workspaces may be created and destroyed frequently, this could accumulate memory. Consider using `lru_cache` with a bounded size, or move the caching to the static method `get_required_buffer_size_bytes`, which doesn't have this issue:

```diff
-    @functools.cache
     def is_buffer_size_sufficient(
         self,
         tp_size: int,
         num_tokens: int,
         hidden_dim: int,
         dtype: torch.dtype,
         strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
     ) -> bool:
-        """
-        Calculate the required buffer size for a given problem size.
-        """
+        """Check if the buffer is sufficient for a given problem size."""
         required_buffer_size = self.get_required_buffer_size_bytes(
             tp_size, num_tokens, hidden_dim, dtype, strategy
         )
-        if required_buffer_size > self.buffer_size_bytes:
-            return False
-        else:
-            return True
+        return required_buffer_size <= self.buffer_size_bytes
```
146-161: Use logging instead of print, and address unused parameter.
- Line 160 uses `print()` instead of `logging`, for consistency with other modules.
- The `use_oneshot` parameter is unused (static analysis ARG002) but may be kept for API consistency with `MNNVLAllReduceFusionWorkspace.is_buffer_size_sufficient`.

```diff
+import logging
+
 ...
     def is_buffer_size_sufficient(
         self,
         tp_size: int,
         num_tokens: int,
         hidden_dim: int,
         dtype: torch.dtype,
-        use_oneshot: Optional[Any] = None,
+        use_oneshot: Optional[Any] = None,  # Unused, kept for API consistency with MNNVL
     ) -> bool:
         try:
             check_trtllm_allreduce_fusion_workspace_metadata(
                 num_tokens, hidden_dim, tp_size, dtype, self.metadata
             )
             return True
         except ValueError as e:
-            print(f"Workspace is insufficient for problem size. {e}")
+            logging.debug(f"Workspace is insufficient for problem size: {e}")
             return False
```
200-217: Simplify `_mnnvl_workspace_check`; it currently always returns True.

The function has a redundant conditional structure since both branches return `True`. If this is intentional scaffolding for future checks, consider adding a comment; otherwise simplify:

```diff
 def _mnnvl_workspace_check(
     backend: str,
     world_size: int,
     rank: int,
     max_token_num: int,
     hidden_dim: int,
     dtype: torch.dtype,
     topology: Literal["single_node", "multi_node"],
 ) -> bool:
     """
     Check if mnnvl backend CAN be used for workspace creation.
-    """
-
-    if topology == "multi_node":
-        return True
-
-    return True
+    # MNNVL supports both single-node and multi-node topologies
+    return True
```
286-297: Use explicit `Optional` type hints for parameters defaulting to None.

Parameters with `None` defaults should explicitly declare `Optional` types for PEP 484 compliance and better IDE support.

```diff
 def create_allreduce_fusion_workspace(
     backend: Literal["trtllm", "mnnvl", "auto"] = "auto",
-    world_size: int = None,
-    rank: int = None,
-    max_token_num: int = None,
-    hidden_dim: int = None,
-    dtype: torch.dtype = None,
+    world_size: Optional[int] = None,
+    rank: Optional[int] = None,
+    max_token_num: Optional[int] = None,
+    hidden_dim: Optional[int] = None,
+    dtype: Optional[torch.dtype] = None,
     topology: Literal["single_node", "multi_node"] = "single_node",
     process_group: Optional["torch.distributed.ProcessGroup"] = None,
-    gpus_per_node: int = None,
+    gpus_per_node: Optional[int] = None,
     comm_backend: Optional[CommBackend] = None,
 ) -> AllReduceFusionWorkspace:
```
682-693: Prefix the unused variable with an underscore.

The `residual_result` variable is unpacked but never used, as flagged by static analysis.

```diff
     # Call the MNNVL fusion function
-    norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+    norm_result, _residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
         input=input,
         residual_in=residual_in,
         gamma=rms_gamma,
         workspace=workspace,
         epsilon=rms_eps,
         output=norm_out,
         residual_out=residual_out,
         launch_with_pdl=launch_with_pdl,
     )
     return norm_result
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 6bb01d1 and 9e672bf7434e07037423c5f9dc324042492e1bb8.
📒 Files selected for processing (13)
- csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
- flashinfer/comm/__init__.py (1 hunks)
- flashinfer/comm/allreduce.py (1 hunks)
- flashinfer/comm/mnnvl.py (19 hunks)
- flashinfer/comm/trtllm_ar.py (5 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
- flashinfer/comm/workspace_base.py (1 hunks)
- include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
- include/flashinfer/utils.cuh (2 hunks)
- tests/comm/test_allreduce_negative.py (1 hunks)
- tests/comm/test_allreduce_unified_api.py (1 hunks)
- tests/comm/test_trtllm_allreduce_fusion.py (6 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
tests/comm/test_allreduce_negative.py (4)
- flashinfer/comm/allreduce.py (4): create_allreduce_fusion_workspace (286-444), allreduce_fusion (452-702), backend (135-136), destroy (163-169)
- flashinfer/comm/trtllm_ar.py (2): AllReduceFusionPattern (64-77), QuantizationSFLayout (80-95)
- flashinfer/comm/mapping.py (1): local_rank (391-392)
- flashinfer/comm/trtllm_mnnvl_ar.py (3): backend (229-230), destroy (232-242), mpi_barrier (24-28)

flashinfer/comm/workspace_base.py (2)
- flashinfer/comm/allreduce.py (3): backend (135-136), destroy (163-169), is_buffer_size_sufficient (146-161)
- flashinfer/comm/trtllm_mnnvl_ar.py (3): backend (229-230), destroy (232-242), is_buffer_size_sufficient (180-197)

tests/comm/test_trtllm_allreduce_fusion.py (2)
- flashinfer/comm/trtllm_ar.py (4): trtllm_create_ipc_workspace_for_all_reduce_fusion (506-645), trtllm_allreduce_fusion (230-275), trtllm_allreduce_fusion (857-963), trtllm_destroy_ipc_workspace_for_all_reduce_fusion (648-664)
- flashinfer/comm/allreduce.py (4): create_allreduce_fusion_workspace (286-444), backend (135-136), allreduce_fusion (452-702), destroy (163-169)

tests/comm/test_trtllm_mnnvl_allreduce.py (5)
- flashinfer/comm/mapping.py (2): Mapping (21-475), tp_rank (325-326)
- flashinfer/comm/trtllm_mnnvl_ar.py (3): MNNVLAllReduceFusionWorkspace (51-242), trtllm_mnnvl_allreduce (327-402), get_allreduce_mnnvl_workspace (503-559)
- flashinfer/comm/mnnvl.py (8): barrier (168-168), barrier (227-228), get_multicast_ptr (968-972), get_multicast_ptr (1270-1272), get_buffer_ptrs_dev (954-956), get_buffer_ptrs_dev (1278-1280), get_unicast_ptr (958-966), get_unicast_ptr (1274-1276)
- tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1): barrier (60-64)
- include/flashinfer/trtllm/fused_moe/runner.h (1): hidden_size (265-265)

tests/comm/test_allreduce_unified_api.py (6)
- flashinfer/comm/allreduce.py (4): create_allreduce_fusion_workspace (286-444), allreduce_fusion (452-702), backend (135-136), destroy (163-169)
- flashinfer/comm/trtllm_ar.py (1): AllReduceFusionPattern (64-77)
- flashinfer/comm/workspace_base.py (1): AllReduceFusionWorkspace (23-89)
- flashinfer/comm/trtllm_mnnvl_ar.py (2): backend (229-230), destroy (232-242)
- csrc/tvm_ffi_utils.h (1): Tensor (287-289)
- flashinfer/comm/mapping.py (1): local_rank (391-392)

flashinfer/comm/trtllm_ar.py (1)
- flashinfer/logits_processor/types.py (1): dtype (126-130)

flashinfer/comm/mnnvl.py (2)
- flashinfer/cuda_utils.py (1): checkCudaErrors (51-61)
- flashinfer/utils.py (1): round_up (631-633)

flashinfer/comm/trtllm_mnnvl_ar.py (3)
- flashinfer/comm/mapping.py (5): rank (311-312), rank (315-322), tp_rank (325-326), local_rank (391-392), is_multi_node (403-404)
- flashinfer/comm/mnnvl.py (13): McastGPUBuffer (1199-1280), CommBackend (152-171), MPIBackend (211-232), lamport_initialize (1179-1196), lamport_initialize (1239-1240), barrier (168-168), barrier (227-228), get_buffer_ptrs_dev (954-956), get_buffer_ptrs_dev (1278-1280), get_unicast_ptr (958-966), get_unicast_ptr (1274-1276), get_multicast_ptr (968-972), get_multicast_ptr (1270-1272)
- flashinfer/comm/workspace_base.py (4): AllReduceFusionWorkspace (23-89), is_buffer_size_sufficient (53-61), backend (38-40), destroy (43-50)

csrc/trtllm_mnnvl_allreduce.cu (2)
- csrc/tvm_ffi_utils.h (1): get_stream (277-279)
- flashinfer/comm/trtllm_mnnvl_ar.py (1): trtllm_mnnvl_allreduce_fusion (269-320)
🪛 Ruff (0.14.8)
flashinfer/comm/workspace_base.py
83-83: Do not catch blind exception: Exception
(BLE001)
flashinfer/comm/allreduce.py
141-143: Avoid specifying long messages outside the exception class
(TRY003)
152-152: Unused method argument: use_oneshot
(ARG002)
158-158: Consider moving this statement to an else block
(TRY300)
178-178: Unused function argument: backend
(ARG001)
179-179: Unused function argument: world_size
(ARG001)
180-180: Unused function argument: rank
(ARG001)
181-181: Unused function argument: max_token_num
(ARG001)
182-182: Unused function argument: hidden_dim
(ARG001)
183-183: Unused function argument: dtype
(ARG001)
201-201: Unused function argument: backend
(ARG001)
202-202: Unused function argument: world_size
(ARG001)
203-203: Unused function argument: rank
(ARG001)
204-204: Unused function argument: max_token_num
(ARG001)
205-205: Unused function argument: hidden_dim
(ARG001)
206-206: Unused function argument: dtype
(ARG001)
227-227: Unused function argument: backend
(ARG001)
228-228: Unused function argument: world_size
(ARG001)
229-229: Unused function argument: rank
(ARG001)
230-230: Unused function argument: max_token_num
(ARG001)
231-231: Unused function argument: hidden_dim
(ARG001)
232-232: Unused function argument: dtype
(ARG001)
288-288: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
289-289: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
290-290: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
291-291: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
295-295: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
398-401: Avoid specifying long messages outside the exception class
(TRY003)
444-444: Avoid specifying long messages outside the exception class
(TRY003)
584-584: Avoid specifying long messages outside the exception class
(TRY003)
646-648: Avoid specifying long messages outside the exception class
(TRY003)
651-653: Avoid specifying long messages outside the exception class
(TRY003)
672-672: Avoid specifying long messages outside the exception class
(TRY003)
674-674: Avoid specifying long messages outside the exception class
(TRY003)
683-683: Unpacked variable residual_result is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
696-696: Avoid specifying long messages outside the exception class
(TRY003)
699-702: Avoid specifying long messages outside the exception class
(TRY003)
tests/comm/test_allreduce_unified_api.py
191-191: Unused function argument: monkeypatch
(ARG001)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
726-726: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
1011-1011: Do not catch blind exception: Exception
(BLE001)
flashinfer/comm/trtllm_mnnvl_ar.py
122-124: Avoid specifying long messages outside the exception class
(TRY003)
179-179: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
359-361: Avoid specifying long messages outside the exception class
(TRY003)
366-368: Avoid specifying long messages outside the exception class
(TRY003)
380-382: Avoid specifying long messages outside the exception class
(TRY003)
442-444: Avoid specifying long messages outside the exception class
(TRY003)
446-448: Avoid specifying long messages outside the exception class
(TRY003)
450-452: Avoid specifying long messages outside the exception class
(TRY003)
456-458: Avoid specifying long messages outside the exception class
(TRY003)
462-464: Avoid specifying long messages outside the exception class
(TRY003)
475-477: Avoid specifying long messages outside the exception class
(TRY003)
608-610: Avoid specifying long messages outside the exception class
(TRY003)
679-681: Avoid specifying long messages outside the exception class
(TRY003)
685-687: Avoid specifying long messages outside the exception class
(TRY003)
690-692: Avoid specifying long messages outside the exception class
(TRY003)
694-696: Avoid specifying long messages outside the exception class
(TRY003)
699-701: Avoid specifying long messages outside the exception class
(TRY003)
704-706: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (30)
flashinfer/comm/__init__.py (1)
42-51: Re-exports look good; just watch for import cycles

Public surface wiring is straightforward; please just verify importing `flashinfer.comm` doesn't create a circular import via `.allreduce`/`.trtllm_mnnvl_ar`.

tests/comm/test_allreduce_negative.py (3)

18-33: LGTM - Clean MPI/CUDA setup helper. The helper properly handles skip conditions for missing CUDA devices and insufficient MPI ranks, and correctly computes local_rank for multi-GPU nodes.

72-127: LGTM - Comprehensive negative tests for unsupported patterns. The parameterized tests properly verify that the MNNVL backend rejects quantization fusion patterns and layout codes with appropriate error messages.

130-185: LGTM - Required parameter validation tests are correct. The tests properly verify that the `kARResidualRMSNorm` pattern requires both `residual_in` and `rms_gamma` inputs, and that meaningful error messages are raised when they're missing.

flashinfer/comm/workspace_base.py (2)

23-34: LGTM - Clean abstract base class definition. The base class properly establishes the contract for AllReduce fusion workspaces with appropriate type annotations and initialization logic.

63-89: LGTM - Appropriate destructor safety net. The `__del__` implementation correctly handles the constraints of Python destructors. The broad `Exception` catch (flagged by Ruff BLE001) is intentional here since exceptions cannot propagate from `__del__`, and the warning provides useful diagnostic information. The docstring clearly communicates that users should not rely on this behavior.

flashinfer/comm/trtllm_ar.py (6)
125-127: LGTM - Clear deprecation notice. The deprecation message clearly directs users to the new unified API.

400-402: LGTM - Consistent deprecation messaging.

503-506: LGTM - Informative deprecation with migration guidance.

809-852: LGTM - Well-structured workspace metadata validation. The validation function properly checks required keys first (failing fast if missing), then validates the values. The error accumulation pattern provides comprehensive error messages when multiple validation issues exist.

854-856: LGTM - Deprecation notice for legacy fusion function.

909-913: LGTM - Clean integration of metadata validation. The optional metadata validation is properly integrated into the existing function flow.

tests/comm/test_trtllm_allreduce_fusion.py (6)

25-27: LGTM - Function signature updated to support both APIs. The `legacy_api=True` default maintains backward compatibility while allowing explicit selection of the unified API.

62-88: LGTM - Clean workspace creation branching. The code properly separates legacy and unified workspace creation paths, with the legacy path returning metadata for validation and the unified path using the new workspace abstraction.

392-403: LGTM - Proper cleanup logic for both API paths. The cleanup correctly branches based on the API path and handles the case where the unified workspace may be None (if creation failed).

448-471: LGTM - Well-structured test parameterization. Adding the `legacy_api` parameter doubles the test coverage to ensure both API paths are validated across all configuration combinations.

474-479: LGTM - Manual test entry point for both APIs.

209-240: The unified API correctly handles `trigger_completion_at_end` internally. At line 616 of `allreduce.py`, the parameter is automatically set as `trigger_completion_at_end=launch_with_pdl`, with an inline comment indicating they have the same meaning. The unified API intentionally doesn't expose this as an explicit parameter; instead, it's controlled via the `launch_with_pdl` argument passed to `allreduce_fusion`. This is a cleaner API design with equivalent behavior to the legacy path.
2-2: LGTM - Added traceback import for better error diagnostics.
105-227: LGTM - Legacy API test function properly preserved.The legacy function maintains the original API surface with explicit buffer pointers, enabling backward compatibility testing.
105-227: LGTM - Legacy API test function properly preserved. The legacy function maintains the original API surface with explicit buffer pointers, enabling backward compatibility testing.

435-460: LGTM - Clear test separation between refactored and legacy APIs. The parameterization differences between the two tests appropriately reflect the different capabilities of each API path.
1-24: LGTM! Well-structured imports and module setup.The imports are appropriate for MPI-based distributed testing with the unified API.
flashinfer/comm/trtllm_mnnvl_ar.py (5)
31-48: LGTM! Clean strategy enum with threshold-based selection. The `MNNVLAllreduceFusionStrategy` enum and `select_strategy` method provide a clear interface for choosing between oneshot and twoshot strategies based on problem size.
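Roughly, this is the kind of rule such a method encodes; the actual threshold and enum values in `trtllm_mnnvl_ar.py` are not shown in this review, so the numbers and the helper name below are placeholders.

```python
def use_oneshot(num_tokens: int, hidden_dim: int, elem_size: int,
                oneshot_max_bytes: int = 1 << 20) -> bool:
    # One-shot: every rank reads all peers' data and reduces locally; cheap for small messages.
    # Two-shot: reduce-scatter then all-gather; moves less data per rank for large messages.
    return num_tokens * hidden_dim * elem_size <= oneshot_max_bytes
```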
129-130: Clarify rank attribute semantics.

Line 129 overwrites `self.rank` (set by the base class `__init__` to `mapping.rank`) with `mapping.tp_rank`. If `rank` and `tp_rank` differ, this creates ambiguity about which rank value the workspace represents. The base class `AllReduceFusionWorkspace` expects `rank` to be the process rank, but this stores the TP rank. Please verify this is intentional. If both values are needed, consider using distinct attribute names:

```diff
-        self.rank = mapping.tp_rank
+        self.tp_rank = mapping.tp_rank
```
327-402: LGTM! Clean implementation with good input validation. The function provides clear input validation and error messages. The strategy selection and buffer size checking are well-implemented.

405-496: LGTM! Comprehensive input validation for fused operation. The function properly validates all tensor dimensions and provides informative error messages. The default epsilon handling is appropriate.

499-633: Good deprecation strategy for backward compatibility. The `@deprecated` decorators with clear migration messages provide a smooth upgrade path for users of the legacy API.

flashinfer/comm/allreduce.py (3)
1-71: LGTM! Well-documented module with clear example usage. The module docstring provides a good overview and usage example. Imports are appropriately organized.

568-639: LGTM! Clean TRTLLM backend dispatch with proper tensor handling. The contiguity check and flattening logic is well-documented. The comment explaining why contiguous tensors are required for view semantics is helpful.

641-702: LGTM! Proper pattern validation for MNNVL backend. Good validation of supported patterns with clear error messages directing users to TRTLLM for unsupported fusion patterns.
```diff
 DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
   // Extract parameters from tensors
-  int64_t num_tokens = in.size(0);
-  int64_t token_dim = in.size(1);
+  int64_t num_tokens = input.size(0);
+  int64_t token_dim = input.size(1);

   // Validate input parameters
-  TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0)
-      << "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type);
+  TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0)
+      << "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type);
   TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1))
       << "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
       << ") but got (" << output.size(0) << ", " << output.size(1) << ")";
   TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64)
       << "nranks must be between 2 and 64, got " << nranks;
   TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
       << "rank must be between 0 and nranks-1, got " << rank;
   TVM_FFI_ICHECK(out.has_value() || !wait_for_results)
       << "out tensor must be provided if wait_for_results is true";
   TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
                   epsilon.has_value()) ||
                  !rmsnorm_fusion)
       << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
          "true";

   if (rmsnorm_fusion) {
     TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
                    residual_in.value().size(1) == token_dim)
         << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
         << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
         << ")";
     TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                    residual_out.value().size(1) == token_dim)
         << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
         << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
         << ")";
     TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
         << "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
         << gamma.value().size(0) << ")";
   }
```
Missing dtype validation across tensors (risk: type confusion / memory corruption)
c_type is derived from input.dtype(), but output (and residual_in/out, gamma) are not checked to have the same dtype. If a caller passes mismatched dtypes, the kernel will reinterpret memory with the wrong element size. Add dtype checks for all participating tensors (especially in rmsnorm_fusion).
```diff
 DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
+  TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(output.dtype()), encode_dlpack_dtype(input.dtype()))
+      << "output dtype must match input dtype";
   ...
   if (rmsnorm_fusion) {
+    TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(residual_in.value().dtype()), encode_dlpack_dtype(input.dtype()))
+        << "residual_in dtype must match input dtype";
+    TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(residual_out.value().dtype()), encode_dlpack_dtype(input.dtype()))
+        << "residual_out dtype must match input dtype";
+    TVM_FFI_ICHECK_EQ(encode_dlpack_dtype(gamma.value().dtype()), encode_dlpack_dtype(input.dtype()))
+        << "gamma dtype must match input dtype";
   ...
   }
```

Also applies to: 79-103
```python
    Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use.

    Args:
        backend: Backend to use ("trtllm", "mnnvl", or "auto")
            "auto" uses heuristic to select best backend based on topology
            and problem size
        world_size: Number of ranks in the process group
        rank: Current rank ID
        max_token_num: Maximum number of tokens to support
        hidden_dim: Hidden dimension size
        dtype: Data type for communication tensors
        topology: Network topology hint for backend selection
            "single_node" - All ranks on one node (default)
            "multi_node" - Ranks span multiple nodes
        process_group: PyTorch distributed process group (for trtllm backend).
        gpus_per_node: Number of GPUs per node (for multi-node topology).
        comm_backend: Communication backend to use (for multi-node topology).

    Returns:
        Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
        The workspace type determines which backend will be used in allreduce_fusion()

    Raises:
        BackendSupportedError: If no suitable backend available for the configuration
        ValueError: If problem size not supported for the specified backend
```
Documentation inconsistencies with implementation.

- Line 314 references `workspace.is_sufficient_for()` but the actual method is `is_buffer_size_sufficient()`.
- Line 337 mentions `BackendSupportedError` but the code raises `ValueError` (line 398).

```diff
-    Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use.
+    Use `workspace.is_buffer_size_sufficient(tp_size, token_num, hidden_dim, dtype)` to check before use.
     ...
     Raises:
-        BackendSupportedError: If no suitable backend available for the configuration
+        ValueError: If no suitable backend available for the configuration
         ValueError: If problem size not supported for the specified backend
```
```python
# The helper class for passing the FD handle over the socket.
class IpcSocket:
    """Unix Domain Socket for IPC file descriptor passing"""

    def __init__(self, rank: int, op_id: int, use_abstract=True):
        """
        Initialize IPC socket

        Args:
            rank: Process rank
            op_id: Unique operation ID (hash)
            use_abstract: Use Linux abstract socket namespace
        """
        self.rank = rank
        self.op_id = op_id
        self.use_abstract = use_abstract

        # Create Unix domain socket (DGRAM for compatibility with C code)
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)

        # Create unique socket name
        socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}"

        if use_abstract:
            # Linux abstract socket: prepend null byte
            self.socket_path = "\0" + socket_name
        else:
            self.socket_path = socket_name
            # Remove existing socket file if it exists
            with contextlib.suppress(FileNotFoundError):
                os.unlink(socket_name)

        # Bind socket
        self.sock.bind(self.socket_path)

    def send_fd(self, fd: int, dest_rank: int, dest_op_id: Optional[int] = None):
        """
        Send a file descriptor to another process

        Args:
            fd: File descriptor to send
            dest_rank: Destination process rank
            dest_op_id: Destination operation ID
        """
        # Construct destination socket path
        dest_op_id = dest_op_id or self.op_id
        dest_socket_name = f"/tmp/mcastmem-socket-{dest_rank}-{dest_op_id:x}"

        if self.use_abstract:
            dest_path = "\0" + dest_socket_name
        else:
            dest_path = dest_socket_name

        # Prepare message with file descriptor
        # Send dummy byte as data (required)
        dummy_data = b"\x00"

        # Pack file descriptor in ancillary data (SCM_RIGHTS)
        fds = array.array("i", [fd])
        ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds.tobytes())]

        # Send message with file descriptor
        self.sock.sendmsg([dummy_data], ancillary, 0, dest_path)

    def recv_fd(self):
        """
        Receive a file descriptor from another process

        Returns:
            int: Received file descriptor
        """
        # Receive message with ancillary data
        # Maximum size for ancillary data containing one fd
        fds = array.array("i")
        msg, ancdata, flags, addr = self.sock.recvmsg(
            1,  # Buffer size for dummy data
            socket.CMSG_SPACE(fds.itemsize),  # Ancillary data size
        )

        # Extract file descriptor from ancillary data
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
                fds = array.array("i")
                fds.frombytes(
                    cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]
                )
                return fds[0]

        raise RuntimeError("No file descriptor received")

    def close(self):
        """Close the socket"""
        self.sock.close()
        if not self.use_abstract and self.socket_path:
            with contextlib.suppress(FileNotFoundError):
                os.unlink(self.socket_path)
```
IPC socket naming + op-id generation: avoid /tmp semantics and random

Even with abstract sockets, embedding "/tmp/..." in the abstract name is misleading (and ruff flags it); also `random.randint()` is avoidable. Prefer an abstract name without a filesystem-looking prefix and generate `opId` via `secrets.randbits(64)` (collision-resistant, no extra baggage).

```diff
-import random
+import secrets
 ...
 class PosixFDHandleExchanger(HandleExchanger):
     def _init_ipc_socket(self) -> IpcSocket:
         if self.rank == 0:
-            opId = random.randint(0, 2**64 - 1)
+            opId = secrets.randbits(64)
```

```diff
-        socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}"
+        socket_name = f"mcastmem-socket-{rank}-{op_id:x}"
```

Also applies to: 717-731
```cpp
#include <cuda_runtime.h>

#include <iostream>
#include <type_traits>
```
Host-side __CUDA_ARCH__ gating won’t do what you expect (cluster features likely disabled)
adjustGridConfig() and the host dispatch code use #if defined(__CUDA_ARCH__) to select cluster sizing and to set cudaLaunchAttributeClusterDimension, but __CUDA_ARCH__ is not defined for host compilation, so these branches won’t trigger at runtime (even on SM90+). Prefer a runtime compute capability / attribute check (e.g., cudaDeviceGetAttribute or flashinfer::GetCudaComputeCapability()) and gate cluster attributes based on that.
Also applies to: 54-171, 448-498, 655-704, 1124-1150
```cpp
#include <iostream>
#include <type_traits>
```

Missing direct standard headers for new types (std::array, std::tuple)

This file uses `std::array` and `std::tuple` but doesn't include `<array>` / `<tuple>` locally. Relying on transitive includes is brittle and can break depending on include order.

```diff
 #include <iostream>
 #include <type_traits>
+#include <array>
+#include <tuple>
```

Also applies to: 123-142, 143-163, 243-289, 450-498
```cpp
PackedVec<PackedType, float> valuesLamport[WorldSize];
while (1) {
  bool valid = true;
#pragma unroll
  for (int r = 0; r < WorldSize; r++) {
    valuesLamport[r].packed = loadPackedVolatile<PackedType>(
        &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim +
                       packedIdx * kELTS_PER_THREAD]);

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#pragma unroll
    for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) {
      valid &= !isNegZero(valuesLamport[r].elements[i]);
    }
  }
  if (valid) {
    break;
  }
}

auto values = reinterpret_cast<PackedVec<PackedType, T>*>(valuesLamport);
// ======================= Reduction =============================
```
UB risk: type-punning via reinterpret_cast between different PackedVec instantiations
auto values = reinterpret_cast<PackedVec<PackedType, T>*>(valuesLamport); relies on aliasing between PackedVec<PackedType, float> and PackedVec<PackedType, T>. That’s undefined behavior in C++ and can miscompile under optimization. Consider loading into a PackedVec<PackedType, T> per-rank (or use memcpy/bit_cast-style copying) rather than pointer re-interpretation.
Also applies to: 825-847
```diff
@@ -21,6 +21,7 @@
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>

+#include <atomic>
```
Cache is not device-aware + caches failures silently
GetCudaMultiProcessorCount() caches a single sm_count for the whole process and never revalidates on device changes; also it doesn’t check CUDA return codes, so a transient failure can get cached as “0”. Consider caching (device_id, sm_count) and only storing on success.
```diff
 inline int GetCudaMultiProcessorCount() {
-  static std::atomic<int> sm_count{0};
-  int cached = sm_count.load(std::memory_order_relaxed);
-  if (cached == 0) {
-    int device_id;
-    cudaGetDevice(&device_id);
-    cudaDeviceProp device_prop;
-    cudaGetDeviceProperties(&device_prop, device_id);
-    cached = device_prop.multiProcessorCount;
-    sm_count.store(cached, std::memory_order_relaxed);
-  }
-  return cached;
+  static std::atomic<int> cached_device{-1};
+  static std::atomic<int> cached_sm{0};
+
+  int device_id = 0;
+  if (cudaGetDevice(&device_id) != cudaSuccess) return 0;
+
+  int sm = cached_sm.load(std::memory_order_relaxed);
+  int dev = cached_device.load(std::memory_order_relaxed);
+  if (sm != 0 && dev == device_id) return sm;
+
+  cudaDeviceProp prop{};
+  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) return 0;
+
+  cached_device.store(device_id, std::memory_order_relaxed);
+  cached_sm.store(prop.multiProcessorCount, std::memory_order_relaxed);
+  return prop.multiProcessorCount;
 }
```

Also applies to: 339-353
```python
def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool):
    """Prepare test data distributed across MPI ranks."""
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    if rank == 0:
        x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype)
        residual = torch.randn((seq_len, hidden_size), dtype=dtype)
        norm_weight = torch.randn((hidden_size,), dtype=dtype)
    else:
        x_full = None
        residual = None
        norm_weight = None

    # Use lowercase bcast() for Python object broadcasting
    x_full = comm.bcast(x_full, root=0)
    residual = comm.bcast(residual, root=0)
    norm_weight = comm.bcast(norm_weight, root=0)

    x_full = x_full.cuda()
    residual = residual.cuda()
    norm_weight = norm_weight.cuda()

    x_local = x_full[rank, :, :]
    reference_output: Tuple[torch.Tensor, ...] = None
    if fusion:
        # Fused case: AllReduce + Residual Add + RMS Norm
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        residual_out = allreduce_result + residual  # Add residual
        norm_out = rmsnorm(
            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
        )

        reference_output = (norm_out, residual_out)
    else:
        # Non-fused case: Only AllReduce
        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
        reference_output = (allreduce_result,)
    return (x_local, residual, norm_weight), reference_output
```
Epsilon inconsistency between reference and test.
The reference RMSNorm at Line 179 uses torch.finfo(dtype).eps, but run_allreduce_test passes eps = 1e-5 (Line 234) to the actual fusion operation. This inconsistency could cause subtle differences between reference and test outputs.
Consider passing the same epsilon value to both:

```diff
-def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool):
+def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool, eps: float):
 ...
         norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
+            residual_out, norm_weight, eps, enable_pdl=False
         )
```
🤖 Prompt for AI Agents
In tests/comm/test_allreduce_unified_api.py around lines 149 to 187, the
reference RMSNorm uses torch.finfo(dtype).eps while the actual fusion
run_allreduce_test uses eps = 1e-5, causing a mismatch; fix by computing a
single eps value (e.g., eps = 1e-5 or derive it once based on dtype) before
building reference_output and pass that same eps into the rmsnorm call so both
reference and tested code use identical epsilon.
9e672bf to f707678
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/comm/allreduce.py (1)
314-338: Documentation inconsistencies remain from previous review. As noted in a previous review comment:
- Line 314 references `workspace.is_sufficient_for()` but the actual method is `is_buffer_size_sufficient()`
- Line 337 mentions `BackendSupportedError` but the code raises `ValueError` at line 398

tests/comm/test_allreduce_unified_api.py (1)
149-187: Epsilon inconsistency between reference and test remains. As noted in a previous review: the reference RMSNorm at line 179 uses `torch.finfo(dtype).eps`, but `run_allreduce_test` uses `eps = 1e-5` (line 234). This can cause subtle numerical differences between reference and test outputs. For `float16`, `torch.finfo(dtype).eps ≈ 9.77e-4`, which differs significantly from `1e-5`.
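For reference, a minimal sketch of the capacity check being discussed, using the `is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype)` signature quoted in this review; the surrounding function and names are illustrative only, not the project's code:

```python
import torch

def run_layer(workspace, x: torch.Tensor, tp_size: int) -> torch.Tensor:
    # Hypothetical guard around a unified allreduce_fusion() call:
    # verify the pre-allocated workspace can hold this problem size first.
    num_tokens, hidden_dim = x.shape
    if not workspace.is_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, x.dtype):
        raise ValueError("workspace too small for this token count / hidden size")
    ...  # dispatch allreduce_fusion(input=x, workspace=workspace, ...)
```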
🧹 Nitpick comments (7)
flashinfer/comm/allreduce.py (5)
138-144: `__getattr__` delegation may cause infinite recursion. The `__getattr__` method attempts to delegate attribute access to `_internal_workspace`, but if `_internal_workspace` itself doesn't exist (e.g., during initialization failure), this will cause infinite recursion. The check for `name.startswith("_")` helps but doesn't fully protect against this edge case. Consider adding a safeguard:

```diff
 def __getattr__(self, name):
     """Delegate attribute access to internal workspace if not found."""
     if name.startswith("_"):
         raise AttributeError(
             f"'{type(self).__name__}' object has no attribute '{name}'"
         )
+    if "_internal_workspace" not in self.__dict__:
+        raise AttributeError(
+            f"'{type(self).__name__}' object has no attribute '{name}'"
+        )
     return getattr(self._internal_workspace, name)
```
146-161: Consider using logging instead of print for validation failures. The `print(f"Workspace is insufficient...")` on line 160 outputs directly to stdout. For a library function, consider using `logging.warning()` or returning a more structured error message. This aligns with the logging usage elsewhere in the codebase (e.g., `trtllm_ar.py` uses `logging.warning`). The unused `use_oneshot` parameter (Ruff ARG002) is intentional for ABC interface compliance with `MNNVLAllReduceFusionWorkspace.is_buffer_size_sufficient`, which does use this parameter.
200-218: Simplify `_mnnvl_workspace_check` - currently always returns `True`. The function has a redundant conditional structure:

```python
if topology == "multi_node":
    return True
return True
```

Both branches return `True`. Consider simplifying to just `return True` with a comment explaining that MNNVL works with both topologies, or remove the function entirely and inline the check. The unused parameters are acceptable for interface consistency with `_trtllm_workspace_check`.

```diff
 def _mnnvl_workspace_check(
     backend: str,
     world_size: int,
     rank: int,
     max_token_num: int,
     hidden_dim: int,
     dtype: torch.dtype,
     topology: Literal["single_node", "multi_node"],
 ) -> bool:
     """
     Check if mnnvl backend CAN be used for workspace creation.
-
+    MNNVL supports both single-node and multi-node topologies.
     """
-
-    if topology == "multi_node":
-        return True
-    return True
+    return True
```
682-693: Unused `residual_result` variable. The unpacked `residual_result` is never used (Ruff RUF059). The residual output is written to the `residual_out` tensor passed as an argument. Consider using `_` to indicate the value is intentionally ignored:

```diff
- norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+ norm_result, _ = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
```
286-297: Consider explicit `Optional` types for parameters with `None` defaults. Several parameters have `None` as default but aren't explicitly typed as `Optional`:

```python
world_size: int = None,  # Should be Optional[int] = None
rank: int = None,
max_token_num: int = None,
hidden_dim: int = None,
gpus_per_node: int = None,
```

Per PEP 484, implicit `Optional` is discouraged (Ruff RUF013). However, these appear to be required parameters in practice - `None` values would cause runtime errors. Consider either:
- Adding validation and raising early if `None`
- Making them required (no default)
- Explicitly typing as `Optional[int]`

tests/comm/test_allreduce_unified_api.py (2)
MASTER_PORT = "29500"is hardcoded. If another process is using this port or if multiple test runs execute concurrently, initialization will fail. Consider using a dynamic port similar toget_open_port()intest_trtllm_allreduce_fusion.py:import socket def get_open_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return str(s.getsockname()[1])However, since MPI tests typically run in isolation and this port is only used for torch.distributed NCCL rendezvous, the risk is lower in practice.
190-197: Unusedmonkeypatchparameter.The
monkeypatchfixture is passed torun_allreduce_testbut never used (Ruff ARG001). Consider removing it if not needed, or prefix with underscore if planned for future use:def run_allreduce_test( - monkeypatch, + _monkeypatch, # Reserved for future use seq_lens: list[int], ...Or simply remove from both
run_allreduce_testandtest_allreduce_unified:-def test_allreduce_unified( - monkeypatch, +def test_allreduce_unified( seq_lens: list[int],
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 9e672bf7434e07037423c5f9dc324042492e1bb8 and f707678.
📒 Files selected for processing (9)
flashinfer/comm/__init__.py(1 hunks)flashinfer/comm/allreduce.py(1 hunks)flashinfer/comm/trtllm_ar.py(5 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(10 hunks)flashinfer/comm/workspace_base.py(1 hunks)tests/comm/test_allreduce_negative.py(1 hunks)tests/comm/test_allreduce_unified_api.py(1 hunks)tests/comm/test_trtllm_allreduce_fusion.py(6 hunks)tests/comm/test_trtllm_mnnvl_allreduce.py(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer/comm/init.py
- tests/comm/test_trtllm_mnnvl_allreduce.py
- flashinfer/comm/trtllm_mnnvl_ar.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/allreduce.py
🧬 Code graph analysis (4)
flashinfer/comm/trtllm_ar.py (1)
flashinfer/logits_processor/types.py (1)
dtype(126-130)
tests/comm/test_allreduce_unified_api.py (7)
flashinfer/comm/allreduce.py (4)
create_allreduce_fusion_workspace(286-444)allreduce_fusion(452-702)backend(135-136)destroy(163-169)flashinfer/comm/trtllm_ar.py (1)
AllReduceFusionPattern(64-77)flashinfer/comm/workspace_base.py (1)
AllReduceFusionWorkspace(23-89)flashinfer/comm/trtllm_mnnvl_ar.py (3)
backend(229-230)mpi_barrier(24-28)destroy(232-242)include/flashinfer/trtllm/fused_moe/runner.h (1)
hidden_size(265-265)csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
enable_pdl(220-220)flashinfer/comm/mapping.py (1)
local_rank(391-392)
flashinfer/comm/allreduce.py (4)
flashinfer/comm/workspace_base.py (4)
AllReduceFusionWorkspace(23-89)backend(38-40)is_buffer_size_sufficient(53-61)destroy(43-50)flashinfer/comm/trtllm_ar.py (3)
trtllm_allreduce_fusion(230-275)trtllm_allreduce_fusion(857-963)AllReduceFusionPattern(64-77)flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (4)
MNNVLAllReduceFusionWorkspace(51-242)backend(229-230)is_buffer_size_sufficient(180-197)destroy(232-242)
flashinfer/comm/workspace_base.py (2)
flashinfer/comm/allreduce.py (3)
backend(135-136)destroy(163-169)is_buffer_size_sufficient(146-161)flashinfer/comm/trtllm_mnnvl_ar.py (3)
backend(229-230)destroy(232-242)is_buffer_size_sufficient(180-197)
🪛 Ruff (0.14.8)
tests/comm/test_allreduce_unified_api.py
191-191: Unused function argument: monkeypatch
(ARG001)
flashinfer/comm/allreduce.py
141-143: Avoid specifying long messages outside the exception class
(TRY003)
152-152: Unused method argument: use_oneshot
(ARG002)
158-158: Consider moving this statement to an else block
(TRY300)
178-178: Unused function argument: backend
(ARG001)
179-179: Unused function argument: world_size
(ARG001)
180-180: Unused function argument: rank
(ARG001)
181-181: Unused function argument: max_token_num
(ARG001)
182-182: Unused function argument: hidden_dim
(ARG001)
183-183: Unused function argument: dtype
(ARG001)
201-201: Unused function argument: backend
(ARG001)
202-202: Unused function argument: world_size
(ARG001)
203-203: Unused function argument: rank
(ARG001)
204-204: Unused function argument: max_token_num
(ARG001)
205-205: Unused function argument: hidden_dim
(ARG001)
206-206: Unused function argument: dtype
(ARG001)
227-227: Unused function argument: backend
(ARG001)
228-228: Unused function argument: world_size
(ARG001)
229-229: Unused function argument: rank
(ARG001)
230-230: Unused function argument: max_token_num
(ARG001)
231-231: Unused function argument: hidden_dim
(ARG001)
232-232: Unused function argument: dtype
(ARG001)
288-288: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
289-289: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
290-290: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
291-291: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
295-295: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
398-401: Avoid specifying long messages outside the exception class
(TRY003)
444-444: Avoid specifying long messages outside the exception class
(TRY003)
584-584: Avoid specifying long messages outside the exception class
(TRY003)
646-648: Avoid specifying long messages outside the exception class
(TRY003)
651-653: Avoid specifying long messages outside the exception class
(TRY003)
672-672: Avoid specifying long messages outside the exception class
(TRY003)
674-674: Avoid specifying long messages outside the exception class
(TRY003)
683-683: Unpacked variable residual_result is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
696-696: Avoid specifying long messages outside the exception class
(TRY003)
699-702: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/comm/workspace_base.py
83-83: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (11)
flashinfer/comm/workspace_base.py (2)
63-89: LGTM - safety net destructor with appropriate error handling.The
__del__implementation correctly handles the edge case wheredestroy()wasn't called explicitly. Catching bareException(flagged by Ruff BLE001) is intentional here since__del__cannot propagate exceptions and we need to warn about any cleanup failure.One minor note:
stacklevel=2in warnings from__del__may not point to the expected call site since__del__is called by the garbage collector, not directly by user code. Consider usingstacklevel=1or omitting it, but this is cosmetic.
23-61: Well-designed abstract base class.The
AllReduceFusionWorkspaceABC establishes a clean contract for workspace lifecycle management. The constructor initializes common attributes, and abstract methods ensure consistent interface acrossTRTLLMAllReduceFusionWorkspaceandMNNVLAllReduceFusionWorkspacesubclasses.flashinfer/comm/trtllm_ar.py (2)
809-852: Clear metadata validation with good error aggregation.The
check_trtllm_allreduce_fusion_workspace_metadatafunction properly validates workspace compatibility:
- Checks required keys first before accessing them
- Aggregates multiple errors for better developer experience
- Validates
world_size, buffer capacity, and dtype alignmentOne observation: the function raises on missing keys, then continues to check values. If keys are missing but no error was raised yet, the subsequent checks would fail with KeyError. This is correctly handled by the early
if errors: raiseafter key checks.
503-506: Consistent deprecation messaging guiding users to unified API.The deprecation decorators now consistently direct users to the unified
allreduce.pyAPI, which aligns with the PR's goal of consolidating the AllReduce interface.tests/comm/test_allreduce_negative.py (3)
36-61: Well-structured test fixture with proper cleanup.The
autousefixture pattern ensures workspace cleanup happens regardless of test outcome. Usingyieldbefore cleanup is correct for pytest fixtures.
72-128: Good parameterized coverage of unsupported patterns and layout codes.The tests comprehensively verify that MNNVL backend rejects:
- Quantization fusion patterns (FP8/FP4 variants)
- Any
layout_codespecificationError message patterns in
pytest.raises(match=...)align with the implementation inallreduce.py.
130-185: Required parameter validation tests are thorough.Tests verify that
kARResidualRMSNormpattern correctly raisesValueErrorwhen:
residual_inis missingrms_gammais missingThis aligns with the validation logic in
allreduce_fusion()for the MNNVL backend.tests/comm/test_trtllm_allreduce_fusion.py (3)
62-88: Clean separation of legacy and unified workspace creation paths.The conditional workspace creation properly handles both APIs:
- Legacy: Uses
trtllm_create_ipc_workspace_for_all_reduce_fusionwith metadata- Unified: Uses
create_allreduce_fusion_workspacereturning a workspace objectNote: Line 77 initializes
workspace = Nonebefore the unified API path, which is good for the cleanup logic at line 400.
209-240: Unified API integration correctly adapts tensor shapes.The unified API calls correctly reshape 1D tensors to 2D
[token_num, hidden_dim]format using.view(). Since the original tensors are contiguous, these views maintain contiguity as required byallreduce_fusion().
448-471: Excellent dual-API test coverage.Parameterizing
legacy_apias[True, False]ensures both the legacy and unified APIs are validated against the same correctness criteria. This is valuable for ensuring the unified API maintains backward compatibility with existing behavior.tests/comm/test_allreduce_unified_api.py (1)
239-311: Robust error handling and cleanup pattern.The test properly:
- Wraps execution in try/except/finally
- Gathers failure status from all ranks for diagnostics
- Ensures workspace cleanup in
finallyblock- Cleans up torch.distributed if initialized
This pattern ensures resources are freed even on test failures in distributed scenarios.
test_trtllm_mnnvl_allreduce_custom_comm.py was not updated in #2118. Could you update it in this PR?
Actionable comments posted: 0
♻️ Duplicate comments (2)
tests/comm/test_allreduce_unified_api.py (1)
155-156: Epsilon inconsistency between reference and actual test remains.The reference RMSNorm at Line 156 uses
torch.finfo(dtype).eps, butrun_allreduce_testuses a hardcodedeps = 1e-5(Line 211) which is passed to the actual fusion operation. This mismatch could cause subtle differences between reference and test outputs.Consider using the same epsilon for both:
- eps = 1e-5 + eps = torch.finfo(dtype).epsOr pass
epstoprepare_test_dataand use it consistently.Also applies to: 211-211
flashinfer/comm/allreduce.py (1)
314-314: Documentation inconsistencies persist.
Line 314 references
workspace.is_sufficient_for(...)but the actual method isis_buffer_size_sufficient(tp_size, num_tokens, hidden_dim, dtype).Line 337 mentions
BackendSupportedErrorbut the code raisesValueErrorat Line 398.- Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use. + Use `workspace.is_buffer_size_sufficient(tp_size, token_num, hidden_dim, dtype)` to check before use. ... Raises: - BackendSupportedError: If no suitable backend available for the configuration + ValueError: If no suitable backend available for the configuration ValueError: If problem size not supported for the specified backendAlso applies to: 337-338
🧹 Nitpick comments (4)
tests/test_helpers/comm.py (1)
48-52: Consider environment variable fallbacks for port configuration.The hardcoded
MASTER_PORT = "29500"could cause conflicts if the port is already in use. Consider checking for existing environment variables first:# Set environment variables for torch.distributed - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size)This allows tests to override the port via environment variables when needed.
tests/comm/test_allreduce_unified_api.py (1)
167-168: Remove unusedmonkeypatchparameter.The
monkeypatchparameter is declared but never used in the function, as flagged by static analysis.def run_allreduce_test( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int, backend: str, ):Also update the call site in
test_allreduce_unified:def test_allreduce_unified( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int, backend: str, ): - run_allreduce_test(monkeypatch, seq_lens, fusion, dtype, hidden_size, backend) + run_allreduce_test(seq_lens, fusion, dtype, hidden_size, backend)flashinfer/comm/allreduce.py (2)
159-160: Consider using logging instead of print for buffer insufficiency messages.Line 160 uses
print()which could be noisy in production environments. Consider using theloggingmodule for consistency with other debug output in the codebase:+import logging + def is_buffer_size_sufficient( ... ) -> bool: try: check_trtllm_allreduce_fusion_workspace_metadata( num_tokens, hidden_dim, tp_size, dtype, self.metadata ) return True except ValueError as e: - print(f"Workspace is insufficient for problem size. {e}") + logging.debug(f"Workspace is insufficient for problem size. {e}") return False
683-683: Prefix unused variable with underscore.The
residual_resultvariable is unpacked but never used (flagged by static analysis RUF059). Use underscore to indicate it's intentionally ignored:- norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm( + norm_result, _residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
flashinfer/comm/allreduce.py(1 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(10 hunks)tests/comm/test_allreduce_negative.py(1 hunks)tests/comm/test_allreduce_unified_api.py(1 hunks)tests/test_helpers/comm.py(1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/allreduce.py
🧬 Code graph analysis (2)
tests/test_helpers/comm.py (1)
flashinfer/comm/mapping.py (1)
local_rank(391-392)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
flashinfer/comm/workspace_base.py (3)
AllReduceFusionWorkspace(23-89)backend(38-40)destroy(43-50)flashinfer/comm/mapping.py (2)
rank(311-312)rank(315-322)flashinfer/comm/allreduce.py (2)
backend(135-136)destroy(163-169)
🪛 Ruff (0.14.8)
tests/comm/test_allreduce_unified_api.py
168-168: Unused function argument: monkeypatch
(ARG001)
flashinfer/comm/allreduce.py
141-143: Avoid specifying long messages outside the exception class
(TRY003)
152-152: Unused method argument: use_oneshot
(ARG002)
158-158: Consider moving this statement to an else block
(TRY300)
178-178: Unused function argument: backend
(ARG001)
179-179: Unused function argument: world_size
(ARG001)
180-180: Unused function argument: rank
(ARG001)
181-181: Unused function argument: max_token_num
(ARG001)
182-182: Unused function argument: hidden_dim
(ARG001)
183-183: Unused function argument: dtype
(ARG001)
201-201: Unused function argument: backend
(ARG001)
202-202: Unused function argument: world_size
(ARG001)
203-203: Unused function argument: rank
(ARG001)
204-204: Unused function argument: max_token_num
(ARG001)
205-205: Unused function argument: hidden_dim
(ARG001)
206-206: Unused function argument: dtype
(ARG001)
227-227: Unused function argument: backend
(ARG001)
228-228: Unused function argument: world_size
(ARG001)
229-229: Unused function argument: rank
(ARG001)
230-230: Unused function argument: max_token_num
(ARG001)
231-231: Unused function argument: hidden_dim
(ARG001)
232-232: Unused function argument: dtype
(ARG001)
288-288: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
289-289: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
290-290: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
291-291: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
295-295: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
398-401: Avoid specifying long messages outside the exception class
(TRY003)
444-444: Avoid specifying long messages outside the exception class
(TRY003)
584-584: Avoid specifying long messages outside the exception class
(TRY003)
646-648: Avoid specifying long messages outside the exception class
(TRY003)
651-653: Avoid specifying long messages outside the exception class
(TRY003)
672-672: Avoid specifying long messages outside the exception class
(TRY003)
674-674: Avoid specifying long messages outside the exception class
(TRY003)
683-683: Unpacked variable residual_result is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
696-696: Avoid specifying long messages outside the exception class
(TRY003)
699-702: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (15)
tests/test_helpers/comm.py (2)
10-32: LGTM!The MPI/CUDA setup function correctly validates prerequisites and assigns CUDA devices based on local rank. The
local_rank = rank % gpus_per_nodecalculation is consistent with the pattern used inflashinfer/comm/mapping.py.
61-67: LGTM!The cleanup function correctly guards against calling
destroy_process_groupwhen not initialized.flashinfer/comm/trtllm_mnnvl_ar.py (5)
79-79: Potential inconsistency:self.rankis set twice with potentially different values.Line 79 passes
mapping.rankto the base class__init__, which setsself.rank = rank. However, Line 128 then overwritesself.rank = mapping.tp_rank. Ifmapping.rankdiffers frommapping.tp_rank, this creates an inconsistency.Please verify this is intentional. If
tp_rankis the correct value for workspace operations, consider passing it to the base class:- super().__init__(mapping.world_size, mapping.rank) + super().__init__(mapping.world_size, mapping.tp_rank)Or if both values are needed, use a different attribute name for the TP-specific rank.
Also applies to: 128-129
227-229: LGTM!The
backendproperty correctly implements the abstract method from the base class.
231-241: LGTM!The
destroy()method correctly implements idempotent cleanup with the_destroyedguard, consistent with the pattern used inTRTLLMAllReduceFusionWorkspace.
326-332: LGTM!The function signatures and docstrings are correctly updated to use the new
MNNVLAllReduceFusionWorkspacetype.Also applies to: 404-414
498-501: LGTM!The deprecation message and legacy function correctly redirect users to the new
MNNVLAllReduceFusionWorkspaceclass.Also applies to: 542-547
tests/comm/test_allreduce_negative.py (3)
25-116: LGTM!The test class properly validates that the MNNVL backend rejects unsupported quantization patterns and layout codes. The fixture correctly handles workspace cleanup and MPI synchronization.
119-174: LGTM!The test class correctly validates that required parameters (
residual_in,rms_gamma) are enforced for thekARResidualRMSNormpattern.
177-272: LGTM!The test class correctly validates buffer size sufficiency checks across both backends. The parametrized approach ensures consistent behavior between MNNVL and TRTLLM implementations.
tests/comm/test_allreduce_unified_api.py (2)
31-123: LGTM!The test function correctly exercises both fused and unfused allreduce patterns and validates outputs against reference implementations.
291-311: LGTM!The parametrized test provides comprehensive coverage across sequence lengths, fusion modes, data types, hidden dimensions, and backends.
flashinfer/comm/allreduce.py (3)
177-217: LGTM!The backend check functions correctly implement topology-based requirements. The unused parameters (flagged by static analysis) appear intentional for a consistent interface that allows future extensibility without signature changes.
225-278: LGTM!The heuristic function correctly implements backend selection based on topology and benchmark data, with appropriate fallbacks.
452-472: LGTM!The
allreduce_fusiondispatcher correctly routes to backend-specific implementations based on workspace type. The contiguity check before flattening ensures writes to flattened tensors are reflected in the original 2D tensors.Also applies to: 568-639
yzh119 left a comment
@nvmbreughe this is huge and thank you for working on the refactor, LGTM overall.
```python
>>> prenorm = torch.empty_like(hidden_states)
>>> normed = torch.empty_like(hidden_states)
>>>
>>> # Call fusion - backend inferred from workspace type
```
```python
print(f"Rank {rank} using GPU {torch.cuda.current_device()}")

eps = 1e-5
torch.manual_seed(42 + rank)
```
Is it designed to make sure each rank was initialized in a different way? (I'm curious if we should have a standard for data initialization for communication APIs).
Good catch! I think it must have been a copy paste error actually: the data is only initialized for rank 0, so this would not have an effect.
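For illustration, a minimal sketch of the initialization pattern under discussion (a hypothetical helper, not the test's code): only rank 0 generates data and broadcasts it, so seeding the other ranks differently has no effect on the tensors the test actually uses.

```python
import torch
from mpi4py import MPI

def make_inputs(shape, dtype, seed=42):
    comm = MPI.COMM_WORLD
    x = None
    if comm.Get_rank() == 0:
        # Seed once on the generating rank; every other rank receives this data.
        torch.manual_seed(seed)
        x = torch.randn(shape, dtype=dtype)
    # bcast() pickles the rank-0 tensor and distributes it to all ranks.
    return comm.bcast(x, root=0)
```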
```python
torch.testing.assert_close(
    output[0],
    reference_output[0],
    rtol=0.05,
```
We might consider different tolerance for different output data types in the future (if we add more precision output data types).
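If that direction is taken, one simple option is a dtype-keyed tolerance table; the values below are illustrative only, not the project's settings:

```python
import torch

# Hypothetical per-dtype tolerances for comparing fused-allreduce outputs.
RTOL_BY_DTYPE = {
    torch.float16: 5e-2,
    torch.bfloat16: 1e-1,
    torch.float32: 1e-5,
}

def assert_allclose(out: torch.Tensor, ref: torch.Tensor) -> None:
    torch.testing.assert_close(out, ref, rtol=RTOL_BY_DTYPE[out.dtype], atol=1e-3)
```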
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/comm/__init__.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/__init__.py (3)
flashinfer/comm/workspace_base.py (1)
AllReduceFusionWorkspace(23-89)flashinfer/comm/trtllm_mnnvl_ar.py (1)
MNNVLAllReduceFusionWorkspace(51-241)flashinfer/comm/allreduce.py (3)
TRTLLMAllReduceFusionWorkspace(88-169)allreduce_fusion(452-702)create_allreduce_fusion_workspace(286-444)
🪛 GitHub Actions: pre-commit
flashinfer/comm/__init__.py
[error] 49-49: Mypy: Syntax error in flashinfer/comm/__init__.py: '(' was never closed.
[error] 50-50: Ruff-check: invalid-syntax in import block (Expected ')' before newline) in flashinfer/comm/__init__.py.
[error] 50-50: Ruff-format: Failed to parse flashinfer/comm/__init__.py due to syntax error.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
…ls. (flashinfer-ai#2130) <!-- .github/pull_request_template.md --> ## 📌 Description A unified API for the MNNVL and single-node AllReduce kernels. * This introduces the API's `create_allreduce_fusion_workspace`, and `allreduce_fusion` * The backend ("trtllm" or "mnnvl") is chosen during workspace creation. We can either pick it explicitly, or use the "auto" backend to have a heuristic pick the best backend. * The API can be used for both single-node allreduce, as well as for multi-node allreduce. Test with ``` mpirun -np 4 pytest tests/comm/test_allreduce_unified_api.py mpirun -np 4 pytest tests/comm/test_allreduce_negative.py ``` note: mpirun is needed for the mnnvl backend, as illustrated in the test commands above ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Unified AllReduce Fusion API with multi-backend support (auto/TRTLLM/MNNVL) and public workspace types * Common workspace base with lifecycle management and automatic cleanup warnings * **Bug Fixes / Validation** * Stronger input/workspace validation with aggregated error messages * Idempotent destroy semantics for safer resource cleanup * **Deprecations** * Legacy AllReduce APIs deprecated in favor of the unified API * **Tests** * Expanded negative, integration, and cross-backend correctness tests plus MPI/CUDA test helpers <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Zihao Ye <expye@outlook.com> Co-authored-by: yzh119 <zihaoy@nvidia.com>
📌 Description
A unified API for the MNNVL and single-node AllReduce kernels.
This introduces the APIs `create_allreduce_fusion_workspace` and `allreduce_fusion`. The backend ("trtllm" or "mnnvl") is chosen during workspace creation, either explicitly or via the "auto" heuristic, and the API covers both single-node and multi-node allreduce. Test with `mpirun -np 4 pytest tests/comm/test_allreduce_unified_api.py` and `mpirun -np 4 pytest tests/comm/test_allreduce_negative.py`.
note: mpirun is needed for the mnnvl backend, as illustrated in the test commands above
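A rough usage sketch of the new entry points follows; the keyword names mirror the signatures quoted in the review above and the import path for `AllReduceFusionPattern` is an assumption, so this is a sketch rather than the definitive API:

```python
import torch
from flashinfer.comm import allreduce_fusion, create_allreduce_fusion_workspace
from flashinfer.comm.trtllm_ar import AllReduceFusionPattern  # assumed import path

def fused_allreduce_rmsnorm(x, residual, gamma, world_size, rank):
    # backend="auto" lets the heuristic choose between "trtllm" and "mnnvl".
    ws = create_allreduce_fusion_workspace(
        backend="auto",
        world_size=world_size,
        rank=rank,
        max_token_num=x.shape[0],
        hidden_dim=x.shape[1],
        dtype=x.dtype,
    )
    try:
        # Fused allreduce + residual add + RMSNorm; this pattern requires
        # residual_in and rms_gamma (see the negative tests in this PR).
        return allreduce_fusion(
            input=x,                      # [token_num, hidden_dim], contiguous
            workspace=ws,
            pattern=AllReduceFusionPattern.kARResidualRMSNorm,
            residual_in=residual,
            rms_gamma=gamma,
        )
    finally:
        ws.destroy()  # explicit cleanup; destroy() is idempotent
```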
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes / Validation
Deprecations
Tests
✏️ Tip: You can customize this high-level summary in your review settings.