feat: Support numLocalTokens=0 for moe All-to-all #2247

yzh119 merged 1 commit into flashinfer-ai:main
Conversation
Walkthrough

The changes extend the MoE All-to-All communication kernels to support zero-token scenarios by implementing thread synchronization mechanisms, per-token shared-memory setup, duplicate target rank deduplication via bitmask, and relaxed parameter validation that now permits zero local tokens.
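To make the deduplication step concrete, below is a minimal sketch of a bitmask-based dedup over a token's top-k routing targets. The function name `dedup_targets`, the 64-rank cap, and the array layout are illustrative assumptions; only the bitmask check and the -1 duplicate marker are described in the walkthrough above.

```cpp
// Illustrative sketch only, not the kernel's real code: each target rank is
// sent to at most once, and duplicate slots are overwritten with a -1 marker.
__device__ void dedup_targets(int* topk_target_ranks, int top_k) {
  uint64_t already_copied = 0;  // one bit per target rank; assumes <= 64 ranks
  for (int k = 0; k < top_k; ++k) {
    int rank = topk_target_ranks[k];
    if (rank < 0) continue;  // slot already invalid
    uint64_t bit = 1ull << rank;
    if (already_copied & bit) {
      topk_target_ranks[k] = -1;  // duplicate target: mirror the -1 marker
    } else {
      already_copied |= bit;  // first occurrence: keep and record the rank
    }
  }
}
```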
Sequence Diagram

```mermaid
sequenceDiagram
    participant Grid as GPU Grid
    participant T0 as Thread 0
    participant TN as Other Threads
    participant SMem as Shared Memory
    participant GMemD as Global Memory (Dispatch)
    Grid->>T0: Launch kernel (local_num_tokens == 0)
    Grid->>TN: Launch kernel
    rect rgba(200, 220, 255, 0.3)
        Note over T0, TN: Edge-case check
        T0->>T0: Check local_num_tokens == 0
        TN->>TN: Check local_num_tokens == 0
        TN->>TN: Exit early
    end
    rect rgba(220, 200, 255, 0.3)
        Note over T0, SMem: Per-token tile setup
        T0->>SMem: Setup shard pointers (based on threading policy)
        T0->>SMem: Initialize shared-memory tiles
        T0->>T0: Sync barrier (ThreadingPolicy::sync)
    end
    rect rgba(255, 220, 200, 0.3)
        Note over T0, SMem: Routing deduplication
        T0->>SMem: Read topk_target_ranks, topk_send_indices
        T0->>SMem: Check bitmask (already_copied) for duplicates
        alt Duplicate detected
            T0->>SMem: Mirror -1 marker
        else First occurrence
            T0->>GMemD: atomicAdd to destination index
            T0->>SMem: Record target rank & dest index
            T0->>GMemD: Record in global memory
        end
        T0->>T0: Sync barrier
    end
    rect rgba(200, 255, 220, 0.3)
        Note over T0, GMemD: Dispatch payloads
        T0->>SMem: Read staged routing from registers
        T0->>GMemD: Vectorized dispatch per payload
    end
    rect rgba(255, 255, 200, 0.3)
        Note over T0, GMemD: Last-token finalization
        T0->>T0: Update is_last_token (no atomic increment for zero-token)
        T0->>GMemD: Synchronize for peer readiness
    end
```
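The zero-token path in the diagram can be summarized in a short kernel sketch. This is a simplified illustration with hypothetical names (`moe_a2a_dispatch_sketch`, the placeholder sync call); the real kernels and their synchronization mechanism live in the FlashInfer sources.

```cpp
// Simplified sketch of the zero-token edge case (names are hypothetical).
// When local_num_tokens == 0 the host still launches one block so this rank
// participates in the cross-rank handshake; all threads but one exit early.
__global__ void moe_a2a_dispatch_sketch(int local_num_tokens /*, ... */) {
  if (local_num_tokens == 0) {
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      // signal_and_wait_for_peers();  // placeholder for the real sync step
    }
    return;  // nothing to stage, dedup, or dispatch on this rank
  }
  // ... normal per-token dispatch path ...
}
```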
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant enhancement to the Mixture-of-Experts (MoE) All-to-All communication mechanism by enabling robust support for scenarios where Data Parallel (DP) workers might have zero local tokens. This capability is essential for integrating with systems like SGlang, which previously relied on cumbersome workarounds. The changes ensure that the underlying CUDA kernels and their launch configurations gracefully handle the absence of tokens, maintaining synchronization integrity and improving the overall flexibility and stability of the MoE system.
Code Review
This pull request ports a feature from TensorRT-LLM to support numLocalTokens=0 for MoE All-to-all communication. This is an important edge case for certain use cases like SGlang integration. The changes correctly handle this scenario by launching a minimal kernel for synchronization even when there are no local tokens, and by adjusting boundary checks and logic within the kernels. The implementation appears correct. I've added a few suggestions to reduce code duplication and improve maintainability.
```cpp
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (grid_size == 0) {
  grid_size = 1;
}
```
This logic to ensure `grid_size` is at least 1 can be written more concisely. This pattern is also repeated in the `else` block below (lines 512-517) and in `moe_a2a_combine_launch`. Using a more compact form would improve maintainability.
```cpp
int grid_size = params.local_num_tokens > 0 ? params.local_num_tokens : 1;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
```
```cpp
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (grid_size == 0) {
  grid_size = 1;
}
```
Similar to the `if` branch above, this logic to ensure `grid_size` is at least 1 can be made more concise. This improves readability and reduces duplication.
```cpp
int grid_size = params.local_num_tokens == 0 ? 1 : ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
```
```cpp
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (grid_size_warp == 0) {
  grid_size_warp = 1;
}
if (grid_size_block == 0) {
  grid_size_block = 1;
}
```
This logic for ensuring grid sizes are at least 1 can be simplified. Since both `grid_size_warp` and `grid_size_block` will be 0 if and only if `params.local_num_tokens` is 0, you can use a single condition to set both, which is clearer and avoids repetition.
```cpp
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (params.local_num_tokens == 0) {
  grid_size_warp = 1;
  grid_size_block = 1;
}
```
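Taking the three suggestions together, one hypothetical way to remove the repetition entirely is a tiny clamping helper shared by the launch sites. `min_one_grid` and the call sites below are illustrative, not part of the PR:

```cpp
// Hypothetical consolidation of the suggestions above; min_one_grid is not in
// the PR. Any computed grid size is clamped to at least one block so a rank
// with zero local tokens still joins the synchronization.
inline int min_one_grid(int grid_size) { return grid_size > 0 ? grid_size : 1; }

// Possible usage at the launch sites:
//   int grid_size      = min_one_grid(params.local_num_tokens);
//   int grid_size_warp = min_one_grid(ceilDiv(params.local_num_tokens, kWarpsPerBlock));
```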
📌 Description
This is a port of NVIDIA/TensorRT-LLM#9822, which was done by @bobboli.
This feature is necessary for SGlang integration because some DP workers may have 0 tokens. The workaround of using a dummy token is messy and brittle.
🔍 Related Issues
Follow-up to #2102
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes