Fix gemm allreduce two shot #2171
Conversation
Walkthrough (CodeRabbit): The changes refactor distributed synchronization helpers by replacing imports from `cutlass.utils.distributed_helpers` with `cutlass.utils.distributed` and re-implementing the removed helpers locally.
Summary of Changes (Gemini Code Assist): This pull request resolves a regression in the GEMM all-reduce two-shot kernel caused by the nvidia-cutlass-dsl upgrade to 4.3.1, which removed several helper functions.
Code Review
This pull request aims to fix a regression caused by an upgrade of nvidia-cutlass-dsl, which removed some helper functions. The changes involve replacing the deprecated cutlass.utils.distributed_helpers with cutlass.utils.distributed and re-implementing some of the removed helper functions locally. While the intent is clear, my review has identified two critical issues in the newly added helper functions. One function, sm_wise_inter_gpu_multimem_barrier, has incomplete logic that will lead to race conditions. Additionally, the new functions use a Pointer type hint that is not imported, which will cause a NameError at runtime. These issues must be addressed before merging.
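The `NameError` the review warns about is easy to reproduce in isolation. In this hypothetical sketch (not the flashinfer code itself), a signature references a `Pointer` type that was never imported; without `from __future__ import annotations`, Python evaluates signature annotations when the `def` statement executes, so the missing import fails immediately:

```python
# Hypothetical sketch of the missing-import failure mode; `Pointer` and
# `helper` are illustrative names, not the actual flashinfer code.
def define_helper():
    # The annotation `Pointer` is evaluated when this `def` runs, and the
    # name was never imported, so a NameError is raised right here.
    def helper(barrier: Pointer, barrier_mc: Pointer) -> None:
        pass
    return helper

try:
    define_helper()
except NameError as e:
    print(f"NameError: {e}")  # name 'Pointer' is not defined
```

This is why an unimported type hint is a runtime error rather than a silent issue, even if the annotated function is never called.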
reported NVIDIA/cutlass#2845
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/cute_dsl/gemm_allreduce_two_shot.py (1)
85-102: Add type hint for `num_ranks` parameter. The wait logic is now properly implemented, addressing the previous concern. However, the `num_ranks` parameter still lacks a type hint. For consistency with other parameters and better code clarity, please add a type annotation (likely `int` or `Int32`). Apply this diff:

```diff
 def sm_wise_inter_gpu_multimem_barrier(
-    barrier: Pointer, barrier_mc: Pointer, num_ranks, loc=None, ip=None
+    barrier: Pointer, barrier_mc: Pointer, num_ranks: int, loc=None, ip=None
 ) -> None:
```

Based on learnings, this was previously flagged as missing.
📜 Review details
📒 Files selected for processing (2)
- `flashinfer/cute_dsl/gemm_allreduce_two_shot.py` (6 hunks)
- `tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py` (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py (1)
flashinfer/utils.py (1)
- `get_compute_capability` (lines 253-256)
🔇 Additional comments (9)
tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py (1)
487-488: LGTM! Broadened compute capability support. The change correctly enables the test on both SM 10.0 and SM 10.3 (GB300) GPUs as intended by the PR.
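The broadened gate can be illustrated with a small predicate (the name `supported_capability` is made up for this sketch, not the actual test code):

```python
# Illustrative predicate mirroring the broadened compute-capability gate
# described above; `supported_capability` is a hypothetical name.
def supported_capability(major: int, minor: int) -> bool:
    # SM 10.0 (B200) and SM 10.3 (GB300) both pass after the change.
    return (major, minor) in {(10, 0), (10, 3)}

print(supported_capability(10, 0))  # True  (B200)
print(supported_capability(10, 3))  # True  (GB300, newly enabled)
print(supported_capability(9, 0))   # False (other architectures still skipped)
```

In the real test the `(major, minor)` pair would come from a helper like `get_compute_capability` shown in the code-graph analysis above.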
flashinfer/cute_dsl/gemm_allreduce_two_shot.py (8)
11-11: LGTM! Import updated to align with cutlass 4.3.1. The import change from `distributed_helpers` to `distributed` correctly addresses the library upgrade.
15-15: LGTM! `Pointer` type now properly imported. This addresses the previous review comment about the missing `Pointer` type import. Based on learnings, this resolves a past concern about undefined types in function signatures.
25-30: LGTM! Clean wrapper for multimem arrive operation. The function correctly wraps the distributed module's `multimem_red_relaxed_gpu_add1` operation for spinlock arrival.
32-83: LGTM! Workaround for cutlass 4.3.1 regression. This function implements the missing CAS-based acquire wait semantics. The HACK comment and GitHub issue link appropriately document that this is a temporary workaround until the functionality is restored in a future cutlass release. Based on learnings, this addresses the past concern about incomplete wait logic.
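The CAS-based acquire-wait semantics being re-implemented can be modeled on the host. This sketch uses a lock-protected compare-and-swap with Python threads as a stand-in for the GPU atomic CAS; all names here (`Flag`, `acquire_wait`) are illustrative, not the cutlass API:

```python
import threading

# Host-side model of a CAS-based acquire wait (illustrative only; the real
# helper issues GPU atomic CAS operations with acquire memory ordering).
class Flag:
    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def cas(self, expected, desired):
        # Atomic compare-and-swap: returns the value observed before the swap.
        with self._lock:
            old = self._value
            if old == expected:
                self._value = desired
            return old

def acquire_wait(flag, expected):
    # Spin until the flag holds `expected`, consuming it by swapping to 0.
    while flag.cas(expected, 0) != expected:
        pass

flag = Flag()
waiter = threading.Thread(target=acquire_wait, args=(flag, 1))
waiter.start()
flag.cas(0, 1)  # "arrive": publish the value the waiter is spinning on
waiter.join()
print("wait released")
```

On a GPU the spin loop additionally needs acquire ordering so that writes made before the arriving side's release become visible after the wait completes, which is exactly the semantics the removed cutlass helper provided.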
1315-1316: LGTM! Correctly uses the new helper function. The call site properly invokes the locally-defined `spin_lock_multimem_arrive` function.
1405-1407: LGTM! Correctly migrated to new distributed module API. The call site properly uses `distributed.spin_lock_atom_cas_relaxed_wait` with appropriate parameters for the all-reduce synchronization barrier.
1442-1461: LGTM! Correctly uses distributed module's multimem operations. The type-specific load-reduce and store operations properly implement the all-reduce logic across different data types (Float16, Float32, BFloat16, Float8E4M3FN, Float8E5M2).
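The data movement of the load-reduce and store phases in a two-shot all-reduce can be modeled on the host with plain Python lists (rank count, buffer size, and fill values here are illustrative stand-ins for the real multicast-memory operations):

```python
# Host model of the two-shot all-reduce data movement. Each rank owns one
# chunk: it load-reduces that chunk across every peer (shot 1), then stores
# the reduced result back to all peers (shot 2). Illustrative only.
NUM_RANKS = 4
N = 8
# Rank r's buffer is filled with r + 1 so the reduced value is easy to check.
buffers = [[float(r + 1)] * N for r in range(NUM_RANKS)]

chunk = N // NUM_RANKS
for r in range(NUM_RANKS):
    lo, hi = r * chunk, (r + 1) * chunk
    # Shot 1 (load-reduce): rank r sums its chunk across every peer buffer.
    reduced = [sum(buf[i] for buf in buffers) for i in range(lo, hi)]
    # Shot 2 (store): rank r writes the reduced chunk to all peers.
    for buf in buffers:
        buf[lo:hi] = reduced

print(buffers[0])  # every element is 1 + 2 + 3 + 4 = 10.0
```

The type-specific variants in the kernel exist because the multimem load-reduce instruction is issued per element type; the reduction itself is the same sum shown here.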
1476-1480: LGTM! Correctly uses the new barrier helper. The call site properly invokes the locally-defined `sm_wise_inter_gpu_multimem_barrier` function for the final synchronization after all-reduce operations.
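The barrier pattern that call site implements — every rank arrives, then waits until all `num_ranks` have arrived — can be sketched on the host with threads standing in for ranks (illustrative only; the real helper uses multicast atomics across GPUs, and `NUM_RANKS` here is a stand-in for the kernel's `num_ranks`):

```python
import threading

# Host-side model of an arrive-then-wait barrier across ranks.
NUM_RANKS = 4
arrived = 0
cond = threading.Condition()

def rank_barrier(rank):
    global arrived
    with cond:
        arrived += 1                 # "arrive": increment the shared counter
        cond.notify_all()
        while arrived < NUM_RANKS:   # "wait": block until every rank arrived
            cond.wait()

threads = [threading.Thread(target=rank_barrier, args=(r,)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"all {NUM_RANKS} ranks passed the barrier")
```

The "sm_wise" qualifier in the real helper means each SM participates in its own instance of this pattern rather than one barrier for the whole device.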
## 📌 Description

- Fix `test_cute_dsl_gemm_allreduce_two_shot.py` regression from the nvidia-cutlass-dsl upgrade to 4.3.1 (removed helper functions)
- GB300 enabled for this kernel as well

Needs 8xB200:

```
pytest tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py -v
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Added new synchronization helper functions for inter-GPU coordination: `spin_lock_multimem_arrive`, `spin_lock_atom_cas_acquire_wait`, and `sm_wise_inter_gpu_multimem_barrier`.
* **Tests**
  * Extended test coverage to support additional GPU architectures (SM10.3).