Skip to content

fix(mlx): Qwen3.5/3.6 VLM training — pass through new mlx-vlm attention kwargs, patch non-differentiable Metal kernels#738

Merged
danielhanchen merged 4 commits into
unslothai:mainfrom
benrey:fix/qwen35-mlx-vlm-training
Jun 11, 2026
Merged

fix(mlx): Qwen3.5/3.6 VLM training — pass through new mlx-vlm attention kwargs, patch non-differentiable Metal kernels#738
danielhanchen merged 4 commits into
unslothai:mainfrom
benrey:fix/qwen35-mlx-vlm-training

Conversation

@benrey

@benrey benrey commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Problem

LoRA training of mlx-vlm qwen3_5-family models (e.g. unsloth/Qwen3.6-35B-A3B-MLX-8bit in Unsloth Studio) crashes before the first step, twice in a row:

  1. TypeError: _fix_qwen35_attention_cache.<locals>.patched_attn_call() got an unexpected keyword argument 'position_embeddings'
  2. After fixing that: ValueError: [Primitive::vjp] Not implemented for CustomKernel raised from value_and_grad in mlx/trainer.py step_fn.

Both crashes are reported in unslothai/unsloth#6002. Reproduced on macOS with mlx 0.31.2 / mlx-lm 0.31.3 / mlx-vlm 0.6.2 / unsloth_zoo 2026.6.1.

Root causes and fixes

