Skip to content

feat: Support numLocalTokens=0 for moe All-to-all#2247

Merged
yzh119 merged 1 commit intoflashinfer-ai:mainfrom
trevor-m:zeropatch
Dec 23, 2025
Merged

feat: Support numLocalTokens=0 for moe All-to-all#2247
yzh119 merged 1 commit intoflashinfer-ai:mainfrom
trevor-m:zeropatch

Conversation

@trevor-m
Copy link
Copy Markdown
Contributor

@trevor-m trevor-m commented Dec 19, 2025

📌 Description

This is a port of NVIDIA/TensorRT-LLM#9822 which was done by @bobboli

This feature is necessary for SGlang integration because some DP workers may have 0 tokens. The workaround to use a dummy token is quite messy and brittle.

🔍 Related Issues

Follow up to #2102

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Improved robustness of mixture-of-experts all-to-all communication to gracefully handle scenarios with zero local tokens, preventing synchronization failures and ensuring stable operation in edge cases.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Dec 19, 2025

Walkthrough

The changes extend the MoE All-to-All communication kernels to support zero-token scenarios by implementing thread synchronization mechanisms, per-token shared-memory setup, duplicate target rank deduplication via bitmask, and relaxed parameter validation from > 0 to >= 0 across dispatch and combine operations.

Changes

Cohort / File(s) Summary
Kernel edge-case handling
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
Added special-case logic for local_num_tokens == 0 to keep thread 0 alive for synchronization while other threads exit early; introduced per-token shared-memory tile setup with WarpPolicy/BlockPolicy conditioning; implemented duplicate target rank deduplication using bitmask (already_copied) with mirror -1 markers; added synchronization barriers at key points; reworked dispatch loop for per-payload vectorized operations reading from staged shared memory; updated last-token detection for zero-token scenarios; ensured minimal grid launch (grid_size >= 1) for synchronization participation.
Validation parameter relaxation
csrc/trtllm_moe_alltoall.cu
Relaxed local_num_tokens validation checks in moeA2ADispatchOp and moeA2ACombineOp from > 0 to >= 0 to accept zero-token scenarios.

Sequence Diagram

sequenceDiagram
    participant Grid as GPU Grid
    participant T0 as Thread 0
    participant TN as Other Threads
    participant SMem as Shared Memory
    participant GMemD as Global Memory (Dispatch)
    
    Grid->>T0: Launch kernel (local_num_tokens == 0)
    Grid->>TN: Launch kernel
    
    rect rgba(200, 220, 255, 0.3)
    Note over T0, TN: Edge-case check
    T0->>T0: Check local_num_tokens == 0
    TN->>TN: Check local_num_tokens == 0
    TN->>TN: Exit early
    end
    
    rect rgba(220, 200, 255, 0.3)
    Note over T0, SMem: Per-token tile setup
    T0->>SMem: Setup shard pointers (based on threading policy)
    T0->>SMem: Initialize shared-memory tiles
    T0->>T0: Sync barrier (ThreadingPolicy::sync)
    end
    
    rect rgba(255, 220, 200, 0.3)
    Note over T0, SMem: Routing deduplication
    T0->>SMem: Read topk_target_ranks, topk_send_indices
    T0->>SMem: Check bitmask (already_copied) for duplicates
    alt Duplicate detected
        T0->>SMem: Mirror -1 marker
    else First occurrence
        T0->>GMemD: atomicAdd to destination index
        T0->>SMem: Record target rank & dest index
        T0->>GMemD: Record in global memory
    end
    T0->>T0: Sync barrier
    end
    
    rect rgba(200, 255, 220, 0.3)
    Note over T0, GMemD: Dispatch payloads
    T0->>SMem: Read staged routing from registers
    T0->>GMemD: Vectorized dispatch per payload
    end
    
    rect rgba(255, 255, 200, 0.3)
    Note over T0, GMemD: Last-token finalization
    T0->>T0: Update is_last_token (no atomic increment for zero-token)
    T0->>GMemD: Synchronize for peer readiness
    end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Synchronization logic: Validate thread 0 keepalive semantics and early-exit correctness for zero-token cases; verify all synchronization barriers are correctly placed to prevent deadlock or race conditions.
  • Shared memory access patterns: Confirm per-token tile setup works correctly with both WarpPolicy and BlockPolicy; check bitmask-based deduplication logic for correctness and warp-level atomicity.
  • Dispatch/combine symmetry: Ensure dispatch and combine kernels handle zero-token scenarios symmetrically and that grid sizing (>= 1) is applied consistently.
  • Edge case robustness: Verify that removing atomic increments for last-token detection in zero-token scenarios doesn't introduce correctness issues with distributed synchronization.

Poem

A kernel that counts zero tokens with care, 🐰
Thread zero stands guard at the synchronization square,
While duplicates vanish via bitmask's delight,
Payloads dance vectorized through shared memory's might—
MoE whispers grateful: "You've fixed my plight!"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely summarizes the main feature: adding support for zero local tokens in MOE all-to-all operations, which is the primary objective of this changeset.
Description check ✅ Passed The description provides context (ports NVIDIA PR 9822), explains the motivation (SGlang integration with zero-token workers), and references related issues, though the checklist items remain unchecked.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 261f5ab and 05a65ce.

