[WIP][MoE] Gpt oss moe kernels #447
This reverts commit 169b1ea.
Summary of Changes

Hello @Datta0, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a set of enhancements aimed at optimizing Mixture of Experts (MoE) models within the Unsloth framework. It integrates 4-bit quantization, uses grouped GEMM kernels for faster expert computation, and refines the handling of LoRA adapters for MoE architectures. The changes also include model-specific optimizations for Qwen3-VL-MoE, provide greater control over MXFP4 quantization, and add support for Group Relative Policy Optimization (GRPO) training. Together these changes are intended to improve efficiency and reduce memory consumption for MoE models.
Code Review
This pull request introduces significant enhancements for Mixture of Experts (MoE) models, particularly for Qwen3 and Qwen3-VL architectures. Key changes include the implementation of 4-bit quantization for MoE layers using bitsandbytes, integration of grouped GEMM kernels (both native PyTorch and Triton) for optimized forward passes, and a robust patching mechanism for PEFT's ParamWrapper to support separated LoRA for MoE. Additionally, the changes improve vLLM integration by adding vllm_config propagation and refining vllm_version detection. The introduction of MXFP4 configuration for training, including conditional dequantization based on triton_kernels availability, further enhances the quantization capabilities. Overall, these changes aim to improve performance, memory efficiency, and compatibility for MoE models within the Unsloth ecosystem.
    gate_up_weight = gate_up_proj[expert_idx].data.clone()  # [2*I, H]
    down_weight = down_proj[expert_idx].data.clone()  # [H, I]
Accessing .data is discouraged in PyTorch because changes made through it are not tracked by autograd. Use .detach().clone() instead to get an independent copy with proper tensor handling.
Suggested change:
    - gate_up_weight = gate_up_proj[expert_idx].data.clone()  # [2*I, H]
    - down_weight = down_proj[expert_idx].data.clone()  # [H, I]
    + gate_up_weight = gate_up_proj[expert_idx].detach().clone()  # [2*I, H]
    + down_weight = down_proj[expert_idx].detach().clone()  # [H, I]
        os.makedirs(UNSLOTH_COMPILE_LOCATION)
    except:
        pass
The bare except clause is too broad: it catches every exception, including KeyboardInterrupt and SystemExit. For file operations it is better to catch OSError (IOError is an alias of OSError in Python 3), which makes the root cause of failures easier to debug.
Suggested change:
    -     os.makedirs(UNSLOTH_COMPILE_LOCATION)
    - except:
    -     pass
    + try:
    +     os.makedirs(UNSLOTH_COMPILE_LOCATION)
    + except OSError:
    +     pass
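If the only failure the original code means to ignore is an already existing directory, `os.makedirs` accepts `exist_ok=True`, which removes the need for a `try` block. A minimal sketch under that assumption; `ensure_dir` and the temporary path are hypothetical stand-ins, not names from the PR:

```python
import os
import tempfile


def ensure_dir(path: str) -> str:
    """Create `path` (and parents) if needed; an existing directory is fine."""
    # exist_ok=True only suppresses the error raised when `path` already
    # exists as a directory. Other OSErrors (for example, permission denied)
    # still propagate to the caller instead of being silently swallowed.
    os.makedirs(path, exist_ok=True)
    return path


if __name__ == "__main__":
    base = tempfile.mkdtemp()
    target = os.path.join(base, "compiled_cache")  # hypothetical directory name
    ensure_dir(target)
    ensure_dir(target)  # calling again on an existing directory does not raise
    print(os.path.isdir(target))
```

This keeps real problems, such as an unwritable cache location, visible at the point where they occur.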
        pass

    install_to_cache(__file__, "moe_utils.py")
Calling install_to_cache(__file__, "moe_utils.py") at module level means it runs as a side effect of importing moe_utils.py (once per process, since Python caches imported modules). That may be intended for setup, but it can cause unexpected side effects or overhead in some scenarios, such as testing environments. If it is meant as a specific initialization step, consider wrapping the call in a function or a conditional block (e.g., if __name__ == "__main__": if it should only run when the file is executed directly); otherwise, add a comment explaining why it must run at import time.
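One way to make such a side effect explicit and idempotent is a small guarded helper. The sketch below is only an illustration: the real signature and behavior of `install_to_cache` are not shown in this excerpt, so a stand-in callable takes its place, and `ensure_installed` is a hypothetical name.

```python
import threading
from typing import Callable, Set

_install_lock = threading.Lock()
_installed: Set[str] = set()


def ensure_installed(name: str, install: Callable[[str], None]) -> bool:
    """Run `install(name)` at most once per process.

    Returns True if the install step ran on this call, False if it had
    already run earlier.
    """
    with _install_lock:
        if name in _installed:
            return False
        install(name)
        # Record the name only after a successful install, so a failed
        # attempt can be retried on the next call.
        _installed.add(name)
        return True


if __name__ == "__main__":
    calls = []

    def fake_install(name: str) -> None:
        # Stand-in for the real install step (an assumption, not the PR's code).
        calls.append(name)

    ensure_installed("moe_utils.py", fake_install)
    ensure_installed("moe_utils.py", fake_install)
    print(calls)
```

Callers that need the cached file would invoke the helper explicitly, which also makes the step easy to skip or stub in tests.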
    except Exception:
        _TORCH_GROUPED_MM_SUPPORTED = False
The try...except Exception block is too broad. For a runtime check like this, it would be more precise to catch specific exceptions that torch._grouped_mm might raise, such as RuntimeError or AttributeError, to avoid masking other potential issues.
Suggested change:
    - except Exception:
    -     _TORCH_GROUPED_MM_SUPPORTED = False
    + except (RuntimeError, AttributeError):
    +     _TORCH_GROUPED_MM_SUPPORTED = False
    except Exception:
        pass
The try...except Exception block is too broad. When dealing with external libraries like Triton, it's usually more appropriate to catch ImportError or ModuleNotFoundError if the issue is related to the library's availability, or specific runtime errors if the issue is with its usage.
Suggested change:
    - except Exception:
    -     pass
    + except ImportError:
    +     pass
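For a pure availability check, a narrow probe keeps unrelated failures visible. A minimal sketch using only the standard library; `has_module` and `try_import` are hypothetical helpers, and whether the PR's block guards an import or a runtime call is not visible in this excerpt:

```python
import importlib
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the top-level module `name` can be located.

    find_spec locates the module without importing it, so import-time
    side effects of the target package are not triggered.
    """
    return importlib.util.find_spec(name) is not None


def try_import(name: str):
    """Import `name`, returning None only when the import itself fails."""
    try:
        return importlib.import_module(name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a missing
        # package lands here. Other exception types raised while the
        # package initializes are deliberately not caught.
        return None


if __name__ == "__main__":
    print(has_module("json"))
    print(try_import("definitely_not_installed_pkg"))
```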
    except Exception:
        return None
    forward_native_grouped_mm = moe_utils.forward_native_grouped_mm
    forward_triton_grouped_gemm = moe_utils.forward_triton_grouped_gemm
    forward_native_moe_loop = moe_utils.forward_native_moe_loop
Binding forward_native_grouped_mm, forward_triton_grouped_gemm, and forward_native_moe_loop locally inside the old_forward function is unusual. It works, but it makes the code harder to read and can lead to unexpected behavior if the module structure changes. Since old_forward is a dynamically patched function, consider binding these names once at module level if they are used consistently across the patched functions.
    forward_native_grouped_mm = moe_utils.forward_native_grouped_mm
    forward_triton_grouped_gemm = moe_utils.forward_triton_grouped_gemm
    forward_native_moe_loop = moe_utils.forward_native_moe_loop
As with old_forward, binding forward_native_grouped_mm, forward_triton_grouped_gemm, and forward_native_moe_loop locally inside this forward function is not ideal. For dynamically patched functions this pattern may be a deliberate workaround, but it hurts readability and maintainability. If these functions are meant to be available to all patched methods, module-level bindings would be clearer.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 60350d5958
    # Get dimensions from the original module
    gate_up_proj = experts.gate_up_proj
    num_experts = gate_up_proj.shape[0]
    intermediate_dim = gate_up_proj.shape[1] // 2
    hidden_dim = gate_up_proj.shape[2]
Fix MoE 4-bit dim inference for Qwen3-VL layout
The quantized module infers intermediate_dim and hidden_dim from gate_up_proj.shape[1]/[2], which assumes the standard (E, 2I, H) layout. This same commit patches Qwen3‑VL MoE to store weights in grouped_mm format (E, H, 2I) in qwen3_vl_moe.py (lines 189–231), so for Qwen3‑VL this calculation swaps dimensions (intermediate_dim becomes H/2 and hidden_dim becomes 2*I). The result is a quantized module with wrong shapes, which will fail to load weights or produce invalid outputs. Consider detecting the transposed layout or using config.hidden_size/moe_intermediate_size instead of raw tensor dims.
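The layout detection suggested above could look like the following sketch. It works on plain shape tuples so it runs without PyTorch; `infer_moe_dims` is a hypothetical helper, and `hidden_size` and `moe_intermediate_size` are assumed to be read from the model config as the comment proposes.

```python
from typing import Optional, Tuple


def infer_moe_dims(
    gate_up_shape: Tuple[int, int, int],
    hidden_size: int,
    moe_intermediate_size: int,
) -> Tuple[int, int, int, Optional[bool]]:
    """Return (num_experts, intermediate_dim, hidden_dim, transposed).

    The dimensions come from the config values; the tensor shape is only
    used to validate them and to detect which layout is in use.
    """
    num_experts, d1, d2 = gate_up_shape
    two_i = 2 * moe_intermediate_size
    is_standard = (d1, d2) == (two_i, hidden_size)    # (E, 2I, H)
    is_transposed = (d1, d2) == (hidden_size, two_i)  # (E, H, 2I)
    if not (is_standard or is_transposed):
        raise ValueError(
            f"gate_up_proj shape {gate_up_shape} matches neither (E, 2I, H) "
            f"nor (E, H, 2I) for H={hidden_size}, I={moe_intermediate_size}"
        )
    # When H == 2I both layouts have the same shape, so the shape alone
    # cannot tell them apart; report None and let the caller decide.
    transposed = None if (is_standard and is_transposed) else is_transposed
    return num_experts, moe_intermediate_size, hidden_size, transposed


if __name__ == "__main__":
    print(infer_moe_dims((8, 1536, 2048), hidden_size=2048, moe_intermediate_size=768))
    print(infer_moe_dims((8, 2048, 1536), hidden_size=2048, moe_intermediate_size=768))
```

Reading the sizes from the config avoids the swapped-dimension result described above, and the explicit mismatch error surfaces an unexpected layout at construction time rather than at weight loading.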
    # Save weight data
    gate_up_prefix = f"{prefix}gate_up_projs.{expert_idx}."
    down_prefix = f"{prefix}down_projs.{expert_idx}."

    destination[f"{gate_up_prefix}weight"] = (
Make 4-bit MoE save/load key scheme consistent
The custom _save_to_state_dict writes weights under gate_up_projs.*/down_projs.*, but _load_from_state_dict only recognizes the stacked gate_up_proj/down_proj keys and otherwise defers to the default loader (which expects _bnb_gate_up_weights.*). A checkpoint saved by this module will therefore reload with missing quantized weights (keys become unexpected and the ParameterLists stay empty), breaking inference on reload. The save/load key names need to align so round‑trip works.
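One way to keep the two sides aligned is to derive every key from a single helper that both the save path and the load path call. The sketch below uses plain dicts and lists in place of tensors and the state-dict hooks; `expert_key`, `save_experts`, and `load_experts` are hypothetical names, and only the `gate_up_projs.{i}.weight` key pattern is taken from the excerpt (the `down_projs` weight key is assumed to mirror it).

```python
from typing import Any, Dict, List, Tuple


def expert_key(prefix: str, proj: str, expert_idx: int) -> str:
    """Single source of truth for per-expert weight keys."""
    return f"{prefix}{proj}s.{expert_idx}.weight"


def save_experts(prefix: str, gate_up: List[Any], down: List[Any]) -> Dict[str, Any]:
    """Write one entry per expert and projection into a flat dict."""
    destination: Dict[str, Any] = {}
    for i, (g, d) in enumerate(zip(gate_up, down)):
        destination[expert_key(prefix, "gate_up_proj", i)] = g
        destination[expert_key(prefix, "down_proj", i)] = d
    return destination


def load_experts(
    prefix: str, state_dict: Dict[str, Any], num_experts: int
) -> Tuple[List[Any], List[Any]]:
    """Read the entries back using the same key helper as save_experts."""
    gate_up = [state_dict[expert_key(prefix, "gate_up_proj", i)] for i in range(num_experts)]
    down = [state_dict[expert_key(prefix, "down_proj", i)] for i in range(num_experts)]
    return gate_up, down


if __name__ == "__main__":
    saved = save_experts("mlp.experts.", ["g0", "g1"], ["d0", "d1"])
    print(sorted(saved))
    print(load_experts("mlp.experts.", saved, num_experts=2))
```

A round-trip test of this shape (save, then load, then compare) would catch the mismatch described above before a checkpoint is ever written.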
Closing in favor of #450
Please take a look at #396 first and then this