fix(mlx): disable fused MRoPE for Qwen3-VL training to allow VJP#750
Merged
danielhanchen merged 2 commits intoJun 11, 2026
Merged
Conversation
Training Qwen3-VL on MLX crashes in value_and_grad with: ValueError: [Primitive::vjp] Not implemented for CustomKernel. The Qwen3-VL language tower's MRoPERotaryEmbedding routes through a fused Metal kernel whenever Metal is available (mlx-vlm 0.6.x), and that kernel has no gradient implementation. The same situation exists in qwen3_5 and is solved by PR unslothai#738 via _disable_fused_mrope, which flips fused_apply off on each rotary module so apply_rotary takes its differentiable cos/sin fallback. Wiring: add a 'qwen3_vl in model_type' block in trainer.py that calls _disable_fused_mrope(model). The function is the same one introduced by PR unslothai#738 (also added here so this PR is self-contained for testing; on rebase after unslothai#738 lands the function definition will dedupe). Verified on M2 16GB with unsloth/Qwen3-VL-2B-Instruct + unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps: - Studio logs: 'Disabled fused MRoPE kernel on 28 modules for training' - Step 1 loss 1.61, grad 4.34, finite throughout - Avg loss 1.73 over 5 steps, adapter saved Note: testing also required unslothai#749 (mlx-vlm flat image list); both fixes are needed end-to-end for Qwen3-VL training, but they fix independent errors in different layers.
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
…ble-fused-mrope # Conflicts: # unsloth_zoo/mlx/loader.py # unsloth_zoo/mlx/trainer.py
Member
|
Merged current main into the branch now that #738 landed, resolving the predicted overlap: the duplicate _disable_fused_mrope definition is dropped (the helper now comes from #738 in loader.py) and this PR keeps only the qwen3_vl wiring block in trainer.py, next to the qwen3_5 block that gained patch_gated_delta_vlm on main. Verified a single helper definition remains, the wiring imports it from loader, and the MLX test files pass. The Metal VJP crash repro itself stays Apple-hardware validation, which your M2 run already covered. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Training Qwen3-VL on MLX crashes in
value_and_gradwith:before step 1. The Qwen3-VL language tower's
MRoPERotaryEmbeddingroutes through a fused Metal kernel whenfused_apply=True(mlx-vlm 0.6.x default), and that kernel has no gradient implementation.This is the same family of bug PR #738 fixes for
qwen3_5. The function_disable_fused_mropeis generic — it walks the model and flipsfused_applyoff on any module that has it — but PR #738 only wires it into theqwen3_5trainer block. This PR adds the parallel wiring for Qwen3-VL.Relationship to #738
_disable_fused_mropeinloader.pyand wires it forqwen3_5_disable_fused_mrope(so it can be tested standalone) and wires it forqwen3_vlOn rebase after #738 lands, the duplicate function definition resolves trivially (drop it from this PR, keep only the trainer-side wiring).
Verification
Qwen3-VL's language tower gates
position_embeddingsprecomputation onfused_apply, exactly mirroring qwen3_5. Whenfused_apply=False,apply_rotarytakes the differentiable cos/sin fallback.Tested on M2 16GB with
unsloth/Qwen3-VL-2B-Instruct+unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps:28 rotary modules had
fused_apply=True; all flipped off, training proceeds with finite gradients throughout.Note on dependency
End-to-end Qwen3-VL training also requires PR #749 (mlx-vlm flat image list), which fixes an independent error in the dataloader. With #749 + this PR, training proceeds end-to-end on M2.