
[Bug] Qwen3.5 LoRA SFT in Studio (MLX): patched_attn_call rejects position_embeddings (+ second VJP error after fix) #6002

@felixboelter

  1. Did you update? N/A. Running Unsloth Studio (the packaged app), not pip. Bundled unsloth_zoo is 2026.6.1 and I launched Studio today after the morning's Mac/MLX commits, so this is on the latest.
  2. Colab or Kaggle or local / cloud -> Local Mac. M2 Pro, macOS 14.6.1, 16 GB unified memory.
  3. Number GPUs used, use nvidia-smi -> N/A. Apple Silicon, no CUDA.
  4. Which notebook? Please link! -> N/A. Studio GUI, SFT tab.
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?
    • unsloth_zoo 2026.6.1, mlx 0.31.2, mlx-lm 0.31.3, mlx-vlm 0.6.1
    • transformers 4.57.6 / trl 0.23.1 (not used on the MLX path)
    • torch not installed
    • Python 3.13.3
  6. Which trainer? -> Studio's MLX SFT trainer (unsloth_zoo/mlx/trainer.py).

Studio config that reproduces:

# Model:    mlx-community/Qwen3.5-0.8B-8bit  (also tried -bf16 and -MLX-4bit, same result)
# Adapter:  LoRA r=16, alpha=16, dropout=0
# Trainer:  SFT
# BS=2, grad_accum=4, max_seq_len=2048, optim=adamw, scheduler=linear
# Features: CCE, gradient checkpointing, mx.compile
# Dataset:  ChatML JSONL, 31 dialogues, ~22 turns each (system+user+assistant)

There are two separate crashes here, the second only visible after working around the first.

First crash: patched_attn_call doesn't accept position_embeddings

Studio crashes before any step completes:

TypeError: _fix_qwen35_attention_cache.<locals>.patched_attn_call() got an
unexpected keyword argument 'position_embeddings'
  at unsloth_zoo/mlx/utils.py:93 (checkpointed_fn)
  at mlx_vlm/models/qwen3_5/language.py:1738

The wrapper in unsloth_zoo/mlx/loader.py:514 was written against an older mlx-vlm:

def patched_attn_call(self, x, mask=None, cache=None, position_ids=None):
    ...

But mlx-vlm 0.6.1's Qwen3_5Attention.__call__ has added two new kwargs, position_embeddings and target_verify, both of which the decoder layer at qwen3_5/language.py:1738 now passes. The wrapper has no way to accept them, hence the TypeError.
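This kind of drift can be caught before training starts by comparing the wrapper's signature with the callable it replaces. The sketch below is only an illustration: `upstream_call`, `wrapper_call` and `missing_params` are hypothetical stand-ins whose parameter lists are modelled on the ones quoted in this report, not the real mlx-vlm or unsloth_zoo code.

```python
import inspect


def upstream_call(self, x, mask=None, cache=None, position_ids=None,
                  position_embeddings=None, target_verify=None):
    """Stand-in for the newer attention __call__ (names taken from this report)."""
    return x


def wrapper_call(self, x, mask=None, cache=None, position_ids=None):
    """Stand-in for the older wrapper signature."""
    return x


def missing_params(wrapper, wrapped):
    """Return the parameters of `wrapped` that `wrapper` cannot accept."""
    wrapper_params = inspect.signature(wrapper).parameters
    # A **kwargs parameter means the wrapper can accept any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in wrapper_params.values()):
        return []
    return [name for name in inspect.signature(wrapped).parameters
            if name not in wrapper_params]


print(missing_params(wrapper_call, upstream_call))
```

A check like this could run once at patch time and log a warning, so a signature change upstream shows up as a clear message instead of a TypeError inside the checkpointed forward pass.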

Adding **kwargs to the signature and forwarding it makes the error disappear:

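def patched_attn_call(self, x, mask=None, cache=None, position_ids=None, **kwargs):
    # Training path (no cache): build default position ids when none are given.
    if cache is None and position_ids is None:
        import mlx.core as mx
        L = x.shape[1]  # sequence length
        position_ids = mx.arange(L)
        position_ids = mx.expand_dims(position_ids, axis=0)
        # Tile to shape (3, 1, L).
        position_ids = mx.tile(position_ids, (3, 1, 1))
    # Forward any extra keyword arguments (e.g. position_embeddings) untouched.
    return original_attn_call(
        self, x, mask=mask, cache=cache, position_ids=position_ids, **kwargs
    )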

mlx-vlm adds kwargs to this signature fairly often (MRoPE work is active), so **kwargs is more durable than spelling out today's names, which would break again the next time a new one is added.
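To make the durability argument concrete, here is a small sketch with plain Python stand-ins. `make_wrapper`, `upstream_today`, `upstream_later` and `some_new_kwarg` are all hypothetical names invented for the illustration; the only point is that a wrapper which forwards `**kwargs` keeps working when the wrapped callable gains a parameter.

```python
def make_wrapper(original_call):
    """Build a wrapper that names only the arguments it needs and forwards the rest."""
    def patched_call(self, x, mask=None, cache=None, position_ids=None, **kwargs):
        return original_call(self, x, mask=mask, cache=cache,
                             position_ids=position_ids, **kwargs)
    return patched_call


# Hypothetical upstream signatures at two points in time.
def upstream_today(self, x, mask=None, cache=None, position_ids=None,
                   position_embeddings=None, target_verify=None):
    return ("today", position_embeddings, target_verify)


def upstream_later(self, x, mask=None, cache=None, position_ids=None,
                   position_embeddings=None, target_verify=None,
                   some_new_kwarg=None):
    return ("later", some_new_kwarg)


# The same wrapper code works against both versions without modification.
print(make_wrapper(upstream_today)(None, [0.0], position_embeddings="rope"))
print(make_wrapper(upstream_later)(None, [0.0], some_new_kwarg=1))
```

The trade-off is that a misspelled keyword is no longer caught by the wrapper itself. It is passed through and rejected, if at all, by the wrapped callable.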

Second crash: missing VJP for a CustomKernel, surfaced after the patch above

With the wrapper fixed, training gets past the forward pass and crashes again on the first backward:

ValueError: [Primitive::vjp] Not implemented for CustomKernel.
  at mlx/nn/utils.py:35 (wrapped_value_grad_fn)
  at unsloth_zoo/mlx/trainer.py:913 (step_fn -> loss_and_grad_fn)
  at unsloth_zoo/mlx/trainer.py:1130 (_train_inner)

mx.compile also fails at runtime and Studio falls back to eager. The crash reproduces in eager.

It's not GatedDeltaNet. Studio prints "Unsloth: Patched GatedDeltaNet with memory-efficient custom VJP." and unsloth_zoo/gated_delta_vjp.py covers that op via @mx.custom_function. Something else in the Qwen3.5 forward graph is hitting autograd as a raw CustomKernel.

I ran a few narrowing experiments, all crashing identically:

| Variant | Outcome |
| --- | --- |
| Qwen3.5-0.8B-8bit, all features on | crash |
| Qwen3.5-0.8B-bf16, all features on | crash, so not 8-bit-specific |
| Qwen3.5-0.8B-MLX-4bit, all features on | crash, so quant level doesn't matter |
| bf16 + CCE disabled | crash, so not the CCE quantized matmul |
| mx.compile auto-falls-back to eager | crash, so not a compile artifact |

Between them this rules out _target_verify_qmv_kernel (only activates on quantized linears) and CCE's internal quantized matmul. What's left in mlx_vlm/models/qwen3_5/language.py that uses a bare mx.fast.metal_kernel:

  • _qwen3_5_ragged_sdpa_{one_pass,two_pass_1,two_pass_2}_kernel for full-attention SDPA. Named after a decode helper (_qwen3_5_ragged_decode_attention at line 1196), so maybe inference-only?
  • _TARGET_VERIFY_GEMV at line 75, reachable from _target_verify_linears even on non-quantized weights.
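For anyone who wants to repeat this search on their own install, a rough way to list candidate kernels is to scan the module's source text for `mx.fast.metal_kernel(` assignments. This is a minimal sketch: `find_metal_kernels` is a hypothetical helper, the sample text is made up, and it assumes kernels are built with a simple `name = mx.fast.metal_kernel(...)` assignment, which may not hold for every kernel in the real file.

```python
import re

# Matches assignments such as `_name = mx.fast.metal_kernel(`.
KERNEL_ASSIGN = re.compile(r"^\s*(\w+)\s*=\s*mx\.fast\.metal_kernel\(", re.MULTILINE)


def find_metal_kernels(source: str):
    """Return (line_number, variable_name) for each metal_kernel assignment."""
    hits = []
    for match in KERNEL_ASSIGN.finditer(source):
        # Count newlines before the variable name to get a 1-based line number.
        line_number = source.count("\n", 0, match.start(1)) + 1
        hits.append((line_number, match.group(1)))
    return hits


# Made-up sample text standing in for a real module's source.
sample = '''import mlx.core as mx

_demo_kernel = mx.fast.metal_kernel(
    name="demo",
    input_names=["x"],
    output_names=["y"],
    source="...",
)

def helper(x):
    return x
'''

print(find_metal_kernels(sample))
```

To use it on an installed package, read the file's text (for example with `pathlib.Path(...).read_text()`) and pass it in. Each hit still has to be checked by hand for whether it is reachable when cache=None.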

The standard attention path in mlx_lm/models/base.py:108 calls mx.fast.scaled_dot_product_attention when cache=None (the training case). That's a built-in MLX primitive with a backward already registered, so it shouldn't be where the missing VJP is.

Anything built with mx.fast.metal_kernel(...) returns a CustomKernel primitive, and no VJP is registered automatically. You have to wrap it in @mx.custom_function or pass vjp_function= at construction, otherwise the backward pass has no gradient function to call when it tries to differentiate through that kernel. That's the exact failure mode in the second crash, and the remaining suspect kernels in qwen3_5/language.py all use the bare form.
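The failure mode can be pictured with a toy model in plain Python. This is not MLX's actual machinery: `Primitive` and `grad` here are hypothetical, and the error string is copied from the traceback above only for illustration. In MLX itself the corresponding fix is to wrap the kernel call in `@mx.custom_function` and attach a backward rule with its `.vjp` decorator.

```python
class Primitive:
    """Toy primitive: a forward function plus an optional backward (VJP) rule."""

    def __init__(self, name, forward, vjp=None):
        self.name = name
        self.forward = forward
        self.vjp = vjp


def grad(primitive, x):
    """Differentiate primitive.forward at x using its VJP rule (cotangent 1.0)."""
    if primitive.vjp is None:
        # Nothing to call during the backward pass.
        raise ValueError(f"[Primitive::vjp] Not implemented for {primitive.name}.")
    return primitive.vjp(x, 1.0)


# Built-in style op: ships with a backward rule.
square = Primitive("Square", lambda x: x * x, vjp=lambda x, g: 2.0 * x * g)

# Custom-kernel style op: forward only.
bare_kernel = Primitive("CustomKernel", lambda x: x * x * x)

# Same forward, with a hand-written backward attached.
wrapped_kernel = Primitive("CustomKernel", bare_kernel.forward,
                           vjp=lambda x, g: 3.0 * x * x * g)

print(grad(square, 3.0))
try:
    grad(bare_kernel, 2.0)
except ValueError as exc:
    print(exc)
print(grad(wrapped_kernel, 2.0))
```

The forward pass of `bare_kernel` works fine, which matches what was observed here: training gets through the forward and only fails once the backward reaches the kernel.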

Gradient checkpointing is the one feature I haven't toggled off, though it'd be surprising if mx.checkpoint were the issue. Happy to run more tests or apply maintainer patches if useful.

For context: this is the third missing-VJP-on-CustomKernel issue I've hit in Qwen3.5-on-Mac training this week, alongside ml-explore/mlx-lm#482 (closed by PR #496, but the fix didn't propagate everywhere) and the still-open ml-explore/mlx-lm#1217. mlx-vlm's Qwen3.5 has several Metal kernels written for inference that don't yet have backward implementations.
