Added heuristic for trtllm_allreduce_fusion #1972
Conversation
Summary of Changes

Hello @nvjullin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Code Review
This pull request introduces a heuristic for trtllm_allreduce_fusion that uses communication size to decide between the oneshot and twoshot strategies. The heuristic replaces the previous token_num-based approach, aiming for better performance. The changes add a _use_oneshot_heuristics dictionary and a _should_use_oneshot function, modify trtllm_allreduce_fusion to use the new heuristic, and update the documentation for the use_oneshot parameter.
```python
if not use_oneshot:
    assert token_num > world_size, "sequence length should be larger than tp_size"
```
```python
def _should_use_oneshot(token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int) -> bool:
    comm_size_mb = token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
    return comm_size_mb <= _use_oneshot_heuristics[world_size]
```
The calculation of comm_size_mb could benefit from using torch.finfo(dtype).bits instead of dtype.itemsize to make the number of bits in the datatype explicit, enhancing readability and maintainability.

Also, consider adding a small constant (e.g., 1e-9) to the denominator to guard against division by zero; while the constant denominator 1024 * 1024 can never actually be zero here, guarding denominators is a general defensive practice.
```python
comm_size_mb = token_num * hidden_dim * 2 * world_size * torch.finfo(dtype).bits / 8 / 1024 / 1024  # Convert bits to MB
return comm_size_mb <= _use_oneshot_heuristics[world_size]
```
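For reference, `torch.finfo(dtype).bits / 8` and `dtype.itemsize` agree for the floating-point dtypes this path sees (note that `torch.finfo` only accepts floating-point types), as this quick sanity check sketches:

```python
import torch

# Sanity check: bits // 8 from torch.finfo matches itemsize for float dtypes.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    assert torch.finfo(dtype).bits // 8 == dtype.itemsize
    print(dtype, torch.finfo(dtype).bits // 8, dtype.itemsize)
```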
```diff
 if use_oneshot is None:
-    use_oneshot = token_num <= 128
+    use_oneshot = _should_use_oneshot(token_num, hidden_dim, allreduce_in.dtype, world_size)
```
Consider adding a log message here to indicate when the heuristic is being used and what the decision is. This can be helpful for debugging and monitoring purposes.
For example, log the calculated comm_size_mb and the chosen use_oneshot value.
```python
use_oneshot = _should_use_oneshot(token_num, hidden_dim, allreduce_in.dtype, world_size)
logging.info(
    f"Heuristic chose use_oneshot={use_oneshot} for token_num={token_num}, "
    f"hidden_dim={hidden_dim}, world_size={world_size}"
)
```
yzh119 left a comment
The figures look awesome! Would you mind adding these benchmarking scripts (in another PR, not urgent) to benchmarks?
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

A heuristic-based mechanism for deciding the oneshot all-reduce strategy was introduced, replacing a fixed token threshold. The update adds a configuration dictionary mapping world size to communication-size thresholds and a helper function that computes whether to enable oneshot based on token count, hidden dimension, and data type.
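To make the thresholds concrete, here is a worked example of the heuristic (a sketch; hidden_dim=8192 and bf16 are assumed illustrative values, and the thresholds are copied from this PR's _use_oneshot_heuristics):

```python
import torch

# Thresholds (in MB) from this PR, keyed by world_size.
_use_oneshot_heuristics = {2: 512, 4: 64, 8: 42}

def _should_use_oneshot(token_num, hidden_dim, dtype, world_size):
    # Communication size as defined by the PR's formula.
    comm_size_mb = token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
    return comm_size_mb <= _use_oneshot_heuristics[world_size]

# 128 * 8192 * 2 * 8 * 2 bytes = 32 MB <= 42 MB threshold -> oneshot
print(_should_use_oneshot(128, 8192, torch.bfloat16, 8))  # True
# Doubling token_num gives 64 MB > 42 MB -> twoshot
print(_should_use_oneshot(256, 8192, torch.bfloat16, 8))  # False
```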
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Caller
participant trtllm_allreduce_fusion
participant _should_use_oneshot
Caller->>trtllm_allreduce_fusion: call with use_oneshot=None
alt use_oneshot is None
trtllm_allreduce_fusion->>_should_use_oneshot: compute heuristic<br/>(token_num, hidden_dim, dtype, world_size)
_should_use_oneshot->>_should_use_oneshot: calculate comm_size_mb<br/>= (token_num × hidden_dim × 2 × world_size × dtype_bytes)<br/> ÷ (1024²)
_should_use_oneshot->>_should_use_oneshot: compare vs threshold<br/>from _use_oneshot_heuristics[world_size]
_should_use_oneshot-->>trtllm_allreduce_fusion: return bool
rect rgb(200, 220, 250)
note over trtllm_allreduce_fusion: Execute oneshot or<br/>standard all-reduce
end
else use_oneshot is explicitly set
rect rgb(240, 240, 240)
note over trtllm_allreduce_fusion: Use provided value<br/>(validate if False)
end
end
```
Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

The change is contained to a single file with focused logic modifications. The new heuristic mechanism is straightforward (communication-size threshold comparison), the signature update is clear, and documentation aligns with implementation. Primary review concerns involve verifying threshold values and control-flow edge cases.
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)
763-769: Consider adding logging for debugging and monitoring.

Based on past review feedback, adding a log message when this heuristic is used would be helpful for debugging and understanding the decision-making process.

```diff
 def _should_use_oneshot(
     token_num: int, hidden_dim: int, dtype: torch.dtype, world_size: int
 ) -> bool:
     comm_size_mb = (
         token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024
     )
-    return comm_size_mb <= _use_oneshot_heuristics[world_size]
+    threshold = _use_oneshot_heuristics[world_size]
+    use_oneshot = comm_size_mb <= threshold
+    logging.debug(
+        f"Oneshot heuristic: comm_size={comm_size_mb:.2f}MB, threshold={threshold}MB, "
+        f"use_oneshot={use_oneshot} (world_size={world_size}, token_num={token_num})"
+    )
+    return use_oneshot
```

Based on learnings
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/comm/trtllm_ar.py (3 hunks)
🔇 Additional comments (2)
flashinfer/comm/trtllm_ar.py (2)
804-804: LGTM!

The updated docstring clearly documents the new behavior when use_oneshot is None.
820-823: Logic correctly implements the new heuristic.

The implementation properly delegates to _should_use_oneshot when use_oneshot is None, replacing the previous hard-coded threshold. This aligns with the PR's objective of using empirically derived heuristics based on communication size rather than token count.

Note: This depends on the KeyError issue being resolved in _should_use_oneshot (see comment on lines 754-760).
```python
# Heuristics based on all configs of trtllm_allreduce_fusion on B200.
# Empirically, the fusion pattern and fp32_acc are irrelevant to the decision.
_use_oneshot_heuristics: dict[int, int] = {
    2: 512,
    4: 64,
    8: 42,
}
```
Add handling for unsupported world_size values.
The dictionary only contains entries for world_size 2, 4, and 8. If _should_use_oneshot is called with any other world_size value, it will raise a KeyError at line 769.
Consider one of these approaches:
Option 1: Add a default fallback value

```diff
 _use_oneshot_heuristics: dict[int, int] = {
     2: 512,
     4: 64,
     8: 42,
 }
+
+def _get_oneshot_threshold(world_size: int) -> int:
+    """Get oneshot threshold with fallback for unsupported world_size."""
+    if world_size in _use_oneshot_heuristics:
+        return _use_oneshot_heuristics[world_size]
+    # Conservative fallback: use smallest threshold
+    return min(_use_oneshot_heuristics.values())
```

Then update line 769 to use _get_oneshot_threshold(world_size).
Option 2: Add validation in the main function

```diff
 def trtllm_allreduce_fusion(
     ...
 ) -> None:
     """..."""
+    if use_oneshot is None and world_size not in _use_oneshot_heuristics:
+        raise ValueError(
+            f"Unsupported world_size {world_size} for oneshot heuristic. "
+            f"Supported values: {list(_use_oneshot_heuristics.keys())}"
+        )
+
     if use_oneshot is None:
```
if use_oneshot is None:Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents

```
In flashinfer/comm/trtllm_ar.py around lines 754-760, the
_use_oneshot_heuristics dict only covers world_size 2, 4, and 8, so the lookup
in _should_use_oneshot will raise a KeyError for other values. Fix by adding
handling for unsupported world_size: either implement a small helper
_get_oneshot_threshold(world_size) that returns
_use_oneshot_heuristics.get(world_size, <sane_default>) and use that in
_should_use_oneshot, or add explicit validation at the start of
_should_use_oneshot that raises a clear ValueError for unsupported world_size;
update callers to use the helper/validation accordingly.
```
📌 Description
The original heuristic did not accurately reflect the performance of oneshot/twoshot. Updated with heuristics based on the benchmark allreduce_test.py, which sweeps the hidden_dim of Llama3, Llama4, and GPT-OSS together with combinations of token_num, fusion patterns, and fp32_acc.

The results are at the bottom. TL;DR: token_num is a bad predictor of whether to use oneshot or twoshot. The communication size of oneshot is a good predictor, but only if each TP size is treated separately. Fusion patterns and fp32_acc are irrelevant to the choice.
Figures (benchmark results):
- Full size results
- Results zoomed in on small comm_size
- Mixing TP=2/4/8 makes the choice noisy
- token_num is a bad predictor
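To illustrate the TL;DR above, a small sketch (hidden dims are assumed illustrative values, not taken from the benchmark; bf16 throughout) showing that a fixed token_num can land on either side of the threshold depending on hidden_dim and TP size:

```python
import torch

_use_oneshot_heuristics = {2: 512, 4: 64, 8: 42}  # MB thresholds from this PR

def comm_size_mb(token_num, hidden_dim, dtype, world_size):
    return token_num * hidden_dim * 2 * world_size * dtype.itemsize / 1024 / 1024

# The same token_num flips between oneshot and twoshot as hidden_dim/TP change.
for world_size in (2, 4, 8):
    for hidden_dim in (2880, 5120, 8192):
        mb = comm_size_mb(512, hidden_dim, torch.bfloat16, world_size)
        choice = "oneshot" if mb <= _use_oneshot_heuristics[world_size] else "twoshot"
        print(f"TP={world_size} hidden_dim={hidden_dim} token_num=512: {mb:.1f} MB -> {choice}")
```

The same token count maps to very different communication sizes across configurations, which is why the PR keys its thresholds on TP size rather than on token_num.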
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- All tests are passing (unittest, etc.).

Reviewer Notes