fix(mlx): Qwen3.5/3.6 VLM training — pass through new mlx-vlm attention kwargs, patch non-differentiable Metal kernels#738
Conversation
…args and non-differentiable kernels Three fixes for LoRA training of mlx-vlm qwen3_5-family models, which currently crashes before the first step: 1. loader.py _fix_qwen35_attention_cache: the patched attention __call__ had a fixed signature and rejected the position_embeddings / target_verify kwargs added by newer mlx-vlm releases (TypeError: unexpected keyword argument 'position_embeddings'). Forward unknown kwargs through, and skip synthesizing position_ids when position_embeddings is already provided. 2. gated_delta_vjp.py patch_gated_delta_vlm: mlx_vlm.models.qwen3_5 ships its own gated_delta_update that calls mlx_lm's gated_delta_kernel directly, so patch_gated_delta() (mlx_lm-only) never intercepts the VLM path and training dies with ValueError: [Primitive::vjp] Not implemented for CustomKernel. Apply the same state=None -> gated_delta_ops_efficient routing to the mlx_vlm module, in both namespaces that hold a reference (the defining module and language.py's from-import). 3. loader.py _disable_fused_mrope: MRoPERotaryEmbedding.apply_rotary takes a fused Metal kernel path whenever Metal is available, with no gradient support and no training gate; same VJP crash. Flip fused_apply off on the model so apply_rotary uses its differentiable cos/sin fallback. The remaining custom kernels in mlx_vlm qwen3_5 (ragged decode SDPA, target-verify GEMV/QMV) are unreachable during training: they require a left-padded KV cache or target_verify=True (gdn_sink), neither of which occurs in the trainer's forward. Fixes the second crash in unslothai/unsloth#6002; item 1 overlaps PR unslothai#721. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request adds support for training Qwen 3.5 models with mlx-vlm by monkey-patching non-differentiable custom kernels. It introduces patch_gated_delta_vlm to route gated delta updates through a memory-efficient VJP, updates attention cache patching to preserve extra keyword arguments, and adds _disable_fused_mrope to fall back to a differentiable MRoPE implementation. The feedback suggests adding a safety check to verify the existence of gated_delta_update before patching to prevent potential AttributeError crashes on different versions of mlx_vlm.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if getattr(vlm_gated_delta, "_unsloth_gated_delta_patched", False): | ||
| return | ||
|
|
||
| original_update = vlm_gated_delta.gated_delta_update |
There was a problem hiding this comment.
To prevent potential 'AttributeError' crashes on different or future versions of 'mlx_vlm', it is safer to verify that 'gated_delta_update' exists on 'vlm_gated_delta' before attempting to access and patch it.
if not hasattr(vlm_gated_delta, "gated_delta_update"):
return
original_update = vlm_gated_delta.gated_delta_update|
Reviewed this in depth, including running the patched code on both Linux CPU mlx and Apple Silicon CI. The two core fixes are right and we want them in; the loader.py wrapper hunk should be dropped before merge. Validated (mlx 0.31.2, mlx-vlm 0.6.2):
Requested changes:
Sequencing note: #684 (explore/mlx) keeps the After the rebase that drops the wrapper hunk, this is approve-ready from my side. |
|
Confirmation run is fully green on macos-14: https://github.com/unslothai/unsloth-zoo/actions/runs/27288537041 All four Metal regression tests pass, including the negative control: differentiating the unpatched mlx-vlm gated-delta kernel raises the exact |
|
Pushed two commits to this branch via maintainer edit to get it merge-ready: a sync merge of current main, and a cleanup commit (d47ff61) that applies the review above. The cleanup drops the The PR diff is now exactly the two validated fixes plus their wiring:
The CPU equivalence tests still pass bitwise on the updated branch (forward, state, and dq/dk/dv gradients identical to the mlx-vlm reference; inference path untouched). Combined with the green Apple Silicon runs above, this is ready to merge from my side. |
|
@codex review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d47ff613c2
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| from mlx_vlm.models.qwen3_5 import gated_delta as vlm_gated_delta | ||
| from mlx_vlm.models.qwen3_5 import language as vlm_language | ||
| except ImportError: |
There was a problem hiding this comment.
Preserve patching for supported mlx-vlm releases
In environments with the declared supported dependency range (mlx-vlm>=0.4.4 in pyproject.toml), Qwen3.5 language.py imports gated_delta_update directly from mlx_lm.models.gated_delta and there is no mlx_vlm.models.qwen3_5.gated_delta module, so this new import raises and the function returns before rebinding vlm_language.gated_delta_update. Because patch_gated_delta() only mutates the mlx-lm module object after the from-import has already copied the old function, Qwen3.5 VLM training on those supported versions still calls the non-differentiable/kernel path this change is meant to avoid. Please handle the older import layout (or raise the minimum mlx-vlm version) rather than silently no-oping here.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Valid catch, confirmed against real installs. On mlx-vlm 0.5.0 the qwen3_5 package has no gated_delta module and language.py holds a by-value copy of mlx_lm's gated_delta_update taken at import time, so the original version of this function ImportErrored and silently no-oped while patch_gated_delta() left the stale copy unpatched.
Fixed in 7f60f49: the gated_delta import is now tried separately, and on the legacy layout the function rebinds language.gated_delta_update to a forwarder that late-binds through the mlx_lm module attribute, so it picks up patch_gated_delta()'s patched version regardless of call order. Verified on mlx-vlm 0.5.0 (CPU): the stale-copy bug reproduces before the fix, the rebind takes effect after, calls route through the patched function, gradients flow, and a second call is a no-op. The 0.6.x path is unchanged and its bitwise equivalence tests still pass.
|
Thanks for the PR! |
Training Qwen3-VL on MLX crashes in value_and_grad with: ValueError: [Primitive::vjp] Not implemented for CustomKernel. The Qwen3-VL language tower's MRoPERotaryEmbedding routes through a fused Metal kernel whenever Metal is available (mlx-vlm 0.6.x), and that kernel has no gradient implementation. The same situation exists in qwen3_5 and is solved by PR #738 via _disable_fused_mrope, which flips fused_apply off on each rotary module so apply_rotary takes its differentiable cos/sin fallback. Wiring: add a 'qwen3_vl in model_type' block in trainer.py that calls _disable_fused_mrope(model). The function is the same one introduced by PR #738 (also added here so this PR is self-contained for testing; on rebase after #738 lands the function definition will dedupe). Verified on M2 16GB with unsloth/Qwen3-VL-2B-Instruct + unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps: - Studio logs: 'Disabled fused MRoPE kernel on 28 modules for training' - Step 1 loss 1.61, grad 4.34, finite throughout - Avg loss 1.73 over 5 steps, adapter saved Note: testing also required #749 (mlx-vlm flat image list); both fixes are needed end-to-end for Qwen3-VL training, but they fix independent errors in different layers. Co-authored-by: Daniel Han <danielhanchen@gmail.com>
After rebasing onto unslothai#738: drop its legacy-layout branch (the binding sweep already rebinds mlx-vlm 0.4-0.5 from-imports), route its mlx-vlm >= 0.6 training branch through the fused-kernel dispatch (ops VJP under whole-step mx.compile would reintroduce the compile_fuse wedge the kernels eliminated), and teach the sweep to recognize the sibling patch instead of warning about a foreign implementation.
After rebasing onto unslothai#738: drop its legacy-layout branch (the binding sweep already rebinds mlx-vlm 0.4-0.5 from-imports), route its mlx-vlm >= 0.6 training branch through the fused-kernel dispatch (ops VJP under whole-step mx.compile would reintroduce the compile_fuse wedge the kernels eliminated), and teach the sweep to recognize the sibling patch instead of warning about a foreign implementation.
Problem
LoRA training of mlx-vlm
qwen3_5-family models (e.g.unsloth/Qwen3.6-35B-A3B-MLX-8bitin Unsloth Studio) crashes before the first step, twice in a row:TypeError: _fix_qwen35_attention_cache.<locals>.patched_attn_call() got an unexpected keyword argument 'position_embeddings'ValueError: [Primitive::vjp] Not implemented for CustomKernelraised fromvalue_and_gradinmlx/trainer.pystep_fn.Both crashes are reported in unslothai/unsloth#6002. Reproduced on macOS with mlx 0.31.2 / mlx-lm 0.31.3 / mlx-vlm 0.6.2 / unsloth_zoo 2026.6.1.
Root causes and fixes
1. Attention wrapper rejects new mlx-vlm kwargs (
mlx/loader.py)mlx-vlm 0.6.x added
position_embeddingsandtarget_verifyparameters toQwen3_5Attention.__call__; the monkey-patch in_fix_qwen35_attention_cachehas a fixed signature and crashes on them. Fix: accept and forward**kwargs, and skip synthesizingposition_idswhenposition_embeddingsis already supplied. (Overlaps #721 — happy to rebase on top of it if that lands first; this PR additionally guards theposition_embeddings-provided case.)2.
patch_gated_delta()never covers the VLM module (gated_delta_vjp.py)mlx_vlm/models/qwen3_5/gated_delta.pydefines its owngated_delta_updatethat calls mlx_lm'sgated_delta_kerneldirectly (imported by name), so patchingmlx_lm.models.gated_delta.gated_delta_updatedoes not intercept the VLM path — the non-differentiable kernel runs and kills the backward pass. This is easy to misdiagnose because the "Patched GatedDeltaNet" message still prints (the mlx_lm patch applies fine; it's just not the copy that executes). Newpatch_gated_delta_vlm()applies the samestate is None→gated_delta_ops_efficientrouting to the mlx_vlm module, patching both namespaces that hold a reference (gated_delta.pyandlanguage.py's from-import). Inference calls (state provided) keep the fast kernel.3. Fused MRoPE kernel has no VJP and no training gate (
mlx/loader.py)MRoPERotaryEmbedding.apply_rotaryroutes through a fused Metal kernel whenever Metal is available — including undervalue_and_grad— even though a differentiable cos/sin fallback exists right below it. Minimal repro:New
_disable_fused_mrope(model)flipsfused_applyoff on the model's rotary embedding modules at training setup, soapply_rotarytakes the differentiable fallback.Both new fixes are wired into the existing
"qwen3_5" in model_typeblock inmlx/trainer.py.Why this is the complete set for training
The remaining bare
mx.fast.metal_kernelusers in mlx_vlm's qwen3_5 code (_qwen3_5_ragged_sdpa_*,_TARGET_VERIFY_GEMV,_target_verify_qmv_kernel) are unreachable in the trainer's forward: they require a left-padded decode KV cache ortarget_verify=True(which needs agdn_sink), none of which occur during training.Verification
mx.gradflows through the patchedgated_delta_update(synthetic batch,use_kernel=True) and throughapply_rotaryafter_disable_fused_mrope; both previously raised the VJP error.unsloth/Qwen3.6-35B-A3B-MLX-8bit(M-series, 128 GB) that previously crashed at both points now trains — loss and grad-norm reported from step 1 onward, ~97% GPU utilization.Fixes the second crash in unslothai/unsloth#6002.
🤖 Generated with Claude Code