GPT-OSS MoE expert routing optimization for native pytorch#535
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the performance of Mixture-of-Experts (MoE) models when using the native PyTorch backend. The core issue addressed was the frequent CUDA synchronizations caused by Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a significant optimization for MoE expert routing in the native PyTorch backend by replacing the torch.where call inside the expert loop with a more efficient pre-sorting approach. This change effectively reduces CUDA stream synchronizations from num_experts to a single one per forward pass, which should lead to noticeable performance improvements as described. The implementation looks correct and follows the described logic. I have one suggestion regarding code duplication to improve maintainability.
mmathew23
left a comment
There was a problem hiding this comment.
Just some small nits., but otherwise look good, thanks! Did you happen to check train time before and after the change? Would be good to know what gains to expect.
| if count == 0: | ||
| continue | ||
| # Use pre-computed indices (no torch.where needed) | ||
| idx = sorted_tokens[offset:offset + count] |
There was a problem hiding this comment.
NIT: Can we keep this as token_idx instead of idx?
| continue | ||
|
|
||
| # Use pre-computed indices (no torch.where needed) | ||
| idx = sorted_tokens[offset:offset + count] |
There was a problem hiding this comment.
NIT: Can we keep this as token_idx instead of idx?
| gated_output = gated_output.to(torch.float32) | ||
| device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" | ||
| with torch.autocast(device_type=device_type, enabled=False): # Force float32 | ||
| with torch.autocast(device_type=device_type, enabled=False): |
There was a problem hiding this comment.
NIT: can we keep the comment?
|
Thank you @mmathew23 for the review! I’ve fixed the naming issue as suggested :)
Yes, I ran this optimization on the GPT-OSS 20B model. The performance improvement is reported in the PR description: "we achieved ~23% speedup for the forward pass and ~13% speedup for the backward pass during GPT-OSS-20B fine-tuning." It would be great if the Unsloth team could validate performance numbers for this PR as well. However, with the current Unsloth version (after the It might be helpful for the Unsloth team to check whether this is reproducible, specifically by comparing VRAM usage between the January version and the current version of Unsloth when running GPT-OSS-20B QLoRA fine-tuning with bs=8, keeping all other hyperparameters identical to the official Unsloth GPT-OSS fine-tuning notebook. PS: I’m currently on vacation, so responses may be delayed. I can make small fixes to the PR from my personal Mac, but I’m not able to run tests or experiments at the moment. Thanks! |
|
Actually you shouldn't need to set |
|
@danielhanchen i checked again on a T4 notebook and B200. We're seeing consistent speed improvements on both around 10-15%. |
Optimized MoE expert routing loop to eliminate unnecessary cudaStreamSynchronize calls caused by
torch.where. Currently benefits GPT-OSS MoE models usingos.environ["UNSLOTH_MOE_BACKEND"] = "native_torch".The original implementation uses
torch.where(router_indices == expert_idx)inside the expert loop to find which tokens are assigned to each expert.torch.whereon a boolean mask requires determining the output size dynamically, which triggers a CUDA synchronization on every iteration — resulting in num_experts synchronization points per forward pass.Solution
Replace
torch.wherewith a pre-sorting approach:bincount+.tolist()(single D2H sync)This reduces CUDA synchronization from num_experts times to 1 time per forward pass.
Nsys profiling traces as follows:
Performance
With our optimized MoE implementation, we achieve a 23% speedup for the forward pass and a 13% speedup for the backward pass in GPT-OSS-20B fine-tuning. Since Unsloth now supports
grouped_mm, this optimization may not be critical as before, but it still benefits those using the native PyTorch backend for MoE models.