
metal : add GDN kernel #20361

Merged: ggerganov merged 5 commits into gg/llama-allow-gdn-ch from gg/metal-add-gdn on Mar 11, 2026
Conversation

@ggerganov (Member)

target #20340
cont #20244

Add fused GDN recurrent kernel. Use both for BS == 1 and BS > 1.

| Model                      | Test   | t/s master | t/s gg/metal-add-gdn | Speedup |
| -------------------------- | ------ | ---------: | -------------------: | ------: |
| kimi-linear 48B.A3B Q4_K_M | pp512  |     613.65 |              1529.10 |    2.49 |
| kimi-linear 48B.A3B Q4_K_M | pp2048 |     654.58 |              1955.46 |    2.99 |
| kimi-linear 48B.A3B Q4_K_M | tg32   |      68.19 |                85.61 |    1.26 |
| qwen35 27B Q8_0            | pp512  |     349.12 |               390.39 |    1.12 |
| qwen35 27B Q8_0            | pp2048 |     363.75 |               406.81 |    1.12 |
| qwen35 27B Q8_0            | tg32   |      17.03 |                20.36 |    1.20 |
| qwen35moe 35B.A3B Q4_0     | pp512  |    1612.12 |              2058.31 |    1.28 |
| qwen35moe 35B.A3B Q4_0     | pp2048 |    1879.76 |              2462.35 |    1.31 |
| qwen35moe 35B.A3B Q4_0     | tg32   |      57.08 |                77.65 |    1.36 |
| qwen3next 80B.A3B Q4_K_M   | pp512  |    1002.26 |              1172.08 |    1.17 |
| qwen3next 80B.A3B Q4_K_M   | pp2048 |    1222.72 |              1477.06 |    1.21 |
| qwen3next 80B.A3B Q4_K_M   | tg32   |      46.58 |                60.67 |    1.30 |
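For context, the op being fused is the gated delta net recurrence. A minimal pure-Python sketch of one token step in GDA mode (one scalar gate per head), assuming the common formulation S_t = a_t * S_{t-1} + beta_t * (v_t - (a_t * S_{t-1})^T k_t) k_t^T with a_t = exp(g_t); the kernel's exact per-head layout and the KDA (per-row gate) variant differ in detail:

```python
import math

def gated_delta_step(S, k, v, g, beta):
    """One token of the gated delta rule, GDA mode (one scalar gate g per head).

    S : d_k x d_v state matrix (list of lists), updated in place.
    Reference semantics only; the Metal kernel fuses this loop over
    tokens and heads in a single dispatch.
    """
    d_k, d_v = len(S), len(S[0])
    a = math.exp(g)  # decay gate, in (0, 1] for g <= 0
    # prediction read from the decayed state: v_hat = (a * S)^T k
    v_hat = [a * sum(S[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]
    # delta-rule write: correct the state toward v along the key direction
    for i in range(d_k):
        for j in range(d_v):
            S[i][j] = a * S[i][j] + beta * k[i] * (v[j] - v_hat[j])
    return S
```

After one step with a unit key, the state stores v along that key, so repeating the same (k, v) pair is a no-op, which is the fixed-point behavior the delta rule is built around.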

arkavo-com and others added 4 commits March 10, 2026 15:23
Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
github-actions bot added labels on Mar 10, 2026: ggml (changes relating to the ggml tensor library for machine learning), Apple Metal (https://en.wikipedia.org/wiki/Metal_(API))
arkavo-com (Contributor) commented Mar 11, 2026

Bench results — Apple M4 Max & M1 Max

All 13/13 test-backend-ops -o GATED_DELTA_NET pass on Metal (both devices).

5 reps, -ngl 99 -t 1, base = merge-base 0cd4f47.
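The settings above likely correspond to invocations along these lines (model path is illustrative; the flags are standard `test-backend-ops` and `llama-bench` options):

```shell
# backend op tests for the new kernel
./build/bin/test-backend-ops -o GATED_DELTA_NET

# benchmark: 5 repetitions, full offload, 1 CPU thread
./build/bin/llama-bench -m models/qwen3.5-0.8b-q4_k_m.gguf \
    -ngl 99 -t 1 -r 5 -p 512 -n 128
```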

Qwen3.5-0.8B Q4_K_M

| Test  | Base (M4 Max / M1 Max) | PR (M4 Max) | Speedup | PR (M1 Max) | Speedup |
| ----- | ---------------------- | ----------- | ------- | ----------- | ------- |
| pp512 | 5,981 / 3,088 t/s      | 7,957 t/s   | +33%    | 4,370 t/s   | +42%    |
| tg128 | 160 / 82 t/s           | 247 t/s     | +54%    | 112 t/s     | +37%    |

Qwen3.5-9B Q4_K_M

| Test  | Base (M4 Max / M1 Max) | PR (M4 Max) | Speedup | PR (M1 Max) | Speedup |
| ----- | ---------------------- | ----------- | ------- | ----------- | ------- |
| pp512 | 772 / 414 t/s          | 838 t/s     | +9%     | 460 t/s     | +11%    |
| tg128 | 52 / 28 t/s            | 61 t/s      | +17%    | 32 t/s      | +17%    |

@arkavo-com left a comment:
Tested on M4 Max and M1 Max — all 13/13 backend tests pass, benchmarks look great.

One potential issue: `supports_op` checks `ne20 % 32 == 0` but there's no upper bound. If a model had head_size > 128 (i.e., `nsg > 4`), `supports_op` would return true but no matching template exists (`NSG` only has 1/2/4 specializations), which would fail at pipeline compilation time.

Suggested fix in `ggml-metal-device.m`:

    return op->src[2]->ne[0] % 32 == 0 && op->src[2]->ne[0] <= 128;

No current models use head_size > 128 for GDN, so this is theoretical — but would be a nice safety guard.
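The reviewer's point can be illustrated with a toy model of the guard (names simplified and hypothetical; the real check reads `op->src[2]->ne[0]` in `ggml-metal-device.m`):

```python
# Toy model of the supports_op shape guard for the GDN Metal kernel.
# head_size maps to nsg = head_size / 32 simdgroups; only nsg = 1/2/4
# template specializations exist, so a head size above 128 has no kernel
# to dispatch even though it passes the multiple-of-32 check.

def supports_gdn_unbounded(head_size: int) -> bool:
    # current check: multiple-of-32 only
    return head_size % 32 == 0

def supports_gdn_bounded(head_size: int) -> bool:
    # suggested check: also cap at the largest specialized head size
    return head_size % 32 == 0 and head_size <= 128
```

With head_size = 160 the unbounded check claims support, while the bounded one correctly rejects it before pipeline compilation is attempted.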

tarruda commented Mar 11, 2026

Pretty good improvements on an M1 Ultra.

AesSedai Qwen 3.5 35B Q4_K_M:

Before:

| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        976.75 ± 7.72 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         39.25 ± 0.05 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        793.13 ± 2.91 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         38.24 ± 0.09 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        667.17 ± 3.33 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         36.84 ± 0.13 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        579.77 ± 3.18 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         35.62 ± 0.04 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        507.58 ± 1.72 |
| qwen35moe ?B Q8_0              |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         34.14 ± 0.05 |

After:

| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |       1203.68 ± 6.87 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         52.91 ± 0.08 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        939.31 ± 2.34 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         50.08 ± 0.16 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        769.63 ± 2.76 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         48.51 ± 0.10 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        649.57 ± 2.92 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         46.25 ± 0.08 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        563.00 ± 4.10 |
| qwen35moe 35B.A3B Q8_0         |  20.61 GiB |    34.66 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         44.06 ± 0.08 |

AesSedai Qwen 3.5 122B Q4_K_M:

Before:

| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        333.09 ± 1.62 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         21.28 ± 0.01 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        287.46 ± 1.59 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         20.37 ± 0.01 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        252.55 ± 0.65 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         19.48 ± 0.01 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        225.05 ± 0.65 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         18.71 ± 0.00 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        201.86 ± 0.68 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         18.00 ± 0.01 |

After:

| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        391.70 ± 2.28 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         29.32 ± 0.03 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        329.11 ± 0.96 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         27.48 ± 0.02 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        284.26 ± 0.55 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         25.81 ± 0.02 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        249.75 ± 0.62 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         24.52 ± 0.02 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        222.30 ± 0.24 |
| qwen35moe 122B.A10B Q8_0       |  71.44 GiB |   122.11 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         23.26 ± 0.02 |

Ubergarm Qwen 3.5 397B smol-IQ2_XS

Before:

| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        171.43 ± 1.59 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         15.19 ± 0.01 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        154.50 ± 0.60 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         14.57 ± 0.00 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        140.73 ± 0.28 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         14.00 ± 0.01 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        129.52 ± 0.27 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         13.51 ± 0.01 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        119.92 ± 0.25 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         13.02 ± 0.01 |

After:

| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        189.70 ± 2.01 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         20.01 ± 0.02 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        168.99 ± 0.56 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         18.94 ± 0.02 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        152.31 ± 0.19 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         17.88 ± 0.00 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        139.28 ± 0.30 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         17.12 ± 0.02 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        128.37 ± 0.45 |
| qwen35moe 397B.A17B Q8_0       | 113.41 GiB |   396.35 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         16.39 ± 0.00 |

ggerganov (Member, Author) commented Mar 11, 2026

Thanks. Now that I finally know how to profile the kernels, more improvements will come.

@ggerganov ggerganov merged commit f6a0c16 into gg/llama-allow-gdn-ch Mar 11, 2026
72 of 74 checks passed
@ggerganov ggerganov deleted the gg/metal-add-gdn branch March 11, 2026 18:39
ggerganov added a commit that referenced this pull request Mar 11, 2026
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix the fix

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix

* metal : add GDN kernel (#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* CUDA: AR gated delta net improvements (#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneeded include in gated_delta_net.cu

* Improve comments

* Apply code-formatting

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 2068908

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <devnull@uvos.xyz>

* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>
ProgenyAlpha pushed a commit to ProgenyAlpha/llama.cpp that referenced this pull request Mar 12, 2026
Labels

Apple Metal (https://en.wikipedia.org/wiki/Metal_(API)), ggml (changes relating to the ggml tensor library for machine learning)

3 participants