
[Bugfix][comm] Fix FP4 one-shot launch config instability in trtllm_allreduce_fusion #2557

Merged
yzh119 merged 12 commits into flashinfer-ai:main from baonudesifeizhai:testforallreduce on Feb 17, 2026

Conversation

@baonudesifeizhai (Contributor) commented Feb 13, 2026

📌 Description

Fix an unstable FP4 one-shot launch configuration in trtllm_allreduce_fusion.
On SM100/B300 with FP4 fusion, the launcher could pick a non-power-of-two cluster_size (e.g. hidden_dim=7168 -> block_size=128 -> cluster_size=7).
This leads to a systematic numerical mismatch (~41.4% mismatched elements observed in the correctness test) and propagates to an accuracy collapse in vLLM's NVFP4 fused path.

File: include/flashinfer/comm/trtllm_allreduce_fusion.cuh

  • In FP4 one-shot candidate selection (160/192/128), only accept candidates with (see the sketch after this list):
    • valid divisibility
    • cluster_size <= 8
    • a power-of-two cluster_size
  • Keep the existing coverage invariant and fallback:
    • if threads_per_block * cluster_size != threads_per_token, fall back to the baseline config
  • No change to the one-shot/two-shot policy.
  • No extra debug probes in the final patch.
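
A minimal sketch of this acceptance check (names are illustrative, not the exact ones in trtllm_allreduce_fusion.cuh; threads_per_token=896 follows from the quoted block_size=128 * cluster_size=7):

#include <cstdio>
#include <initializer_list>

bool is_power_of_two(int x) { return x > 0 && (x & (x - 1)) == 0; }

// Accept a candidate block size only if it divides the per-token thread
// count evenly and yields a small power-of-two cluster size.
bool accept_fp4_candidate(int threads_per_token, int candidate_block_size,
                          int& cluster_size) {
  if (threads_per_token % candidate_block_size != 0) return false;  // divisibility
  int cs = threads_per_token / candidate_block_size;
  if (cs > 8 || !is_power_of_two(cs)) return false;                 // new constraints
  cluster_size = cs;
  return true;
}

int main() {
  int cs = 0;
  // For the failing shape (threads_per_token=896) every candidate is now
  // rejected, so the launcher must take the baseline fallback:
  for (int cand : {160, 192, 128}) {
    bool ok = accept_fp4_candidate(896, cand, cs);
    std::printf("block_size=%d -> %s\n", cand, ok ? "accepted" : "rejected");
  }
  return 0;
}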

🔍 Related Issues

vllm-project/vllm#34395

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
Result on the vLLM side:
https://paste.ubuntu.com/p/JTctq8WNZS/
(the fallback path was not taken)

Result: https://paste.ubuntu.com/p/cpkWSmTRVn/

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

pytest -s tests/comm/test_trtllm_allreduce_fusion.py::test_trtllm_allreduce_fusion[True-7168-dtype0-2] --maxfail=1 -q
(passed)

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • Performance Improvements
    • Optimized FP4 launch configurations with adaptive block-sizing and SM-aware adjustments to improve resource use and stability.
    • Preserves the baseline configuration to maintain full token coverage when specialized tuning isn't applicable.
  • Bug Fixes
    • Added validation and fallback logic for oneshot and tuned launches to ensure coverage and enforce runtime invariants.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @baonudesifeizhai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug causing numerical instability and accuracy collapse in FP4 all-reduce fusion, particularly on SM100/B300 GPUs. The changes ensure that the one-shot launch configurations for FP4 operations are robust and numerically stable by enforcing specific constraints on cluster_size and providing a reliable fallback mechanism when these constraints are not met.

Highlights

  • FP4 One-Shot Launch Configuration Stability: Implemented stricter criteria for selecting FP4 one-shot launch configurations, specifically requiring cluster_size to be a power-of-two and less than or equal to 8, to prevent numerical instability.
  • Robust Fallback Mechanism: Introduced a fallback to a known-good baseline launch configuration if the FP4 specialization results in an invalid token coverage, ensuring correctness.
  • Refactored FP4 Optimization Logic: The logic for trying different block_size candidates (160, 192, 128) for FP4 optimization was refactored into a helper lambda for improved readability and maintainability.


Changelog
  • include/flashinfer/comm/trtllm_allreduce_fusion.cuh
    • Added an is_power_of_two lambda function for checking power-of-two values.
    • Introduced baseline_threads_per_block and baseline_cluster_size to store a known-good configuration before FP4 specialization.
    • Refactored the FP4 block size selection logic into a try_fp4_block_size lambda, incorporating new checks for power-of-two cluster_size and cluster_size <= 8.
    • Added a conditional fallback to the baseline configuration if the FP4 specialized configuration leads to invalid token coverage.
    • Included a FLASHINFER_CHECK to assert a correct one-shot launch configuration after all adjustments (see the sketch below).
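
A stand-in sketch of the invariant named in that last changelog item, using a plain assert in place of FlashInfer's FLASHINFER_CHECK macro (the variable names are assumptions mirroring the PR description):

#include <cassert>

// A one-shot launch must cover exactly one token's worth of threads per
// cluster after all adjustments.
void check_oneshot_coverage(bool oneshot, int threads_per_block,
                            int cluster_size, int threads_per_token) {
  if (oneshot) {
    assert(threads_per_block * cluster_size == threads_per_token);
  }
}
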
Activity
  • The author reported successful pre-commit checks.
  • The author reported that tests passed for test_trtllm_allreduce_fusion.
  • The author provided vLLM side results indicating the fix did not trigger a fallback.

@coderabbitai (Bot) commented Feb 13, 2026

📝 Walkthrough

This change refactors FP4-specialized launch configuration in the TRTLLM allreduce kernel launcher: it captures a baseline configuration, validates and tries FP4 block sizes via helpers, adapts thread/block sizing to SM count, and adds fallbacks and a runtime check to ensure threads_per_block * cluster_size equals threads_per_token for oneshot launches.

Changes

FP4 Allreduce Launch Configuration (include/flashinfer/comm/trtllm_allreduce_fusion.cuh):
Add helpers (is_power_of_two, adjust_for_sm_count, try_fp4_block_size); capture baseline threads/cluster sizes; replace hard-coded FP4 block-size attempts with a validated sequence (160, 192, 128); apply SM-aware adjustments; add a oneshot fallback to baseline when FP4 yields invalid coverage; add a FLASHINFER_CHECK invariant for oneshot coverage; retain the remaining final validations.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Suggested labels

v0.6.2

Suggested reviewers

  • djmmoss
  • yongwww
  • nvmbreughe
  • cyx-6
  • yzh119

Poem

🐰 I tuned the blocks with FP4 delight,
Power-of-two clusters hopping just right,
Baseline tucked safe when coverage is thin,
SM-aware hops keep launches in.
✨ Kernels leap — a rabbit's grin!

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Merge Conflict Detection ⚠️ Warning: merge conflicts detected (8 files):

⚔️ csrc/flashinfer_sampling_binding.cu (content)
⚔️ csrc/sampling.cu (content)
⚔️ flashinfer/aot.py (content)
⚔️ flashinfer/sampling.py (content)
⚔️ include/flashinfer/comm/trtllm_allreduce_fusion.cuh (content)
⚔️ include/flashinfer/sampling.cuh (content)
⚔️ scripts/task_run_unit_tests.sh (content)
⚔️ scripts/test_utils.sh (content)

Resolution: these conflicts must be resolved before merging into main. Resolve them locally and push the changes to this branch.
✅ Passed checks (3 passed)

  • Title check ✅ Passed: The title clearly summarizes the main change (fixing FP4 one-shot launch config instability in trtllm_allreduce_fusion), which matches the primary objective of the changeset.
  • Docstring Coverage ✅ Passed: No functions found in the changed files to evaluate docstring coverage; skipping the docstring coverage check.
  • Description check ✅ Passed: The PR description covers the core issue and solution with technical details about FP4 launch config constraints, related issues, pre-commit checks, and test results.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request effectively fixes an instability in the FP4 one-shot launch configuration by ensuring the cluster_size is a power of two. The changes are well-structured, refactoring the selection logic into a helper lambda and adding a robust fallback mechanism to a baseline configuration if invariants are not met. I have a couple of suggestions to improve code reuse and maintainability by extracting duplicated logic into helpers.


// FP4 optimization: apply BEFORE SM count check to avoid being overridden
// This allows FP4 to use smaller block_size even when cluster_num is large
auto is_power_of_two = [](int x) { return x > 0 && (x & (x - 1)) == 0; };
Severity: medium

This is_power_of_two helper is useful. Consider moving it to include/flashinfer/utils.cuh as a __device__ __host__ function to promote code reuse, similar to UpPowerOfTwo. This would make it available for other parts of the codebase as well.
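
A sketch of what that relocation might look like (hypothetical; the name and qualifiers follow the reviewer's suggestion and the existing UpPowerOfTwo style):

// Hypothetical host/device helper for include/flashinfer/utils.cuh,
// callable from both kernels and launcher code:
__device__ __host__ inline bool IsPowerOfTwo(int x) {
  return x > 0 && (x & (x - 1)) == 0;
}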

Comment on lines +1477 to +1489
if (oneshot && threads_per_block * cluster_size != threads_per_token) {
// Fallback to baseline launch config when FP4 specialization produces
// an invalid coverage configuration.
threads_per_block = baseline_threads_per_block;
cluster_size = baseline_cluster_size;
block_size = threads_per_block;
while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
threads_per_block <= max_threads_per_block / 2) {
threads_per_block *= 2;
cluster_size /= 2;
block_size = threads_per_block;
}
}
Severity: medium

The while loop here for adjusting launch parameters based on SM count duplicates logic from earlier in the function (lines 1462-1475). This code repetition can make maintenance more difficult.

Consider extracting this logic into a helper function to be called in both places. This would improve code clarity and maintainability.

For example:

void adjust_for_sm_count(int cluster_num, int sm_count, int max_threads_per_block,
                         int& threads_per_block, int& cluster_size, int& block_size) {
  while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
         threads_per_block <= max_threads_per_block / 2) {
    threads_per_block *= 2;
    cluster_size /= 2;
  }
  block_size = threads_per_block;
}
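
With such a helper (hypothetical, as in the reviewer's sketch above), both the pre-FP4 adjustment and the oneshot fallback path would reduce to a single call:

adjust_for_sm_count(cluster_num, sm_count, max_threads_per_block,
                    threads_per_block, cluster_size, block_size);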

@aleozlx aleozlx self-assigned this Feb 14, 2026
@ProExpertProg commented

@aleozlx could we prioritize merging this (or a different fix if preferred)? It's causing an accuracy collapse for vLLM DS3-fp4, so we have to downgrade to unfused kernels, which greatly affects performance.

@yzh119 (Collaborator) left a comment

LGTM overall.

@yzh119 yzh119 merged commit 432f343 into flashinfer-ai:main Feb 17, 2026
17 of 30 checks passed
@Bias92 (Contributor) commented Feb 18, 2026

Congratulations, my friend 💯
