fix(mlx): disable fused MRoPE for Qwen3-VL training to allow VJP #750

Merged
danielhanchen merged 2 commits into unslothai:main from BardiaKoopah:fix/mlx-qwen3-vl-disable-fused-mrope
Jun 11, 2026

Conversation

@BardiaKoopah

Contributor

Summary

Training Qwen3-VL on MLX crashes in value_and_grad before step 1 with:

ValueError: [Primitive::vjp] Not implemented for CustomKernel.

The Qwen3-VL language tower's MRoPERotaryEmbedding routes through a fused Metal kernel when fused_apply=True (the mlx-vlm 0.6.x default), and that kernel has no gradient implementation.

This is the same family of bug that PR #738 fixes for qwen3_5. The function _disable_fused_mrope is generic: it walks the model and flips fused_apply off on any module that has it. However, PR #738 only wires it into the qwen3_5 trainer block. This PR adds the parallel wiring for Qwen3-VL.
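
To make the "walk the model and flip fused_apply off" idea concrete, here is a minimal sketch. It is not the helper from PR #738: the `Module`, `Rotary`, `Attention`, `Layer`, and `Model` classes are hypothetical plain-Python stand-ins for the real framework modules, `disable_fused_mrope_sketch` is an illustrative name, and the layer count is arbitrary.

```python
class Module:
    """Hypothetical stand-in for a framework module (not the real MLX class)."""

    def children(self):
        # Yield direct submodules stored as attributes, or inside lists/tuples.
        for value in vars(self).values():
            items = value if isinstance(value, (list, tuple)) else [value]
            for item in items:
                if isinstance(item, Module):
                    yield item


class Rotary(Module):
    def __init__(self):
        self.fused_apply = True  # assumed default: fused kernel path enabled


class Attention(Module):
    def __init__(self):
        self.rotary_emb = Rotary()


class Layer(Module):
    def __init__(self):
        self.self_attn = Attention()


class Model(Module):
    def __init__(self, num_layers):
        self.layers = [Layer() for _ in range(num_layers)]


def disable_fused_mrope_sketch(model):
    """Turn fused_apply off wherever it is on; return how many were flipped."""
    flipped = 0
    stack, seen = [model], set()
    while stack:
        module = stack.pop()
        if id(module) in seen:  # guard against shared submodules
            continue
        seen.add(id(module))
        if getattr(module, "fused_apply", False):
            module.fused_apply = False
            flipped += 1
        stack.extend(module.children())
    return flipped


model = Model(num_layers=4)  # arbitrary size for illustration
flipped = disable_fused_mrope_sketch(model)
print(f"Disabled fused MRoPE kernel on {flipped} modules")
```

Because the walk keys on the attribute rather than on a class name, the same function works for any architecture that exposes fused_apply, which is why only the trainer-side call needs to be added per model type.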

Relationship to #738

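This PR also carries a copy of the _disable_fused_mrope definition so that it can be tested on its own. On rebase after #738 lands, the duplicate function definition resolves trivially (drop it from this PR, keep only the trainer-side wiring).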

Verification

$ grep -n 'fused_apply' .../mlx_vlm/models/qwen3_vl/language.py
224:            and not self.layers[0].self_attn.rotary_emb.fused_apply

Qwen3-VL's language tower gates position_embeddings precomputation on fused_apply, exactly mirroring qwen3_5. When fused_apply=False, apply_rotary takes the differentiable cos/sin fallback.
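
For readers unfamiliar with the cos/sin formulation, the sketch below shows the usual rotate-half form of rotary embeddings, written with elementwise operations only, which is what makes this kind of path differentiable by ordinary autodiff. It is an assumption-laden illustration, not mlx-vlm's code: it uses plain Python lists, covers single-axis RoPE only (the multi-axis sectioning of MRoPE is omitted), and the names `rotate_half`, `apply_rotary_sketch`, and `cos_sin` are chosen here for illustration.

```python
import math


def rotate_half(x):
    """Map the two halves [x1, x2] of a vector to [-x2, x1]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def apply_rotary_sketch(x, cos, sin):
    """Elementwise x * cos + rotate_half(x) * sin."""
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


def cos_sin(position, dim, base=10000.0):
    """Precompute cos/sin for one position; each frequency spans both halves."""
    half = dim // 2
    angles = [position / (base ** (2 * i / dim)) for i in range(half)]
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


x = [1.0, 2.0, 3.0, 4.0]
cos, sin = cos_sin(position=3, dim=len(x))
y = apply_rotary_sketch(x, cos, sin)
print(y)
```

Each pair (x[i], x[i + half]) is rotated by one angle, so the vector norm is preserved and position 0 leaves the input unchanged.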

Tested on M2 16GB with unsloth/Qwen3-VL-2B-Instruct + unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps:

Unsloth: Disabled fused MRoPE kernel on 28 modules for training (no VJP).
...
Step 1/5 | Loss: 1.6067 | Grad: 4.3397
Step 2/5 | Loss: 1.9784 | Grad: 4.9310
Step 3/5 | Loss: 1.4756 | Grad: 3.6400
Step 4/5 | Loss: 2.0330 | Grad: 4.8749
Step 5/5 | Loss: 1.5666 | Grad: 3.4725
Unsloth: Training complete!

28 rotary modules had fused_apply=True; all flipped off, training proceeds with finite gradients throughout.

Note on dependency

End-to-end Qwen3-VL training also requires PR #749 (mlx-vlm flat image list), which fixes an independent error in the dataloader. With #749 + this PR, training proceeds end-to-end on M2.

Training Qwen3-VL on MLX crashes in value_and_grad with:
  ValueError: [Primitive::vjp] Not implemented for CustomKernel.

The Qwen3-VL language tower's MRoPERotaryEmbedding routes through a
fused Metal kernel whenever Metal is available (mlx-vlm 0.6.x), and
that kernel has no gradient implementation. The same situation exists
in qwen3_5 and is solved by PR unslothai#738 via _disable_fused_mrope, which
flips fused_apply off on each rotary module so apply_rotary takes its
differentiable cos/sin fallback.

Wiring: add a 'qwen3_vl in model_type' block in trainer.py that calls
_disable_fused_mrope(model). The function is the same one introduced
by PR unslothai#738 (also added here so this PR is self-contained for testing;
on rebase after unslothai#738 lands the function definition will dedupe).
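
The wiring described above can be sketched as follows. This is a hedged illustration, not the actual trainer.py code: `prepare_for_training` is a hypothetical hook name, the model is represented by a plain dict, and `_disable_fused_mrope` here is a placeholder that only records that it was called.

```python
def _disable_fused_mrope(model):
    """Placeholder for the real helper; it only records the call."""
    model["fused_mrope_disabled"] = True


def prepare_for_training(model, model_type):
    # Hypothetical hook; the real block lives somewhere in trainer.py.
    # Substring check, so any model_type containing "qwen3_vl" matches.
    if "qwen3_vl" in model_type:
        _disable_fused_mrope(model)
    return model


print(prepare_for_training({}, "qwen3_vl"))
print(prepare_for_training({}, "llama"))
```

Note that a substring test also matches any longer model_type string that contains "qwen3_vl", which is worth keeping in mind if related variants should be treated differently.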

Verified on M2 16GB with unsloth/Qwen3-VL-2B-Instruct +
unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps:
- Studio logs: 'Disabled fused MRoPE kernel on 28 modules for training'
- Step 1 loss 1.61, grad 4.34, finite throughout
- Avg loss 1.73 over 5 steps, adapter saved

Note: testing also required unslothai#749 (mlx-vlm flat image list); both fixes
are needed end-to-end for Qwen3-VL training, but they fix independent
errors in different layers.
@gemini-code-assist

Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

…ble-fused-mrope

# Conflicts:
#	unsloth_zoo/mlx/loader.py
#	unsloth_zoo/mlx/trainer.py
@danielhanchen

Member

Merged current main into the branch now that #738 has landed, resolving the predicted overlap: the duplicate _disable_fused_mrope definition is dropped (the helper now comes from #738 in loader.py), and this PR keeps only the qwen3_vl wiring block in trainer.py, next to the qwen3_5 block that gained patch_gated_delta_vlm on main. Verified that a single helper definition remains, that the wiring imports it from loader, and that the MLX test files pass. Reproducing the Metal VJP crash itself still requires Apple hardware, which your M2 run already covered.

@danielhanchen danielhanchen merged commit d3d833c into unslothai:main Jun 11, 2026
1 of 11 checks passed
2 participants