
Fix gemm allreduce two shot #2171

Merged
yzh119 merged 3 commits into flashinfer-ai:main from aleozlx:fix/gemm_allreduce_two_shot
Dec 5, 2025

Conversation

@aleozlx (Collaborator) commented Dec 3, 2025

📌 Description

  • Fix test_cute_dsl_gemm_allreduce_two_shot.py regression from nvidia-cutlass-dsl upgrade to 4.3.1 (removed helper functions)
  • GB300 is now enabled for this kernel as well

Testing requires 8x B200 GPUs:

pytest tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py -v

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added new synchronization helper functions for inter-GPU coordination: spin_lock_multimem_arrive, spin_lock_atom_cas_acquire_wait, and sm_wise_inter_gpu_multimem_barrier.
  • Tests

    • Extended test coverage to support additional GPU architectures (SM10.3).


@coderabbitai (Bot) commented Dec 3, 2025

Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments, which may make this review less comprehensive.

Walkthrough

The changes refactor distributed synchronization helpers by consolidating imports from a distributed_helpers module into a distributed module. Three new helper functions for GPU synchronization are introduced: spin-lock multimem arrival, atomic CAS acquisition wait, and inter-GPU multimem barrier coordination. Public API exports are updated accordingly.
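
As a rough mental model, here is a CPU-side analogue of the arrive/wait semantics these helpers provide, in plain Python. This is a sketch only: the real helpers are CuTe DSL device functions built on multimem PTX operations, and every name below is invented for illustration.

import threading

class ArriveWaitBarrier:
    """CPU analogue of the GPU arrive/wait pattern (illustrative only)."""

    def __init__(self, num_ranks: int):
        self.num_ranks = num_ranks
        self.count = 0
        self.lock = threading.Lock()

    def arrive(self) -> None:
        # Analogue of spin_lock_multimem_arrive: atomically bump a shared
        # counter that every participant can observe.
        with self.lock:
            self.count += 1

    def acquire_wait(self) -> None:
        # Analogue of spin_lock_atom_cas_acquire_wait: spin until all ranks
        # have arrived. On the GPU the CAS provides acquire ordering; here
        # the lock round-trip plays that role.
        while True:
            with self.lock:
                if self.count >= self.num_ranks:
                    return

def barrier_participant(barrier: ArriveWaitBarrier) -> None:
    # Analogue of sm_wise_inter_gpu_multimem_barrier: every rank arrives,
    # then waits for all other ranks.
    barrier.arrive()
    barrier.acquire_wait()

barrier = ArriveWaitBarrier(num_ranks=8)
threads = [threading.Thread(target=barrier_participant, args=(barrier,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("all 8 simulated ranks passed the barrier")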

Changes

Cohort / File(s): Summary
  • Distributed synchronization refactoring (flashinfer/cute_dsl/gemm_allreduce_two_shot.py): Replaces distributed_helpers imports with the distributed module. Introduces three new public helper functions for GPU synchronization: spin_lock_multimem_arrive(), spin_lock_atom_cas_acquire_wait(), and sm_wise_inter_gpu_multimem_barrier(). Updates kernel call sites to use the new functions. Adds Pointer to the public typing imports.
  • Test coverage expansion (tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py): Broadens the GPU compute capability check from strict equality with (10, 0) to allow both (10, 0) and (10, 3), enabling the test to also run on GB300 (SM103) GPUs (see the guard sketch below).
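
A minimal sketch of the broadened guard; get_compute_capability lives in flashinfer/utils.py (per the code graph analysis later in this review), while the device argument and skip message are illustrative:

import pytest
import torch

from flashinfer.utils import get_compute_capability

# Skip unless running on SM100 (B200) or SM103 (GB300) hardware.
compute_capability = get_compute_capability(torch.device("cuda:0"))
if compute_capability not in [(10, 0), (10, 3)]:
    pytest.skip("requires SM100 (B200) or SM103 (GB300)")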

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

  • Primary changes are refactoring of helper function imports and extraction of synchronization primitives
  • New functions follow established patterns for GPU synchronization operations
  • Test change is straightforward condition modification
  • Attention areas:
    • Verify correctness of spin_lock_atom_cas_acquire_wait() scope parameter handling (gpu vs. system; gpu scope orders memory operations within a single GPU, while system scope extends visibility across peer GPUs and the host)
    • Confirm sm_wise_inter_gpu_multimem_barrier() coordination logic matches the prior distributed_helpers behavior
    • Ensure the new public API exports are properly documented

Suggested reviewers

  • yzh119
  • djmmoss
  • kaixih

Poem

🔒 Locks spin true and barriers bloom,
Distributed threads escape the gloom,
From helpers old to modules new,
SM100 joins the GPU crew! 🐰✨

Pre-merge checks

❌ Failed checks (1 inconclusive)
  • Title check (❓ Inconclusive): The title references the main component being fixed (gemm allreduce two shot) but lacks specificity about the nature of the fix. Resolution: consider a more descriptive title, e.g. 'Fix gemm allreduce two shot regression from cutlass-dsl upgrade' or 'Replace distributed helpers in gemm allreduce two shot'.
✅ Passed checks (2 passed)
  • Description check (✅ Passed): The description covers the main issues and includes checklist completions, but lacks detail on specific changes and the implementation approach.
  • Docstring coverage (✅ Passed): Docstring coverage is 80.00%, which meets the required threshold of 80.00%.


@gemini-code-assist (Contributor) commented
Summary of Changes

Hello @aleozlx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a regression in the gemm_allreduce_two_shot test suite, which arose after an upgrade of the nvidia-cutlass-dsl library to version 4.3.1. The core issue involved changes to the distributed helper functions within the cutlass library. The solution involves updating module imports and reimplementing specific distributed synchronization and memory access primitives to ensure continued compatibility and correct operation with the new library version.

Highlights

  • Dependency Update Compatibility: Addressed a regression caused by upgrading nvidia-cutlass-dsl to version 4.3.1, which removed certain helper functions previously used.
  • Distributed Utilities Refactoring: Updated the import path for cutlass distributed utilities from cutlass.utils.distributed_helpers to cutlass.utils.distributed (a sketch follows this list).
  • Reimplementation of Helper Functions: Introduced local implementations for spin_lock_multimem_arrive and sm_wise_inter_gpu_multimem_barrier to restore functionality that was removed or changed in the new cutlass-dsl version.
  • API Migration: Migrated existing calls to distributed memory operations, such as spin locks and multimemory load/reduce/store functions, to align with the updated API of the cutlass.utils.distributed module.
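
A sketch of the import-path migration described above. The module and function names are those cited in this PR; the commented call site and elided arguments are illustrative:

# nvidia-cutlass-dsl < 4.3.1 exposed the helpers as:
#   from cutlass.utils import distributed_helpers
# 4.3.1 moves the surviving utilities to:
from cutlass.utils import distributed

# Call sites then go through the new module, e.g. (arguments elided):
# distributed.spin_lock_atom_cas_relaxed_wait(...)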

@gemini-code-assist (Bot) left a comment

Code Review

This pull request aims to fix a regression caused by an upgrade of nvidia-cutlass-dsl, which removed some helper functions. The changes involve replacing the deprecated cutlass.utils.distributed_helpers with cutlass.utils.distributed and re-implementing some of the removed helper functions locally. While the intent is clear, my review has identified two critical issues in the newly added helper functions. One function, sm_wise_inter_gpu_multimem_barrier, has incomplete logic that will lead to race conditions. Additionally, the new functions use a Pointer type hint that is not imported, which will cause a NameError at runtime. These issues must be addressed before merging.

Comment thread on flashinfer/cute_dsl/gemm_allreduce_two_shot.py
Comment thread on flashinfer/cute_dsl/gemm_allreduce_two_shot.py (outdated)
@aleozlx (Collaborator, Author) commented Dec 5, 2025

Reported upstream as NVIDIA/cutlass#2845.

@aleozlx (Collaborator, Author) commented Dec 5, 2025

(py312) root@bia0097:/home/aleyang/repos/flashinfer# pytest tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py -v -s
================================================================================================ test session starts =================================================================================================
platform linux -- Python 3.12.11, pytest-9.0.1, pluggy-1.6.0 -- /opt/conda/envs/py312/bin/python3.12
cachedir: .pytest_cache
rootdir: /home/aleyang/repos/flashinfer
configfile: pytest.ini
collected 1 item

tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py::test_cute_dsl_gemm_allreduce_two_shot[8] Running test for world_size=8
[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
Running Blackwell Persistent Dense GEMM test with:
mnkl: (2048, 2048, 4096, 1)
AB dtype: TFloat32, C dtype: Float32, Acc dtype: Float32
Matrix majors - A: k, B: k, C: n
Mma Tiler (M, N): (128, 128), Cluster Shape (M, N): (1, 1)
2CTA MMA instructions: False
Use TMA Store: False
Tolerance: 0.1
Warmup iterations: 0
Iterations: 1
Skip reference checking: False
Use cold L2: False
Fused AllReduce Op: two_shot
[The same configuration banner is printed by each of the remaining seven ranks; repeated output elided.]
exec_time: 403.6799967288971
cute_dsl_gemm_allreduce_two_shot on 8 GPUs: OK
PASSED

================================================================================================= 1 passed in 48.25s =================================================================================================

@aleozlx aleozlx marked this pull request as ready for review December 5, 2025 03:31
@aleozlx aleozlx requested review from kaixih and yzh119 as code owners December 5, 2025 03:31
@coderabbitai (Bot) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/cute_dsl/gemm_allreduce_two_shot.py (1)

85-102: Add type hint for num_ranks parameter.

The wait logic is now properly implemented, addressing the previous concern. However, the num_ranks parameter still lacks a type hint. For consistency with other parameters and better code clarity, please add a type annotation (likely int or Int32).

Apply this diff:

 def sm_wise_inter_gpu_multimem_barrier(
-    barrier: Pointer, barrier_mc: Pointer, num_ranks, loc=None, ip=None
+    barrier: Pointer, barrier_mc: Pointer, num_ranks: int, loc=None, ip=None
 ) -> None:

Based on learnings, this was previously flagged as missing.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 442dec9 and fa4467e.

📒 Files selected for processing (2)
  • flashinfer/cute_dsl/gemm_allreduce_two_shot.py (6 hunks)
  • tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (253-256)
🔇 Additional comments (9)
tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py (1)

487-488: LGTM! Broadened compute capability support.

The change correctly enables the test on both SM 10.0 and SM 10.3 (GB300) GPUs as intended by the PR.

flashinfer/cute_dsl/gemm_allreduce_two_shot.py (8)

11-11: LGTM! Import updated to align with cutlass 4.3.1.

The import change from distributed_helpers to distributed correctly addresses the library upgrade.


15-15: LGTM! Pointer type now properly imported.

This addresses the previous review comment about the missing Pointer type import. Based on learnings, this resolves a past concern about undefined types in function signatures.


25-30: LGTM! Clean wrapper for multimem arrive operation.

The function correctly wraps the distributed module's multimem_red_relaxed_gpu_add1 operation for spinlock arrival.


32-83: LGTM! Workaround for cutlass 4.3.1 regression.

This function implements the missing CAS-based acquire wait semantics. The HACK comment and GitHub issue link appropriately document that this is a temporary workaround until the functionality is restored in a future cutlass release. Based on learnings, this addresses the past concern about incomplete wait logic.


1315-1316: LGTM! Correctly uses the new helper function.

The call site properly invokes the locally-defined spin_lock_multimem_arrive function.


1405-1407: LGTM! Correctly migrated to new distributed module API.

The call site properly uses distributed.spin_lock_atom_cas_relaxed_wait with appropriate parameters for the all-reduce synchronization barrier.


1442-1461: LGTM! Correctly uses distributed module's multimem operations.

The type-specific load-reduce and store operations properly implement the all-reduce logic across different data types (Float16, Float32, BFloat16, Float8E4M3FN, Float8E5M2).
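
For intuition: the two-shot all-reduce that these operations implement first reduces each rank's assigned shard across all ranks (shot one), then gathers the reduced shards back to every rank (shot two), which is numerically a full sum all-reduce. Below is a single-device sketch in plain PyTorch; the shapes are borrowed from the test log above, and nothing in it is flashinfer or cutlass API.

import torch

num_ranks = 8
# Per-rank GEMM outputs, emulated on a single device.
per_rank_out = [torch.randn(2048, 2048) for _ in range(num_ranks)]

# Shot 1: rank r reduces row-shard r across all ranks' outputs.
shards = [
    sum(out.chunk(num_ranks, dim=0)[r] for out in per_rank_out)
    for r in range(num_ranks)
]

# Shot 2: every rank gathers all reduced shards; the result equals a
# full sum all-reduce of the per-rank outputs.
allreduced = torch.cat(shards, dim=0)
assert torch.allclose(allreduced, sum(per_rank_out), atol=1e-5)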


1476-1480: LGTM! Correctly uses the new barrier helper.

The call site properly invokes the locally-defined sm_wise_inter_gpu_multimem_barrier function for the final synchronization after all-reduce operations.

@yzh119 yzh119 merged commit b972005 into flashinfer-ai:main Dec 5, 2025
4 checks passed
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026