1. Attention wrapper rejects new mlx-vlm kwargs (mlx/loader.py)
mlx-vlm 0.6.x added position_embeddings and target_verify parameters to Qwen3_5Attention.__call__; the monkey-patch in _fix_qwen35_attention_cache has a fixed signature and crashes on them. Fix: accept and forward **kwargs, and skip synthesizing position_ids when position_embeddings is already supplied. (Overlaps #721 — happy to rebase on top of it if that lands first; this PR additionally guards the position_embeddings-provided case.)

2. patch_gated_delta() never covers the VLM module (gated_delta_vjp.py)
mlx_vlm/models/qwen3_5/gated_delta.py defines its own gated_delta_update that calls mlx_lm's gated_delta_kernel directly (imported by name), so patching mlx_lm.models.gated_delta.gated_delta_update does not intercept the VLM path — the non-differentiable kernel runs and kills the backward pass. This is easy to misdiagnose because the "Patched GatedDeltaNet" message still prints (the mlx_lm patch applies fine; it's just not the copy that executes). New patch_gated_delta_vlm() applies the same state is Nonegated_delta_ops_efficient routing to the mlx_vlm module, patching both namespaces that hold a reference (gated_delta.py and language.py's from-import). Inference calls (state provided) keep the fast kernel.

3. Fused MRoPE kernel has no VJP and no training gate (mlx/loader.py)
MRoPERotaryEmbedding.apply_rotary routes through a fused Metal kernel whenever Metal is available — including under value_and_grad — even though a differentiable cos/sin fallback exists right below it. Minimal repro:

rope = Qwen3_5RotaryEmbedding(32, max_position_embeddings=4096, base=10000.0, mrope_section=[8,4,4])
q = k = mx.random.normal((1,2,6,32))
pos = mx.tile(mx.expand_dims(mx.arange(6), 0), (3,1,1))
mx.grad(lambda q: rope.apply_rotary(q, k, pos, unsqueeze_dim=1)[0].sum())(q)
# ValueError: [Primitive::vjp] Not implemented for CustomKernel.

New _disable_fused_mrope(model) flips fused_apply off on the model's rotary embedding modules at training setup, so apply_rotary takes the differentiable fallback.

Both new fixes are wired into the existing "qwen3_5" in model_type block in mlx/trainer.py.

Why this is the complete set for training

The remaining bare mx.fast.metal_kernel users in mlx_vlm's qwen3_5 code (_qwen3_5_ragged_sdpa_*, _TARGET_VERIFY_GEMV, _target_verify_qmv_kernel) are unreachable in the trainer's forward: they require a left-padded decode KV cache or target_verify=True (which needs a gdn_sink), none of which occur during training.

Verification

  • Unit-level: mx.grad flows through the patched gated_delta_update (synthetic batch, use_kernel=True) and through apply_rotary after _disable_fused_mrope; both previously raised the VJP error.
  • End-to-end: with all three fixes applied to the 2026.6.1 install used by Unsloth Studio, a QLoRA run of unsloth/Qwen3.6-35B-A3B-MLX-8bit (M-series, 128 GB) that previously crashed at both points now trains — loss and grad-norm reported from step 1 onward, ~97% GPU utilization.

Fixes the second crash in unslothai/unsloth#6002.

🤖 Generated with Claude Code

…args and non-differentiable kernels

Three fixes for LoRA training of mlx-vlm qwen3_5-family models, which
currently crashes before the first step:

1. loader.py _fix_qwen35_attention_cache: the patched attention __call__
   had a fixed signature and rejected the position_embeddings /
   target_verify kwargs added by newer mlx-vlm releases
   (TypeError: unexpected keyword argument 'position_embeddings').
   Forward unknown kwargs through, and skip synthesizing position_ids
   when position_embeddings is already provided.

2. gated_delta_vjp.py patch_gated_delta_vlm: mlx_vlm.models.qwen3_5 ships
   its own gated_delta_update that calls mlx_lm's gated_delta_kernel
   directly, so patch_gated_delta() (mlx_lm-only) never intercepts the
   VLM path and training dies with
   ValueError: [Primitive::vjp] Not implemented for CustomKernel.
   Apply the same state=None -> gated_delta_ops_efficient routing to the
   mlx_vlm module, in both namespaces that hold a reference (the defining
   module and language.py's from-import).

3. loader.py _disable_fused_mrope: MRoPERotaryEmbedding.apply_rotary
   takes a fused Metal kernel path whenever Metal is available, with no
   gradient support and no training gate; same VJP crash. Flip
   fused_apply off on the model so apply_rotary uses its differentiable
   cos/sin fallback.

The remaining custom kernels in mlx_vlm qwen3_5 (ragged decode SDPA,
target-verify GEMV/QMV) are unreachable during training: they require a
left-padded KV cache or target_verify=True (gdn_sink), neither of which
occurs in the trainer's forward.

Fixes the second crash in unslothai/unsloth#6002; item 1 overlaps PR unslothai#721.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for training Qwen 3.5 models with mlx-vlm by monkey-patching non-differentiable custom kernels. It introduces patch_gated_delta_vlm to route gated delta updates through a memory-efficient VJP, updates attention cache patching to preserve extra keyword arguments, and adds _disable_fused_mrope to fall back to a differentiable MRoPE implementation. The feedback suggests adding a safety check to verify the existence of gated_delta_update before patching to prevent potential AttributeError crashes on different versions of mlx_vlm.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

if getattr(vlm_gated_delta, "_unsloth_gated_delta_patched", False):
return

original_update = vlm_gated_delta.gated_delta_update

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential 'AttributeError' crashes on different or future versions of 'mlx_vlm', it is safer to verify that 'gated_delta_update' exists on 'vlm_gated_delta' before attempting to access and patch it.

    if not hasattr(vlm_gated_delta, "gated_delta_update"):
        return

    original_update = vlm_gated_delta.gated_delta_update

@danielhanchen

Copy link
Copy Markdown
Member

Reviewed this in depth, including running the patched code on both Linux CPU mlx and Apple Silicon CI. The two core fixes are right and we want them in; the loader.py wrapper hunk should be dropped before merge.

Validated (mlx 0.31.2, mlx-vlm 0.6.2):

  • patch_gated_delta_vlm is numerically exact: forward output, returned state, and dq/dk/dv gradients from the patched state=None path are bitwise identical (max abs diff 0) to mlx-vlm's own reference path, in masked, unmasked, and state-provided inference variants. compute_g parity, beta = sigmoid(b), the (B, Hv, Dv, Dk) zero-state shape, and the return structure all match mlx-vlm exactly. The two-namespace patch is needed and sufficient: language.py from-imports gated_delta_update, so patching only mlx_lm.models.gated_delta (what main does today) never intercepts the VLM path.
  • _disable_fused_mrope genuinely reaches the rotary modules: model.modules() traverses the plain-list layers and finds the MRoPERotaryEmbedding instances (verified by forcing fused_apply=True and observing 2/2 flips, with forward and grad still working after).
  • The remaining bare Metal kernels (_qwen3_5_ragged_sdpa_*, _TARGET_VERIFY_GEMV, _target_verify_qmv_kernel) are confirmed unreachable on the training path (cache=None, target_verify=False, no gdn_sink, no left-padded decode), so this is the complete set for training, as the PR claims.
  • Apple Silicon proof, macos-14 run on this PR merged onto current main plus Metal-only regression tests: https://github.com/unslothai/unsloth-zoo/actions/runs/27282784483. The patched gated-delta gradient runs on Metal and matches the use_kernel=False reference, the MRoPE gradient works after _disable_fused_mrope with fused-vs-fallback outputs matching, and a tiny qwen3_5 end-to-end value_and_grad step completes with both patches applied. (The one red test in that run was a bug in the reviewer-written negative control, not in this PR: the raw Metal kernel path returns a list, the test expected a tuple. Fixed; a fully green confirmation run is queued behind the macOS runner backlog as https://github.com/unslothai/unsloth-zoo/actions/runs/27288537041.) The plain merge of this PR onto main is also green on Mac CI: https://github.com/unslothai/unsloth-zoo/actions/runs/27282786321

Requested changes:

  1. Drop the unsloth_zoo/mlx/loader.py patched_attn_call hunk entirely. Main already has the **kwargs signature and forwarding from Fix Qwen3.5 attention wrapper to accept new mlx-vlm kwargs #721, so after a rebase the only thing the hunk adds is the kwargs.get("position_embeddings") is None clause, which should not land: mlx-vlm only supplies position_embeddings together with non-None position_ids (the model-level forward computes them jointly), so the clause is unreachable today. And if it ever did fire, skipping synthesis would forward position_ids=None into the original attention, whose position_ids is None branch dereferences cache.offset and crashes on cache=None, the exact failure this patch function exists to prevent. Keep _disable_fused_mrope in loader.py; that part is net-new and correct.
  2. Optional: set the idempotency marker somewhere both namespaces see it (it currently lives only on vlm_gated_delta while vlm_language.gated_delta_update is also replaced). Harmless in the normal flow, just a theoretical re-import gap.

Sequencing note: #684 (explore/mlx) keeps the "qwen3_5" in model_type trainer block without these new calls; flagged there that a post-#738 rebase must not mechanically resolve that block back to the old form.

After the rebase that drops the wrapper hunk, this is approve-ready from my side.

@danielhanchen

Copy link
Copy Markdown
Member

Confirmation run is fully green on macos-14: https://github.com/unslothai/unsloth-zoo/actions/runs/27288537041

All four Metal regression tests pass, including the negative control: differentiating the unpatched mlx-vlm gated-delta kernel raises the exact [Primitive::vjp] error from the issue, the patched path runs and matches the reference, the MRoPE gradient works after _disable_fused_mrope, and the end-to-end qwen3_5 training step completes with both patches applied. Validation on this PR is done; only the loader.py hunk removal from the earlier comment remains before merge.

@danielhanchen

Copy link
Copy Markdown
Member

Pushed two commits to this branch via maintainer edit to get it merge-ready: a sync merge of current main, and a cleanup commit (d47ff61) that applies the review above. The cleanup drops the patched_attn_call hunk entirely (main already has the **kwargs forwarding from #721, and the position_embeddings guard was unreachable and would reintroduce the cache.offset crash if it ever fired) and tightens the new docstrings and comments.

The PR diff is now exactly the two validated fixes plus their wiring:

  • unsloth_zoo/gated_delta_vjp.py: patch_gated_delta_vlm()
  • unsloth_zoo/mlx/loader.py: _disable_fused_mrope()
  • unsloth_zoo/mlx/trainer.py: both wired into the qwen3_5 block

The CPU equivalence tests still pass bitwise on the updated branch (forward, state, and dq/dk/dv gradients identical to the mlx-vlm reference; inference path untouched). Combined with the green Apple Silicon runs above, this is ready to merge from my side.

@danielhanchen

Copy link
Copy Markdown
Member

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d47ff613c2

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/gated_delta_vjp.py Outdated
Comment on lines +288 to +290
from mlx_vlm.models.qwen3_5 import gated_delta as vlm_gated_delta
from mlx_vlm.models.qwen3_5 import language as vlm_language
except ImportError:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve patching for supported mlx-vlm releases

In environments with the declared supported dependency range (mlx-vlm>=0.4.4 in pyproject.toml), Qwen3.5 language.py imports gated_delta_update directly from mlx_lm.models.gated_delta and there is no mlx_vlm.models.qwen3_5.gated_delta module, so this new import raises and the function returns before rebinding vlm_language.gated_delta_update. Because patch_gated_delta() only mutates the mlx-lm module object after the from-import has already copied the old function, Qwen3.5 VLM training on those supported versions still calls the non-differentiable/kernel path this change is meant to avoid. Please handle the older import layout (or raise the minimum mlx-vlm version) rather than silently no-oping here.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid catch, confirmed against real installs. On mlx-vlm 0.5.0 the qwen3_5 package has no gated_delta module and language.py holds a by-value copy of mlx_lm's gated_delta_update taken at import time, so the original version of this function ImportErrored and silently no-oped while patch_gated_delta() left the stale copy unpatched.

Fixed in 7f60f49: the gated_delta import is now tried separately, and on the legacy layout the function rebinds language.gated_delta_update to a forwarder that late-binds through the mlx_lm module attribute, so it picks up patch_gated_delta()'s patched version regardless of call order. Verified on mlx-vlm 0.5.0 (CPU): the stale-copy bug reproduces before the fix, the rebind takes effect after, calls route through the patched function, gradients flow, and a second call is a no-op. The 0.6.x path is unchanged and its bitwise equivalence tests still pass.

@danielhanchen danielhanchen merged commit e0f7e60 into unslothai:main Jun 11, 2026
1 check failed
@danielhanchen

Copy link
Copy Markdown
Member

Thanks for the PR!

danielhanchen added a commit that referenced this pull request Jun 11, 2026
Training Qwen3-VL on MLX crashes in value_and_grad with:
  ValueError: [Primitive::vjp] Not implemented for CustomKernel.

The Qwen3-VL language tower's MRoPERotaryEmbedding routes through a
fused Metal kernel whenever Metal is available (mlx-vlm 0.6.x), and
that kernel has no gradient implementation. The same situation exists
in qwen3_5 and is solved by PR #738 via _disable_fused_mrope, which
flips fused_apply off on each rotary module so apply_rotary takes its
differentiable cos/sin fallback.

Wiring: add a 'qwen3_vl in model_type' block in trainer.py that calls
_disable_fused_mrope(model). The function is the same one introduced
by PR #738 (also added here so this PR is self-contained for testing;
on rebase after #738 lands the function definition will dedupe).

Verified on M2 16GB with unsloth/Qwen3-VL-2B-Instruct +
unsloth/LaTeX_OCR, vision-frozen LoRA, 5 steps:
- Studio logs: 'Disabled fused MRoPE kernel on 28 modules for training'
- Step 1 loss 1.61, grad 4.34, finite throughout
- Avg loss 1.73 over 5 steps, adapter saved

Note: testing also required #749 (mlx-vlm flat image list); both fixes
are needed end-to-end for Qwen3-VL training, but they fix independent
errors in different layers.

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Lyxot added a commit to Lyxot/unsloth-zoo that referenced this pull request Jun 12, 2026
After rebasing onto unslothai#738: drop its legacy-layout branch (the binding
sweep already rebinds mlx-vlm 0.4-0.5 from-imports), route its
mlx-vlm >= 0.6 training branch through the fused-kernel dispatch (ops
VJP under whole-step mx.compile would reintroduce the compile_fuse
wedge the kernels eliminated), and teach the sweep to recognize the
sibling patch instead of warning about a foreign implementation.
Lyxot added a commit to Lyxot/unsloth-zoo that referenced this pull request Jun 12, 2026
After rebasing onto unslothai#738: drop its legacy-layout branch (the binding
sweep already rebinds mlx-vlm 0.4-0.5 from-imports), route its
mlx-vlm >= 0.6 training branch through the fused-kernel dispatch (ops
VJP under whole-step mx.compile would reintroduce the compile_fuse
wedge the kernels eliminated), and teach the sweep to recognize the
sibling patch instead of warning about a foreign implementation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants