Fix Qwen3.5 attention wrapper to accept new mlx-vlm kwargs #721
mlx-vlm 0.6.1 added position_embeddings and target_verify kwargs to Qwen3_5Attention.__call__. The patched wrapper rejected them with a TypeError on training start. Add **kwargs to the wrapper's signature and forward them to the original call. **kwargs is used instead of named parameters for durability, since mlx-vlm frequently adds new kwargs to this interface. Fixes #6002 (first crash).
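For illustration, a minimal sketch of the forwarding pattern described above. It uses a stand-in class rather than the real mlx-vlm or unsloth_zoo code: `Attention`, `patch_attention`, and the `x`/`mask`/`cache` parameters are hypothetical, and only the `position_embeddings` and `target_verify` names come from the description.

```python
import functools


class Attention:
    """Stand-in for an upstream attention module (hypothetical, not mlx-vlm code)."""

    def __call__(self, x, mask=None, cache=None,
                 position_embeddings=None, target_verify=False):
        return {
            "x": x,
            "position_embeddings": position_embeddings,
            "target_verify": target_verify,
        }


def patch_attention(cls):
    """Wrap cls.__call__ so kwargs the wrapper does not name pass through."""
    original_call = cls.__call__

    @functools.wraps(original_call)
    def patched_call(self, x, mask=None, cache=None, **kwargs):
        # Wrapper-specific logic (e.g. cache handling) would go here.
        # Anything the wrapper does not name is forwarded untouched,
        # so a new upstream kwarg no longer raises TypeError.
        return original_call(self, x, mask=mask, cache=cache, **kwargs)

    cls.__call__ = patched_call
    return cls


patch_attention(Attention)
out = Attention()(1, position_embeddings=("cos", "sin"), target_verify=True)
```

With a fixed signature `patched_call(self, x, mask=None, cache=None)`, the last line would fail with `TypeError: ... got an unexpected keyword argument 'position_embeddings'`, which is the failure mode this PR addresses.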
Code Review
This pull request updates the _fix_qwen35_attention_cache function in unsloth_zoo/mlx/loader.py to accept and forward arbitrary keyword arguments (**kwargs) in the patched attention call. This ensures compatibility with any additional arguments that might be passed to the attention mechanism. There are no review comments, and I have no further feedback to provide.
Validated this on mlx 0.31.2 + mlx-vlm 0.6.2 (CPU backend) with a tiny random-weight
Two notes:
…on kwargs, patch non-differentiable Metal kernels (#738)

* fix(mlx): unblock Qwen3.5/3.6 MLX training crashing on new mlx-vlm kwargs and non-differentiable kernels

Three fixes for LoRA training of mlx-vlm qwen3_5-family models, which currently crashes before the first step:

1. loader.py _fix_qwen35_attention_cache: the patched attention __call__ had a fixed signature and rejected the position_embeddings / target_verify kwargs added by newer mlx-vlm releases (TypeError: unexpected keyword argument 'position_embeddings'). Forward unknown kwargs through, and skip synthesizing position_ids when position_embeddings is already provided.

2. gated_delta_vjp.py patch_gated_delta_vlm: mlx_vlm.models.qwen3_5 ships its own gated_delta_update that calls mlx_lm's gated_delta_kernel directly, so patch_gated_delta() (mlx_lm-only) never intercepts the VLM path and training dies with ValueError: [Primitive::vjp] Not implemented for CustomKernel. Apply the same state=None -> gated_delta_ops_efficient routing to the mlx_vlm module, in both namespaces that hold a reference (the defining module and language.py's from-import).

3. loader.py _disable_fused_mrope: MRoPERotaryEmbedding.apply_rotary takes a fused Metal kernel path whenever Metal is available, with no gradient support and no training gate; same VJP crash. Flip fused_apply off on the model so apply_rotary uses its differentiable cos/sin fallback.

The remaining custom kernels in mlx_vlm qwen3_5 (ragged decode SDPA, target-verify GEMV/QMV) are unreachable during training: they require a left-padded KV cache or target_verify=True (gdn_sink), neither of which occurs in the trainer's forward.

Fixes the second crash in unslothai/unsloth#6002; item 1 overlaps PR #721.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Drop redundant attention wrapper hunk and tighten comments for PR #738

* Handle legacy mlx-vlm gated_delta import layout for PR #738

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
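To illustrate why the commit above patches "both namespaces that hold a reference": a `from module import name` copies the binding, so rebinding the name only in the defining module leaves the importer pointing at the old function. The sketch below is an assumption-laden toy, not the real mlx_vlm layout. The modules `defining` and `consumer`, the helper `patch_everywhere`, and the string return values are all hypothetical; only the name `gated_delta_update` and the idea of routing on `state=None` come from the commit message.

```python
import types

# Hypothetical stand-ins: a defining module, and a consumer that did
# `from defining import gated_delta_update`.
defining = types.ModuleType("defining")


def _kernel_update(x, state=None):
    # Stands in for a custom-kernel path without gradient support.
    return "kernel"


defining.gated_delta_update = _kernel_update

consumer = types.ModuleType("consumer")
consumer.gated_delta_update = defining.gated_delta_update  # effect of a from-import


def make_routed(original):
    def routed(x, state=None):
        # No state: take the differentiable route; otherwise keep the kernel.
        if state is None:
            return "differentiable"
        return original(x, state=state)

    return routed


def patch_everywhere(name, make_replacement, modules):
    """Rebind `name` in every module that still holds the original object."""
    original = getattr(modules[0], name)
    replacement = make_replacement(original)
    patched = []
    for mod in modules:
        if getattr(mod, name, None) is original:
            setattr(mod, name, replacement)
            patched.append(mod.__name__)
    return patched


patched = patch_everywhere("gated_delta_update", make_routed, [defining, consumer])
```

The identity check (`is original`) keeps the helper from overwriting a name that some other code has already replaced. Whether the real patch needs that guard is not stated in the commit message.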
A newer mlx-lm adding parameters (precedent: the qwen3_5 attention kwargs that needed unslothai#721) would previously raise a TypeError inside the patched wrapper. Silently dropping them could change training semantics, so unknown kwargs now delegate to the original implementation with a one-time warning naming the new arguments. Behavior with old mlx-lm is unchanged, and newer mlx-lm stays correct until the patch is updated.
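A rough sketch of the delegate-and-warn behavior described above, under stated assumptions: `original_impl`, `patched_impl`, and the `q`/`k` parameters are hypothetical placeholders, not mlx-lm functions, and the warning text is invented for the example.

```python
import warnings

_warned = set()  # kwarg-name tuples that have already triggered a warning


def original_impl(q, k, **extra):
    """Placeholder for the upstream implementation (hypothetical)."""
    return ("original", sorted(extra))


def patched_impl(q, k, **unknown):
    """Placeholder for the patched fast path (hypothetical)."""
    if unknown:
        names = tuple(sorted(unknown))
        if names not in _warned:
            # Warn once per distinct set of new argument names.
            _warned.add(names)
            warnings.warn(
                f"unrecognized kwargs {names}; "
                "delegating to the original implementation"
            )
        # Do not drop the arguments: let upstream handle them.
        return original_impl(q, k, **unknown)
    return ("patched", [])
```

Calls with only the known arguments keep using the patched path. Calls carrying anything new fall back to upstream behavior instead of failing or silently ignoring the argument.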