- Did you update? -> N/A. Running Unsloth Studio (the packaged app), not pip. Bundled `unsloth_zoo` is 2026.6.1 and I launched Studio today after the morning's Mac/MLX commits, so this is on the latest.
- `Colab` or `Kaggle` or local / cloud -> Local Mac. M2 Pro, macOS 14.6.1, 16 GB unified memory.
- Number GPUs used, use `nvidia-smi` -> N/A. Apple Silicon, no CUDA.
- Which notebook? Please link! -> N/A. Studio GUI, SFT tab.
- Which Unsloth version, TRL version, transformers version, PyTorch version?
  - `unsloth_zoo` 2026.6.1, `mlx` 0.31.2, `mlx-lm` 0.31.3, `mlx-vlm` 0.6.1
  - `transformers` 4.57.6 / `trl` 0.23.1 (not used on the MLX path)
  - `torch` not installed
- Python 3.13.3
- Which trainer? -> Studio's MLX SFT trainer (`unsloth_zoo/mlx/trainer.py`).

Studio config that reproduces:

```
# Model: mlx-community/Qwen3.5-0.8B-8bit (also tried -bf16 and -MLX-4bit, same result)
# Adapter: LoRA r=16, alpha=16, dropout=0
# Trainer: SFT
# BS=2, grad_accum=4, max_seq_len=2048, optim=adamw, scheduler=linear
# Features: CCE, gradient checkpointing, mx.compile
# Dataset: ChatML JSONL, 31 dialogues, ~22 turns each (system+user+assistant)
```
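
The dataset itself isn't attached, so a synthetic stand-in with the same shape may help reproduce this. This is a minimal sketch that assumes Studio accepts one JSON object per line with a `messages` list of `{"role", "content"}` dicts (the common ChatML / Hugging Face chat layout). The exact key Studio expects is an assumption. `make_dialogue` and `to_jsonl` are hypothetical helpers, and using 1 system message plus 11 user/assistant pairs is just one way to approximate "~22 turns".

```python
import json

# Assumed layout: one JSON object per line with a "messages" list of
# {"role", "content"} dicts. Whether Studio wants exactly this key is an assumption.
N_DIALOGUES = 31  # same count as the dataset in this report
N_PAIRS = 11      # 1 system + 11 user/assistant pairs = 23 messages (~22 turns)


def make_dialogue(idx, n_pairs=N_PAIRS):
    """Build one synthetic dialogue filled with placeholder text."""
    messages = [{"role": "system", "content": f"You are test assistant #{idx}."}]
    for turn in range(n_pairs):
        messages.append({"role": "user", "content": f"Question {turn} in dialogue {idx}?"})
        messages.append({"role": "assistant", "content": f"Answer {turn} in dialogue {idx}."})
    return {"messages": messages}


def to_jsonl(n_dialogues=N_DIALOGUES, n_pairs=N_PAIRS):
    """Serialize synthetic dialogues as JSONL text, one dialogue per line."""
    return "".join(json.dumps(make_dialogue(i, n_pairs)) + "\n" for i in range(n_dialogues))
```

Writing the result out, for example with `pathlib.Path("synthetic_chatml.jsonl").write_text(to_jsonl())`, gives a 31-line JSONL file that can stand in for the real dataset.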
There are two separate crashes here, the second only visible after working around the first.

First crash: `patched_attn_call` doesn't accept `position_embeddings`

Studio crashes before any step completes:

```
TypeError: _fix_qwen35_attention_cache.<locals>.patched_attn_call() got an unexpected keyword argument 'position_embeddings'
  at unsloth_zoo/mlx/utils.py:93 (checkpointed_fn)
  at mlx_vlm/models/qwen3_5/language.py:1738
```

The wrapper in `unsloth_zoo/mlx/loader.py:514` was written against an older mlx-vlm:

```python
def patched_attn_call(self, x, mask=None, cache=None, position_ids=None):
    ...
```

But mlx-vlm 0.6.1's `Qwen3_5Attention.__call__` has added two new kwargs, `position_embeddings` and `target_verify`, both of which the decoder layer at `qwen3_5/language.py:1738` now passes. The wrapper has no way to accept them, hence the `TypeError`.
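
Adding `**kwargs` to the signature and forwarding it makes the error disappear:

```python
# Defined inside _fix_qwen35_attention_cache; original_attn_call is the
# unpatched __call__ saved by that enclosing function.
def patched_attn_call(self, x, mask=None, cache=None, position_ids=None, **kwargs):
    if cache is None and position_ids is None:
        # Training path (no KV cache): default position ids 0..L-1,
        # tiled into three MRoPE-style rows, giving shape (3, 1, L).
        import mlx.core as mx
        L = x.shape[1]
        position_ids = mx.arange(L)
        position_ids = mx.expand_dims(position_ids, axis=0)
        position_ids = mx.tile(position_ids, (3, 1, 1))
    # Forward position_embeddings, target_verify, and any future kwargs unchanged.
    return original_attn_call(
        self, x, mask=mask, cache=cache, position_ids=position_ids, **kwargs
    )
```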

mlx-vlm is adding kwargs to this signature fairly often (MRoPE work is active), so `**kwargs` is more durable than spelling out today's named ones, which would just break again the next time something new gets added.
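
To show why forwarding `**kwargs` is the durable fix, here is a minimal pure-Python sketch that doesn't need MLX. `FakeAttention`, `rigid_patch`, `durable_patch`, and `call_patched` are hypothetical stand-ins: `FakeAttention.__call__` only mimics the 0.6.1 signature (with `position_embeddings` and `target_verify`), and the default position ids are plain nested lists laid out like the `(3, 1, L)` result of the `mx.tile` call above.

```python
# Hypothetical stand-in for mlx-vlm 0.6.1's Qwen3_5Attention: only the call
# signature matters here, so the "forward" just echoes what it received.
class FakeAttention:
    def __call__(self, x, mask=None, cache=None, position_ids=None,
                 position_embeddings=None, target_verify=False):
        return {"position_ids": position_ids,
                "position_embeddings": position_embeddings,
                "target_verify": target_verify}


# The unpatched method, captured the same way a monkeypatch would capture it.
original_attn_call = FakeAttention.__call__


def rigid_patch(self, x, mask=None, cache=None, position_ids=None):
    # Old-style wrapper: any kwarg not named here raises TypeError at call time.
    return original_attn_call(self, x, mask=mask, cache=cache, position_ids=position_ids)


def durable_patch(self, x, mask=None, cache=None, position_ids=None, **kwargs):
    # x is a (batch, seq_len) nested list here. With no cache, default to three
    # identical rows of 0..L-1: a list analogue of the (3, 1, L) mx.tile result.
    if cache is None and position_ids is None:
        seq_len = len(x[0])
        position_ids = [[list(range(seq_len))] for _ in range(3)]
    # Forward position_embeddings, target_verify, or anything added later.
    return original_attn_call(self, x, mask=mask, cache=cache,
                              position_ids=position_ids, **kwargs)


def call_patched(patch, x, **kwargs):
    """Install `patch` as __call__, invoke it once, then restore the original."""
    FakeAttention.__call__ = patch
    try:
        return FakeAttention()(x, **kwargs)
    finally:
        FakeAttention.__call__ = original_attn_call


x = [[0.0] * 4]  # batch of 1, seq_len 4

try:
    call_patched(rigid_patch, x, position_embeddings="rope")
except TypeError as err:
    print("rigid:", err)

print("durable:", call_patched(durable_patch, x, position_embeddings="rope", target_verify=True))
```

The rigid wrapper fails with the same "unexpected keyword argument 'position_embeddings'" shape as the real crash, while the `**kwargs` version passes both new arguments through without having to know their names.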

Second crash: missing VJP for a `CustomKernel`, surfaced after the patch above

With the wrapper fixed, training gets past the forward pass and crashes again on the first backward:

```
ValueError: [Primitive::vjp] Not implemented for CustomKernel.
  at mlx/nn/utils.py:35 (wrapped_value_grad_fn)
  at unsloth_zoo/mlx/trainer.py:913 (step_fn -> loss_and_grad_fn)
  at unsloth_zoo/mlx/trainer.py:1130 (_train_inner)
```

`mx.compile` also fails at runtime and Studio falls back to eager. The crash reproduces in eager.

It's not GatedDeltaNet. Studio prints `Unsloth: Patched GatedDeltaNet with memory-efficient custom VJP.` and `unsloth_zoo/gated_delta_vjp.py` covers that op via `@mx.custom_function`. Something else in the Qwen3.5 forward graph is hitting autograd as a raw `CustomKernel`.

I ran a few narrowing experiments, all crashing identically:

| Variant | Outcome |
|---|---|
| `Qwen3.5-0.8B-8bit`, all features on | crash |
| `Qwen3.5-0.8B-bf16`, all features on | crash, so not 8-bit-specific |
| `Qwen3.5-0.8B-MLX-4bit`, all features on | crash, so quant level doesn't matter |
| `bf16` + CCE disabled | crash, so not the CCE quantized matmul |
| `mx.compile` auto-falls-back to eager | crash, so not a compile artifact |

Between them, this rules out `_target_verify_qmv_kernel` (only activates on quantized linears) and CCE's internal quantized matmul. What's left in `mlx_vlm/models/qwen3_5/language.py` that uses a bare `mx.fast.metal_kernel`:

- `_qwen3_5_ragged_sdpa_{one_pass,two_pass_1,two_pass_2}_kernel` for full-attention SDPA. Named after a decode helper (`_qwen3_5_ragged_decode_attention` at line 1196), so maybe inference-only?
- `_TARGET_VERIFY_GEMV` at line 75, reachable from `_target_verify_linears` even on non-quantized weights.
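
To enumerate these candidates mechanically, a small static scan can list every `mx.fast.metal_kernel(...)` call and every `@mx.custom_function`-decorated function in a source file without importing MLX. This is a hypothetical helper (`find_metal_kernels` and `find_custom_functions` are not part of any library) and only a heuristic: it matches call names ending in `fast.metal_kernel`, so a kernel called through a bare `metal_kernel` alias would be missed, and it says nothing about which kernels actually run on the training path.

```python
import ast


def _dotted(node):
    """Return a dotted name like 'mx.fast.metal_kernel' for Name/Attribute chains, else ''."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return ""


def find_metal_kernels(source):
    """Return (line, kernel_name) for each *.fast.metal_kernel(...) call, sorted by line.

    kernel_name is None when the name isn't a plain string literal (e.g. an f-string).
    """
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call) and _dotted(node.func).endswith("fast.metal_kernel"):
            name = None
            # The kernel name may be passed positionally or as name=...
            if node.args and isinstance(node.args[0], ast.Constant):
                name = node.args[0].value
            for kw in node.keywords:
                if kw.arg == "name" and isinstance(kw.value, ast.Constant):
                    name = kw.value.value
            found.append((node.lineno, name))
    return sorted(found, key=lambda item: item[0])


def find_custom_functions(source):
    """Return names of functions decorated with something ending in 'custom_function'."""
    names = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if any(_dotted(dec).endswith("custom_function") for dec in node.decorator_list):
                names.append(node.name)
    return sorted(names)
```

Pointed at the installed `mlx_vlm/models/qwen3_5/language.py` (for example `find_metal_kernels(open(path).read())`), the reported line numbers can be cross-checked against the kernels listed above, and any kernel not reached through a `custom_function` is a candidate for the missing VJP.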

The standard attention path in `mlx_lm/models/base.py:108` calls `mx.fast.scaled_dot_product_attention` when `cache=None` (the training case). That's a built-in MLX primitive with a backward already registered, so it shouldn't be where the missing VJP is.

Anything built with `mx.fast.metal_kernel(...)` returns a `CustomKernel` primitive, and no VJP is registered automatically. You have to wrap it in `@mx.custom_function` or pass `vjp_function=` at construction, otherwise the backward pass has no gradient function to call when it tries to differentiate through that kernel. That's the exact failure mode in the second crash, and the remaining suspect kernels in `qwen3_5/language.py` all use the bare form.
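
The mechanism can be modeled with a toy tape-based reverse-mode autodiff in plain Python. This is not MLX, and `Primitive`, `grad`, `bare`, and `wrapped` are hypothetical names. It only illustrates why the forward pass runs through an opaque op while the first backward fails until a VJP rule is attached, which is the role `@mx.custom_function` plus a registered `.vjp` plays for a Metal kernel.

```python
import math


class Primitive:
    """Toy differentiable op: a forward function plus an optional VJP rule.

    A Primitive with no rule plays the role of a bare CustomKernel.
    """

    def __init__(self, fn):
        self.fn = fn
        self.vjp_rule = None

    def vjp(self, rule):
        # Decorator that attaches a backward rule; loosely mirrors registering
        # a VJP on an mx.custom_function (this toy is not MLX).
        self.vjp_rule = rule
        return rule

    def __call__(self, x, tape):
        out = self.fn(x)
        tape.append((self, x, out))  # record inputs/outputs for the backward pass
        return out


def grad(f, x):
    """Reverse-mode derivative of scalar f at x, walking the tape backwards."""
    tape = []
    f(x, tape)
    cotangent = 1.0
    for prim, primal, out in reversed(tape):
        if prim.vjp_rule is None:
            # The forward already ran fine; only the backward needs the rule.
            raise NotImplementedError(f"no VJP registered for {prim.fn.__name__}")
        cotangent = prim.vjp_rule(primal, cotangent, out)
    return cotangent


# A "built-in" op that ships with a backward rule (like mx.exp).
exp_p = Primitive(math.exp)
exp_p.vjp(lambda primal, cotangent, out: cotangent * out)


def _scale_kernel(x):
    # Stand-in for an opaque Metal kernel: autodiff cannot see inside it.
    return 3.0 * x


bare = Primitive(_scale_kernel)     # like a bare mx.fast.metal_kernel(...)
wrapped = Primitive(_scale_kernel)  # same forward, plus a hand-written VJP


@wrapped.vjp
def _scale_vjp(primal, cotangent, out):
    return 3.0 * cotangent


def loss_bare(x, tape):
    return exp_p(bare(x, tape), tape)


def loss_wrapped(x, tape):
    return exp_p(wrapped(x, tape), tape)
```

`grad(loss_wrapped, 0.0)` returns 3.0 (the derivative of exp(3x) at 0), while `grad(loss_bare, 0.0)` raises even though `exp_p` has a rule: one op without a VJP anywhere on the tape is enough to fail the whole backward, which is consistent with a single bare kernel in the Qwen3.5 graph producing the second crash.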

Gradient checkpointing is the one feature I haven't toggled off, though it'd be surprising if `mx.checkpoint` were the issue. Happy to run more tests or apply maintainer patches if useful.

For context: this is the third missing-VJP-on-CustomKernel issue I've hit in Qwen3.5-on-Mac training this week, alongside ml-explore/mlx-lm#482 (closed by PR #496, but the fix didn't propagate everywhere) and the still-open ml-explore/mlx-lm#1217. mlx-vlm's Qwen3.5 has several Metal kernels written for inference that don't yet have backward implementations.