[Bugfix][comm] Fix FP4 one-shot launch config instability in trtllm_allreduce_fusion#2557
Conversation
**Summary of Changes**

Hello @baonudesifeizhai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a critical bug causing numerical instability and accuracy collapse in FP4 all-reduce fusion, particularly on SM100/B300 GPUs. The changes ensure that the one-shot launch configurations for FP4 operations are robust and numerically stable by enforcing specific constraints.
Activity
📝 **Walkthrough**

This change refactors the FP4-specialized launch configuration in the TRTLLM allreduce kernel launcher: it captures a baseline configuration, validates and tries FP4 block sizes via helpers, adapts thread/block sizing to the SM count, and adds fallbacks plus a runtime check to ensure `threads_per_block * cluster_size` equals `threads_per_token` for one-shot launches.
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (warning)
Code Review
This pull request effectively fixes an instability in the FP4 one-shot launch configuration by ensuring that `cluster_size` is a power of two. The changes are well structured: the selection logic is refactored into a helper lambda, and a robust fallback to a baseline configuration is added when the invariants are not met. I have a couple of suggestions to improve code reuse and maintainability by extracting duplicated logic into helpers.
```cpp
// FP4 optimization: apply BEFORE SM count check to avoid being overridden
// This allows FP4 to use smaller block_size even when cluster_num is large
auto is_power_of_two = [](int x) { return x > 0 && (x & (x - 1)) == 0; };
// ...
if (oneshot && threads_per_block * cluster_size != threads_per_token) {
  // Fallback to baseline launch config when FP4 specialization produces
  // an invalid coverage configuration.
  threads_per_block = baseline_threads_per_block;
  cluster_size = baseline_cluster_size;
  block_size = threads_per_block;
  while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
         threads_per_block <= max_threads_per_block / 2) {
    threads_per_block *= 2;
    cluster_size /= 2;
    block_size = threads_per_block;
  }
}
```
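The runtime check in the `if` condition above guards a simple coverage invariant. A standalone sketch of that invariant (the function name here is illustrative, not a helper from the PR):

```cpp
// One-shot coverage invariant: in the one-shot path, a single cluster must
// exactly cover one token's worth of threads.
// Note that 128 * 7 == 896 also satisfies this check, which is why the
// separate power-of-two constraint on cluster_size is still needed.
bool oneshot_covers_token(int threads_per_block, int cluster_size,
                          int threads_per_token) {
  return threads_per_block * cluster_size == threads_per_token;
}
```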
The while loop here for adjusting launch parameters based on SM count duplicates logic from earlier in the function (lines 1462-1475). This code repetition can make maintenance more difficult.
Consider extracting this logic into a helper function to be called in both places. This would improve code clarity and maintainability.
For example:

```cpp
void adjust_for_sm_count(int cluster_num, int sm_count, int max_threads_per_block,
                         int& threads_per_block, int& cluster_size, int& block_size) {
  while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
         threads_per_block <= max_threads_per_block / 2) {
    threads_per_block *= 2;
    cluster_size /= 2;
  }
  block_size = threads_per_block;
}
```
@aleozlx could we prioritize merging this (or a different fix if preferred)? It's causing an accuracy collapse for vLLM DS3-fp4 so we have to downgrade to unfused kernels, greatly affecting performance. |
Congratulate my friend 💯
📌 Description
Fix unstable FP4 one-shot launch configuration in `trtllm_allreduce_fusion`.

On SM100/B300 with FP4 fusion, the launcher could pick a non-power-of-two `cluster_size` (e.g. hidden_dim=7168 -> block_size=128 -> cluster_size=7). This can lead to systematic numerical mismatch (~41.4% mismatched elements observed in the correctness test) and propagates to a vLLM accuracy collapse in the NVFP4 fused path.

File: `include/flashinfer/comm/trtllm_allreduce_fusion.cuh`

- When trying the FP4-specialized block sizes (160/192/128), only accept candidates where:
  - `cluster_size <= 8`
  - `cluster_size` is a power of two
- If `threads_per_block * cluster_size != threads_per_token`, fall back to the baseline config
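The candidate filter described above can be sketched as follows. Function names are illustrative, not the PR's actual helpers; `threads_per_token = 896` is derived from the hidden_dim=7168 example (128 threads x cluster_size 7), and block_size=112 is a hypothetical accepted candidate, not one of the PR's FP4 candidates.

```cpp
// Returns true for positive powers of two (same predicate as in the patch).
bool is_power_of_two(int x) { return x > 0 && (x & (x - 1)) == 0; }

// Illustrative candidate filter: accept a block-size candidate only if it
// evenly divides threads_per_token and yields a cluster_size that is both
// <= 8 and a power of two.
bool accept_fp4_candidate(int threads_per_token, int block_size,
                          int& cluster_size) {
  if (block_size <= 0 || threads_per_token % block_size != 0) return false;
  cluster_size = threads_per_token / block_size;
  return cluster_size <= 8 && is_power_of_two(cluster_size);
}
```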
vllm-project/vllm#34395
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

Result on the vLLM side: https://paste.ubuntu.com/p/JTctq8WNZS/ (did not hit the fallback path)

Result: https://paste.ubuntu.com/p/cpkWSmTRVn/
✅ Pre-commit Checks
- [x] Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] Installed the hooks with `pre-commit install`.
- [x] Ran `pre-commit run --all-files` and fixed any reported issues.
```
pytest -s tests/comm/test_trtllm_allreduce_fusion.py::test_trtllm_allreduce_fusion[True-7168-dtype0-2] --maxfail=1 -q
```

-- passed
- [x] Tests have been added or updated (unittest, etc.).

**Reviewer Notes**