Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon)#1217
Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon)#1217SudarkinV wants to merge 3 commits into
Conversation
Provides a training-time backward path for gated_delta_update when the current use_kernel=False fallback (gated_delta_ops) exceeds the 36 GB unified-memory budget on Apple Silicon at T >= 2048. New files: - gated_delta_vjp.py: pure-Python chunked reference VJP with mx.checkpoint; O(T/chunk) autodiff graph. - gated_delta_vjp_metal.py: Metal backward kernel registered as mx.custom_function; reverse sweep over saved state history with threadgroup reduction. 8-11x faster than the Python reference and bit-identical gradients up to fp32 SIMD-reduction noise. Integration: - gated_delta_update() gets a new training=False argument. When set (by Qwen3.5 / Qwen3-Next self_attn at .train()), routing picks the Metal VJP first and the Python VJP as import fallback. - qwen3_5.py / qwen3_next.py linear_attn call sites set training=self.training. Tests (appended to tests/test_models.py TestModels): - test_gated_delta_vjp_forward_equivalence: Python VJP forward matches gated_delta_update(use_kernel=False). - test_gated_delta_vjp_fd_gradient: central-difference check of mx.grad vs. analytic backward on a toy shape. - test_gated_delta_vjp_metal_matches_python: Metal backward matches the Python reference up to fp32 SIMD-reduction noise. Inference path and kv-cache behaviour are unchanged.
- gated_delta_vjp / gated_delta_vjp_metal: initialize default state as fp32 to match the existing gated_delta_update path. Previously the VJP modules used q.dtype, silently downgrading the recurrent state to bf16 during Qwen3.5 / Qwen3-Next training when state=None. - gated_delta_update: training=True now selects the Metal VJP only when GPU/Metal is available, mask is None, Dk%32==0 and Dv%4==0. Otherwise it falls back to the Python VJP, which has no shape constraints and also runs on CPU. Previously training=True unconditionally invoked the Metal kernel and crashed for non-GPU runs and for shapes the kernel does not handle. - tests/test_models.py: add test_gated_delta_vjp_metal_gradients_match_python to exercise mx.grad through both the Metal and Python VJP and compare gradients for q, k, v, a, b, A_log, dt_bias, state. The previous Metal-vs-Python test only compared forward outputs, so a broken Metal backward could have passed.
The previous fix initialised the default recurrent state as fp32 to
match the existing gated_delta_update path, but the Metal forward and
backward kernels typed all state buffers as InT (the input dtype).
For bf16 inputs with state=None this produced a Metal compile error:
error: incompatible pointer types assigning to
'const device bfloat *' from 'const device float *'
caused by `S_prev_row = s_initial + ...` where s_initial points to
the fp32 state_initial while S_prev_row was declared `const device InT*`.
Fix: introduce a separate StT template parameter for state-typed
buffers and route it explicitly through both the scalar and the
vectorised forward/backward kernels. State writes (state_history,
state_out, dS_initial) are cast to StT instead of InT, and the Python
wrappers _fwd_save / _bwd publish the state dtype on the kernel
template and on the corresponding output_dtypes slots.
Also adds tests/test_gated_delta_vjp_bf16_default_state to cover the
public training route with bf16 inputs and state=None — the configuration
the previous combination would have crashed on.
|
@angeloskath gentle ping when you have bandwidth. In #997 you noted that the fp32 gated-delta state affects finetuning fairly heavily It is scoped to the training/backward path only; inference routing is unchanged. |
|
Hi @SudarkinV, FYI: I've opened #1389, which addresses the same training-path problem on the ops side, replacing the sequential While benchmarking I also ran your branch on an M3 Ultra (your table's shape, bf16, fwd+bwd): your Metal kernel is faster where its constraints hold (0.097 s vs 0.117 s at T=2048), while the chunked path uses less peak memory (3.0 GB vs 8.5 GB). The two work well together: if your PR lands, the chunked path could replace the sequential loop in the fallback your routing already has. Happy to rebase mine to fit. |
|
Hi @tsato081 — thanks for the heads-up and for #1389. I spent a day Correctness. I cross-checked One stress-test datapoint that supports your defaults: at the End-to-end LoRA. 30-iteration Qwen3.5-9B LoRA smoke (real SFT Performance on Max-class chips. On M4 Max your path is faster A speedup for your path. While verifying I found a GQA Given all this: +1 to #1389 as the default training fallback. Happy |
|
@SudarkinV And yes, please send the GQA patch, I'll run it through the equivalence tests here and fold it into #1389 with credit. Thanks for the +1 as well, If Ultra numbers for the kernel would help at any point, happy to run them. |
Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon)
Extends #496 with a Metal backward kernel for
gated_delta_update. Theexisting
use_kernel=Falsefallback intogated_delta_opsunblocks thetracing error but unrolls an
O(T)-node auto-diff graph, which OOMs atT ≥ 2048on a 36 GB M-series Mac — the common full-parameter / LoRAfine-tuning setup for Qwen3.5-9B and Qwen3-Next-80B.
In merged PR #997, @angeloskath noted "This does affect finetuning
fairly heavily but I think we need a kernel for that to be an enjoyable
experience anyway." This PR is that kernel.
Related context: #1206 reports Qwen3.5-9B LoRA crashing on the first
backward pass. I do not claim this PR fully fixes that hardware-specific
report, but it targets the same training-time bottleneck: avoiding an
O(T)-node autodiff graph through the gated-delta recurrence by providing
a custom VJP path.
What this PR adds
Two new modules and a small integration change:
mlx_lm/models/gated_delta_vjp.py— pure-Python reference VJPusing
mx.checkpointon fixed-size chunks.O(T/CHUNK_SIZE)autodiffgraph, verifiable against
gated_delta_ops. Used as the importfallback.
mlx_lm/models/gated_delta_vjp_metal.py— Metal backward kernelregistered via
mx.custom_function. Forward-with-save + reversesweep in the same chunked layout; threadgroup-local reduction (no
atomics, deterministic). 8–11× faster than the Python VJP at
Qwen3.5-9B shapes.
mlx_lm/models/gated_delta.py—gated_delta_update()gains atraining: bool = Falseargument that routes to the VJP path. TheMetal backward is selected only when GPU/Metal is available,
mask is None,Dk % 32 == 0andDv % 4 == 0; otherwise the call falls backto the Python VJP (which has no shape constraints and runs on CPU).
The existing
use_kerneland mask/inference behaviour is unchanged.mlx_lm/models/qwen3_5.py/mlx_lm/models/qwen3_next.py—linear_attncall site passestraining=self.training.Inference and KV-cache paths are untouched.
Correctness
New tests appended to
tests/test_models.py(TestModelsclass,unittest.TestCasestyle, matching the existingtest_gated_delta*block):
test_gated_delta_vjp_forward_equivalence— Python VJP forward isbit-exact against
gated_delta_update(use_kernel=False).test_gated_delta_vjp_fd_gradient— central-difference check ofmx.gradoutput vs. analytic backward on a toy shape (B=1, T=4, Hk=2, Hv=4, Dk=8, Dv=8, fp32), tolerance1e-3.test_gated_delta_vjp_metal_matches_python— Metal VJP forward andstate agree with the Python reference within fp32 SIMD-reduction
noise (
atol=1e-3).test_gated_delta_vjp_metal_gradients_match_python—mx.gradoutput of the Metal VJP matches the Python VJP across all eight
trainable inputs (
q, k, v, a, b, A_log, dt_bias, state) withinatol=1e-3, rtol=1e-3.All eight
test_gated_delta*tests pass locally (python -m unittest tests.test_models.TestModels -k gated_delta).Performance (Qwen3.5-9B linear_attn shape: `B=1, Hk=16, Hv=64, Dk=192,
Dv=128`, bf16)
use_kernel=False(ops) fwd+bwdEnd-to-end training (Qwen3.5-9B LoRA on 36 GB M4 Max,
max_seq=4096)500-iteration full LoRA run on the unfiltered training set,
batch=1,grad_checkpoint=true, 4 LoRA keys (q_proj, v_proj, in_proj_qkv, out_proj):Converges in this configuration; peak memory stable across the run.
Total time ≈ 84 minutes (10 s/iter). This is one observation on one
shape and is not a generic stability guarantee for all training
configurations.
Relationship to #496
PR #496 added the
use_kernel: bool = Truerouting so that.trainingfalls through togated_delta_ops. This PR reuses thatsignal — the new
training=Truepath selects the VJP module; anythingelse goes through the existing
use_kernelbranch. No behaviour changefor inference or for the
use_kernel=Falseeval path.Scope
This PR covers only the training-time VJP/backward path for
gated_delta_update.Out of scope and not included:
Files
mlx_lm/models/gated_delta_vjp.py(new, ~180 LoC)mlx_lm/models/gated_delta_vjp_metal.py(new, ~770 LoC)mlx_lm/models/gated_delta.py(+~30 LoC)mlx_lm/models/qwen3_5.py(+1 LoC)mlx_lm/models/qwen3_next.py(+1 LoC)tests/test_models.py(+~160 LoC)