📒 Files selected for processing (2)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (7 hunks)
  • csrc/trtllm_moe_alltoall.cu (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (6)

297-374: LGTM! Zero-token handling is correct.

The special-case logic correctly keeps only threads with local_token_idx == 0 alive for synchronization while others return early. For WarpPolicy with 1 block launched, this means warp 0 stays alive (warps 1-7 return), which is sufficient since only the first warp participates in synchronization at line 376+.


382-387: LGTM! Last-token detection correctly handles the zero case.

Directly setting is_last_token = true when local_num_tokens == 0 is correct and avoids unnecessary atomic operations.


459-459: LGTM! Validation correctly relaxed to support zero tokens.


501-503: LGTM! Grid size adjustments ensure synchronization participation.

The logic correctly ensures at least one block is launched when local_num_tokens == 0, which is necessary for inter-rank synchronization even when a rank has no tokens to process.

Also applies to: 515-517


703-711: LGTM! Combine kernel correctly mirrors dispatch logic.

The zero-token handling in the combine kernel follows the same pattern as the dispatch kernel: keep threads with local_token_idx == 0 alive for synchronization, then exit early (line 766) before any token processing.

Also applies to: 766-766


824-824: LGTM! Combine launch correctly adjusted for zero tokens.

The validation relaxation and grid size adjustments mirror the dispatch launch changes, ensuring consistency across the pipeline.

Also applies to: 834-839

csrc/trtllm_moe_alltoall.cu (2)

136-136: LGTM! Dispatch validation correctly relaxed.

The change from > 0 to >= 0 aligns with the kernel-level support for zero tokens and is consistent with the PR objective.


284-284: LGTM! Combine validation correctly relaxed.

The change enables zero-token combine operations, allowing alloc_tensor to create a valid [0, elementsPerToken] output tensor when needed.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the Mixture-of-Experts (MoE) All-to-All communication mechanism by enabling robust support for scenarios where Data Parallel (DP) workers might have zero local tokens. This capability is essential for integrating with systems like SGlang, which previously relied on cumbersome workarounds. The changes ensure that the underlying CUDA kernels and their launch configurations gracefully handle the absence of tokens, maintaining synchronization integrity and improving the overall flexibility and stability of the MoE system.

Highlights

  • Zero Token Support: The Mixture-of-Experts (MoE) All-to-All communication kernels (moeA2ADispatchKernel and moeA2ACombineKernel) now correctly handle cases where numLocalTokens is zero, ensuring necessary synchronization occurs even when no tokens are present.
  • Validation Relaxed: Validation checks for local_num_tokens in launch functions (moe_a2a_dispatch_launch, moe_a2a_combine_launch) and FFI operations (moeA2ADispatchOp, moeA2ACombineOp) have been updated to accept values greater than or equal to zero (>= 0), instead of strictly greater than zero (> 0).
  • Kernel Launch Adjustment: Grid sizes for kernel launches are now set to a minimum of 1 when local_num_tokens is zero. This ensures that kernels are still launched and can participate in synchronization, even if there are no tokens to process.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@trevor-m
Copy link
Copy Markdown
Contributor Author

@djns99 @yzh119 It would be great if this could make it into the next release, it makes the Sglang integraton of the new a2a much easier.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request ports a feature from TensorRT-LLM to support numLocalTokens=0 for MoE All-to-all communication. This is an important edge case for certain use cases like SGlang integration. The changes correctly handle this scenario by launching a minimal kernel for synchronization even when there are no local tokens, and by adjusting boundary checks and logic within the kernels. The implementation appears correct. I've added a few suggestions to reduce code duplication and improve maintainability.

Comment on lines 498 to +503
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (grid_size == 0) {
grid_size = 1;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to ensure grid_size is at least 1 can be written more concisely. This pattern is also repeated in the else block below (lines 512-517) and in moe_a2a_combine_launch. Using a more compact form would improve maintainability.

    int grid_size = params.local_num_tokens > 0 ? params.local_num_tokens : 1;
    // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
    // synchronization.

Comment on lines 512 to +517
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (grid_size == 0) {
grid_size = 1;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the if branch above, this logic to ensure grid_size is at least 1 can be made more concise. This improves readability and reduces duplication.

    int grid_size = params.local_num_tokens == 0 ? 1 : ceilDiv(params.local_num_tokens, kWarpsPerBlock);
    // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
    // synchronization.

Comment on lines +832 to +839
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
// synchronization.
if (grid_size_warp == 0) {
grid_size_warp = 1;
}
if (grid_size_block == 0) {
grid_size_block = 1;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for ensuring grid sizes are at least 1 can be simplified. Since both grid_size_warp and grid_size_block will be 0 if and only if params.local_num_tokens is 0, you can use a single condition to set both, which is clearer and avoids repetition.

  // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the
  // synchronization.
  if (params.local_num_tokens == 0) {
    grid_size_warp = 1;
    grid_size_block = 1;
  }

Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @trevor-m thanks for the PR, I have no problem with this, it will be better if we have unittest coverage for the case of numLocalTokens=0.

@yzh119 yzh119 merged commit 25de38e into flashinfer-ai:main Dec 23, 2025
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants