Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217

Open: SudarkinV wants to merge 3 commits into ml-explore:main from SudarkinV:feat/gated-delta-vjp-narrow

@SudarkinV

Extends #496 with a Metal backward kernel for gated_delta_update. The
existing use_kernel=False fallback into gated_delta_ops unblocks the
tracing error but unrolls an O(T)-node auto-diff graph, which OOMs at
T ≥ 2048 on a 36 GB M-series Mac — the common full-parameter / LoRA
fine-tuning setup for Qwen3.5-9B and Qwen3-Next-80B.

In merged PR #997, @angeloskath noted "This does affect finetuning
fairly heavily but I think we need a kernel for that to be an enjoyable
experience anyway."
This PR is that kernel.

Related context: #1206 reports Qwen3.5-9B LoRA crashing on the first
backward pass. I do not claim this PR fully fixes that hardware-specific
report, but it targets the same training-time bottleneck: avoiding an
O(T)-node autodiff graph through the gated-delta recurrence by providing
a custom VJP path.
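
The memory argument can be illustrated with a toy model. The sketch below is not the PR's code: it uses a hypothetical scalar recurrence `s_{t+1} = g[t] * s_t + u[t]` with a hand-written reverse sweep, and compares a backward pass that stores every state against one that stores only chunk-boundary states and recomputes the rest, which is the idea behind chunked checkpointing. All function names are made up for illustration.

```python
import math

def grad_full(s0, g, u, w):
    """Backprop through s_{t+1} = g[t]*s_t + u[t], keeping all T+1 states."""
    states = [s0]
    for gt, ut in zip(g, u):
        states.append(gt * states[-1] + ut)
    T = len(g)
    dg, du, ds = [0.0] * T, [0.0] * T, 0.0
    for t in range(T - 1, -1, -1):
        ds += w[t]               # loss = sum_t w[t] * s_{t+1}
        dg[t] = ds * states[t]   # d s_{t+1} / d g[t] = s_t
        du[t] = ds               # d s_{t+1} / d u[t] = 1
        ds *= g[t]               # propagate the gradient back to s_t
    return dg, du, ds, len(states)

def grad_chunked(s0, g, u, w, chunk):
    """Same gradients, but only chunk-boundary states survive the forward pass."""
    T = len(g)
    starts = list(range(0, T, chunk))
    boundaries, s = [], s0
    for start in starts:
        boundaries.append(s)     # state entering this chunk
        for t in range(start, min(start + chunk, T)):
            s = g[t] * s + u[t]
    dg, du, ds = [0.0] * T, [0.0] * T, 0.0
    peak = len(boundaries)
    for start, s_start in zip(reversed(starts), reversed(boundaries)):
        end = min(start + chunk, T)
        local = [s_start]        # recompute this chunk's states only
        for t in range(start, end):
            local.append(g[t] * local[-1] + u[t])
        peak = max(peak, len(boundaries) + len(local))
        for t in range(end - 1, start - 1, -1):
            ds += w[t]
            dg[t] = ds * local[t - start]
            du[t] = ds
            ds *= g[t]
    return dg, du, ds, peak

T = 64
g = [0.9 + 0.001 * t for t in range(T)]
u = [math.sin(t) for t in range(T)]
w = [1.0] * T
dg_f, du_f, ds_f, stored_full = grad_full(0.5, g, u, w)
dg_c, du_c, ds_c, stored_chunked = grad_chunked(0.5, g, u, w, chunk=8)
```

In this toy setup the chunked version holds far fewer states at any one time than the full version, at the cost of one extra forward pass per chunk, and produces the same gradients.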

What this PR adds

Two new modules and a small integration change:

  • mlx_lm/models/gated_delta_vjp.py: pure-Python reference VJP
    using mx.checkpoint on fixed-size chunks. O(T/CHUNK_SIZE) autodiff
    graph, verifiable against gated_delta_ops. Used as the import
    fallback.
  • mlx_lm/models/gated_delta_vjp_metal.py: Metal backward kernel
    registered via mx.custom_function. Forward-with-save + reverse
    sweep in the same chunked layout; threadgroup-local reduction (no
    atomics, deterministic). 8–11× faster than the Python VJP at
    Qwen3.5-9B shapes.
  • mlx_lm/models/gated_delta.py: gated_delta_update() gains a
    training: bool = False argument that routes to the VJP path. The
    Metal backward is selected only when GPU/Metal is available, mask
    is None, Dk % 32 == 0 and Dv % 4 == 0; otherwise the call falls back
    to the Python VJP (which has no shape constraints and runs on CPU).
    The existing use_kernel and mask/inference behaviour is unchanged.
  • mlx_lm/models/qwen3_5.py / mlx_lm/models/qwen3_next.py: the
    linear_attn call site passes training=self.training.

Inference and KV-cache paths are untouched.
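
The routing rule described above can be written down as a small decision function. This is a sketch of the stated conditions only, not the code in gated_delta.py: the function name, the `metal_available` flag, and the returned labels are hypothetical.

```python
def select_backward_path(training, metal_available, mask, Dk, Dv):
    """Return which path the description above says would be taken.

    Labels are illustrative strings, not identifiers from the PR.
    """
    if not training:
        # Inference keeps the existing use_kernel / ops routing.
        return "existing"
    if metal_available and mask is None and Dk % 32 == 0 and Dv % 4 == 0:
        return "metal_vjp"
    # No shape constraints; also runs on CPU.
    return "python_vjp"

# Qwen3.5-9B linear_attn head dims from the benchmark table (Dk=192, Dv=128).
qwen_path = select_backward_path(True, True, None, 192, 128)
# Toy test shape from the correctness section (Dk=8, Dv=8).
toy_path = select_backward_path(True, True, None, 8, 8)
# Same Qwen shape on a machine without Metal.
cpu_path = select_backward_path(True, False, None, 192, 128)
```

With these inputs the Qwen3.5-9B shape satisfies both divisibility constraints, while the toy test shape and the non-Metal run fall back to the Python VJP.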

Correctness

New tests appended to tests/test_models.py (TestModels class,
unittest.TestCase style, matching the existing test_gated_delta*
block):

  • test_gated_delta_vjp_forward_equivalence: Python VJP forward is
    bit-exact against gated_delta_update(use_kernel=False).
  • test_gated_delta_vjp_fd_gradient: central-difference check of
    mx.grad output vs. analytic backward on a toy shape (B=1, T=4,
    Hk=2, Hv=4, Dk=8, Dv=8, fp32), tolerance 1e-3.
  • test_gated_delta_vjp_metal_matches_python: Metal VJP forward and
    state agree with the Python reference within fp32 SIMD-reduction
    noise (atol=1e-3).
  • test_gated_delta_vjp_metal_gradients_match_python: mx.grad
    output of the Metal VJP matches the Python VJP across all eight
    trainable inputs (q, k, v, a, b, A_log, dt_bias, state) within
    atol=1e-3, rtol=1e-3.

All eight test_gated_delta* tests pass locally (python -m unittest tests.test_models.TestModels -k gated_delta).
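
For readers unfamiliar with the finite-difference check, here is a minimal pure-Python sketch of the technique on a hypothetical three-input function. It is not the test from tests/test_models.py: the function `f`, its analytic gradient, and the step size are assumptions chosen for illustration, and it runs in float64 rather than fp32.

```python
import math

def f(x):
    # Hypothetical scalar function standing in for a loss.
    return x[0] * x[0] * x[1] + math.sin(x[2])

def analytic_grad(x):
    # Hand-derived gradient of f.
    return [2.0 * x[0] * x[1], x[0] * x[0], math.cos(x[2])]

def central_difference(fn, x, i, eps=1e-3):
    """Estimate d fn / d x[i] as (fn(x + eps*e_i) - fn(x - eps*e_i)) / (2*eps)."""
    xp, xm = list(x), list(x)
    xp[i] += eps
    xm[i] -= eps
    return (fn(xp) - fn(xm)) / (2.0 * eps)

x0 = [0.5, -1.5, 0.3]
numeric = [central_difference(f, x0, i) for i in range(len(x0))]
exact = analytic_grad(x0)
max_err = max(abs(n - e) for n, e in zip(numeric, exact))
```

The central difference has O(eps²) truncation error, so a tolerance of 1e-3 leaves ample headroom here; in fp32 the step size has to be balanced against rounding error, which is one reason such checks are usually run on small shapes.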

Performance (Qwen3.5-9B linear_attn shape: `B=1, Hk=16, Hv=64, Dk=192, Dv=128`, bf16)

| T | use_kernel=False (ops) fwd+bwd | Python VJP fwd+bwd | Metal VJP fwd+bwd | Peak mem (Metal) |
|---|---|---|---|---|
| 256 | 152 ms (graph-bound) | 145.3 ms | 13.4 ms | 1.8 GB |
| 512 | 304 ms | 296.3 ms | 28.2 ms | 3.0 GB |
| 1024 | 617 ms | 599.6 ms | 62.2 ms | 4.7 GB |
| 2048 | OOM on 36 GB | 1233.5 ms | 149.8 ms | 8.1 GB |
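
The 8–11× figure quoted for the Metal kernel can be re-derived from the Python VJP and Metal VJP fwd+bwd timings above. The snippet below only redoes that arithmetic; the dictionary names are illustrative.

```python
# fwd+bwd times in milliseconds, copied from the benchmark rows above.
python_vjp_ms = {256: 145.3, 512: 296.3, 1024: 599.6, 2048: 1233.5}
metal_vjp_ms = {256: 13.4, 512: 28.2, 1024: 62.2, 2048: 149.8}

# Speedup of the Metal VJP over the Python VJP at each sequence length.
speedup = {T: python_vjp_ms[T] / metal_vjp_ms[T] for T in python_vjp_ms}

for T, s in speedup.items():
    print(f"T={T}: {s:.2f}x")
```

The ratios come out at roughly 10.8×, 10.5×, 9.6× and 8.2× for T = 256, 512, 1024 and 2048, so the speedup shrinks as T grows but stays inside the quoted range.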

End-to-end training (Qwen3.5-9B LoRA on 36 GB M4 Max, max_seq=4096)

500-iteration full LoRA run on the unfiltered training set, batch=1,
grad_checkpoint=true, 4 LoRA keys (q_proj, v_proj, in_proj_qkv, out_proj):

| Iter | Val loss | Peak mem |
|---|---|---|
| 1 | 0.524 | 9.1 GB |
| 50 | 0.248 | 9.2 GB |
| 100 | 0.121 | 9.2 GB |
| 200 | 0.143 | 9.2 GB |
| 300 | 0.246 | 13.8 GB |
| 500 | 0.155 | 13.8 GB |

Converges in this configuration. Peak memory holds at 9.1–9.2 GB
through iteration 200, then steps up to 13.8 GB by iteration 300 and
stays there. Total time ≈ 84 minutes (10 s/iter). This is one
observation on one shape and is not a generic stability guarantee for
all training configurations.

Relationship to #496

PR #496 added the use_kernel: bool = True routing so that
.training falls through to gated_delta_ops. This PR reuses that
signal — the new training=True path selects the VJP module; anything
else goes through the existing use_kernel branch. No behaviour change
for inference or for the use_kernel=False eval path.

Scope

This PR covers only the training-time VJP/backward path for
gated_delta_update.

Out of scope and not included:

  • inference kernel changes
  • speculative decoding
  • prefix-scan prototypes
  • broader training or quantization changes

Files

  • mlx_lm/models/gated_delta_vjp.py (new, ~180 LoC)
  • mlx_lm/models/gated_delta_vjp_metal.py (new, ~770 LoC)
  • mlx_lm/models/gated_delta.py (+~30 LoC)
  • mlx_lm/models/qwen3_5.py (+1 LoC)
  • mlx_lm/models/qwen3_next.py (+1 LoC)
  • tests/test_models.py (+~160 LoC)

Viktor Sudarkin added 3 commits April 27, 2026 22:59
Provides a training-time backward path for gated_delta_update when
the current use_kernel=False fallback (gated_delta_ops) exceeds the
36 GB unified-memory budget on Apple Silicon at T >= 2048.

New files:
- gated_delta_vjp.py: pure-Python chunked reference VJP with
  mx.checkpoint; O(T/chunk) autodiff graph.
- gated_delta_vjp_metal.py: Metal backward kernel registered as
  mx.custom_function; reverse sweep over saved state history with
  threadgroup reduction. 8-11x faster than the Python reference,
  with gradients that match it up to fp32 SIMD-reduction noise.

Integration:
- gated_delta_update() gets a new training=False argument. When
  it is set to True (by the Qwen3.5 / Qwen3-Next linear_attn layer
  in .train() mode), routing picks the Metal VJP first and the
  Python VJP as import fallback.
- qwen3_5.py / qwen3_next.py linear_attn call sites set
  training=self.training.

Tests (appended to tests/test_models.py TestModels):
- test_gated_delta_vjp_forward_equivalence: Python VJP forward
  matches gated_delta_update(use_kernel=False).
- test_gated_delta_vjp_fd_gradient: central-difference check of
  mx.grad vs. analytic backward on a toy shape.
- test_gated_delta_vjp_metal_matches_python: Metal backward matches
  the Python reference up to fp32 SIMD-reduction noise.

Inference path and kv-cache behaviour are unchanged.

- gated_delta_vjp / gated_delta_vjp_metal: initialize default state as
  fp32 to match the existing gated_delta_update path. Previously the VJP
  modules used q.dtype, silently downgrading the recurrent state to
  bf16 during Qwen3.5 / Qwen3-Next training when state=None.

- gated_delta_update: training=True now selects the Metal VJP only when
  GPU/Metal is available, mask is None, Dk%32==0 and Dv%4==0. Otherwise
  it falls back to the Python VJP, which has no shape constraints and
  also runs on CPU. Previously training=True unconditionally invoked
  the Metal kernel and crashed for non-GPU runs and for shapes the
  kernel does not handle.

- tests/test_models.py: add test_gated_delta_vjp_metal_gradients_match_python
  to exercise mx.grad through both the Metal and Python VJP and compare
  gradients for q, k, v, a, b, A_log, dt_bias, state. The previous
  Metal-vs-Python test only compared forward outputs, so a broken Metal
  backward could have passed.

The previous fix initialised the default recurrent state as fp32 to
match the existing gated_delta_update path, but the Metal forward and
backward kernels typed all state buffers as InT (the input dtype).
For bf16 inputs with state=None this produced a Metal compile error:

    error: incompatible pointer types assigning to
    'const device bfloat *' from 'const device float *'

caused by `S_prev_row = s_initial + ...` where s_initial points to
the fp32 state_initial while S_prev_row was declared `const device InT*`.

Fix: introduce a separate StT template parameter for state-typed
buffers and route it explicitly through both the scalar and the
vectorised forward/backward kernels. State writes (state_history,
state_out, dS_initial) are cast to StT instead of InT, and the Python
wrappers _fwd_save / _bwd publish the state dtype on the kernel
template and on the corresponding output_dtypes slots.

Also adds tests/test_gated_delta_vjp_bf16_default_state to cover the
public training route with bf16 inputs and state=None — the configuration
the previous combination would have crashed on.
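
Why an fp32 default state matters can be illustrated with a toy accumulation. The sketch below is not taken from the kernels: it emulates bfloat16 and float32 rounding in pure Python (the helpers `to_bf16` and `to_fp32` are hypothetical) and adds a small update to a state of 1.0 two hundred times.

```python
import struct

def to_fp32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value.

    bfloat16 is the top 16 bits of a float32; this applies
    round-to-nearest-even to the discarded lower 16 bits.
    Intended for ordinary finite values only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

state_bf16 = 1.0
state_fp32 = 1.0
for _ in range(200):
    # Each step adds a small update and rounds the state to its storage dtype.
    state_bf16 = to_bf16(state_bf16 + 0.001)
    state_fp32 = to_fp32(state_fp32 + 0.001)
```

Near 1.0 the bfloat16 spacing is 2^-7, so an update of 0.001 is below half a step and is rounded away every time: the bf16 state never moves, while the fp32 state ends close to 1.2. A recurrent state carried across thousands of timesteps is exposed to the same effect.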
@SudarkinV

Author

@angeloskath gentle ping when you have bandwidth.

In #997 you noted that the fp32 gated-delta state affects finetuning fairly heavily
and that a kernel would likely be needed for that path. This PR is intended as that
training-time VJP/backward kernel for gated_delta_update.

It is scoped to the training/backward path only; inference routing is unchanged.

@tsato081


Hi @SudarkinV, FYI: I've opened #1389, which addresses the same training-path problem on the ops side, replacing the sequential gated_delta_ops fallback with the chunked (UT/WY) formulation wrapped in mx.checkpoint.

While benchmarking I also ran your branch on an M3 Ultra (your table's shape, bf16, fwd+bwd): your Metal kernel is faster where its constraints hold (0.097 s vs 0.117 s at T=2048), while the chunked path uses less peak memory (3.0 GB vs 8.5 GB). The two work well together: if your PR lands, the chunked path could replace the sequential loop in the fallback your routing already has. Happy to rebase mine to fit.

@SudarkinV

Author

Hi @tsato081 — thanks for the heads-up and for #1389. I spent a day
verifying your branch independently on an M4 Max (40-core GPU, 36 GB),
since the 36 GB machines in #1206 are Max-class hardware. Findings:

Correctness. I cross-checked gated_delta_ops_chunked against
three implementations (the sequential ops reference, plus the two VJP
implementations from this PR) on GQA shapes with carried-in state:
forward and final state agree to fp32 noise, and gradients for all
eight inputs agree to rel ≤ 1.3e-5. I also ran cases your tests don't
cover: gradients under padding masks (left/right), repeated unit keys
with β=0.999 through the blocked solve, and a CPU-device run — all
clean. Gradients at decay extremes are clean too, with one documented
caveat: below the 1e-12 clamp the gradient w.r.t. g is zeroed
where the sequential reference returns a nonzero value — harmless in
practice (the model's parameterization multiplies by g≈0 anyway), and
this PR's Python VJP has the equivalent dead zone via its own clamp.
At T=4096 bf16 your path and this PR's Metal kernel show matching
drift vs the fp32 sequential reference (y rel 5.64e-3 both, final
state 5.57e-3 vs 5.59e-3). Numerically interchangeable.

One stress-test datapoint that supports your defaults: at the
no-decay/collinear corner (repeated unit keys, β=0.999, g→1) the
blocked solve degrades gracefully at SUB_BLOCK=16 (rel ~4e-4) but
blows up at SUB_BLOCK=32 (rel ~2e11). Your sb=16 choice is the right
one; I'd keep it.

End-to-end LoRA. 30-iteration Qwen3.5-9B LoRA smoke (real SFT
dataset, identical seed, merged tree with both branches behind an env
switch): loss curves agree within reordering noise at every report
point (≤0.004; final val 0.125 vs 0.128), peak memory 11.48 vs
11.55 GB. After re-running the kernel path on an equally warmed-up
machine, throughput matches too (single runs each, so ±5–10% noise
applies). Caveats: batch=1, 16 LoRA layers, samples average ~350
tokens — which also means the large memory gaps from long-context
micro-benchmarks don't show up in this particular e2e; they would
with ≥2k-token samples. Your OOM fix vs the main fallback stands
regardless — this comparison is specifically chunked vs this PR's
kernel.

Performance on Max-class chips. On M4 Max your path is faster
than this PR's Metal kernel at T=256–2048 and roughly equal at
T=4096 (within-process pairs, e.g. 184.9 ms / 3.68 GB vs 216.8 ms /
9.37 GB at T=2048, fwd+bwd bf16; absolute times drift ±20% between
sessions on this laptop, ratios within a process are stable). That
matches your M3 Ultra numbers pointing the other way — plausibly the
kernel's latency-bound sequential sweep vs your matmul-throughput
path trading off with GPU core count, though I haven't profiled
this. Since #1206-class machines are Max-class, I agree #1389 is the
right default fallback and I'd support it landing regardless of what
happens with this PR.

A speedup for your path. While verifying I found a GQA
optimization worth ~13–14% at your exact defaults (C=64, sb=16),
same-or-lower peak memory, same numerics: in Qwen3.5, q/k have
Hk=16 heads vs Hv=64, so k @ k^T and q @ k^T — the two heaviest
[C,C] matmuls — can be computed on Hk heads before the GQA repeat,
with the per-Hv gating (β, decay mask) applied afterwards via
broadcast; the solve stays per-Hv. Within one process on M4 Max at
T=4096: 487.9 ms / 5.92 GB vs 569.9 ms / 9.05 GB for your branch
(T=2048: 160.3 ms / 3.10 GB vs 183.7 ms / 3.68 GB); an interleaved
A/B protocol (alternating pairs, two fresh processes) puts the
median ratio higher — 1.21–1.22x at T=2048 and 1.34–1.37x at
T=4096 — so ~13–14% is the conservative floor. Part of the win is
memory traffic: the current code materializes the GQA repeat of
q/k to Hv heads for the full sequence before the chunk loop, while
the patch keeps them at Hk and broadcasts inside the checkpointed
chunk (peak drops accordingly at long T). I verified the patch with the
same correctness battery I ran on your branch (cross-check vs the
sequential reference incl. gradients and masks, the collinear stress
above, plus a 30-iter e2e LoRA run: val loss 0.128, same as your
path). For Qwen3-Next the effect should be smaller (rf=2). It's
~30 lines against your branch — happy to send it as a patch to
#1389 if you want it.
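
The core of the GQA idea can be shown on tiny shapes. The sketch below is not the patch itself: it uses plain Python lists, a hypothetical consecutive-repeat head layout, and a simplified per-row β scaling in place of the full gating and decay mask. It checks that computing k @ k^T once per Hk head and reusing it for every Hv head in the group gives the same result as repeating k to Hv heads first.

```python
def gram(K):
    """K is a [C][D] chunk of keys; return the [C][C] matrix K @ K^T."""
    return [[sum(a * b for a, b in zip(ki, kj)) for kj in K] for ki in K]

def scale_rows(M, beta):
    """Apply a per-row gate: row i of M is multiplied by beta[i]."""
    return [[b * x for x in row] for b, row in zip(beta, M)]

# Hk=2 key heads, chunk length C=3, head dim D=2 (toy values).
k_heads = [
    [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]],
    [[0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]],
]
rf = 2                      # repeat factor Hv / Hk (assumed layout)
Hv = len(k_heads) * rf
# One gate vector per Hv head.
beta = [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1], [1.0, 0.5, 0.25]]

# Baseline: repeat keys to Hv heads, then one Gram matrix per Hv head.
repeated = [K for K in k_heads for _ in range(rf)]
baseline = [scale_rows(gram(K), beta[h]) for h, K in enumerate(repeated)]

# GQA-aware: one Gram matrix per Hk head, shared by its rf value heads.
grams = [gram(K) for K in k_heads]
optimized = [scale_rows(grams[h // rf], beta[h]) for h in range(Hv)]
```

Here the baseline performs Hv Gram computations and the GQA-aware version performs Hk, which for the Qwen3.5 head counts quoted above (Hk=16, Hv=64) would be a 4× reduction for those two matmuls; the per-Hv work (gating and the solve) is unchanged.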

Given all this: +1 to #1389 as the default training fallback. Happy
to rebase this PR on top of yours or slim it down — maintainers'
call.

@tsato081


@SudarkinV
Thank you for taking a full day to verify this so thoroughly. The sb=32 result and the Max-class numbers are both very useful; I'd only tested on the Ultra, so the flip was good to learn about. I'll add your collinear stress case to the test suite as a regression guard for sb, and note the clamp dead-zone in the PR description. I agree it's harmless, but it should be on the record.

And yes, please send the GQA patch. I'll run it through the equivalence tests here and fold it into #1389 with credit.

Thanks for the +1 as well. If Ultra numbers for the kernel would help at any point, I'm happy to run them.
