[WIP][MoE] Gpt oss moe kernels#447

Closed
Datta0 wants to merge 68 commits into unslothai:nightly from Datta0:gpt_oss_moe_kernels

Datta0 commented Jan 26, 2026

Collaborator

Please take a look at #396 first and then this

Datta0 changed the title from "[MoE] Gpt oss moe kernels" to "[WIP][MoE] Gpt oss moe kernels" on Jan 26, 2026
@gemini-code-assist

Contributor

Summary of Changes

Hello @Datta0, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes Mixture of Experts (MoE) models within the Unsloth framework. It adds 4-bit quantization for MoE layers, uses grouped GEMM kernels for faster forward passes, and refines the handling of LoRA adapters for MoE architectures. The changes also include model-specific optimizations for Qwen3-VL-MoE, provide greater control over MXFP4 quantization, and add support for Group Relative Policy Optimization (GRPO) training. Together these changes aim to improve efficiency, reduce memory consumption, and broaden MoE model support.

Highlights

  • MoE 4-bit Quantization: Introduced bitsandbytes-style 4-bit quantization for Mixture of Experts (MoE) layers, enabling on-the-fly quantization and seamless integration with model loading processes. This addresses the challenge of quantizing MoE layers that use nn.Parameter tensors instead of standard nn.Linear modules.
  • Optimized MoE Backends: Implemented high-performance MoE forward passes utilizing torch._grouped_mm (native PyTorch) and Unsloth's Triton grouped GEMM kernels. The system now dynamically selects the most efficient backend based on availability and environment variables, significantly accelerating inference and training.
  • Enhanced LoRA for MoE: Improved LoRA support for MoE layers by patching PEFT's ParamWrapper. This allows for handling separated LoRA weights, with specific logic to correctly process both standard and transposed weight formats found in different MoE architectures (e.g., Qwen3-VL-MoE).
  • Qwen3-VL-MoE Specific Optimizations: Added specialized handling for Qwen3-VL-MoE models, including patching the __init__ method of Qwen3VLMoeTextExperts to initialize expert weights in a grouped_mm-compatible transposed format. This ensures efficient loading and computation for these vision-language models.
  • MXFP4 Quantization Flexibility: Introduced new configuration options for MXFP4 quantization. Users can now choose to keep MXFP4 weights quantized (without dequantization to bf16) if triton_kernels is available, offering improved memory efficiency and performance for compatible hardware.
  • GRPO Training Support: Patched the forward methods of Qwen3MoeForCausalLM and Qwen3VLMoeForConditionalGeneration to optionally return hidden states instead of logits. This feature supports Group Relative Policy Optimization (GRPO) training workflows.
  • vLLM Integration Improvements: Made minor enhancements to vLLM integration, including a more robust import mechanism for vllm_version and the addition of an explicit lora_request_id parameter to the load_lora function for better management of LoRA requests.
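
The backend selection described in the highlights could be sketched roughly as follows. This is an illustration under assumptions, not the PR's code: the environment variable name `UNSLOTH_MOE_BACKEND` and the backend labels are hypothetical, and availability is passed in as flags so the logic can be exercised without PyTorch or Triton installed.

```python
import os

# Preference order for automatic selection (labels are hypothetical).
BACKENDS = ("grouped_mm", "triton", "loop")


def select_moe_backend(has_grouped_mm: bool, has_triton: bool, env=None) -> str:
    """Pick a MoE forward backend from availability flags and an env override."""
    env = os.environ if env is None else env
    # The plain per-expert loop needs no special kernels, so it is always available.
    available = {"grouped_mm": has_grouped_mm, "triton": has_triton, "loop": True}

    # UNSLOTH_MOE_BACKEND is an assumed name used only for this sketch.
    requested = env.get("UNSLOTH_MOE_BACKEND", "").strip().lower()
    if requested:
        if requested not in available:
            raise ValueError(f"unknown MoE backend: {requested!r}")
        if available[requested]:
            return requested
        # Requested backend is unavailable: fall back to automatic selection.

    return next(name for name in BACKENDS if available[name])
```

With both kernels available and no override, this returns "grouped_mm"; an override naming an unavailable backend falls back to the first available one rather than failing.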

gemini-code-assist Bot left a comment

Code Review

This pull request introduces significant enhancements for Mixture of Experts (MoE) models, particularly for Qwen3 and Qwen3-VL architectures. Key changes include the implementation of 4-bit quantization for MoE layers using bitsandbytes, integration of grouped GEMM kernels (both native PyTorch and Triton) for optimized forward passes, and a robust patching mechanism for PEFT's ParamWrapper to support separated LoRA for MoE. Additionally, the changes improve vLLM integration by adding vllm_config propagation and refining vllm_version detection. The introduction of MXFP4 configuration for training, including conditional dequantization based on triton_kernels availability, further enhances the quantization capabilities. Overall, these changes aim to improve performance, memory efficiency, and compatibility for MoE models within the Unsloth ecosystem.

Comment on lines +184 to +185
gate_up_weight = gate_up_proj[expert_idx].data.clone() # [2*I, H]
down_weight = down_proj[expert_idx].data.clone() # [H, I]

medium

Accessing .data is discouraged in PyTorch because changes made through it are not tracked by autograd, which can silently produce wrong gradients. Use .detach().clone() instead to get an independent copy that is safely cut off from the graph.

Suggested change
- gate_up_weight = gate_up_proj[expert_idx].data.clone() # [2*I, H]
- down_weight = down_proj[expert_idx].data.clone() # [H, I]
+ gate_up_weight = gate_up_proj[expert_idx].detach().clone() # [2*I, H]
+ down_weight = down_proj[expert_idx].detach().clone() # [H, I]

Comment on lines +34 to +37
try:
    os.makedirs(UNSLOTH_COMPILE_LOCATION)
except:
    pass

medium

The bare except: clause is too broad: it catches everything, including KeyboardInterrupt and SystemExit. For file operations it's better to catch OSError (IOError is an alias of it in Python 3). This helps in debugging and understanding the root cause of failures.

Suggested change
- try:
-     os.makedirs(UNSLOTH_COMPILE_LOCATION)
- except:
-     pass
+ try:
+     os.makedirs(UNSLOTH_COMPILE_LOCATION)
+ except OSError:
+     pass
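
An alternative that avoids the try/except altogether is `os.makedirs(..., exist_ok=True)`, which ignores an already-existing directory but still raises on real failures such as permission errors. A minimal sketch; the helper name `ensure_dir` and the directory name are assumptions for illustration, not part of the PR:

```python
import os
import tempfile


def ensure_dir(path: str) -> str:
    """Create path (and any missing parents); leave an existing directory alone."""
    # exist_ok=True suppresses only the "directory already exists" case.
    # Other OSError subclasses, such as PermissionError, still propagate.
    os.makedirs(path, exist_ok=True)
    return path


with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "compiled_cache")
    ensure_dir(target)
    ensure_dir(target)  # second call is a no-op and raises nothing
    created = os.path.isdir(target)
```

This keeps the common case silent while letting unexpected filesystem errors surface instead of being swallowed.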

pass


install_to_cache(__file__, "moe_utils.py")

medium

Calling install_to_cache(__file__, "moe_utils.py") at module level makes it an import-time side effect: it runs the first time moe_utils.py is imported in each process, and again on reload. That may be intended for setup, but it can cause unexpected side effects or overhead in some scenarios (e.g., test environments or multi-process launches). Consider wrapping the call in an explicit initialization function, or add a comment explaining why it must run at import time. A conditional block such as if __name__ == "__main__": only fits if the step is meant to run as a script, since it skips the call on import entirely.
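
One way to make such a setup step explicit and run at most once per process is to memoize it. The sketch below is hypothetical: `ensure_installed` is a stand-in that simply copies a file into a cache directory, and it makes no claim about what `install_to_cache` actually does.

```python
import functools
import shutil
import tempfile
from pathlib import Path


@functools.lru_cache(maxsize=None)
def ensure_installed(source: str, cache_dir: str, name: str) -> str:
    """Copy source to cache_dir/name; repeated calls with the same args are no-ops."""
    destination = Path(cache_dir) / name
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(source, destination)
    return str(destination)


with tempfile.TemporaryDirectory() as root:
    src = Path(root) / "moe_utils.py"
    src.write_text("# placeholder module\n")
    cache = str(Path(root) / "cache")
    first = ensure_installed(str(src), cache, "moe_utils.py")   # performs the copy
    second = ensure_installed(str(src), cache, "moe_utils.py")  # served from the memo
    copied_text = Path(first).read_text()
```

Callers then invoke the function where the cached file is first needed, which keeps plain imports free of filesystem writes.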

Comment on lines +98 to +99
except Exception:
    _TORCH_GROUPED_MM_SUPPORTED = False

medium

The try...except Exception block is too broad. For a runtime check like this, it would be more precise to catch specific exceptions that torch._grouped_mm might raise, such as RuntimeError or AttributeError, to avoid masking other potential issues.

Suggested change
- except Exception:
-     _TORCH_GROUPED_MM_SUPPORTED = False
+ except (RuntimeError, AttributeError):
+     _TORCH_GROUPED_MM_SUPPORTED = False

Comment on lines +145 to +146
except Exception:
    pass

medium

The try...except Exception block is too broad. When dealing with external libraries like Triton, it's usually more appropriate to catch ImportError or ModuleNotFoundError if the issue is related to the library's availability, or specific runtime errors if the issue is with its usage.

Suggested change
- except Exception:
-     pass
+ except ImportError:
+     pass
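
The narrow-exception pattern for an optional dependency can be sketched as below. The helper name `optional_import` is an assumption for illustration; if importing a given library can also fail for other reasons, those exception types would need to be listed explicitly.

```python
import importlib


def optional_import(module_name: str):
    """Return the imported module, or None if it cannot be imported."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so it is covered too.
        return None


json_module = optional_import("json")
missing = optional_import("a_module_that_should_not_exist_12345")
```

Any other exception raised while importing still propagates, so genuine bugs are not masked as "library not installed".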

Comment on lines +421 to +422
except Exception:
    return None

medium

The try...except Exception block is too broad. When extracting LoRA weights, it's better to catch specific exceptions that might occur during attribute access or tensor manipulation, such as AttributeError, KeyError, or IndexError, to provide more targeted error handling and debugging.

Comment on lines +78 to +80
forward_native_grouped_mm = moe_utils.forward_native_grouped_mm
forward_triton_grouped_gemm = moe_utils.forward_triton_grouped_gemm
forward_native_moe_loop = moe_utils.forward_native_moe_loop

medium

Binding forward_native_grouped_mm, forward_triton_grouped_gemm, and forward_native_moe_loop from moe_utils locally within the old_forward function is unusual. It works, but it makes the code harder to read and can break in subtle ways if the module structure changes. It's generally better to place these bindings at the top of the module rather than inside a dynamically patched function. Consider moving them to module level if they are used consistently across the patched functions.

Comment on lines +134 to +136
forward_native_grouped_mm = moe_utils.forward_native_grouped_mm
forward_triton_grouped_gemm = moe_utils.forward_triton_grouped_gemm
forward_native_moe_loop = moe_utils.forward_native_moe_loop

medium

As with the old_forward function, binding forward_native_grouped_mm, forward_triton_grouped_gemm, and forward_native_moe_loop locally within this forward function is not ideal. For dynamically patched functions this pattern may be a workaround, but it hurts readability and maintainability. If these functions are meant to be available to all patched methods, module-level bindings would be clearer.

chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 60350d5958

Comment on lines +401 to +405
# Get dimensions from the original module
gate_up_proj = experts.gate_up_proj
num_experts = gate_up_proj.shape[0]
intermediate_dim = gate_up_proj.shape[1] // 2
hidden_dim = gate_up_proj.shape[2]

[P1] Fix MoE 4-bit dim inference for Qwen3-VL layout

The quantized module infers intermediate_dim and hidden_dim from gate_up_proj.shape[1]/[2], which assumes the standard (E, 2I, H) layout. This same commit patches Qwen3‑VL MoE to store weights in grouped_mm format (E, H, 2I) in qwen3_vl_moe.py (lines 189–231), so for Qwen3‑VL this calculation swaps dimensions (intermediate_dim becomes H/2 and hidden_dim becomes 2*I). The result is a quantized module with wrong shapes, which will fail to load weights or produce invalid outputs. Consider detecting the transposed layout or using config.hidden_size/moe_intermediate_size instead of raw tensor dims.
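
The layout detection suggested above could look roughly like this. It is a sketch under assumptions: the function name `infer_moe_dims` is hypothetical, the shape is passed as a plain tuple, and the expected sizes are taken from config values such as hidden_size and moe_intermediate_size rather than from the tensor alone.

```python
def infer_moe_dims(shape, hidden_size, intermediate_size):
    """Return (num_experts, intermediate_dim, hidden_dim, transposed).

    shape is the gate_up_proj shape: (E, 2I, H) in the standard layout,
    or (E, H, 2I) in the transposed grouped_mm layout.
    """
    if len(shape) != 3:
        raise ValueError(f"expected a 3-D gate_up_proj shape, got {tuple(shape)}")
    num_experts, dim1, dim2 = shape

    standard = (dim1, dim2) == (2 * intermediate_size, hidden_size)
    transposed = (dim1, dim2) == (hidden_size, 2 * intermediate_size)

    if standard and transposed:
        # hidden_size == 2 * intermediate_size: the shape alone cannot decide.
        raise ValueError("ambiguous layout: hidden_size equals 2 * intermediate_size")
    if not (standard or transposed):
        raise ValueError(f"shape {tuple(shape)} matches neither expected layout")

    return num_experts, intermediate_size, hidden_size, transposed
```

For illustrative sizes (not taken from any real checkpoint), `infer_moe_dims((128, 1536, 2048), 2048, 768)` reports the standard layout and `infer_moe_dims((128, 2048, 1536), 2048, 768)` reports the transposed one, with the same intermediate and hidden dimensions in both cases.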


Comment on lines +328 to +332
# Save weight data
gate_up_prefix = f"{prefix}gate_up_projs.{expert_idx}."
down_prefix = f"{prefix}down_projs.{expert_idx}."

destination[f"{gate_up_prefix}weight"] = (

[P1] Make 4-bit MoE save/load key scheme consistent

The custom _save_to_state_dict writes weights under gate_up_projs.*/down_projs.*, but _load_from_state_dict only recognizes the stacked gate_up_proj/down_proj keys and otherwise defers to the default loader (which expects _bnb_gate_up_weights.*). A checkpoint saved by this module will therefore reload with missing quantized weights (keys become unexpected and the ParameterLists stay empty), breaking inference on reload. The save/load key names need to align so round‑trip works.
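
One way to keep the two sides aligned is to derive every key from a single helper that both saving and loading call. The sketch below uses plain dicts and placeholder values instead of quantized tensors; the function names are hypothetical, and only the `gate_up_projs.{idx}.weight` key pattern is taken from the excerpt above.

```python
def expert_key(prefix: str, proj: str, expert_idx: int, suffix: str = "weight") -> str:
    """Single source of truth for per-expert state-dict keys."""
    return f"{prefix}{proj}s.{expert_idx}.{suffix}"


def save_experts(weights: dict, prefix: str = "") -> dict:
    """Flatten {proj: [per-expert values]} into a state-dict-like mapping."""
    destination = {}
    for proj, per_expert in weights.items():
        for idx, value in enumerate(per_expert):
            destination[expert_key(prefix, proj, idx)] = value
    return destination


def load_experts(state_dict: dict, projs, num_experts: int, prefix: str = "") -> dict:
    """Rebuild {proj: [per-expert values]}; a missing key raises KeyError."""
    return {
        proj: [state_dict[expert_key(prefix, proj, idx)] for idx in range(num_experts)]
        for proj in projs
    }


weights = {"gate_up_proj": ["g0", "g1"], "down_proj": ["d0", "d1"]}
state = save_experts(weights, prefix="experts.")
restored = load_experts(state, ["gate_up_proj", "down_proj"], 2, prefix="experts.")
```

A round-trip check of this kind (save, reload, compare) is also a cheap regression test for the mismatch described in the comment.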


Datta0 commented Feb 3, 2026

Collaborator Author

Closing in favor of #450
