
Added workspace check and reflected this in test#1991

Merged
nvmbreughe merged 5 commits into flashinfer-ai:main from nvmbreughe:mbreughe/1986
Oct 28, 2025

Conversation

@nvmbreughe
Contributor

@nvmbreughe nvmbreughe commented Oct 27, 2025

📌 Description

This PR attempts to fix #1986 (to be confirmed by requester)

The issue is that num_tokens was larger than MAX_TOKEN_NUM, which results in an illegal memory access (IMA) or even a hang. To address this, I added a validation check. This required a non-breaking API change:

  • create_ipc_workspace_for_all_reduce_fusion now takes an optional "create_metadata" bool, which adds a metadata dict as an extra return value
    • it is made optional because unconditionally adding a return value would break the API
  • trtllm_allreduce_fusion now takes an optional metadata dictionary
    • when provided, the validation check runs
    • again, this is optional to avoid breaking the API
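The validation described above can be sketched in plain Python. This is an illustrative stand-in, not the actual flashinfer implementation; the function name and error messages are hypothetical, but the three checks (token count, world size, hidden dimension) match what this PR adds.

```python
from typing import Optional


def validate_against_workspace(
    token_num: int,
    world_size: int,
    hidden_dim: int,
    metadata: Optional[dict] = None,
) -> None:
    """Raise ValueError if runtime arguments exceed what the workspace supports."""
    if metadata is None:
        return  # metadata is optional, so omitting it keeps the old behavior
    errors = []
    if token_num > metadata["max_token_num"]:
        errors.append(
            f"token_num {token_num} exceeds workspace max_token_num "
            f"{metadata['max_token_num']} (risks IMA or hang)"
        )
    if world_size != metadata["tp_size"]:
        errors.append(
            f"world_size {world_size} does not match tp_size {metadata['tp_size']}"
        )
    if hidden_dim != metadata["hidden_dim"]:
        errors.append(
            f"hidden_dim {hidden_dim} does not match workspace hidden_dim "
            f"{metadata['hidden_dim']}"
        )
    if errors:
        # accumulate all mismatches and report them in a single error
        raise ValueError("; ".join(errors))
```

Passing no metadata skips the check entirely, which is what keeps the change non-breaking for existing callers.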

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Workspace creation can optionally return a metadata object describing configuration (for runtime validation).
    • Allreduce fusion operations accept optional metadata to validate parameters against the workspace at runtime.
  • Bug Fixes

    • Added metadata-based validation that surfaces clear errors on configuration mismatches.
  • Chores

    • Legacy workspace/allreduce wrappers marked deprecated and a workspace destruction endpoint renamed for consistency.

@coderabbitai
Contributor

coderabbitai Bot commented Oct 27, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Added optional workspace metadata creation in the IPC workspace API and metadata-based validation in the all-reduce fusion call. Signatures changed to support create_metadata and metadata; deprecated non-fusion wrappers were introduced and lamport buffer sizing and metadata schema were adjusted.

Changes

Cohort / File(s) Summary
Core API & Deprecations
flashinfer/comm/trtllm_ar.py
Added create_metadata: bool = False to trtllm_create_ipc_workspace_for_all_reduce_fusion and widened return to optionally include a metadata dict. Added metadata: Optional[dict] = None to trtllm_allreduce_fusion with runtime validation (checks token_num ≤ max_token_num, world_size == tp_size, hidden_dim match) raising ValueError on mismatch. Constructed metadata when create_metadata=True (keys: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size). Adjusted lamport buffer size calculation/comments. Imported and applied typing_extensions.deprecated wrappers to deprecate trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce.
Tests & API migration
tests/comm/test_trtllm_allreduce_fusion.py
Updated tests to request create_metadata=True, capture returned workspace_metadata, pass it into trtllm_allreduce_fusion, and use the _fusion destroy API (trtllm_destroy_ipc_workspace_for_all_reduce_fusion). Adjusted call sites to match new signatures and return shapes.
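The walkthrough above mentions that the legacy wrappers were deprecated via typing_extensions.deprecated. As a rough illustration of what that decorator does at call time, here is a minimal hand-rolled equivalent built on the standard warnings module; the wrapped function body is a placeholder, not the real workspace setup.

```python
import functools
import warnings


def deprecated(message: str):
    """Minimal stand-in for typing_extensions.deprecated, for illustration."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator


@deprecated(
    "trtllm_create_ipc_workspace_for_all_reduce is deprecated; "
    "use trtllm_create_ipc_workspace_for_all_reduce_fusion instead."
)
def trtllm_create_ipc_workspace_for_all_reduce():
    # placeholder body; the real function allocates IPC buffers
    return "legacy-workspace"
```

Callers still get the original return value; they just see a DeprecationWarning pointing at the fusion API.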

Sequence Diagram(s)

sequenceDiagram
    participant Test
    participant CreateWS as trtllm_create_ipc_workspace_for_all_reduce_fusion
    participant AllReduce as trtllm_allreduce_fusion

    Test->>CreateWS: create_metadata=True, (tp_rank,tp_size,max_token_num,hidden_dim,...)
    CreateWS-->>Test: (ipc_handles, workspace_tensor, metadata)

    note right of Test: metadata → {tp_rank,tp_size,max_token_num,hidden_dim,use_fp32_lamport,buffer_size,flag_size,lamport_comm_size,lamport_buffer_size}

    Test->>AllReduce: allreduce_in, world_size, token_num, hidden_dim, workspace_ptrs, ..., metadata

    rect rgb(245,250,255)
      note over AllReduce: Validation when metadata provided
      AllReduce->>AllReduce: assert token_num ≤ metadata.max_token_num
      AllReduce->>AllReduce: assert world_size == metadata.tp_size
      AllReduce->>AllReduce: assert hidden_dim == metadata.hidden_dim
    end

    alt Validation passes
      AllReduce->>AllReduce: launch fused kernels / run logic
      AllReduce-->>Test: completion
    else Validation fails
      AllReduce-->>Test: raise ValueError (detailed message)
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Pay attention to:
    • Exact metadata schema and produced values in trtllm_create_ipc_workspace_for_all_reduce_fusion
    • Validation branches and clarity of raised ValueError messages in trtllm_allreduce_fusion
    • Lamport buffer size formula and unit/edge-case correctness (potential memory/hang risk)
    • Correct behavior and visibility of deprecation wrappers

Poem

🐰 I stitched a tiny metadata map,
A carrot of counts and buffer cap,
I hop through ranks and check each token,
Guarding buffers so kernels stay woken,
A tidy hop — the fusion's on the lap.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
Check name Status Explanation
Title Check ✅ Passed The title "Added workspace check and reflected this in test" directly references the core functional improvement of this PR—the validation check that prevents illegal memory access with large token counts. While the title does not capture the complete scope of changes (particularly the significant API extension with optional metadata parameters), it clearly identifies a real and important aspect of the changeset. A reviewer scanning the history would understand that this PR adds validation logic and corresponding test updates.
Linked Issues Check ✅ Passed Issue #1986 requires detecting and preventing illegal memory access or hangs with larger token counts in allreduce fusion. The raw summary shows that this PR implements metadata-based validation hooks in both trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion that verify consistency between workspace metadata and runtime parameters, raising ValueError with detailed messages on mismatch. The metadata includes max_token_num which enables enforcement of token count constraints. Tests have been updated to pass and utilize the new metadata parameters [#1986].
Out of Scope Changes Check ✅ Passed The primary changes—adding metadata-based validation, optional parameters, and test updates—directly address issue #1986. However, the raw summary also indicates that deprecation wrappers were added to trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce. While these deprecation wrappers are tangential to the core validation fix, they are reasonable as part of guiding users toward the new metadata-aware API. The deprecations are intentionally documented in the raw summary as public API changes accompanying the broader improvement.
Description Check ✅ Passed The PR description includes the required sections from the template: a Description section explaining what the PR does and why it's needed (fixing #1986 with validation checks and non-breaking API changes), a Pull Request Checklist with pre-commit and test sections, and a Related Issues section. However, the Related Issues section is empty rather than containing an explicit link to #1986, even though the issue is mentioned in the Description. The checklist also shows incomplete pre-commit work and unconfirmed full test passage, though these are procedural notes rather than content gaps.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial safety mechanism to the all-reduce fusion operations by implementing workspace validation checks. It addresses a potential issue where incorrect token counts or dimension mismatches could lead to memory access errors or system hangs. By optionally providing and consuming workspace metadata, the system can now proactively detect and prevent these inconsistencies, enhancing the robustness and stability of the communication primitives.

Highlights

  • Workspace Metadata Generation: The trtllm_create_ipc_workspace_for_all_reduce_fusion function now supports an optional create_metadata flag, allowing it to return a dictionary containing workspace configuration details such as tp_rank, tp_size, max_token_num, and hidden_dim.
  • All-Reduce Fusion Validation: The trtllm_allreduce_fusion function now accepts an optional metadata dictionary. When provided, it performs critical validation checks to prevent potential issues like Illegal Memory Access (IMA) or hangs.
  • Validation Checks Implemented: The new validation logic ensures that token_num does not exceed max_token_num, world_size matches tp_size, and hidden_dim aligns with the workspace's hidden_dim. A ValueError is raised if any of these conditions are violated.
  • Test Suite Updates: Existing correctness tests for trtllm_allreduce_fusion have been updated to utilize and verify the new metadata generation and validation features, ensuring the robustness of the changes.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a valuable validation check to prevent potential illegal memory access errors by ensuring num_tokens does not exceed the workspace's max_token_num. The changes are implemented in a non-breaking way by using optional parameters, which is a good practice for maintaining API compatibility. The tests have been updated accordingly, and a bug in the test cleanup logic was also fixed.

I have a couple of suggestions to improve the maintainability and type safety of the new API, related to using typing.overload for clearer function signatures and a TypedDict or dataclass for the metadata dictionary. Overall, this is a solid contribution that improves the robustness of the library.

Comment on lines +504 to +506
) -> Union[
Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict]
]:
Contributor


medium

While using Union of tuples with different lengths maintains API compatibility, it can be challenging for static analysis tools and IDEs to handle correctly. For better type safety and developer experience, you could use typing.overload to define the function signatures explicitly. This would make it clear to users of the function what to expect for different inputs, improving usability and reducing potential for errors.

Comment on lines +622 to +632
metadata = {
"tp_rank": tp_rank,
"tp_size": tp_size,
"max_token_num": max_token_num,
"hidden_dim": hidden_dim,
"use_fp32_lamport": use_fp32_lamport,
"buffer_size": buffer_size,
"flag_size": flag_size,
"lamport_comm_size": lamport_comm_size,
"lamport_buffer_size": lamport_buffer_size,
}


medium

Using a raw dict for metadata is prone to errors from typos in keys, both when creating it here and when consuming it in trtllm_allreduce_fusion. To make this more robust and self-documenting, consider defining a TypedDict or a dataclass for the metadata. This provides static type checking and autocompletion support in IDEs, reducing the chance of runtime errors.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (3)
flashinfer/comm/trtllm_ar.py (2)

503-506: Return typing is fine; consider a stronger metadata type for clarity.

Union return is OK. To improve API clarity, define a TypedDict for metadata and use it in the signature and docstrings. This avoids untyped dict usage downstream.

-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, TypedDict
+
+class AllReduceFusionWorkspaceMeta(TypedDict):
+    tp_rank: int
+    tp_size: int
+    max_token_num: int
+    hidden_dim: int
+    use_fp32_lamport: bool
+    buffer_size: int
+    flag_size: int
+    lamport_comm_size: int
+    lamport_buffer_size: int
@@
-) -> Union[
-    Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict]
-]:
+) -> Union[
+    Tuple[List[List[int]], torch.Tensor],
+    Tuple[List[List[int]], torch.Tensor, AllReduceFusionWorkspaceMeta],
+]:

621-636: Good: metadata emission. Consider adding device and validating key presence.

Including creation params is useful. Add a 'device' (CUDA index) for extra safety and validate required keys on consumption to avoid KeyError if a caller passes a partial dict.

         metadata = {
             "tp_rank": tp_rank,
             "tp_size": tp_size,
             "max_token_num": max_token_num,
             "hidden_dim": hidden_dim,
             "use_fp32_lamport": use_fp32_lamport,
+            "device": torch.cuda.current_device(),
             "buffer_size": buffer_size,
             "flag_size": flag_size,
             "lamport_comm_size": lamport_comm_size,
             "lamport_buffer_size": lamport_buffer_size,
         }
tests/comm/test_trtllm_allreduce_fusion.py (1)

60-71: Good: create workspace with metadata; add quick negative-path checks here.

Leverage the new validation with a small trio of negative assertions right after creating metadata to guard against regressions (no heavy kernels executed).

         ipc_handles, workspace_tensor, workspace_metadata = (
             comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                 rank,
                 world_size,
                 MAX_TOKEN_NUM,
                 hidden_dim,
                 group=group,
                 use_fp32_lamport=lamport_use_fp32,
                 create_metadata=True,  # Get metadata for validation
             )
         )
+
+        # Negative-path validation: token_num overflow
+        bad_token_num = MAX_TOKEN_NUM + 1
+        with pytest.raises(ValueError, match="token_num .* exceeds"):
+            comm.trtllm_allreduce_fusion(
+                allreduce_in=torch.empty(bad_token_num * hidden_dim, dtype=dtype, device=device),
+                world_size=world_size,
+                world_rank=rank,
+                token_num=bad_token_num,
+                hidden_dim=hidden_dim,
+                workspace_ptrs=workspace_tensor,
+                launch_with_pdl=True,
+                trigger_completion_at_end=True,
+                fp32_acc=False,
+                pattern_code=comm.AllReduceFusionPattern.kAllReduce,
+                use_oneshot=None,
+                allreduce_out=None,
+                residual_in=None,
+                residual_out=None,
+                norm_out=None,
+                quant_out=None,
+                scale_out=None,
+                rms_gamma=None,
+                rms_eps=None,
+                scale_factor=None,
+                layout_code=None,
+                metadata=workspace_metadata,
+            )
+
+        # Negative-path validation: world_size mismatch
+        meta_ws_mismatch = dict(workspace_metadata)
+        meta_ws_mismatch["tp_size"] = world_size + 1
+        with pytest.raises(ValueError, match="world_size .* does not match"):
+            comm.trtllm_allreduce_fusion(
+                allreduce_in=torch.empty(1 * hidden_dim, dtype=dtype, device=device),
+                world_size=world_size,
+                world_rank=rank,
+                token_num=1,
+                hidden_dim=hidden_dim,
+                workspace_ptrs=workspace_tensor,
+                launch_with_pdl=True,
+                trigger_completion_at_end=True,
+                fp32_acc=False,
+                pattern_code=comm.AllReduceFusionPattern.kAllReduce,
+                use_oneshot=None,
+                allreduce_out=None,
+                residual_in=None,
+                residual_out=None,
+                norm_out=None,
+                quant_out=None,
+                scale_out=None,
+                rms_gamma=None,
+                rms_eps=None,
+                scale_factor=None,
+                layout_code=None,
+                metadata=meta_ws_mismatch,
+            )
+
+        # Negative-path validation: hidden_dim mismatch
+        meta_hd_mismatch = dict(workspace_metadata)
+        meta_hd_mismatch["hidden_dim"] = hidden_dim + 64
+        with pytest.raises(ValueError, match="hidden_dim .* does not match"):
+            comm.trtllm_allreduce_fusion(
+                allreduce_in=torch.empty(1 * hidden_dim, dtype=dtype, device=device),
+                world_size=world_size,
+                world_rank=rank,
+                token_num=1,
+                hidden_dim=hidden_dim,
+                workspace_ptrs=workspace_tensor,
+                launch_with_pdl=True,
+                trigger_completion_at_end=True,
+                fp32_acc=False,
+                pattern_code=comm.AllReduceFusionPattern.kAllReduce,
+                use_oneshot=None,
+                allreduce_out=None,
+                residual_in=None,
+                residual_out=None,
+                norm_out=None,
+                quant_out=None,
+                scale_out=None,
+                rms_gamma=None,
+                rms_eps=None,
+                scale_factor=None,
+                layout_code=None,
+                metadata=meta_hd_mismatch,
+            )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d4a3ff4 and 0420306.

📒 Files selected for processing (2)
  • flashinfer/comm/trtllm_ar.py (6 hunks)
  • tests/comm/test_trtllm_allreduce_fusion.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/comm/test_trtllm_allreduce_fusion.py (1)
flashinfer/comm/trtllm_ar.py (2)
  • trtllm_create_ipc_workspace_for_all_reduce_fusion (496-635)
  • trtllm_destroy_ipc_workspace_for_all_reduce_fusion (638-654)
flashinfer/comm/trtllm_ar.py (1)
flashinfer/comm/mapping.py (1)
  • tp_rank (325-326)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
flashinfer/comm/trtllm_ar.py (2)

893-898: LGTM: oneshot fallback guard.

Warning and disabling oneshot when exceeding MAX_COMM_SIZE is appropriate.


818-819: No breaking changes detected; new param added safely.

The new metadata parameter is added as a trailing optional parameter with a default value (None), which is the safe way to extend a function signature. The call at line 249 passes 21 positional arguments that correctly match the first 21 function parameters in order. The new trailing parameter does not affect this or any other call site—all existing calls continue to work without modification. Test calls (lines 166, 196) and internal calls (line 906) use keyword arguments, which are unaffected.

tests/comm/test_trtllm_allreduce_fusion.py (3)

188-189: LGTM: pass metadata during warmup.

This exercises the validation path without affecting CUDA graph capture.


218-219: LGTM: pass metadata during capture.

Safe since validation occurs on Python side before kernel launch; no graph-safety issues.


310-312: LGTM: updated destroy to fusion variant.

Matches the new creation path and keeps resource lifecycle consistent.

Comment thread flashinfer/comm/trtllm_ar.py
dist.barrier(group=group)

comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group)
comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
Contributor Author

@nvmbreughe nvmbreughe Oct 27, 2025


@yzh119 note this subtle issue of calling the wrong destroy function. It is not a problem here, as they have the same implementation.

As an alternative solution for this, as well as for the root cause of #1986 (workspace size insufficient) we could implement this as an AR-class. e.g., we would have trtllm_create_ipc_workspace_for_all_reduce_fusion return an AR object. The destructor would clean up (using the correct destructor :)), and we could just add members like "max_token_size", etc, for workspace validation.

This PR addresses this by having trtllm_create_ipc_workspace_for_all_reduce_fusion return a metadata dictionary that enables these checks. It's not ideal, but it is non-API-breaking. It's not high priority, but we could consider the proposed alternative solution and deprecate both the custom_all_reduce and all_reduce_fusion variants.
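The AR-class alternative proposed in this thread could look roughly like the sketch below: a workspace object that carries its own limits for validation and always runs the matching destroy call on cleanup. Every name here is hypothetical, and the real object would wrap IPC handles and a CUDA workspace tensor rather than a plain callback.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class AllReduceFusionWorkspace:
    """Hypothetical workspace object bundling limits and cleanup."""
    tp_size: int
    max_token_num: int
    hidden_dim: int
    destroy_fn: Callable[[], None] = lambda: None  # the matching fusion destroy
    _destroyed: bool = False

    def validate(self, world_size: int, token_num: int, hidden_dim: int) -> None:
        if token_num > self.max_token_num:
            raise ValueError(
                f"token_num {token_num} > max_token_num {self.max_token_num}"
            )
        if world_size != self.tp_size:
            raise ValueError(f"world_size {world_size} != tp_size {self.tp_size}")
        if hidden_dim != self.hidden_dim:
            raise ValueError(
                f"hidden_dim {hidden_dim} != workspace hidden_dim {self.hidden_dim}"
            )

    def close(self) -> None:
        # idempotent: the correct destructor runs exactly once
        if not self._destroyed:
            self.destroy_fn()
            self._destroyed = True

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()
```

Because cleanup lives on the object, callers cannot accidentally reach for the wrong destroy function, and the validation members travel with the workspace instead of a loose dict.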

Collaborator


Agreed, making it a class would be a more reasonable solution.

Comment thread flashinfer/comm/trtllm_ar.py Outdated
"""

# Validate against workspace metadata if provided
if metadata is not None:
Collaborator


To avoid similar cases where num_tokens > MAX_TOKEN_NUM, the check is necessary anyway — so why not just return meta regardless?

Contributor Author


because it would break the API. We can only do that for major bumps if we want to respect semantic versioning.

Contributor Author


I didn't address this, but we should keep it in mind for an upcoming major bump. I added a TODO comment so we don't forget.

Collaborator

@yzh119 yzh119 left a comment


LGTM!

@nvmbreughe nvmbreughe enabled auto-merge (squash) October 28, 2025 19:33
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
flashinfer/comm/trtllm_ar.py (3)

510-513: Consider @overload for clearer type signatures.

While the Union return type maintains compatibility, static analysis tools and IDEs struggle with variable-length tuple unions. Using typing.overload would provide explicit signatures for each case, improving type safety and developer experience.

Example:

@overload
def trtllm_create_ipc_workspace_for_all_reduce_fusion(
    tp_rank: int,
    tp_size: int,
    max_token_num: int,
    hidden_dim: int,
    use_fp32_lamport: bool = False,
    group: Optional[ProcessGroup] = None,
    create_metadata: Literal[False] = False,
) -> Tuple[List[List[int]], torch.Tensor]: ...

@overload
def trtllm_create_ipc_workspace_for_all_reduce_fusion(
    tp_rank: int,
    tp_size: int,
    max_token_num: int,
    hidden_dim: int,
    use_fp32_lamport: bool = False,
    group: Optional[ProcessGroup] = None,
    create_metadata: Literal[True] = ...,
) -> Tuple[List[List[int]], torch.Tensor, dict]: ...

628-642: Consider TypedDict or dataclass for type-safe metadata.

Raw dicts are vulnerable to typos in both creation and consumption sites. Defining a structured type would enable static type checking, IDE autocompletion, and early error detection.

Example with TypedDict:

from typing import TypedDict

class WorkspaceMetadata(TypedDict):
    tp_rank: int
    tp_size: int
    max_token_num: int
    hidden_dim: int
    use_fp32_lamport: bool
    buffer_size: int
    flag_size: int
    lamport_comm_size: int
    lamport_buffer_size: int

913-917: Major: Size calculation should use lamport dtype from metadata.

When metadata is provided, the required lamport size should be calculated using metadata["use_fp32_lamport"] rather than allreduce_in.dtype. The workspace's lamport buffer dtype determines safety, not the input tensor dtype. Using the input dtype can underestimate required bytes and incorrectly enable oneshot mode, leading to buffer overruns.

Apply this diff:

+    # Determine element size: use metadata's lamport dtype when available
+    if metadata is not None:
+        elem_bytes = 4 if metadata["use_fp32_lamport"] else 2
+    else:
+        elem_bytes = 4 if allreduce_in.dtype == torch.float32 else 2
+    
     required_lamport_comm_size = (
-        token_num * hidden_dim * 2 * world_size
-        if allreduce_in.dtype != torch.float32
-        else token_num * hidden_dim * 4 * world_size
+        token_num * hidden_dim * elem_bytes * world_size
     )

📥 Commits

Reviewing files that changed from the base of the PR and between f997ceb and 89e8cf4.

📒 Files selected for processing (1)
  • flashinfer/comm/trtllm_ar.py (10 hunks)
🔇 Additional comments (5)
flashinfer/comm/trtllm_ar.py (5)

22-22: LGTM: Import added for deprecation decorators.

The typing_extensions.deprecated import supports the deprecation decorators applied to legacy functions, guiding users toward the fusion API.


125-127: LGTM: Deprecation decorators guide users to fusion API.

The deprecation markers on legacy functions provide a clear migration path to trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion.

Also applies to: 400-402, 709-711


522-528: LGTM: Docstrings accurately document the new behavior.

The documentation correctly describes the conditional return values and accurately reflects the lamport buffer size computation based on max_token_num without referencing the deprecated OneShotMaxToken.

Also applies to: 537-538


828-828: LGTM: Optional metadata parameter maintains backward compatibility.

The default None value ensures existing callers continue to work while enabling validation when metadata is explicitly provided.


853-855: LGTM: Documentation clearly explains validation behavior.

The docstring effectively communicates the metadata parameter's purpose and the consequences of validation failures.

Comment thread flashinfer/comm/trtllm_ar.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
flashinfer/comm/trtllm_ar.py (1)

893-898: Critical: Inverted logic in use_fp32_lamport validation.

Line 894 uses == which raises an error when metadata and dtype correctly match, allowing mismatches to pass. This is inverted and defeats the validation purpose.

Apply this diff to fix the inverted condition:

-        if metadata["use_fp32_lamport"] == (allreduce_in.dtype == torch.float32):
+        if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32):
🧹 Nitpick comments (2)
flashinfer/comm/trtllm_ar.py (2)

512-514: Consider typing.overload for better type safety (optional).

As noted in a previous review, using @overload decorators would provide clearer type signatures for different create_metadata values and improve IDE autocompletion. While the current Union approach works, overloads make it explicit what callers receive based on the boolean flag.

Based on learnings


630-640: Consider TypedDict or dataclass for metadata (recommended).

As noted in a previous review, defining a structured type for the metadata dict would prevent typos in keys and enable static type checking. This would make the API more robust when metadata is created here and consumed in trtllm_allreduce_fusion.

Based on learnings


📥 Commits

Reviewing files that changed from the base of the PR and between 89e8cf4 and 4a90d53.

📒 Files selected for processing (1)
  • flashinfer/comm/trtllm_ar.py (10 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/comm/trtllm_ar.py (3)

22-22: LGTM: Deprecation markers correctly applied.

The deprecation decorators properly guide users toward the fusion APIs and maintain backward compatibility.

Also applies to: 125-127, 400-402, 710-712


862-891: LGTM: Validation logic for keys, token_num, world_size, and hidden_dim is correct.

The checks appropriately detect missing keys, token overflow, and dimension mismatches, providing clear error messages.


829-829: LGTM: Metadata parameter properly added with clear documentation.

The optional parameter maintains API compatibility while enabling validation when workspace metadata is provided.

Also applies to: 854-856

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)

859-905: Comprehensive workspace metadata validation addresses issue #1986.

The validation logic correctly prevents illegal memory access by checking token_num against max_token_num and ensuring all workspace parameters match the runtime configuration. The error accumulation pattern provides excellent user feedback by reporting all mismatches in a single error message.

The critical issue from past reviews regarding inverted logic at line 894 has been fixed—the condition now correctly uses != to detect mismatches between use_fp32_lamport and the input dtype.
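The accumulate-then-raise pattern praised here can be sketched as follows (names and messages are illustrative, not the actual trtllm_ar.py code):

```python
def validate_against_workspace(metadata: dict, token_num: int, hidden_dim: int) -> None:
    """Collect every mismatch, then raise once with all of them."""
    errors = []
    if token_num > metadata["max_token_num"]:
        errors.append(
            f"token_num {token_num} exceeds max_token_num {metadata['max_token_num']}"
        )
    if hidden_dim != metadata["hidden_dim"]:
        errors.append(
            f"hidden_dim {hidden_dim} != workspace hidden_dim {metadata['hidden_dim']}"
        )
    if errors:
        # One exception carrying every mismatch, so the caller can fix all at once.
        raise ValueError("; ".join(errors))


meta = {"max_token_num": 1024, "hidden_dim": 4096}
validate_against_workspace(meta, 512, 4096)  # passes silently
try:
    validate_against_workspace(meta, 2048, 1024)
except ValueError as e:
    assert "token_num" in str(e) and "hidden_dim" in str(e)  # both reported together
```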

Optional: Minor Pythonic improvement for key checking.

Lines 863-865 could be slightly more concise using set operations:

-        for key in required_keys:
-            if key not in metadata:
-                errors.append(f"Workspace metadata is missing required key: {key}")
+        missing_keys = set(required_keys) - set(metadata.keys())
+        if missing_keys:
+            errors.append(f"Workspace metadata is missing required keys: {sorted(missing_keys)}")

This is purely stylistic and doesn't affect correctness.
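The set-difference variant behaves as follows (key names are assumptions based on the checks discussed in this review):

```python
def missing_required_keys(metadata: dict, required_keys) -> list:
    # One pass; all missing keys reported together, in deterministic order.
    return sorted(set(required_keys) - set(metadata))


required = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"]
assert missing_required_keys({"max_token_num": 1, "tp_size": 8}, required) == [
    "hidden_dim",
    "use_fp32_lamport",
]
```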

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a90d53 and c83ac93.

📒 Files selected for processing (1)
  • flashinfer/comm/trtllm_ar.py (10 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
flashinfer/comm/trtllm_ar.py (6)

22-22: LGTM: Appropriate import for API deprecation.

The deprecated decorator import is correctly placed and supports the API migration strategy.


125-127: LGTM: Consistent deprecation strategy.

The deprecation decorators are applied consistently to both the custom op registrations and wrapper functions, with clear migration guidance pointing users to the fusion-based APIs.

Also applies to: 400-402, 710-712


503-514: Optional metadata return maintains backward compatibility.

The create_metadata parameter allows callers to opt into receiving workspace metadata for validation purposes without breaking existing code. The TODO comment appropriately flags that this should become the default behavior in a future major version.

Note: A past review suggested using @overload for better type safety instead of Union of tuples, but the current approach is functional and maintains simplicity.

Also applies to: 523-530


538-539: LGTM: Documentation now matches implementation.

The lamport buffer size formula correctly documents the dtype-dependent sizing (2 bytes for fp16/bf16, 4 bytes for fp32), matching the conditional logic in lines 551-555.


629-643: Conditional metadata return logic is correct.

The implementation correctly returns metadata when requested, including all parameters necessary for downstream validation in trtllm_allreduce_fusion.

Note: A past review suggested using TypedDict or dataclass for the metadata structure to improve type safety and reduce typo-related errors, but the current dict-based approach is functional.


829-829: LGTM: Optional metadata parameter enables workspace validation.

The optional metadata parameter allows callers to opt into runtime validation while maintaining backward compatibility with existing code that doesn't provide metadata.

Also applies to: 854-856

@nvmbreughe nvmbreughe merged commit 7d9d7af into flashinfer-ai:main Oct 28, 2025
4 checks passed
@coderabbitai coderabbitai Bot mentioned this pull request Nov 27, 2025
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR attempts to fix flashinfer-ai#1986 (to be confirmed by requester)

The issue is that num_tokens was larger than MAX_TOKEN_NUM, which
results in an IMA (illegal memory access), or even in a hang. To address
this, I added a validation check. This required a non-breaking API change:
* create_ipc_workspace_for_all_reduce_fusion now has an optional
"create_metadata" bool, which results in an additional return value
  * it is made optional, as an additional return value could break the API
* trtllm_allreduce_fusion now takes an optional metadata dictionary
  * When provided, this will run the validation check
  * again, this is also optional, to avoid breaking the API


In addition this PR deprecates the older AllReduce functionality so it can be removed in a major version bump.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **API Changes**
* Workspace creation can optionally return metadata describing the
workspace configuration (create_metadata flag).
* Allreduce fusion operations accept optional metadata to validate
runtime parameters against the workspace and raise clear errors on
mismatch.
  * A workspace destruction endpoint was renamed for naming consistency.
* Legacy wrappers were marked deprecated and now point users toward the
newer fusion variants.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->


Successfully merging this pull request may close these issues.

Allreduce fusion illegal memory access or hang on larger sizes