GPT-OSS MoE expert routing optimization for native pytorch #535

Merged: mmathew23 merged 2 commits into unslothai:main from ruixiang63:gpt-oss-optimization on Mar 18, 2026

@ruixiang63

Contributor

Optimized MoE expert routing loop to eliminate unnecessary cudaStreamSynchronize calls caused by torch.where. Currently benefits GPT-OSS MoE models using os.environ["UNSLOTH_MOE_BACKEND"] = "native_torch".

The original implementation uses torch.where(router_indices == expert_idx) inside the expert loop to find which tokens are assigned to each expert. torch.where on a boolean mask requires determining the output size dynamically, which triggers a CUDA synchronization on every iteration — resulting in num_experts synchronization points per forward pass.
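To illustrate the point about dynamic output size, here is a minimal sketch using a hypothetical 4-token, top-2 routing table (the tensor values are made up for illustration and are not taken from the PR). It runs on CPU, where no synchronization happens; it only shows that the length of the torch.where result depends on the data, while argsort always returns num_tokens * top_k entries.

```python
import torch

# Hypothetical routing table: 4 tokens, top_k = 2, 4 experts (illustrative values).
router_indices = torch.tensor([[0, 1], [0, 2], [0, 1], [2, 3]])
num_experts = 4

# torch.where returns one entry per match, so its length varies per expert.
where_sizes = [
    torch.where(router_indices == expert_idx)[0].numel()
    for expert_idx in range(num_experts)
]
print(where_sizes)  # [3, 2, 2, 1]

# argsort always returns exactly num_tokens * top_k indices.
order = torch.argsort(router_indices.reshape(-1))
print(order.numel())  # 8
```

On a CUDA device, the data-dependent length is what forces the host to wait for the kernel before it can allocate the result.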

Solution

Replace torch.where with a pre-sorting approach:

  • Flatten and sort all expert assignments using argsort (fixed output size, no sync)
  • Compute per-expert token counts using bincount + .tolist() (single D2H sync)
  • Index into the sorted array using cumulative offsets (no sync)

This reduces CUDA synchronization from num_experts times to 1 time per forward pass.
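The steps above can be sketched roughly as follows. This is not the code merged in the PR: the function names route_with_where and route_with_presort, the tensor sizes, and the random routing table are assumptions made for illustration. Only the variable names sorted_tokens, counts, num_tokens, top_k, offset and count follow the PR. The sketch uses a stable argsort so that each expert's token order matches what torch.where would return.

```python
import torch

def route_with_where(router_indices, num_experts):
    """Baseline (hypothetical helper): one torch.where per expert."""
    groups = []
    for expert_idx in range(num_experts):
        # Output length is data-dependent; on CUDA this forces a sync.
        token_idx, _ = torch.where(router_indices == expert_idx)
        groups.append(token_idx)
    return groups

def route_with_presort(router_indices, num_experts):
    """Pre-sorting variant (hypothetical helper): one .tolist() per call."""
    num_tokens, top_k = router_indices.shape
    flat_experts = router_indices.reshape(-1)
    # Fixed-size output: positions in the flat array, grouped by expert id.
    order = torch.argsort(flat_experts, stable=True)
    # Flat position p belongs to token p // top_k.
    sorted_tokens = torch.div(order, top_k, rounding_mode="floor")
    # Single device-to-host transfer for all per-expert counts.
    counts = torch.bincount(flat_experts, minlength=num_experts).tolist()
    groups, offset = [], 0
    for count in counts:
        # Plain slicing with Python ints; no further sync is needed.
        groups.append(sorted_tokens[offset:offset + count])
        offset += count
    return groups

# Illustrative sizes, not GPT-OSS's real configuration.
torch.manual_seed(0)
num_tokens, top_k, num_experts = 16, 4, 8
router_indices = torch.randn(num_tokens, num_experts).topk(top_k, dim=-1).indices

groups_where = route_with_where(router_indices, num_experts)
groups_presort = route_with_presort(router_indices, num_experts)
```

Both helpers should return the same token indices per expert; the difference is only in how many times the host has to wait on the device.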

Nsys profiling traces (screenshots not reproduced here):

  • Without this PR: [image]
  • With this PR: [image]

Performance

With our optimized MoE implementation, we achieve a 23% speedup for the forward pass and a 13% speedup for the backward pass in GPT-OSS-20B fine-tuning. Since Unsloth now supports grouped_mm, this optimization may not be as critical as before, but it still benefits those using the native PyTorch backend for MoE models.

@gemini-code-assist

Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Mixture-of-Experts (MoE) models when using the native PyTorch backend. The core issue addressed was the frequent CUDA synchronizations caused by torch.where within the expert routing loop. By implementing a more efficient pre-sorting algorithm, the number of synchronization points has been drastically reduced, leading to substantial speedups in both forward and backward passes during fine-tuning.

Highlights

  • Performance Optimization: Optimized the Mixture-of-Experts (MoE) expert routing loop in native PyTorch to eliminate unnecessary cudaStreamSynchronize calls, which were previously triggered by torch.where.
  • Algorithm Change: Replaced the torch.where based token assignment with a pre-sorting approach utilizing argsort, bincount, and cumulative offsets. This reduces CUDA synchronizations from num_experts times to a single synchronization per forward pass.
  • Impact: Achieved a 23% speedup for the forward pass and a 13% speedup for the backward pass in GPT-OSS-20B fine-tuning, specifically benefiting models using the native PyTorch backend for MoE.

Changelog

  • unsloth_zoo/temporary_patches/gpt_oss.py
    • Refactored the forward method within the MoE layer to implement a pre-sorting approach for expert token assignment.
    • Introduced num_tokens and top_k variables to facilitate the new routing logic.
    • Replaced dynamic torch.where calls with pre-computed sorted indices (sorted_tokens) and counts (counts) for efficient token distribution to experts.
    • Updated the torch_native_forward function with the same optimization strategy to ensure consistent performance improvements.

Activity

  • No human activity has been recorded for this pull request yet.
@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a significant optimization for MoE expert routing in the native PyTorch backend by replacing the torch.where call inside the expert loop with a more efficient pre-sorting approach. This change effectively reduces CUDA stream synchronizations from num_experts to a single one per forward pass, which should lead to noticeable performance improvements as described. The implementation looks correct and follows the described logic. I have one suggestion regarding code duplication to improve maintainability.

@mmathew23 mmathew23 left a comment

Collaborator

Just some small nits, but otherwise looks good, thanks! Did you happen to check train time before and after the change? Would be good to know what gains to expect.

if count == 0:
    continue
# Use pre-computed indices (no torch.where needed)
idx = sorted_tokens[offset:offset + count]

Collaborator

NIT: Can we keep this as token_idx instead of idx?

continue

# Use pre-computed indices (no torch.where needed)
idx = sorted_tokens[offset:offset + count]

Collaborator

NIT: Can we keep this as token_idx instead of idx?

  gated_output = gated_output.to(torch.float32)
  device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with torch.autocast(device_type=device_type, enabled=False):

Collaborator

NIT: can we keep the comment?

@ruixiang63

ruixiang63 commented Mar 11, 2026

Contributor Author

Thank you @mmathew23 for the review! I’ve fixed the naming issue as suggested :)

Did you happen to check train time before and after the change? Would be good to know what gains to expect.

Yes, I ran this optimization on the GPT-OSS 20B model. The performance improvement is reported in the PR description: "we achieved ~23% speedup for the forward pass and ~13% speedup for the backward pass during GPT-OSS-20B fine-tuning." It would be great if the Unsloth team could validate performance numbers for this PR as well.
These measurements were obtained using the January version of Unsloth (i.e., before the grouped_mm changes).

However, with the current Unsloth version (after the grouped_mm for MoE update), I encountered an OOM issue using the same training script. The script was unchanged except for adding: os.environ["UNSLOTH_MOE_BACKEND"] = "native_torch".
I ran into this issue about three weeks ago, and I’m not sure whether it has been resolved since then. I didn’t have the bandwidth to investigate further.

It might be helpful for the Unsloth team to check whether this is reproducible, specifically by comparing VRAM usage between the January version and the current version of Unsloth when running GPT-OSS-20B QLoRA fine-tuning with bs=8, keeping all other hyperparameters identical to the official Unsloth GPT-OSS fine-tuning notebook.
CC. @danielhanchen

PS: I’m currently on vacation, so responses may be delayed. I can make small fixes to the PR from my personal Mac, but I’m not able to run tests or experiments at the moment. Thanks!

@mmathew23

mmathew23 commented Mar 11, 2026

Collaborator

Actually, you shouldn't need to set UNSLOTH_MOE_BACKEND. I did test the PR end to end on a B200 with our gpt-oss 20B notebook, and I'm seeing a ~10% speedup, which is excellent.

@mmathew23

Collaborator

@danielhanchen I checked again on a T4 notebook and a B200. We're seeing consistent speed improvements of around 10-15% on both.

@mmathew23 mmathew23 merged commit 378a905 into unslothai:main Mar 18, 2026