Refactor trtllm_mnnvl_allreduce#2118
Conversation
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughReplaces legacy MNNVL all-reduce with a fused Lamport-buffer allreduce exposed as Changes
Sequence Diagram(s)sequenceDiagram
participant App as Application
participant PyAPI as Python API / Workspace
participant Strategy as Strategy Selector
participant Comm as Comm Backend (MPI / IpcSocket)
participant Kernel as CUDA Kernel (trtllm_mnnvl_allreduce_fusion)
participant Buff as Buffers / Output
App->>PyAPI: call trtllm_mnnvl_allreduce(...) or fused API
PyAPI->>PyAPI: validate inputs, prepare workspace & outputs
PyAPI->>Strategy: select ONESHOT / TWOSHOT (AUTO inspects workspace/problem)
PyAPI->>Comm: exchange/share handles (MPI bcast/barrier or IpcSocket FD exchange)
PyAPI->>Kernel: invoke trtllm_mnnvl_allreduce_fusion(params)
rect rgb(245,250,255)
Kernel->>Kernel: lamport-stage broadcast & per-token reduction
alt RMSNorm fusion enabled
Kernel->>Kernel: compute RMS, apply gamma, add residuals
end
Kernel->>Buff: write output (and residual_out if present)
end
Buff-->>PyAPI: return tensor(s)
PyAPI-->>App: deliver result(s)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello @timlee0212, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a comprehensive refactoring of the multi-node NVLink (MNNVL) all-reduce system within FlashInfer. It unifies the all-reduce and RMSNorm operations into a single, highly configurable C++ kernel, exposed through intuitive new Python APIs. A key improvement is the new workspace management class, which automates and optimizes buffer allocation. Furthermore, the PR adds crucial support for IPC Socket-based handle transfer, broadening compatibility to hardware environments like DGX machines. These changes collectively enhance the flexibility, performance, and overall robustness of distributed computations. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request is a significant refactoring of the MNNVL all-reduce implementation, introducing a new, cleaner API with a dedicated workspace manager class, and adding support for IPC sockets for single-node communication. The changes are extensive and substantially improve the code's structure and capabilities. My review focuses on ensuring backward compatibility is fully maintained as intended, removing leftover debug code, improving memory usage efficiency, adding a critical safety check for buffer sizes in the new API, and suggesting a minor precision improvement in a CUDA kernel.
| def trtllm_mnnvl_allreduce( | ||
| input: torch.Tensor, | ||
| workspace: MNNVLAllreduceFusionWorkspace, | ||
| launch_with_pdl: bool, | ||
| output: Optional[torch.Tensor] = None, | ||
| strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, | ||
| ) -> torch.Tensor: |
There was a problem hiding this comment.
The new trtllm_mnnvl_allreduce function is missing a check to ensure that the input tensor fits within the allocated workspace. The old API had a check like if inp.shape[0] > buffer_M: raise ValueError(...). A similar check should be added here to prevent potential out-of-bounds memory access, which could lead to crashes or incorrect results. The required buffer size depends on the strategy (one-shot vs. two-shot) and can be calculated using MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes.
There was a problem hiding this comment.
Do we want this check to be on the execution path? Or should we assuming it is the user's liability to ensure it does not overflow.
There was a problem hiding this comment.
We do want this check. I recently added it because it did bite others.
| ) | ||
| self.buf_size = buf_size | ||
| self.local_device = device | ||
|
|
||
| def lamport_initialize(self, rank: int, dtype: torch.dtype): | ||
| self.mcast_device_memory.lamport_initialize(rank, dtype) | ||
|
|
||
| def get_mc_buffer( | ||
| def get_multicast_buffer( |
There was a problem hiding this comment.
There was a problem hiding this comment.
This class is used internally, and left as a placeholder but not implemented. Thus, a breaking changes is fine. Tag @nvmbreughe for confirmation.
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/comm/mnnvl.py (1)
132-149:alloc_and_copy_to_cudareturn type and empty-input behavior are inconsistentThe function is annotated as returning
intbut returnsNonewhenhost_ptr_arrayis empty. Callers currently pass non‑empty lists, but this mismatch can trip type checkers and hide bugs if an empty list is ever passed.Either tighten behavior or relax the signature, for example:
- If empty input is invalid, raise:
def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: if not host_ptr_array: raise ValueError("host_ptr_array must be non-empty")
- Or, if you want the sentinel, change the annotation to
int | Noneand document theNonecase.tests/comm/test_trtllm_mnnvl_allreduce.py (1)
328-427: Moveallgather()and finalmpi_barrier()tofinallyblock to ensure all ranks participate in collectivesLines 414 and 434 create a deadlock risk in error scenarios. The
allgather()at line 414 is inside theexceptblock, so only ranks that hit an exception call it. Meanwhile, thempi_barrier()at line 434 is unconditionally called aftertry/except/finally. If an error occurs on some but not all ranks, failing ranks block inallgather()waiting for non-failing ranks that never enter theexceptblock, while non-failing ranks block in the final barrier—both unable to proceed.Move the
allgather()call and finalmpi_barrier()to thefinallyblock to ensure all ranks participate in these collective operations:rank_failed = False try: ... except Exception as e: rank_failed = True failure_message = ... print(failure_message) import traceback print(traceback.format_exc()) raise finally: all_failures = MPI.COMM_WORLD.allgather(rank_failed) if any(all_failures): failed_ranks = [i for i, failed in enumerate(all_failures) if failed] if rank == 0: print(f"Test failed on ranks: {failed_ranks}") if "workspace" in locals(): del workspace trtllm_mnnvl_ar.mpi_barrier()This applies to line 328–426 (main
try/except) and line 434 (final barrier).
🧹 Nitpick comments (8)
flashinfer/comm/mnnvl.py (1)
640-655: Minor polish: unused recvmsg outputs and predictableopIdTwo small, non‑blocking cleanups:
- In
IpcSocket.recv_fd(), the unpackedmsg,flags, andaddrfromrecvmsgare unused. Renaming them to_msg,_flags,_addrwill make that explicit and silence linters:_msg, ancdata, _flags, _addr = self.sock.recvmsg(...)
opIdfor the socket name is generated withrandom.randint. Since it’s only used as a uniqueness hint (not security‑sensitive), this is fine; if you want to appease S311 you could switch tosecrets.randbits(64)or document that it’s non‑cryptographic.Both are optional, but would make static analysis quieter.
Also applies to: 885-889
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)
23-25: Explicitly include<array>and<tuple>, and guardadjustGridConfigagainstsmCount == 0Within this header:
LamportBufferLayout,LamportFlags,PackedVec, and several kernels usestd::array.adjustGridConfigreturnsstd::tuple<int, int, int>and callers usestd::get.But only
<type_traits>is included;<array>and<tuple>are currently pulled in (if at all) via transitive includes, which is fragile.Also,
adjustGridConfigrelies onGetCudaMultiProcessorCount():int smCount = GetCudaMultiProcessorCount(); while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { ... }If
GetCudaMultiProcessorCount()ever returns 0 (e.g., CUDA error or misconfiguration), this loop will keep shrinkingclusterSizeand inflatingblockSizein a somewhat opaque way.Suggestions:
- Add explicit includes at the top of the header:
#include <array> #include <tuple>
- Make
adjustGridConfigrobust to a 0 or negative SM count by early‑clamping:int smCount = GetCudaMultiProcessorCount(); if (smCount <= 0) { // Fall back to single-SM configuration clusterSize = 1; blockSize = std::min(threadsNeeded, 1024); return {blockSize, clusterSize, loadsPerThread}; }This keeps the fused path predictable even if the helper cannot obtain a valid SM count.
Also applies to: 54-177, 143-163, 291-313, 348-359, 385-419, 449-497
509-651: Confirm lamport clear / wait protocol assumptions for oneshot kernelThe oneshot fused kernel uses
LamportFlagsas follows:
- Out‑of‑bounds threads call
ctaArrive()thenclearDirtyLamportBuf()and return.- In‑bounds threads:
- write their shard into the multicast lamport buffer,
- call
ctaArrive()again,- then call
clearDirtyLamportBuf()and spin on the Lamport buffers until all entries are non‑negZero.This protocol assumes:
- Every thread in the grid calls
clearDirtyLamportBuf()exactly once per iteration.- Buffer flags and
bytesToClearare correctly initialized to match the configurednumTokens * tokenDim * WorldSize.Given that this is a direct Lamport port, the logic looks consistent, but the protocol is subtle. I’d recommend:
- Double‑checking the initialization of
buffer_flagsinMNNVLAllreduceFusionWorkspacematches the expectations here (current index, dirty index, bytes per buffer, and stage counts).- Adding a brief comment near the kernel launch documenting that
buffer_flagsmust follow the[cur, dirty, bytes_per_buffer, dirty_num_stages, bytes_to_clear[4], access_ptr]layout used byLamportFlags.No code change strictly required, but the invariants are nontrivial and worth locking down in comments/tests.
754-885: Two‑shot path & RMSNorm fusion: validate world sizes and loads‑per‑thread boundsThe two‑shot kernels and dispatchers introduce several constraints:
twoshotAllreduceFusionDispatch<T>only supportsnRanksin{2, 4, 8, 16, 32, 64}and enforcestokenDim % (sizeof(float4) / sizeof(T)) == 0.rmsNormLamportis instantiated withLoadsPerThreadin[1, 8]and usesfloat4loads into shared memory; dynamic shared memory is sized as3 * rnBlockSize * iters * sizeof(T)and indexed accordingly.The implementation looks coherent, but a few invariants are implicit:
MNNVLTwoShotStage::NUM_STAGESmust stay in sync with theLamportFlags<float4>usage and the twobytes_to_clearentries inwaitAndUpdate.rnLoadsPerThreadretrieved fromadjustGridConfigmust remain in[1, 8]; thedefault:branch already errors if it’s out of range, which is good.rnClusterSizefromadjustGridConfigis assumed to be<= 8given__shared__ float sharedVal[8];in the RMSNorm kernel.Given these contracts, I’d suggest:
- Adding asserts (or comments) that
rnClusterSize <= 8when CGA is used, to guard future changes toadjustGridConfig.- Extending tests to cover the corner cases where
tokenDimis just at or above the supported boundary (e.g., maximum hidden size and multiple world sizes) so we don’t regress theFLASHINFER_CHECKconditions.Functionally the code looks sound; this is mainly about making the implicit constraints explicit.
Also applies to: 898-959, 1062-1219
csrc/trtllm_mnnvl_allreduce.cu (1)
99-107: Error message still mentions “twoshot” even for oneshot pathRegardless of
use_oneshot, the failure message says:TVM_FFI_ICHECK(status == cudaSuccess) << "twoshot_allreduce_dispatch_world_size failed with error code " << cudaGetErrorString(status);This is slightly misleading when the oneshot dispatch is used. Consider making the message neutral (e.g., “allreduce_fusion_dispatch failed…”) or switching on
use_oneshotto provide a more accurate label. Behavior is otherwise fine.tests/comm/test_trtllm_mnnvl_allreduce.py (2)
232-270: Use the sameepsfor reference RMSNorm as the fused kernelIn
prepare_test_data, the fused reference path uses:norm_out = rmsnorm( residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False )But the actual fused kernel is driven by the
epsargument passed intorow_linear_residual_norm_fusion_forward(eps = 1e-5inrun_mnnvl_ar_full).To keep the reference as close as possible to the fused implementation (and not rely on loose tolerances), consider:
def prepare_test_data(..., fusion: bool, eps: float): ... if fusion: ... norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False)and threading
epsthrough the call sites.
273-281: Annotatelegacy_explicit_workspace_bytesas optionalRuff’s RUF013 warning here is valid:
def run_mnnvl_ar_full(..., legacy_explicit_workspace_bytes: int = None, legacy_api: bool = False, ):Changing the signature to make the optionality explicit improves readability and typing:
from typing import Optional def run_mnnvl_ar_full( ..., legacy_explicit_workspace_bytes: Optional[int] = None, legacy_api: bool = False, ) -> None: ...or, in Python 3.10+:
legacy_explicit_workspace_bytes: int | None = Noneflashinfer/comm/trtllm_mnnvl_ar.py (1)
203-205: Drop debug print from hot path.
This unconditional- print( - f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}" - )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 0753095 and a2670e8c69a66ae142c31582d58621173fe2408a.
📒 Files selected for processing (6)
csrc/trtllm_mnnvl_allreduce.cu(1 hunks)flashinfer/comm/mnnvl.py(18 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh(2 hunks)include/flashinfer/utils.cuh(1 hunks)tests/comm/test_trtllm_mnnvl_allreduce.py(7 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (7)
MNNVLAllreduceFusionWorkspace(47-141)mpi_barrier(23-27)trtllm_mnnvl_fused_allreduce_add_rmsnorm(301-391)MNNVLAllreduceFusionStrategy(30-40)trtllm_mnnvl_allreduce(229-298)get_allreduce_mnnvl_workspace(398-451)get_required_buffer_size_bytes(116-141)flashinfer/comm/mnnvl.py (10)
barrier(168-168)barrier(227-228)bcast(165-165)bcast(224-225)get_multicast_ptr(868-872)get_multicast_ptr(1191-1193)get_buffer_ptrs_dev(854-856)get_buffer_ptrs_dev(1199-1201)get_unicast_ptr(858-866)get_unicast_ptr(1195-1197)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)cudaGetErrorString(146-147)csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(168-222)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/utils.py (2)
register_custom_op(273-282)register_custom_op(292-311)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1121-1201)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1101-1118)lamport_initialize(1160-1161)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(854-856)get_buffer_ptrs_dev(1199-1201)get_unicast_ptr(858-866)get_unicast_ptr(1195-1197)get_multicast_ptr(868-872)get_multicast_ptr(1191-1193)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(31-109)trtllm_mnnvl_allreduce_fusion(31-37)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
checkCudaErrors(51-61)
🪛 Ruff (0.14.5)
tests/comm/test_trtllm_mnnvl_allreduce.py
279-279: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
flashinfer/comm/trtllm_mnnvl_ar.py
74-76: Avoid specifying long messages outside the exception class
(TRY003)
261-263: Avoid specifying long messages outside the exception class
(TRY003)
268-270: Avoid specifying long messages outside the exception class
(TRY003)
338-340: Avoid specifying long messages outside the exception class
(TRY003)
342-344: Avoid specifying long messages outside the exception class
(TRY003)
346-348: Avoid specifying long messages outside the exception class
(TRY003)
352-354: Avoid specifying long messages outside the exception class
(TRY003)
358-360: Avoid specifying long messages outside the exception class
(TRY003)
500-502: Avoid specifying long messages outside the exception class
(TRY003)
571-573: Avoid specifying long messages outside the exception class
(TRY003)
577-579: Avoid specifying long messages outside the exception class
(TRY003)
582-584: Avoid specifying long messages outside the exception class
(TRY003)
586-588: Avoid specifying long messages outside the exception class
(TRY003)
591-593: Avoid specifying long messages outside the exception class
(TRY003)
596-598: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
885-885: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
| inline int GetCudaMultiProcessorCount() { | ||
| static int sm_count = 0; | ||
| if (sm_count == 0) { | ||
| int device_id; | ||
| cudaGetDevice(&device_id); | ||
| cudaDeviceProp device_prop; | ||
| cudaGetDeviceProperties(&device_prop, device_id); | ||
| sm_count = device_prop.multiProcessorCount; | ||
| } | ||
| return sm_count; |
There was a problem hiding this comment.
Make GetCudaMultiProcessorCount thread‑safe and clarify multi‑device semantics
static int sm_countis written without synchronization; concurrent calls from multiple host threads can cause a data race and undefined behavior.- The function also permanently caches the SM count of whichever device is current on the first call; if the process later switches devices, the cached value will be wrong.
Consider making the cache atomic (or using std::call_once) and, if needed, keying by device ID. For example:
- inline int GetCudaMultiProcessorCount() {
- static int sm_count = 0;
- if (sm_count == 0) {
- int device_id;
- cudaGetDevice(&device_id);
- cudaDeviceProp device_prop;
- cudaGetDeviceProperties(&device_prop, device_id);
- sm_count = device_prop.multiProcessorCount;
- }
- return sm_count;
-}
+ inline int GetCudaMultiProcessorCount() {
+ static std::atomic<int> sm_count{0};
+ int cached = sm_count.load(std::memory_order_relaxed);
+ if (cached == 0) {
+ int device_id = 0;
+ cudaGetDevice(&device_id);
+ cudaDeviceProp device_prop{};
+ cudaGetDeviceProperties(&device_prop, device_id);
+ cached = device_prop.multiProcessorCount;
+ sm_count.store(cached, std::memory_order_relaxed);
+ }
+ return cached;
+}(You’d also need to #include <atomic> in this header.)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| inline int GetCudaMultiProcessorCount() { | |
| static int sm_count = 0; | |
| if (sm_count == 0) { | |
| int device_id; | |
| cudaGetDevice(&device_id); | |
| cudaDeviceProp device_prop; | |
| cudaGetDeviceProperties(&device_prop, device_id); | |
| sm_count = device_prop.multiProcessorCount; | |
| } | |
| return sm_count; | |
| inline int GetCudaMultiProcessorCount() { | |
| static std::atomic<int> sm_count{0}; | |
| int cached = sm_count.load(std::memory_order_relaxed); | |
| if (cached == 0) { | |
| int device_id = 0; | |
| cudaGetDevice(&device_id); | |
| cudaDeviceProp device_prop{}; | |
| cudaGetDeviceProperties(&device_prop, device_id); | |
| cached = device_prop.multiProcessorCount; | |
| sm_count.store(cached, std::memory_order_relaxed); | |
| } | |
| return cached; | |
| } |
🤖 Prompt for AI Agents
In include/flashinfer/utils.cuh around lines 292-301, the function uses a
non-synchronized static sm_count that creates a data race and also permanently
caches the SM count for whichever device was current on first call; fix by
making the lookup thread-safe and per-device: first call cudaGetDevice() to get
the current device id, then fetch properties if no cached value exists for that
device; protect the cache with thread-safety (e.g., use std::call_once
per-device, or a mutex + an unordered_map<int,int> keyed by device id, or an
atomic<int> array sized by device count) and include the required headers
(<mutex>, <unordered_map> or <atomic>); return the cached per-device SM count
after initialization.
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)
56-69: Guard RMSNorm fusion against missingresidual_inand validate its shape.The precondition at line 56 enforces presence of
residual_out,gamma, andepsilonwhenrmsnorm_fusion=true, but omitsresidual_in. The kernel will dereferenceresidualInPtrunconditionally whenRMSNormFusionis true, causing undefined behavior ifresidual_inis absent.Additionally, shape validation (lines 61-68) only covers
residual_outandgamma;residual_inis not validated.Extend the precondition to include
residual_in:- TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) || + TVM_FFI_ICHECK((residual_out.has_value() && residual_in.has_value() && + gamma.has_value() && epsilon.has_value()) || !rmsnorm_fusion) - << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; + << "residual_out, residual_in, gamma, and epsilon must be provided if rmsnorm_fusion is true";Add shape validation for
residual_inwithin theif (rmsnorm_fusion)block:if (rmsnorm_fusion) { TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && residual_out.value().size(1) == token_dim) << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1) << ")"; + TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens && + residual_in.value().size(1) == token_dim) + << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dim + << ") but got (" << residual_in.value().size(0) << ", " + << residual_in.value().size(1) << ")"; TVM_FFI_ICHECK(gamma.value().size(0) == token_dim) << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" << gamma.value().size(0) << ")"; }flashinfer/comm/trtllm_mnnvl_ar.py (1)
331-332: Restore RMSNorm epsilon default to 1e-5.Overriding
epsilonwithtorch.finfo(input.dtype).epsreplaces the kernel's built-in 1e-5 default (see line 91 incsrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM.Apply this diff:
if epsilon is None: - epsilon = torch.finfo(input.dtype).eps + epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
502-504: Clarify assertion for legacy API compatibility.The assertion at lines 502-504 will fail with a cryptic message if
wait_for_results=Falseis passed. Since this is deprecated legacy code, the assertion is reasonable, but consider improving the error message for clarity:- assert wait_for_results and (out is not None), ( - "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." - ) + if not wait_for_results or out is None: + raise ValueError( + "Legacy trtllm_mnnvl_all_reduce requires wait_for_results=True and a valid output tensor. " + "Please use the new trtllm_mnnvl_allreduce API instead." + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between a2670e8c69a66ae142c31582d58621173fe2408a and 92cbd483f63f102ce7b34ae5521df8023db1c96d.
📒 Files selected for processing (3)
csrc/trtllm_mnnvl_allreduce.cu(1 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)tests/comm/test_trtllm_mnnvl_allreduce.py(8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (3)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)cudaGetErrorString(146-147)csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(168-219)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/utils.py (2)
register_custom_op(273-282)register_custom_op(292-311)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1121-1201)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1101-1118)lamport_initialize(1160-1161)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(854-856)get_buffer_ptrs_dev(1199-1201)get_unicast_ptr(858-866)get_unicast_ptr(1195-1197)get_multicast_ptr(868-872)get_multicast_ptr(1191-1193)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(31-108)trtllm_mnnvl_allreduce_fusion(31-37)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (7)
MNNVLAllreduceFusionWorkspace(47-141)mpi_barrier(23-27)trtllm_mnnvl_fused_allreduce_add_rmsnorm(298-388)MNNVLAllreduceFusionStrategy(30-40)trtllm_mnnvl_allreduce(226-295)get_allreduce_mnnvl_workspace(395-448)get_required_buffer_size_bytes(116-141)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py
74-76: Avoid specifying long messages outside the exception class
(TRY003)
258-260: Avoid specifying long messages outside the exception class
(TRY003)
265-267: Avoid specifying long messages outside the exception class
(TRY003)
335-337: Avoid specifying long messages outside the exception class
(TRY003)
339-341: Avoid specifying long messages outside the exception class
(TRY003)
343-345: Avoid specifying long messages outside the exception class
(TRY003)
349-351: Avoid specifying long messages outside the exception class
(TRY003)
355-357: Avoid specifying long messages outside the exception class
(TRY003)
497-499: Avoid specifying long messages outside the exception class
(TRY003)
568-570: Avoid specifying long messages outside the exception class
(TRY003)
574-576: Avoid specifying long messages outside the exception class
(TRY003)
579-581: Avoid specifying long messages outside the exception class
(TRY003)
583-585: Avoid specifying long messages outside the exception class
(TRY003)
588-590: Avoid specifying long messages outside the exception class
(TRY003)
593-595: Avoid specifying long messages outside the exception class
(TRY003)
tests/comm/test_trtllm_mnnvl_allreduce.py
280-280: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
226-232: Add workspace capacity check to prevent buffer overflow.The new
trtllm_mnnvl_allreducefunction doesn't verify that the input tensor fits within the allocated workspace buffer. A previous review comment suggested adding a check similar to the legacy API'sif inp.shape[0] > buffer_Mvalidation.While the author questioned whether this should be on the execution path, buffer overflow can cause crashes or silent memory corruption. Consider adding a validation check:
required_size = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy ) if required_size > workspace.buffer_size_bytes: raise ValueError( f"Input tensor requires {required_size} bytes but workspace only has " f"{workspace.buffer_size_bytes} bytes. Please increase workspace size." )Based on past review comments, the maintainer questioned if this check should be on the execution path. If this is intentionally omitted for performance, please document this as a user responsibility in the docstring.
There was a problem hiding this comment.
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)
566-664: Close remaining POSIX FDs in IPC path to avoid leaksIn the POSIX handle path of
_alloc_mn_mcast_mem, a few FDs are still never closed:
local_shareable_uc_handlereturned bycuMemExportToShareableHandle(line 958) is used in the IPC ring allgather but never closed.- During the ring, each rank sends its
local_shareable_uc_handleto all peers, including itself. The self‑recv forp == group_rankpopulatesall_shareable_uc_handles[self.group_rank], but that FD is never imported (due toif p != self.group_rank) and also never closed.You already close imported POSIX FDs after
cuMemImportFromShareableHandleand close the multicast FD after import; closing the remaining two FDs will complete the cleanup and prevent per‑allocation FD leaks in long‑running jobs.One way to fix this:
if ( self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC ): # All-gather fabric handles all_shareable_uc_handles = self.comm_backend.allgather( local_shareable_uc_handle.data ) else: # Implement the allgather logic with ipc socket all_shareable_uc_handles = [None] * self.group_size for i in range(self.group_size): self.comm_backend.barrier() # Send to peer at offset i dest_rank = (self.group_rank + i) % self.group_size self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) # Receive from peer at offset -i src_rank = (self.group_rank + self.group_size - i) % self.group_size all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close our exported UC handle FD and the self-received FD + os.close(local_shareable_uc_handle) + if all_shareable_uc_handles[self.group_rank] is not None: + os.close(all_shareable_uc_handles[self.group_rank])The existing per‑peer close after import:
if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR: os.close(all_shareable_uc_handles[p])and the multicast close:
if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR: os.close(shareable_mc_handle)can stay as‑is.
Together with the new
__del__logic callingself._ipc_socket.close(), this fully addresses the descriptor‑leak concern in the IPC path.Also applies to: 957-1005, 1008-1055
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
233-271: Align reference RMSNorm epsilon with kernel default (still usingtorch.finfo(dtype).eps)
prepare_test_datastill usestorch.finfo(dtype).epsas the epsilon for the reference RMSNorm:norm_out = rmsnorm( residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False )while the kernel and test harness default to
eps = 1e-5(seerun_mnnvl_ar_fulland the C++ FFI wrapper’sparams.epsilondefault). This inconsistency can mask subtle discrepancies behind loose tolerances or cause avoidable test drift.To keep the reference path exactly aligned with the implementation, switch this to the same constant:
- norm_out = rmsnorm( - residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False - ) + norm_out = rmsnorm( + residual_out, + norm_weight, + 1e-5, + enable_pdl=False, + )(or better, reuse the same
epsvalue passed intorun_mnnvl_ar_fullto avoid hard‑coding the constant twice).
🧹 Nitpick comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)
100-114: Ensure epsilon defaults stay consistent with Python API and testsHere
params.epsilonfalls back to1e-5when the Optionalepsilonis not provided:params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;The Python wrapper in
flashinfer/comm/trtllm_mnnvl_ar.pyand the tests intests/comm/test_trtllm_mnnvl_allreduce.pyshould use the same default to avoid silent discrepancies between the kernel and reference paths. The core test harness already setseps = 1e-5; the remaining mismatch is in the reference RMSNorm computation (seeprepare_test_data), which still usestorch.finfo(dtype).eps.flashinfer/comm/mnnvl.py (3)
132-150: Fixalloc_and_copy_to_cudareturn type vsNonebehavior
alloc_and_copy_to_cudais annotated as returningintbut still returnsNonefor an emptyhost_ptr_array:def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: if not host_ptr_array: return NoneCurrent call sites (
signal_padsanduc_ptrs) always pass non‑empty lists, so behavior is correct, but the annotation is misleading and could hide bugs if the function gets reused.Either make the return type explicit about
Noneor enforce non‑emptiness by raising:-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: - if not host_ptr_array: - return None +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: + if not host_ptr_array: + raise ValueError("host_ptr_array must be non-empty")(or change the annotation to
Optional[int]if you prefer the sentinel behavior).
885-893: IPC opId bootstrap looks fine; consider documenting ordering guarantees
_init_ipc_socketuses an MPI‑likebcastto distribute a randomly chosenopIdfrom rank 0, then uses it to constructIpcSocketendpoints on all ranks. This nicely avoids hard‑coding operation IDs and lines up with the C++ IPC model.Given the reliance on collective barriers around
send_fd/recv_fd, it would help future maintainers to mention in a comment here that all ranks are expected to participate in the same sequence of IPC operations for a givenopId, and that mismatched usage will deadlock. The code is correct as written; this is just a documentation/clarity suggestion.
1143-1170: McastGPUBuffer workspace integration and pointer getters look consistentThe new
comm_backend_for_handle_transferparameter is threaded through toMcastDeviceMemory, and the addedget_unicast_ptrwrapper simply delegates tomcast_device_memory.get_unicast_ptr(rank). This lines up with how tests andget_allreduce_mnnvl_workspaceuse these pointers and keeps pointer access encapsulated.The placeholder buffer‑view methods (
get_multicast_buffer,get_unicast_buffer) are clearly markedNotImplementedError, so they won’t be hit accidentally. If you plan to expose tensor views later, you can implement them viacreate_tensor_from_cuda_memory.Also applies to: 1209-1212
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 92cbd483f63f102ce7b34ae5521df8023db1c96d and 5be26976a40db2935501d132ec1db9ba9fbbd1bb.
📒 Files selected for processing (4)
csrc/trtllm_mnnvl_allreduce.cu(1 hunks)flashinfer/comm/mnnvl.py(18 hunks)include/flashinfer/utils.cuh(2 hunks)tests/comm/test_trtllm_mnnvl_allreduce.py(8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/utils.cuh
🧬 Code graph analysis (3)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (7)
MNNVLAllreduceFusionWorkspace(47-141)mpi_barrier(23-27)trtllm_mnnvl_fused_allreduce_add_rmsnorm(298-388)MNNVLAllreduceFusionStrategy(30-40)trtllm_mnnvl_allreduce(226-295)get_allreduce_mnnvl_workspace(395-448)get_required_buffer_size_bytes(116-141)flashinfer/comm/mnnvl.py (14)
barrier(168-168)barrier(227-228)Get_rank(156-156)Get_rank(215-216)Get_size(159-159)Get_size(218-219)bcast(165-165)bcast(224-225)get_multicast_ptr(871-875)get_multicast_ptr(1205-1207)get_buffer_ptrs_dev(857-859)get_buffer_ptrs_dev(1213-1215)get_unicast_ptr(861-869)get_unicast_ptr(1209-1211)
csrc/trtllm_mnnvl_allreduce.cu (2)
csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(168-219)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
checkCudaErrors(51-61)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
888-888: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
include/flashinfer/utils.cuh (1)
293-307: Thread‑safe SM count cache looks good; confirm single‑GPU‑per‑process assumptionUsing
static std::atomic<int>with relaxed loads/stores makes this helper thread‑safe and avoids the previousstatic intdata race. The comment explicitly assumes one CUDA device per process, since the cachedsm_countis never recomputed if the current device changes.If there are any call sites that may run in a multi‑GPU‑per‑process setup, consider extending this to a per‑device cache (e.g., keyed by device id) rather than a single global integer; otherwise, this implementation is fine as long as the single‑device assumption holds.
csrc/trtllm_mnnvl_allreduce.cu (1)
41-76: RMSNorm fusion validation and shape checks look correctThe updated precondition now correctly requires
residual_in,residual_out,gamma, andepsilonwhenrmsnorm_fusionis true, and the subsequent shape checks onresidual_in,residual_out, andgammaguard the fused path against mismatched tensors. This should prevent the fused kernels from ever seeing invalid residual/norm inputs via the FFI boundary.The overall parameter wiring into
AllReduceFusionParams(including buffer pointers and flags) also looks consistent with the Python side.flashinfer/comm/mnnvl.py (1)
781-790: Good: IPC socket is now closed in destructorThe addition of:
if hasattr(self, "_ipc_socket"): self._ipc_socket.close()inside
__del__ensures the Unix domain socket is closed and, for non‑abstract sockets, the filesystem entry is unlinked. This addresses the earlier socket‑leak concern while remaining safe when construction fails before_ipc_socketis set.tests/comm/test_trtllm_mnnvl_allreduce.py (1)
16-103: Test harness refactor cleanly exercises both refactored and legacy APIsThe new helpers (
row_linear_residual_norm_fusion_forward,_legacy,run_mnnvl_ar_full) and parametrized tests (test_mnnvl_allreduce_refactored,test_mnnvl_allreduce_legacy) do a good job of:
- Sharing core logic between fused and non‑fused paths.
- Covering both the new workspace‑based API and the legacy pointer‑based API.
- Exercising a variety of sequence lengths, dtypes, and hidden sizes.
- Integrating MPI barriers and rank‑aware logging to make multi‑rank failures diagnosable.
Once the epsilon alignment in
prepare_test_datais fixed, this test suite should give solid coverage for the new fused implementation and its backward‑compatibility guarantees.Also applies to: 274-397, 439-465
| comm_backend: Optional[CommBackend] = None, | ||
| ): | ||
| """ | ||
| Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. |
There was a problem hiding this comment.
Is there a way we can check this?
There was a problem hiding this comment.
Forgot to update the doc. Fixed.
| def __init__( | ||
| self, | ||
| mapping: Mapping, | ||
| buffer_size_in_bytes: Optional[int] = None, |
There was a problem hiding this comment.
Could you provide guidance for buffer_size_in_bytes? E.g., in function of number of tokens and hidden size? Or just refer to get_required_buffer_size_bytes
There was a problem hiding this comment.
just refer to get_required_buffer_size_bytes
| def trtllm_mnnvl_allreduce( | ||
| input: torch.Tensor, | ||
| workspace: MNNVLAllreduceFusionWorkspace, | ||
| launch_with_pdl: bool, | ||
| output: Optional[torch.Tensor] = None, | ||
| strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, | ||
| ) -> torch.Tensor: |
There was a problem hiding this comment.
We do want this check. I recently added it because it did bite others.
| def __init__( | ||
| self, | ||
| mapping: Mapping, | ||
| buffer_size_in_bytes: Optional[int] = None, |
There was a problem hiding this comment.
Another option would be to replace "buffer_size_in_bytes" by the parameters that get_required_buffer_size_bytes takes, and just call this from the init function. Seems more user friendly.
If you do want to just allocate a blob of memory, we could still have buffer_size_in_bytes as an addtional parameter that would override whatever is calculated.
There was a problem hiding this comment.
I don't think that is a good design as it might give the user an impression that the allocated workspace will ONLY support the set of parameters (max_num_tokens, hidden_dim, dtype, strategy)
But actually, the workspace usage is quite flexible and as long as the required workspace size is smaller than the allocation, it will work. Thus, the intended usage is the user checks the required workspace size (or we can check it when calling the allreduce function, but at some cost) before using it.
| Optional<TensorView> out) { | ||
| cudaSetDevice(in.device().device_id); | ||
| auto stream = get_stream(in.device()); | ||
| // FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this |
There was a problem hiding this comment.
I don't think it is a problem
There was a problem hiding this comment.
Comment removed.
| AUTO = 99 | ||
|
|
||
| @staticmethod | ||
| def is_one_shot( |
There was a problem hiding this comment.
Maybe rename this to "heuristic_for_one_shot" or something like that.
There was a problem hiding this comment.
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
361-362: Critical: Restore epsilon default to 1e-5 to match kernel.This epsilon fallback was flagged as critical in a previous review but remains unresolved. Using
torch.finfo(input.dtype).epssets epsilon to approximately 1e-3 for fp16, diverging from the kernel's built-in 1e-5 default (seecsrc/trtllm_mnnvl_allreduce.culine 96). This materially alters RMSNorm results and breaks compatibility with TensorRT-LLM.Apply this fix:
- if epsilon is None: - epsilon = torch.finfo(input.dtype).eps + if epsilon is None: + epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
118-136: Consider replacing@functools.cacheon instance method.Using
@functools.cacheon an instance method can prevent the instance from being garbage collected, leading to memory leaks. Since this method takesselfas the first parameter, the cache will hold references to the instance.Consider either:
- Making this a standalone function that takes workspace parameters explicitly
- Using
@functools.lru_cache(maxsize=...)with a reasonable limit- Implementing manual caching in the instance if needed
Based on learnings
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 5be26976a40db2935501d132ec1db9ba9fbbd1bb and c6ed14724f52798bb8edc8fda26ebec1bbbec7e2.
📒 Files selected for processing (2)
csrc/trtllm_mnnvl_allreduce.cu(1 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (6)
Mapping(21-475)rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/utils.py (2)
register_custom_op(273-282)register_custom_op(292-311)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1135-1215)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1115-1132)lamport_initialize(1174-1175)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(857-859)get_buffer_ptrs_dev(1213-1215)get_unicast_ptr(861-869)get_unicast_ptr(1209-1211)get_multicast_ptr(871-875)get_multicast_ptr(1205-1207)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(29-113)trtllm_mnnvl_allreduce_fusion(29-35)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)cudaGetErrorString(146-147)csrc/tvm_ffi_utils.h (1)
get_stream(272-274)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(192-243)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
118-118: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
282-284: Avoid specifying long messages outside the exception class
(TRY003)
289-291: Avoid specifying long messages outside the exception class
(TRY003)
303-305: Avoid specifying long messages outside the exception class
(TRY003)
365-367: Avoid specifying long messages outside the exception class
(TRY003)
369-371: Avoid specifying long messages outside the exception class
(TRY003)
373-375: Avoid specifying long messages outside the exception class
(TRY003)
379-381: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
398-400: Avoid specifying long messages outside the exception class
(TRY003)
528-530: Avoid specifying long messages outside the exception class
(TRY003)
599-601: Avoid specifying long messages outside the exception class
(TRY003)
605-607: Avoid specifying long messages outside the exception class
(TRY003)
610-612: Avoid specifying long messages outside the exception class
(TRY003)
614-616: Avoid specifying long messages outside the exception class
(TRY003)
619-621: Avoid specifying long messages outside the exception class
(TRY003)
624-626: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)
29-113: LGTM! Fusion entry point is well-structured.The refactored entry point properly validates all inputs, including the RMSNorm fusion parameters that were flagged in previous reviews. The dispatch logic cleanly selects between oneshot and twoshot strategies, and error messages are clear and actionable.
flashinfer/comm/trtllm_mnnvl_ar.py (3)
30-48: Strategy enum and heuristic look good.The
MNNVLAllreduceFusionStrategyenum provides a clear interface for selecting between oneshot and twoshot approaches, with a sensible AUTO mode that uses an empirically-derived threshold.
250-326: Buffer size validation properly implemented.The function now includes the buffer size check that was requested in previous reviews (lines 300-305), preventing potential out-of-bounds access. Input validation is comprehensive and error messages are clear.
422-646: Deprecation strategy is well-executed.The legacy APIs are properly marked with
@deprecateddecorators and include clear migration guidance. The wrappers correctly redirect to the new fusion-based implementations, maintaining backward compatibility while encouraging adoption of the improved APIs.
There was a problem hiding this comment.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
132-149: Fix return type inconsistency.The function returns
Noneat line 137 whenhost_ptr_arrayis empty, but the return type annotation at line 132 indicatesint. This creates a type mismatch.Consider one of these fixes:
Option 1: Return
Optional[int]and update callers to handle None:-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:Option 2: Raise an error instead of returning None:
if not host_ptr_array: - return None + raise ValueError("host_ptr_array cannot be empty")
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
377-378: Restore RMSNorm epsilon default to 1e-5.Overriding
epsilonwithtorch.finfo(input.dtype).epsreplaces the kernel's built-in 1e-5 default (seetrtllm_mnnvl_allreduce_fusionincsrc/trtllm_mnnvl_allreduce.culine ~35:params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking numerical parity.Apply this diff to fix:
if epsilon is None: - epsilon = torch.finfo(input.dtype).eps + epsilon = 1e-5
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
134-152: Consider alternatives to@functools.cacheon instance methods.Using
@functools.cache(or@lru_cache) on instance methods can prevent garbage collection of instances because the cache holds references to bound methods, which in turn hold references toself. SinceMNNVLAllreduceFusionWorkspaceinstances are likely long-lived in typical usage, this may be acceptable, but consider these alternatives:
- Use
@functools.lru_cache(maxsize=128)to limit cache growth- Move caching logic to a module-level cache keyed on relevant parameters
- Document the caching behavior and its memory implications
Based on learnings
Apply this diff if you want to limit cache size:
- @functools.cache + @functools.lru_cache(maxsize=128) def is_buffer_size_sufficient(flashinfer/comm/mnnvl.py (2)
640-654: Prefix unused unpacked variables with underscore.The variables
msg,flags, andaddrfromrecvmsgare unpacked but never used. Prefix them with_to indicate they're intentionally ignored.Apply this diff:
- msg, ancdata, flags, addr = self.sock.recvmsg( + _msg, ancdata, _flags, _addr = self.sock.recvmsg(
893-900: Consider usingsecretsmodule for opId generation.While cryptographic randomness is not strictly required for socket naming, using
secrets.randbelow(2**64)instead ofrandom.randintprovides better collision resistance if multiple jobs run concurrently on the same node.Apply this diff:
+import secrets + def _init_ipc_socket(self): if self.group_rank == 0: - # Gnerate the opId - opId = random.randint(0, 2**64 - 1) + # Generate the opId + opId = secrets.randbelow(2**64)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between c6ed14724f52798bb8edc8fda26ebec1bbbec7e2 and a390685bc98732bf7cf97b14833775c2b4962daf.
📒 Files selected for processing (2)
flashinfer/comm/mnnvl.py(19 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
checkCudaErrors(51-61)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/comm/mapping.py (5)
rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/utils.py (2)
register_custom_op(273-282)register_custom_op(292-311)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1143-1224)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1123-1140)lamport_initialize(1183-1184)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(857-859)get_buffer_ptrs_dev(1222-1224)get_unicast_ptr(861-869)get_unicast_ptr(1218-1220)get_multicast_ptr(871-875)get_multicast_ptr(1214-1216)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(29-113)trtllm_mnnvl_allreduce_fusion(29-35)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
896-896: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
298-300: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
319-321: Avoid specifying long messages outside the exception class
(TRY003)
381-383: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
389-391: Avoid specifying long messages outside the exception class
(TRY003)
395-397: Avoid specifying long messages outside the exception class
(TRY003)
401-403: Avoid specifying long messages outside the exception class
(TRY003)
414-416: Avoid specifying long messages outside the exception class
(TRY003)
544-546: Avoid specifying long messages outside the exception class
(TRY003)
615-617: Avoid specifying long messages outside the exception class
(TRY003)
621-623: Avoid specifying long messages outside the exception class
(TRY003)
626-628: Avoid specifying long messages outside the exception class
(TRY003)
630-632: Avoid specifying long messages outside the exception class
(TRY003)
635-637: Avoid specifying long messages outside the exception class
(TRY003)
640-642: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
266-341: LGTM!The function correctly validates inputs, selects strategy, checks buffer size sufficiency (addressing past review feedback), and invokes the fusion kernel with appropriate parameters.
flashinfer/comm/mnnvl.py (1)
788-789: LGTM!The IPC socket cleanup correctly uses
hasattrto check for existence before closing, addressing the file descriptor leak concern from past reviews.
| #elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) | ||
| asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) | ||
| : "memory"); | ||
| #else |
There was a problem hiding this comment.
Does this mean the kernel is supported for archs < 700? What's the minimal requirement?
For the API we use the @backend_requirement decorator, which lists supported SMs. So as a minimum I think we can list: 70,80,90,100,103,110,120
Would you agree? Further back is probably not as relevant.
There was a problem hiding this comment.
This macro is fairly common in barrier and semaphore utility functions, largely due to the memory consistency qualifiers introduced with the Volta architecture. For example, CUTLASS uses a similar pattern:
That said, I believe our usage here simply follows established convention. Given that our minimum supported architecture is sm_75, we shouldn't actually need these qualifiers.
There was a problem hiding this comment.
The kernel needs multicast to work, which at least requires SM90 and needs NVSwitch
| if (NUM_INPUTS > 0) { | ||
| T_IN accum[ELTS_PER_THREAD]; | ||
| float4* accum4 = (float4*)&accum; | ||
| flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER); |
There was a problem hiding this comment.
Instead of clearing the buffer here, could we assume we have properly initialized lamport buffers at the start?
And then, at the end (e.g., right after we call PDL), we can clear the buffers, so that a next kernel using the same workspace can also assume properly initialized lamport buffers.
There was a problem hiding this comment.
We use triple buffer so you can move the buffer clear anywhere. I found the current arrangement get the best performance.
If your assumption is using single buffer, clear the buffer at the end then assume the buffer is initialized for the next kernel, this won't work as it requires membar.sys which is very expensive. Moreover, we need to order the buffer write of the next kernel, other ranks after the buffer clear of this kernel, which has to use a flag and invalidate the usage of lamport protocol.
| // Check validity across all elements | ||
| #pragma unroll | ||
| for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { | ||
| valid &= !isNegZero(valuesLamport[r].elements[i]); | ||
| } |
There was a problem hiding this comment.
As long as one of the values in WorldSize isNegZero, we will keep reading the others over and over again. Perhaps these are all cache hits, in which case: ignore my comment. Otherwise, would it make sense to break out of the inner loop (kLAMPORT_ELTS_PER_PACKED) early?
There was a problem hiding this comment.
It will be cache hit if it is still invalid.
For the inner loop, I had a version that only check one element for validity. But there is no architecture guarantee that a 128B read/write is atomic, which mean beak the inner loop early could, in rare case, cause wrong result.
| all_shareable_uc_handles = self.comm_backend.allgather( | ||
| local_shareable_uc_handle.data | ||
| ) | ||
| else: |
There was a problem hiding this comment.
[nit] Add a comment regarding handle type in this case: CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
There was a problem hiding this comment.
Branching logic no longer exists in the latest commit.
| mc_fabric_handle.data if mc_fabric_handle else None, root=0 | ||
| ) | ||
| shareable_mc_handle = None | ||
| if ( |
There was a problem hiding this comment.
With all the branching on this handle type, would it make sense to build a class that encapsulate this logic?
Roughly thinking:
class ABCCommBufferAllocator:
def allocate_UnicastBuffer()
def allocate_MulitcastBuffer()
class CUFabricCommBufferAllocator(ABCCommBufferAllocator)
class CUPosixFDCommBufferAllocator(ABCCommBufferAllocator)
There was a problem hiding this comment.
Refactored in the latest commit.
| ) | ||
| self.buf_size = buf_size | ||
| self.local_device = device | ||
|
|
||
| def lamport_initialize(self, rank: int, dtype: torch.dtype): | ||
| self.mcast_device_memory.lamport_initialize(rank, dtype) | ||
|
|
||
| def get_mc_buffer( | ||
| def get_multicast_buffer( |
There was a problem hiding this comment.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
406-425: Potential deadlock:allgathercalled only by failing ranks.The
MPI.COMM_WORLD.allgather(rank_failed)call is inside theexceptblock, so only ranks that throw an exception will participate. If rank 0 fails while rank 1 succeeds, rank 1 continues to the barrier at line 433 while rank 0 blocks onallgather—causing a distributed deadlock.Consider moving failure synchronization outside the exception handler:
except Exception as e: rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) print(traceback.format_exc()) - # Gather failure status from all ranks for logging - all_failures = MPI.COMM_WORLD.allgather(rank_failed) - - if any(all_failures): - failed_ranks = [i for i, failed in enumerate(all_failures) if failed] - if rank == 0: - print(f"Test failed on ranks: {failed_ranks}") - # Cleanup before re-raising if "workspace" in locals(): del workspace # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: # Ensure cleanup happens for this list's workspace if "workspace" in locals(): del workspace + + # Gather failure status from all ranks for logging (must be outside except to avoid deadlock) + all_failures = MPI.COMM_WORLD.allgather(rank_failed) + if any(all_failures): + failed_ranks = [i for i, failed in enumerate(all_failures) if failed] + if rank == 0: + print(f"Test failed on ranks: {failed_ranks}")Note: With this change, the
raiseat line 425 would need to be moved after thefinallyblock completes, or handled differently to ensure all ranks synchronize before any rank exits.
♻️ Duplicate comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
257-264: Epsilon mismatch between reference computation and kernel execution.The reference RMSNorm uses
torch.finfo(dtype).eps(line 263), which is ~6e-8 for float16 and ~1e-7 for bfloat16, while the actual kernel execution useseps = 1e-5(line 322). This 100-1000x difference in epsilon values will produce different numerical results between reference and actual outputs.Additionally, line 257 has a type annotation issue:
Tuple[torch.Tensor, ...]assigned toNone.- reference_output: Tuple[torch.Tensor, ...] = None + reference_output: Optional[Tuple[torch.Tensor, ...]] = None if fusion: # Fused case: AllReduce + Residual Add + RMS Norm allreduce_result = torch.sum(x_full, dim=0) # AllReduce result residual_out = allreduce_result + residual # Add residual norm_out = rmsnorm( - residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + residual_out, norm_weight, 1e-5, enable_pdl=False )
🧹 Nitpick comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
60-69: Remove unnecessary tensor allocation.Line 61 allocates
output = torch.empty_like(input)but it's immediately overwritten by the return value oftrtllm_mnnvl_allreduceon line 63. This wastes GPU memory unnecessarily.else: - output = torch.empty_like(input) - output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce( input, workspace, launch_with_pdl=use_pdl, strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) return (output.view(shape),)
274-276: Unusedmonkeypatchparameter.The
monkeypatchfixture is passed but never used in the function body. If it's not needed, consider removing it from the signature and test function parameters.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/comm/test_trtllm_mnnvl_allreduce.py(9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
Mapping(21-475)tp_rank(325-326)flashinfer/comm/trtllm_mnnvl_ar.py (5)
MNNVLAllreduceFusionWorkspace(50-181)MNNVLAllreduceFusionStrategy(30-43)trtllm_mnnvl_allreduce(266-341)get_allreduce_mnnvl_workspace(442-495)get_required_buffer_size_bytes(156-181)flashinfer/comm/mnnvl.py (2)
barrier(168-168)barrier(227-228)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
105-124: LGTM!The legacy API test helper correctly passes buffer pointers and uses MPI barrier for synchronization.
439-465: LGTM!Good test coverage with comprehensive parameterization across sequence lengths, fusion modes, data types, and hidden sizes. The separation between refactored and legacy API tests is well-structured.
There was a problem hiding this comment.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
132-149: Clarify empty-input behavior inalloc_and_copy_to_cuda
alloc_and_copy_to_cudais annotated to returnintbut returnsNonewhenhost_ptr_arrayis empty (Line 136). That mismatch can surface as a runtime bug if an empty list ever slips through.Consider failing fast instead of returning
None, e.g.:def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: - """ - A helper function that allocates memory on cuda and copies the data from the host to the device. - """ - if not host_ptr_array: - return None + """ + A helper function that allocates memory on cuda and copies the data from the host to the device. + """ + if not host_ptr_array: + raise ValueError("host_ptr_array must be non-empty")This keeps the return type consistent and surfaces misuse early.
♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (2)
300-305: Fixset_comm_from_configwhenconfigisNoneWhen
configisNone, you assign the fallback toMnnvlMemory.configbut still callconfig.comm_backend.Split(...)(Line 302), which will raiseAttributeError.Use the effective config for both the assignment and the
Splitcall:@staticmethod def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None): - MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined] - comm = config.comm_backend.Split( - mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank - ) + config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined] + MnnvlMemory.config = config + comm = config.comm_backend.Split( + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, + mapping.tp_rank, + ) MnnvlMemory.comm = comm # type: ignore[assignment]This matches the earlier suggested fix and ensures the fallback path actually works.
#!/bin/bash # Simple sanity check: search for other direct uses of bare `config` in this method. rg -n "set_comm_from_config" flashinfer/comm/mnnvl.py -n -C5
720-759: Close local POSIX shareable handles to avoid slow FD leaksThe new exchanger abstraction correctly closes imported POSIX FDs via
PosixFDHandleExchanger.cleanup()(Lines 760–762), but two categories of FDs still remain unclosed in the POSIX path:
local_shareable_uc_handleproduced bycuMemExportToShareableHandle(Lines 1065–1072) is never passed throughcleanup.- For the multicast handle, rank 0’s
shareable_mc_handle(Lines 1117–1125) is broadcast but only non‑root ranks callcleanupafter import (Lines 1135–1141).Over long‑running runs that create/destroy
McastDeviceMemoryrepeatedly, these will accumulate OS FDs.You can reuse the exchanger cleanup hook to fix both without special‑casing for handle type:
# All-gather shareable handles all_shareable_uc_handles = self._exchanger.allgather(local_shareable_uc_handle) cuda.cuCtxSynchronize() # Import remote handles for p in range(self.group_size): if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( all_shareable_uc_handles[p], self._exchanger.handle_type, ) ) self._exchanger.cleanup(all_shareable_uc_handles[p]) + + # We no longer need our own exported shareable handle. + self._exchanger.cleanup(local_shareable_uc_handle)And in
_setup_multicast:# Broadcast multicast handle from rank 0 shareable_mc_handle = self._exchanger.broadcast(shareable_mc_handle, root=0) cuda.cuCtxSynchronize() - # Import multicast handle for non-root ranks - if self.group_rank != 0: + # Import multicast handle for non-root ranks + if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( shareable_mc_handle, self._exchanger.handle_type, ) ) self._exchanger.cleanup(shareable_mc_handle) + else: + # Root rank can now drop its exported handle (POSIX FD path). + self._exchanger.cleanup(shareable_mc_handle)
FabricHandleExchanger.cleanup()is a no‑op, so these calls are harmless in the fabric path and fix the leak in the POSIX‑FD path.#!/bin/bash # Check remaining uses of cuMemExportToShareableHandle and ensure every handle is cleaned up. rg -n "cuMemExportToShareableHandle" flashinfer/comm/mnnvl.py -n -C5Also applies to: 760-765, 1065-1076, 1079-1088, 1129-1142
🧹 Nitpick comments (3)
flashinfer/comm/mnnvl.py (3)
566-664: TidyIpcSocket.recv_fdunused variables and document/tmppath usage
recv_fdcurrently bindsmsg,flags, andaddrbut never uses them (Line 640), which Ruff flags (RUF059). You can keep the signature while silencing the warning:- fds = array.array("i") - msg, ancdata, flags, addr = self.sock.recvmsg( + fds = array.array("i") + _msg, ancdata, _flags, _addr = self.sock.recvmsg( 1, socket.CMSG_SPACE( fds.itemsize ), # Buffer size for dummy data # Ancillary data size )On the S108
/tmp/mcastmem-socket-warning: since you default touse_abstract=True, the filesystem path is never actually used in normal operation. If you expectuse_abstract=Falsein multi‑tenant environments, consider adding a brief comment noting that this path is intended for trusted deployments only.
1003-1013: Consider narrowing the broadExceptioncatch in_verify_cuda_context
_verify_cuda_contextcurrently catches a bareException(Lines 1005–1012). Given this is only used for diagnostics, that’s acceptable, but narrowing to CUDA‑related exceptions (or at least logging the exception type) would align better with BLE001 and avoid swallowing unrelated programming errors.Not strictly required, but worth considering:
- except Exception as e: - print(f"Error checking CUDA context: {e}") + except Exception as e: + # Broad catch is intentional: any CUDA context error should only warn here. + print(f"Error checking CUDA context: {type(e).__name__}: {e}")
1242-1268: Explicitly markget_multicast_buffer/get_unicast_bufferas internal or implement via the new helperBoth
get_multicast_bufferandget_unicast_buffercurrently raiseNotImplementedError(Lines 1257–1258 and 1267–1268). Given prior discussion that this class is internal, that’s acceptable, but the public‑sounding names can confuse users if they discover them.Two options:
- If they are truly internal placeholders, consider prefixing with
_or clarifying in the docstring that they are not expected to be called yet.- If you want them usable, wire them up through
create_tensor_from_cuda_memoryat the top of the file, e.g., usingself.get_multicast_ptr()/self.get_unicast_ptr(rank)plus the desired shape/dtype.Happy to sketch an implementation using
create_tensor_from_cuda_memoryif you plan to expose these in this PR.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/comm/mnnvl.py(19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/mnnvl.py (2)
flashinfer/cuda_utils.py (1)
checkCudaErrors(51-61)flashinfer/utils.py (1)
round_up(631-633)
🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
726-726: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
1011-1011: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/mnnvl.py (2)
878-887: Allocation sizing and usable buffer exposure look consistent
get_allocation_sizeandget_usable_buffer_size(Lines 982–989) align withlamport_initialize, which usesallocation_size - SIGNAL_PAD_SIZE(Lines 1191–1193), andMcastGPUBuffernow surfaces this usable size viaself.buf_size(Lines 1235–1237).This keeps the Python view of capacity in sync with what Lamport initialization actually touches, which is a nice improvement over passing the raw requested size around.
Also applies to: 982-989, 1191-1193, 1235-1237
164-166: I'm unable to access the repository due to persistent cloning issues. However, based on the information provided in the review comment itself, I can provide analysis:The review comment acknowledges uncertainty with the phrase "if any exist in this repo", which suggests:
- The reviewer was aware that there may not be other
CommBackendimplementations- The reviewer already identified
MPIBackendas implementingbcast(lines 224-226)- The concern is conditional on whether other subclasses exist
Without direct repository access to verify all
CommBackendsubclasses and their implementations, I cannot definitively confirm or refute the original review comment's concern.Recommended next steps:
- Manually search the repository for classes inheriting from
CommBackend- Verify each implementation includes the
bcastmethod- Or provide confirmation if only
MPIBackendexists as aCommBackendsubclassSince I cannot complete the verification due to technical limitations:
New
bcastabstract requires allCommBackendimplementations to be updatedAdding
CommBackend.bcast(...)(Lines 164–166) and implementing it inMPIBackend(Lines 224–226) is correct, but any otherCommBackendsubclasses must also implementbcastto avoidTypeErrorat instantiation. Verify that all subclasses (if any exist) have been updated accordingly.
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
263-263: Ensure epsilon consistency with the API under test.Line 263 uses
torch.finfo(dtype).epsfor the reference RMSNorm calculation. This should match whatever epsilon default the actual API uses. I've flagged the API's epsilon default (line 407 intrtllm_mnnvl_ar.py) for verification—if that's changed to 1e-5, update this test accordingly:norm_out = rmsnorm( - residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + residual_out, norm_weight, 1e-5, enable_pdl=False )This is related to the epsilon verification in the main API file.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)tests/comm/test_trtllm_mnnvl_allreduce.py(9 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_mnnvl_ar.py (6)
flashinfer/comm/mapping.py (4)
Mapping(21-475)rank(311-312)rank(315-322)tp_rank(325-326)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/comm/mnnvl.py (13)
McastGPUBuffer(1199-1280)CommBackend(152-171)MPIBackend(211-232)lamport_initialize(1179-1196)lamport_initialize(1239-1240)barrier(168-168)barrier(227-228)get_buffer_ptrs_dev(954-956)get_buffer_ptrs_dev(1278-1280)get_unicast_ptr(958-966)get_unicast_ptr(1274-1276)get_multicast_ptr(968-972)get_multicast_ptr(1270-1272)tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)
barrier(60-64)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(29-113)trtllm_mnnvl_allreduce_fusion(29-35)flashinfer/fused_moe/utils.py (1)
shape(76-77)
🪛 Ruff (0.14.7)
flashinfer/comm/trtllm_mnnvl_ar.py
106-108: Avoid specifying long messages outside the exception class
(TRY003)
163-163: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
327-329: Avoid specifying long messages outside the exception class
(TRY003)
334-336: Avoid specifying long messages outside the exception class
(TRY003)
348-350: Avoid specifying long messages outside the exception class
(TRY003)
410-412: Avoid specifying long messages outside the exception class
(TRY003)
414-416: Avoid specifying long messages outside the exception class
(TRY003)
418-420: Avoid specifying long messages outside the exception class
(TRY003)
424-426: Avoid specifying long messages outside the exception class
(TRY003)
430-432: Avoid specifying long messages outside the exception class
(TRY003)
443-445: Avoid specifying long messages outside the exception class
(TRY003)
576-578: Avoid specifying long messages outside the exception class
(TRY003)
647-649: Avoid specifying long messages outside the exception class
(TRY003)
653-655: Avoid specifying long messages outside the exception class
(TRY003)
658-660: Avoid specifying long messages outside the exception class
(TRY003)
662-664: Avoid specifying long messages outside the exception class
(TRY003)
667-669: Avoid specifying long messages outside the exception class
(TRY003)
672-674: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
flashinfer/comm/trtllm_mnnvl_ar.py (2)
295-371: LGTM! Well-structured API with proper validation.The function includes comprehensive input validation, automatic strategy selection, and buffer size checking before kernel invocation. The interface is clean and user-friendly.
406-407: Verify epsilon default—past review flagged torch.finfo(..).eps as problematic.A previous review marked this as critical: using
torch.finfo(input.dtype).epsgives ~1e-3 for fp16, potentially breaking parity with TensorRT-LLM's 1e-5 default. That issue was reportedly addressed in earlier commits, yet this code still usestorch.finfo.To resolve this, verify:
- Check
csrc/trtllm_mnnvl_allreduce.cuto confirm the kernel's epsilon default- Review git history to determine if the fix was reverted
- Run existing tests to check for numerical discrepancies
- Review documentation to understand the intent of using
torch.finfovs. a fixed constantIf the kernel indeed defaults to 1e-5 and tests reveal numerical differences, the fix should be applied as suggested in the original review.
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
17-103: LGTM! Test helper properly exercises both fused and unfused paths.The refactored helper correctly uses the new workspace-based API, handles shape transformations, and validates outputs against reference results with appropriate tolerances.
274-430: LGTM! Robust test orchestration with proper error handling.The test function handles both legacy and refactored API paths, includes comprehensive error handling with per-rank failure tracking, and ensures proper workspace cleanup. The MPI synchronization points are appropriately placed.
435-461: LGTM! Comprehensive test parameterization.The test functions provide good coverage across multiple sequence lengths, fusion modes, data types, and hidden dimensions. The distinction between refactored and legacy API tests is appropriate.
| if max_num_tokens > MNNVL_ONE_SHOT_THRESHOLD: | ||
| two_shot_size_bytes = self.get_required_buffer_size_bytes( | ||
| mapping.tp_size, | ||
| max_num_tokens, | ||
| hidden_dim, | ||
| dtype, | ||
| MNNVLAllreduceFusionStrategy.TWOSHOT, | ||
| ) | ||
| else: | ||
| two_shot_size_bytes = 0 |
There was a problem hiding this comment.
Fix incorrect threshold comparison—comparing tokens to bytes.
Line 85 compares max_num_tokens > MNNVL_ONE_SHOT_THRESHOLD, but MNNVL_ONE_SHOT_THRESHOLD is a byte threshold (line 47), not a token count. This causes incorrect strategy selection when deciding whether to compute two_shot_size_bytes.
Apply this diff to compare the computed one-shot size in bytes instead:
- if max_num_tokens > MNNVL_ONE_SHOT_THRESHOLD:
+ if one_shot_size_bytes > MNNVL_ONE_SHOT_THRESHOLD:
two_shot_size_bytes = self.get_required_buffer_size_bytes(📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if max_num_tokens > MNNVL_ONE_SHOT_THRESHOLD: | |
| two_shot_size_bytes = self.get_required_buffer_size_bytes( | |
| mapping.tp_size, | |
| max_num_tokens, | |
| hidden_dim, | |
| dtype, | |
| MNNVLAllreduceFusionStrategy.TWOSHOT, | |
| ) | |
| else: | |
| two_shot_size_bytes = 0 | |
| if one_shot_size_bytes > MNNVL_ONE_SHOT_THRESHOLD: | |
| two_shot_size_bytes = self.get_required_buffer_size_bytes( | |
| mapping.tp_size, | |
| max_num_tokens, | |
| hidden_dim, | |
| dtype, | |
| MNNVLAllreduceFusionStrategy.TWOSHOT, | |
| ) | |
| else: | |
| two_shot_size_bytes = 0 |
🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_mnnvl_ar.py around lines 85 to 94, the code
incorrectly compares max_num_tokens to MNNVL_ONE_SHOT_THRESHOLD (bytes), so
compute the one-shot buffer size in bytes by calling
get_required_buffer_size_bytes(...) with MNNVLAllreduceFusionStrategy.ONESHOT
and compare that byte size to MNNVL_ONE_SHOT_THRESHOLD; if the computed
one_shot_size_bytes exceeds the threshold then compute two_shot_size_bytes as
shown, otherwise set two_shot_size_bytes to 0. Ensure you pass the same
mapping.tp_size, max_num_tokens, hidden_dim, dtype and use the ONESHOT strategy
for the size calculation before the threshold comparison.
| self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) | ||
| self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() | ||
|
|
||
| @functools.cache |
There was a problem hiding this comment.
Cache on instance method can prevent garbage collection.
Using @functools.cache on an instance method can cause memory leaks because the cache holds references to self, preventing the instance from being garbage collected. Since workspace objects may be created and destroyed across different configurations, consider using @functools.lru_cache(maxsize=...) with a bounded cache or moving this to a module-level function that doesn't capture self.
Based on learnings, flashinfer may accept this pattern if workspace instances are expected to be long-lived. If that's the case, add a comment explaining the rationale.
🧰 Tools
🪛 Ruff (0.14.7)
163-163: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_mnnvl_ar.py around line 163, the use of
@functools.cache on an instance method can leak memory because the cache retains
references to self; replace it with @functools.lru_cache(maxsize=...) to bound
the cache size or refactor the cached logic into a module-level function that
accepts only hashable inputs (not self), and if workspace instances are
intentionally long-lived, add a concise comment explaining that choice and why
unbounded caching is acceptable.
| if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: | ||
| # For one-shot, each rank needs to store num_tokens * tp_size tokens | ||
| buffer_size = num_tokens * hidden_dim * tp_size * elem_size | ||
| else: |
There was a problem hiding this comment.
As discussed we would like to make one-shot vs two-shot at runtime, so here I would add MAX_ONE_SHOT_SIZE= MNNVL_ONE_SHOT_THRESHOLD
and buffer_size = max(...., MAX_ONE_SHOT_SIZE). And add a comment explaining that this workspace can be used for both one-shot and two-shot.
There was a problem hiding this comment.
Forgot to press submit. Ignore this.
There was a problem hiding this comment.
I assume this is considered resolved?
This function is used purely as an utility for calculating required workspace given a problem size and implementaation. It should not implement the logic we discuss, which is implemented in the constructor.
|
|
||
| Args: | ||
| mapping: Mapping configuration containing rank info | ||
| buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. |
There was a problem hiding this comment.
[nit] document the new parameters, and also bring up the point that buffer_size_in_bytes is needed when we force one-shot.
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
419-420: Critical: RMSNorm epsilon default breaks parity with TensorRT-LLM.The code sets
epsilon = torch.finfo(input.dtype).epswhen epsilon is None, which for fp16 becomes ~1e-3 instead of the kernel's built-in default of 1e-5. This materially changes RMSNorm behavior and breaks numerical parity with TensorRT-LLM.A past review comment flagged this issue as addressed, but the problematic code remains. Please restore the correct default.
Apply this diff to fix the epsilon default:
if epsilon is None: - epsilon = torch.finfo(input.dtype).eps + epsilon = 1e-5
🧹 Nitpick comments (2)
csrc/trtllm_mnnvl_allreduce.cu (2)
63-64: Optional: Use extracted variables in error messages for consistency.The error message references
input.size(0)andinput.size(1), butnum_tokensandtoken_dimhave already been extracted (lines 41-42). Using the variables would improve consistency.Apply this diff for consistency:
- << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dimThe same applies to the
residual_outerror message at lines 68-69:- << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << "residual_out shape mismatch: expected (" << num_tokens << ", " << token_dim
96-96: Optional: Consider extracting the default epsilon value to a named constant.The default epsilon value
1e-5is hardcoded. For better maintainability and documentation, consider defining it as a named constant at file or namespace scope.Example:
constexpr double kDefaultRMSNormEpsilon = 1e-5;Then use:
- params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; + params.epsilon = epsilon.has_value() ? epsilon.value() : kDefaultRMSNormEpsilon;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
csrc/trtllm_mnnvl_allreduce.cu(1 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (1)
csrc/trtllm_mnnvl_allreduce.cu (3)
csrc/tvm_ffi_utils.h (1)
get_stream(277-279)flashinfer/comm/cuda_ipc.py (1)
cudaGetErrorString(146-147)flashinfer/comm/trtllm_mnnvl_ar.py (1)
trtllm_mnnvl_allreduce_fusion(250-301)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_mnnvl_ar.py
119-121: Avoid specifying long messages outside the exception class
(TRY003)
176-176: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
340-342: Avoid specifying long messages outside the exception class
(TRY003)
347-349: Avoid specifying long messages outside the exception class
(TRY003)
361-363: Avoid specifying long messages outside the exception class
(TRY003)
423-425: Avoid specifying long messages outside the exception class
(TRY003)
427-429: Avoid specifying long messages outside the exception class
(TRY003)
431-433: Avoid specifying long messages outside the exception class
(TRY003)
437-439: Avoid specifying long messages outside the exception class
(TRY003)
443-445: Avoid specifying long messages outside the exception class
(TRY003)
456-458: Avoid specifying long messages outside the exception class
(TRY003)
589-591: Avoid specifying long messages outside the exception class
(TRY003)
660-662: Avoid specifying long messages outside the exception class
(TRY003)
666-668: Avoid specifying long messages outside the exception class
(TRY003)
671-673: Avoid specifying long messages outside the exception class
(TRY003)
675-677: Avoid specifying long messages outside the exception class
(TRY003)
680-682: Avoid specifying long messages outside the exception class
(TRY003)
685-687: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (15)
csrc/trtllm_mnnvl_allreduce.cu (6)
29-35: Function signature correctly implements the fusion API.The signature properly exposes all necessary parameters for the fused allreduce operation, including the new RMSNorm fusion support. The parameter organization is clear and matches the Python wrapper.
36-42: Device setup and parameter extraction look correct.The device guard and stream extraction follow standard patterns. Parameter extraction from the input tensor is straightforward.
54-74: Excellent: RMSNorm fusion validation properly implemented.The validation logic correctly addresses the previous review comment by:
- Requiring
residual_in,residual_out,gamma, andepsilonwhenrmsnorm_fusionis true (lines 54-58)- Validating shapes for all three tensors (lines 60-74)
This prevents undefined behavior when the fused kernel dereferences these pointers.
77-102: Parameter struct setup is correct.The parameter mapping properly handles all inputs, outputs, and buffers. The conditional pointer extraction for optional parameters is handled correctly with null fallbacks.
104-112: Dispatch logic and error handling are correct.The conditional dispatch based on
use_oneshotproperly routes to the appropriate kernel implementation. The error message correctly references the new function name.
115-115: FFI export is correct.The function is properly exported with the correct symbol name matching the function signature.
flashinfer/comm/trtllm_mnnvl_ar.py (9)
30-48: LGTM—strategy enum and selection logic are correct.The strategy selection compares total data size against the threshold to choose between one-shot and two-shot approaches, which aligns with the documented behavior.
196-224: LGTM—buffer size calculation is correct.The static method correctly computes buffer sizes for both strategies. The two-shot calculation properly rounds up to
tp_sizemultiples and accounts for the 2-stage pipeline.
226-305: LGTM—kernel registration and exposure are correct.The custom op registration properly lists all mutating arguments, and the wrapper function provides clear documentation for all parameters.
308-384: LGTM—validation and kernel invocation are correct.The function properly validates inputs, auto-selects strategy, checks buffer sufficiency, and invokes the fusion kernel with RMSNorm disabled.
422-477: LGTM—validation and kernel invocation logic are correct.The function properly validates all input/output tensors, checks buffer sufficiency, and correctly invokes the fusion kernel with RMSNorm enabled. (Note: epsilon default issue is flagged separately.)
480-540: LGTM—deprecated API properly redirects to new workspace.The legacy function is correctly marked as deprecated and redirects to the new
MNNVLAllreduceFusionWorkspacewhile maintaining backward compatibility.
543-614: LGTM—deprecated API maintains backward compatibility.The legacy function properly redirects to the new fusion kernel while maintaining expected behavior. The assertion at lines 594-596 correctly prevents unsupported usage patterns.
617-707: LGTM—deprecated RMSNorm fusion API correctly redirects.The legacy fused RMSNorm function is properly deprecated and redirects to the new kernel with comprehensive validation and correct parameter mapping.
88-110: Verify the oneshot buffer size calculation logic.The code clamps
oneshot_max_num_tokensto ensure the one-shot size doesn't exceedMNNVL_ONE_SHOT_THRESHOLDby dividing the threshold by(tp_size * elem_size * hidden_dim). Confirm that this division formula correctly inverts the buffer size calculation used inget_required_buffer_size_bytes(), and verify whetherTWOSHOTstrategy also respects a similar threshold or if it's unbounded relative toONESHOT.
| @functools.cache | ||
| def is_buffer_size_sufficient( | ||
| self, | ||
| tp_size: int, | ||
| num_tokens: int, | ||
| hidden_dim: int, | ||
| dtype: torch.dtype, | ||
| strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, | ||
| ) -> bool: | ||
| """ | ||
| Calculate the required buffer size for a given problem size. | ||
| """ | ||
| required_buffer_size = self.get_required_buffer_size_bytes( | ||
| tp_size, num_tokens, hidden_dim, dtype, strategy | ||
| ) | ||
| if required_buffer_size > self.buffer_size_bytes: | ||
| return False | ||
| else: | ||
| return True | ||
|
|
There was a problem hiding this comment.
Reconsider @functools.cache on instance method—potential memory leak.
Using @functools.cache on an instance method prevents the workspace instance from being garbage collected because the cache holds a reference to self. Since workspace objects allocate GPU buffers and may be created/destroyed across different configurations, this can lead to GPU memory leaks.
Consider using @functools.lru_cache(maxsize=128) to bound the cache size, or refactor the check to avoid caching on the instance method altogether (e.g., compute on-demand since it's a simple comparison).
Based on learnings, if workspace instances are intentionally long-lived for the entire application lifetime, add a comment explaining why unbounded caching is acceptable here.
🧰 Tools
🪛 Ruff (0.14.8)
176-176: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
| workspace: MNNVLAllreduceFusionWorkspace, | ||
| launch_with_pdl: bool, | ||
| output: Optional[torch.Tensor] = None, | ||
| strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, |
There was a problem hiding this comment.
Before calling this API, users are expected to create
workspace = MNNVLAllreduceFusionWorkspace(...),
which internally decides whether to use the one-shot or two-shot strategy.
But what happens if the framework (user) chooses a different strategy than the one determined by the workspace?
There was a problem hiding this comment.
It's possible, and by design actually.
There is a check "if not workspace.is_buffer_size_sufficient(...)" that happens before the allreduce executes.
There was a problem hiding this comment.
The longer answer is that: (1) the workspace creation will use a heuristic to find the optimal situation: create one-shot space (the largest one for the same problem space) up to a certain threshold, and two-shot after that. (*)
Then at runtime (2) when allreduce gets called, we check the heuristic again, and use one-shot/two-shot based on that.
However, this means we could fall below the one-shot threshold at runtime. Therefore, internally we also made sure that enough memory was allocated for the largest one-shot problem sizes (refer to trtllm_mnnvl_ar.py:L88).
Now, if a user intentionally wants to use one-shot all the time, they are expected to calculate the memory themselves (see the comments).
| required_buffer_size = self.get_required_buffer_size_bytes( | ||
| tp_size, num_tokens, hidden_dim, dtype, strategy | ||
| ) | ||
| if required_buffer_size > self.buffer_size_bytes: |
There was a problem hiding this comment.
simplify this to: return required_buffer_size <= self.buffer_size_bytes
| def trtllm_mnnvl_allreduce( | ||
| input: torch.Tensor, | ||
| workspace: MNNVLAllreduceFusionWorkspace, | ||
| launch_with_pdl: bool, |
There was a problem hiding this comment.
should we just make this optional and default to True?
| input: Local Input Shard [num_tokens, hidden_dim] | ||
| workspace: MNNVLAllreduceFusionWorkspace | ||
| launch_with_pdl: Whether to launch with PDL | ||
| output: Output tensor to store the result, empty tensor will be created if not provided. |
There was a problem hiding this comment.
Not a blocker for this PR, but just a question: is there a plan to fuse FP8-Quant and NVFP4-Quant into this kernel later?
Currently, the frameworks are using the AR+Norm+Q fused kernels in https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/comm/trtllm_ar.py . It becomes quite difficult for us to choose among these different options.
cc @nvmbreughe
<!-- .github/pull_request_template.md --> ## 📌 Description This PR porting all changes in [TensorRT-LLM#8018](NVIDIA/TensorRT-LLM#8018) into Flashinfer. Apart from the changes mentioned in the original PR, this PR also introduce new API interface as `trtllm_mnnvl_allreduce` and `trtllm_mnnvl_fused_allreduce_add_rmsnorm` to replace the original ones. The workspace allocation is wrapped as an entire class with a given buffer size and the user does not need to worry about the details inside. This PR adds support for IPC Socket based mcast device memory bootstrap so that it can run on DGX machine that does not support fabric handle. @wenscarl This PR also incorporate the changes made in flashinfer-ai#2056 and should be able to replace that PR. A bcast interface is added to the comm backend as this is needed during the handle transfer. The old API is tagged as deprecated and redirected to the new APIs. The user of the old API should not need to make any changes. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Fused all‑reduce with optional RMSNorm fusion and selectable one‑shot/two‑shot strategies; new Python APIs and workspace utilities; IPC-based handle exchange and bcast support. * **Improvements** * Pluggable handle‑exchange backends (Fabric/POSIX), stricter I/O and shape validation, renamed/standardized fusion entry points and parameter surfaces, cached CUDA SM count for tuning, and safer lifecycle/cleanup. * **Tests** * MPI‑aware tests for fused and legacy flows, workspace-based runs, synchronization, and expanded sequence/hidden‑size coverage. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
📌 Description
This PR porting all changes in TensorRT-LLM#8018 into Flashinfer.
Apart from the changes mentioned in the original PR, this PR also introduce new API interface as
trtllm_mnnvl_allreduceandtrtllm_mnnvl_fused_allreduce_add_rmsnormto replace the original ones. The workspace allocation is wrapped as an entire class with a given buffer size and the user does not need to worry about the details inside.This PR adds support for IPC Socket based mcast device memory bootstrap so that it can run on DGX machine that does not support fabric handle.
@wenscarl This PR also incorporate the changes made in #2056 and should be able to replace that PR. A bcast interface is added to the comm backend as this is needed during the handle transfer.
The old API is tagged as deprecated and redirected to the new APIs. The user of the old API should not need to make any changes.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests
✏️ Tip: You can customize this high-level summary in your review settings.