
Added heuristic for trtllm_allreduce_fusion #1972

Merged
yzh119 merged 3 commits into flashinfer-ai:main from nvjullin:all-reduce-fusion-heuristic
Oct 24, 2025

Conversation

@nvjullin
Contributor

@nvjullin nvjullin commented Oct 23, 2025

📌 Description

The original heuristic does not accurately reflect the relative performance of oneshot vs. twoshot. This PR replaces it with heuristics derived from the benchmark allreduce_test.py, which sweeps the hidden_dim of Llama3, Llama4, and GPT-OSS against combinations of token_num, fusion patterns, and fp32_acc.

The results are at the bottom. TL;DR: token_num is a bad predictor of whether to use oneshot or twoshot. The oneshot communication size is a good predictor, but only if each TP size is treated separately. Fusion patterns and fp32_acc are irrelevant to the choice.
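The decision rule described above can be sketched as a small standalone function. This is a hypothetical simplification, not the code merged in this PR; the thresholds (in MB) and the `2 * world_size` factor in the communication-size formula mirror the values and arithmetic quoted later in this thread from `flashinfer/comm/trtllm_ar.py`:

```python
# Per-TP thresholds (MB) below which oneshot wins, from the B200 benchmarks.
ONESHOT_THRESHOLD_MB = {2: 512, 4: 64, 8: 42}


def should_use_oneshot(token_num: int, hidden_dim: int,
                       dtype_bytes: int, world_size: int) -> bool:
    # Oneshot communication size in MB, matching the formula in the PR
    # (note the factor of 2 * world_size).
    comm_size_mb = (
        token_num * hidden_dim * 2 * world_size * dtype_bytes / (1024 * 1024)
    )
    return comm_size_mb <= ONESHOT_THRESHOLD_MB[world_size]


# e.g. Llama3 hidden_dim=4096, bf16 (2 bytes), 64 tokens, TP=8:
# 64 * 4096 * 2 * 8 * 2 / 2**20 = 8 MB <= 42 MB -> oneshot
```

Note how the per-TP thresholds drop sharply as world_size grows, which is exactly why mixing TP=2/4/8 into a single threshold makes the choice noisy.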

Full-size results

[Figures: comm_size for TP=2, TP=4, TP=8]

Results zoomed in on small comm_size

[Figures: enlarged comm_size plots for TP=2, TP=4, TP=8]

Mixing TP=2/4/8 makes the choice noisy

[Figures: comm_size plots with TP=2/4/8 mixed, full and enlarged]

token_num is a bad predictor

[Figures: token_num plots with TP=2/4/8 mixed, full and enlarged]

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Improvements
    • Enhanced distributed training performance with intelligent heuristic-based all-reduce strategy selection that dynamically adapts based on model size and system configuration.
    • The all-reduce strategy parameter now supports automatic mode to select the optimal strategy, while manual override options remain available.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nvjullin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the trtllm_allreduce_fusion function by replacing its previous static 'oneshot' strategy selection with a dynamic, empirically-derived heuristic. The new logic leverages comprehensive benchmark data to determine the most performant all-reduce strategy based on communication size and world size, aiming to optimize performance across various configurations, particularly on B200 GPUs.

Highlights

  • New Heuristic for All-Reduce Fusion: Implemented a new heuristic to dynamically decide between 'oneshot' and 'twoshot' all-reduce strategies within the trtllm_allreduce_fusion function.
  • Performance-Driven Decision Making: The strategy selection is now based on the communication size in megabytes (comm_size_mb) and the world_size, derived from extensive benchmarks on B200 GPUs, replacing a fixed token_num threshold.
  • Irrelevant Factors Identified: Benchmarking revealed that fusion patterns and fp32_acc do not significantly influence the optimal choice between the oneshot and twoshot strategies.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a heuristic for trtllm_allreduce_fusion based on communication size to determine whether to use oneshot or twoshot strategy. The heuristic replaces the previous token_num based approach, aiming for better performance. The changes include adding a _use_oneshot_heuristics dictionary, a _should_use_oneshot function, and modifying the trtllm_allreduce_fusion function to use the new heuristic. The documentation for the use_oneshot parameter is also updated.

Comment on lines 817 to 818
    if not use_oneshot:
        assert token_num > world_size, "sequence length should be larger than tp_size"
Contributor


high

This assertion message is not very informative. Consider providing more context, such as the actual values of token_num and world_size, to aid in debugging.

    if not use_oneshot:
        assert token_num > world_size, f"sequence length ({token_num}) should be larger than tp_size ({world_size})"

Comment thread flashinfer/comm/trtllm_ar.py Outdated
Comment on lines +761 to +763
    def _should_use_oneshot(token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int) -> bool:
        comm_size_mb = token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
        return comm_size_mb <= _use_oneshot_heuristics[world_size]
Contributor


medium

The calculation of comm_size_mb could potentially benefit from using torch.finfo(dtype).bits instead of dtype.itemsize to explicitly represent the number of bits in the datatype, enhancing readability and maintainability.

Also, consider adding a small constant (e.g., 1e-9) to the denominator to guard against division by zero; although 1024 * 1024 can never be zero, it's a good defensive practice.

    comm_size_mb = token_num * hidden_dim * 2 * world_size * torch.finfo(dtype).bits / 8 / 1024 / 1024 # Convert bits to MB
    return comm_size_mb <= _use_oneshot_heuristics[world_size]

Comment thread flashinfer/comm/trtllm_ar.py Outdated
Comment on lines +806 to +815
     if use_oneshot is None:
-        use_oneshot = token_num <= 128
+        use_oneshot = _should_use_oneshot(token_num, hidden_dim, allreduce_in.dtype, world_size)
Contributor


medium

Consider adding a log message here to indicate when the heuristic is being used and what the decision is. This can be helpful for debugging and monitoring purposes.

For example, log the calculated comm_size_mb and the chosen use_oneshot value.

    use_oneshot = _should_use_oneshot(token_num, hidden_dim, allreduce_in.dtype, world_size)
    logging.info(f"Heuristic chose use_oneshot={use_oneshot} for token_num={token_num}, hidden_dim={hidden_dim}, world_size={world_size}")

Collaborator

@yzh119 yzh119 left a comment


The figures look awesome, would you mind adding these benchmarking scripts (in another PR, not urgent) to benchmarks?

@coderabbitai
Contributor

coderabbitai Bot commented Oct 24, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

A heuristic-based mechanism for deciding the oneshot all-reduce strategy was introduced, replacing a fixed token threshold. The update adds a configuration dictionary mapping world size to communication-size thresholds and a helper function that decides whether to enable oneshot based on token count, hidden dimension, data type, and world size.

Changes

Cohort / File(s) Summary
All-reduce heuristic refinement
flashinfer/comm/trtllm_ar.py
Added _use_oneshot_heuristics dict and _should_use_oneshot() helper function for adaptive oneshot strategy selection. Updated trtllm_allreduce_fusion() signature to accept use_oneshot: Optional[bool], enabling internal heuristic computation when None. Changed control flow from fixed token_num <= 128 rule to communication-size-based threshold comparison. Updated documentation to reflect new behavior.
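The `use_oneshot: Optional[bool]` pattern summarized above (explicit override vs. heuristic default) can be sketched as follows; `pick_strategy` and `_THRESHOLD_MB` are illustrative names, not the actual identifiers in trtllm_ar.py:

```python
from typing import Optional

_THRESHOLD_MB = {2: 512, 4: 64, 8: 42}  # per-TP thresholds from the PR


def pick_strategy(comm_size_mb: float, world_size: int,
                  use_oneshot: Optional[bool] = None) -> str:
    # None means "let the heuristic decide"; an explicit bool always wins.
    if use_oneshot is None:
        use_oneshot = comm_size_mb <= _THRESHOLD_MB[world_size]
    return "oneshot" if use_oneshot else "twoshot"
```

A caller that passes nothing gets the communication-size heuristic; a caller that passes `use_oneshot=False` keeps the old manual-override behavior.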

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant trtllm_allreduce_fusion
    participant _should_use_oneshot
    
    Caller->>trtllm_allreduce_fusion: call with use_oneshot=None
    alt use_oneshot is None
        trtllm_allreduce_fusion->>_should_use_oneshot: compute heuristic<br/>(token_num, hidden_dim, dtype, world_size)
        _should_use_oneshot->>_should_use_oneshot: calculate comm_size_mb<br/>= (token_num × hidden_dim × 2 ×<br/>world_size × dtype_bytes) ÷ 1024²
        _should_use_oneshot->>_should_use_oneshot: compare vs threshold<br/>from _use_oneshot_heuristics[world_size]
        _should_use_oneshot-->>trtllm_allreduce_fusion: return bool
        rect rgb(200, 220, 250)
            note over trtllm_allreduce_fusion: Execute oneshot or<br/>standard all-reduce
        end
    else use_oneshot is explicitly set
        rect rgb(240, 240, 240)
            note over trtllm_allreduce_fusion: Use provided value<br/>(validate if False)
        end
    end
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

The change is contained to a single file with focused logic modifications. The new heuristic mechanism is straightforward (communication-size threshold comparison), the signature update is clear, and documentation aligns with implementation. Primary review concerns involve verifying threshold values and control-flow edge cases.

Poem

🐰 A strategy refined, no longer bound,
To rigid tokens of one-twenty-eight,
Communication wisdom now we've found,
Heuristics guide what once was fate,
All-reduce flows both swift and great!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title Check ✅ Passed The title "Added heuristic for trtllm_allreduce_fusion" directly and accurately describes the main change in the pull request. The core modification is replacing the fixed 128-token threshold with a heuristic-based mechanism that uses communication size to decide between oneshot and twoshot all-reduce strategies. The title is concise, specific, and clearly summarizes the primary change without unnecessary details or vague terminology.
Description Check ✅ Passed The pull request description follows the required template structure with all major sections present. The Description section is substantive and well-detailed, explaining the motivation for the change (original heuristic was inaccurate), the benchmarking methodology (covering Llama3, Llama4, and GPT-OSS models with various configurations), and key findings (communication size is a good predictor when treating each TP separately). The Related Issues section is empty but non-critical, and the checklist items are present with standard unchecked boxes typical for newly created PRs. The description adequately communicates the purpose and context of the changes.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@yzh119 yzh119 enabled auto-merge (squash) October 24, 2025 00:09
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)

763-769: Consider adding logging for debugging and monitoring.

Based on past review feedback, adding a log message when this heuristic is used would be helpful for debugging and understanding the decision-making process.

 def _should_use_oneshot(
     token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int
 ) -> bool:
     comm_size_mb = (
         token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
     )
-    return comm_size_mb <= _use_oneshot_heuristics[world_size]
+    threshold = _use_oneshot_heuristics[world_size]
+    use_oneshot = comm_size_mb <= threshold
+    logging.debug(
+        f"Oneshot heuristic: comm_size={comm_size_mb:.2f}MB, threshold={threshold}MB, "
+        f"use_oneshot={use_oneshot} (world_size={world_size}, token_num={token_num})"
+    )
+    return use_oneshot

Based on learnings

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 739df61 and 2af730f.

📒 Files selected for processing (1)
  • flashinfer/comm/trtllm_ar.py (3 hunks)
🔇 Additional comments (2)
flashinfer/comm/trtllm_ar.py (2)

804-804: LGTM!

The updated docstring clearly documents the new behavior when use_oneshot is None.


820-823: Logic correctly implements the new heuristic.

The implementation properly delegates to _should_use_oneshot when use_oneshot is None, replacing the previous hard-coded threshold. This aligns with the PR's objective of using empirically-derived heuristics based on communication size rather than token count.

Note: This depends on the KeyError issue being resolved in _should_use_oneshot (see comment on lines 754-760).

Comment on lines +754 to +760
    # Heuristics based on all configs of trtllm_allreduce_fusion on B200.
    # Empirically, the fusion pattern and fp32_acc are irrelevant to the decision.
    _use_oneshot_heuristics: dict[int, int] = {
        2: 512,
        4: 64,
        8: 42,
    }
Contributor


⚠️ Potential issue | 🔴 Critical

Add handling for unsupported world_size values.

The dictionary only contains entries for world_size 2, 4, and 8. If _should_use_oneshot is called with any other world_size value, it will raise a KeyError at line 769.

Consider one of these approaches:

Option 1: Add a default fallback value

 _use_oneshot_heuristics: dict[int, int] = {
     2: 512,
     4: 64,
     8: 42,
 }
+
+def _get_oneshot_threshold(world_size: int) -> int:
+    """Get oneshot threshold with fallback for unsupported world_size."""
+    if world_size in _use_oneshot_heuristics:
+        return _use_oneshot_heuristics[world_size]
+    # Conservative fallback: use smallest threshold
+    return min(_use_oneshot_heuristics.values())

Then update line 769 to use _get_oneshot_threshold(world_size).

Option 2: Add validation in the main function

 def trtllm_allreduce_fusion(
     ...
 ) -> None:
     """..."""
 
+    if use_oneshot is None and world_size not in _use_oneshot_heuristics:
+        raise ValueError(
+            f"Unsupported world_size {world_size} for oneshot heuristic. "
+            f"Supported values: {list(_use_oneshot_heuristics.keys())}"
+        )
+
     if use_oneshot is None:

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_ar.py around lines 754-760, the
_use_oneshot_heuristics dict only covers world_size 2,4,8 so lookup in
_should_use_oneshot will KeyError for other values; fix by adding handling for
unsupported world_size—either implement a small helper
_get_oneshot_threshold(world_size) that returns
_use_oneshot_heuristics.get(world_size, <sane_default>) and use that in
_should_use_oneshot, or add explicit validation at the start of
_should_use_oneshot that raises a clear ValueError for unsupported world_size;
update callers to use the helper/validation accordingly.

@yzh119 yzh119 merged commit 4512a6c into flashinfer-ai:main Oct 24, 2025
4 checks passed
@coderabbitai coderabbitai Bot mentioned this pull request Nov 20, 2025
5 tasks
@coderabbitai coderabbitai Bot mentioned this pull request Mar 6, 2026
5 tasks