
ggml: add GATED_DELTA_NET op#19504

Merged
am17an merged 5 commits into ggml-org:master from am17an:gated_delta_net
Mar 7, 2026

Conversation

@am17an
Contributor

@am17an am17an commented Feb 11, 2026

Add a CPU/CUDA implementation of GATED_DELTA_NET, used in qwen3next and many upcoming attention models. This is a basic vector implementation, not the chunked one, although it works for n_tokens > 1 as a reference implementation. I tested this against build_delta_net_autoregressive and the results matched. I plan to add the chunked implementation for CPU and CUDA.
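For orientation, the per-token recurrence such an op computes can be sketched in NumPy. This is a minimal reference sketch of the gated delta rule as commonly formulated, not the PR's actual kernel; the shapes, names, and exact gate placement are assumptions:

```python
import numpy as np

def gated_delta_net_ref(q, k, v, alpha, beta):
    """Token-by-token (autoregressive) gated delta rule, one head/sequence.

    q, k : (T, d_k)  queries / keys
    v    : (T, d_v)  values
    alpha: (T,)      per-token decay gate in [0, 1]
    beta : (T,)      per-token delta-rule learning rate

    State S is a (d_v, d_k) fast-weight matrix, updated per token as
        S_t = alpha_t * (S_{t-1} - beta_t * (S_{t-1} k_t) k_t^T) + beta_t * v_t k_t^T
        o_t = S_t q_t
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k))
    out = np.empty((T, d_v))
    for t in range(T):
        kt, vt, qt = k[t], v[t], q[t]
        # decay the old state and overwrite the association stored along k_t
        S = alpha[t] * (S - beta[t] * np.outer(S @ kt, kt)) + beta[t] * np.outer(vt, kt)
        # read the state with the query
        out[t] = S @ qt
    return out
```

A chunked implementation solves the same recurrence but processes blocks of tokens with matrix-matrix products instead of one rank-one update per token.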

master:

| model | size | params | backend | threads | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CPU | 16 | 1 | tg32 | 4.77 ± 0.03 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CPU | 16 | 1 | tg32 @ d1024 | 4.55 ± 0.13 |

sched_reserve: graph nodes = 14990 (with bs=512), 6242 (with bs=1)

ggml_op_gated_delta_net added to the qwen3next graph (this graph change is not included in the PR):

| model | size | params | backend | threads | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen35moe ?B Q4_K - Small | 18.55 GiB | 34.66 B | CPU | 16 | 1 | tg32 | 11.08 ± 0.20 |
| qwen35moe ?B Q4_K - Small | 18.55 GiB | 34.66 B | CPU | 16 | 1 | tg32 @ d1024 | 11.21 ± 0.07 |

sched_reserve: graph nodes = 14990 (with bs=512), 5342 (with bs=1)

@am17an am17an requested a review from ggerganov as a code owner February 11, 2026 07:09
@am17an am17an requested a review from pwilkin February 11, 2026 07:09
@github-actions github-actions bot added the testing and ggml labels Feb 11, 2026
@ggerganov
Member

ggerganov commented Feb 11, 2026

I think it is too early to implement the dedicated delta net ops. There are still many things to optimize in the existing implementation (you can track my progress in #19375). After that we have to consolidate the KDA version of the delta net (#18792). Btw, the L2 norm should not be part of this op (fixed in my branch). Also, I'm not sure how to handle the 2 variants of this operator (autoregressive and chunked).

So I think we can experiment with a dedicated op in a branch, but merging this in master will likely take time.

@am17an
Contributor Author

am17an commented Feb 11, 2026

@ggerganov I defer to your judgement. My thinking was that qwen3.5 is already a major model series, so the op makes sense even if it's just for that model.

For KDA, AFAIK the gate is a matrix, so it will just be another dot product instead of a scale. For chunked vs autoregressive, we have the vec FA path on CPU which now serves as a reference kernel. I was thinking it would be the same here: the autoregressive kernel remains the simple kernel while chunking is the optimisation; both solve the same recurrence.
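The scalar-vs-matrix gate distinction above can be illustrated on the decay step of the recurrence alone. This is a hypothetical sketch under the per-channel-gate reading of KDA (names and shapes are assumptions); NumPy broadcasting expresses both variants with the same line:

```python
import numpy as np

def decay_state(S, gate):
    """Decay step of the recurrence applied to state S of shape (d_v, d_k).

    gate: a scalar (GDA-style: one decay per head per token), or a
    vector of shape (d_k,) (KDA-style: one decay per key channel).
    Broadcasting covers both: a scalar scales every entry, while a
    (d_k,) vector scales each key column of the state independently.
    """
    return S * np.asarray(gate)

S = np.ones((2, 3))
gda = decay_state(S, 0.5)                         # uniform scale
kda = decay_state(S, np.array([0.1, 0.5, 0.9]))   # per-column scales
```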

@ggerganov
Member

Ok, let's prototype a branch that also has this op together with the CUDA implementation rebased on #19375. I will then add the Metal version of the kernel and from there we can consider a quicker merge if things are looking good. Also, want to see if having this op will allow the CUDA graphs to be more easily enabled.

@pwilkin
Contributor

pwilkin commented Feb 11, 2026

So this is basically what the Transformers implementations have as the "recurrent" implementation, right? No chunking, just iterating over tokens.

@am17an
Contributor Author

am17an commented Feb 11, 2026

@pwilkin yes, just calculating the recurrence token by token

@ggerganov
Member

ggerganov commented Feb 11, 2026

Btw, should also consider small batch sizes larger than 1 to be handled by this operator too. I'm not sure where the break-even point would be, but I imagine that processing a few tokens auto-regressively (i.e. more than 1 and less than ~16) would be more efficient compared to the chunking path. Also don't forget that dim 3 will handle separate sequences - though from a quick look, this implementation already accounts for that.

@am17an
Contributor Author

am17an commented Feb 11, 2026

Btw, should also consider small batch sizes larger than 1 to be handled by this operator too. I'm not sure where the break-even point would be, but I imagine that processing a few tokens auto-regressively (i.e. more than 1 and less than ~16) would be more efficient compared to the chunking path.

Yes, for a small number of tokens we can just run a loop even in CUDA. I have not looked into the chunked impl yet, but I will invest some time in finding the break-even point.

Also don't forget that dim 3 will handle separate sequences - though from a quick look, this implementation already accounts for that.

I think this should be fine; the work is split across dim1 * dim3 (heads * sequences).
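The break-even idea in the exchange above amounts to a simple dispatch rule. A hypothetical sketch follows; the ~16 figure is only the estimate mentioned in the thread, not a measured threshold:

```python
def choose_gdn_path(n_tokens: int, breakeven: int = 16) -> str:
    """Pick the kernel variant for a batch of n_tokens.

    Below the break-even point, the token-by-token (autoregressive)
    kernel is assumed cheaper; at or above it, the chunked kernel is
    assumed to win. The default threshold is a placeholder.
    """
    return "autoregressive" if n_tokens < breakeven else "chunked"
```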

@ymcki
Contributor

ymcki commented Feb 11, 2026

Great performance gain for inference. Looking forward to seeing your implementation done for the major backends.

If you plan to do the chunking version as well, it would be great to base it on the block implementation in fla:

https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/chunk_intra_token_parallel.py
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/chunk_intra.py

Contributor

@pwilkin pwilkin left a comment


Looks clean to me. Are you planning on doing the chunking version here as well, or separate op / PR?

@ggerganov
Member

Converted to draft since I am not sure my comment was clear: #19504 (comment). First we will prototype in a new branch, and after that we will consider adding the new op.

@pwilkin
Contributor

pwilkin commented Feb 11, 2026

Should we use this PR or will you create a dedicated branch?

@am17an
Contributor Author

am17an commented Feb 11, 2026

@ggerganov I removed the norm and also added the autoregressive CUDA op in 01eda69; it passes test-backend-ops. I have not done the rebase on #19375.

@github-actions github-actions bot added the model and Nvidia GPU labels Feb 11, 2026
@ggerganov
Member

ggerganov commented Feb 11, 2026

Just a heads up, I will be rebasing the #19375 branch from time to time. Hope it's not a big issue. Just always put your commits on top. I'm hoping to merge in a day or two.

@am17an
Contributor Author

am17an commented Feb 11, 2026

I did a quick perf test with this PR + #19375 + replacing the autoregressive path for qwen3next with gated_delta_net, on a 5090.

Details

master

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 | 83.92 ± 0.39 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d1024 | 84.45 ± 0.36 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d2048 | 84.20 ± 0.61 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d4096 | 83.82 ± 0.56 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d8192 | 83.43 ± 1.73 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d16384 | 83.56 ± 0.47 |

PR:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 | 105.95 ± 0.36 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d1024 | 105.05 ± 0.91 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d2048 | 105.33 ± 0.42 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d4096 | 105.10 ± 0.50 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d8192 | 98.13 ± 1.79 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d16384 | 97.22 ± 0.49 |

@ggerganov
Member

For reference, what do you get with CUDA graphs force-enabled:

```diff
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f3d8317e1..605cb3ed4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2894,7 +2894,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 #endif
         }
 
-        if (node->op == GGML_OP_ADD &&
+        if (false && node->op == GGML_OP_ADD &&
             node->src[1] && node->src[1]->ne[1] > 1 &&
             (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
             (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
```

@am17an
Contributor Author

am17an commented Feb 11, 2026

With CUDA graphs force-enabled:

Details
| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 | 111.89 ± 2.48 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d1024 | 135.26 ± 6.93 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d2048 | 135.89 ± 4.95 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d4096 | 134.77 ± 4.67 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d8192 | 123.30 ± 6.07 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d16384 | 121.22 ± 5.28 |

@ubergarm
Contributor

ubergarm commented Mar 6, 2026

Very nice increase in TG speeds for CPU-only here! I didn't measure any increase for PP, however (which may be expected). Finally, this gaming rig's CPU is Zen5 and has avx512_vnni, which I don't believe provides extra benefit here on mainline.

[chart: sweep-bench, Qwen3.5-35B-A3B, ik_llama.cpp vs mainline, CPU]
Details

ik_llama.cpp main@277fc1d2

```shell
model=/mnt/astrodata/llm/models/ubergarm/Qwen3.5-35B-A3B-GGUF/Qwen3.5-35B-A3B-Q4_0.gguf
./build/bin/llama-sweep-bench \
  --model "$model" \
  -ctk q8_0 -ctv q8_0 \
  -c 69632 \
  -ub 1024 -b 2048 \
  --merge-qkv \
  --threads 16 \
  --warmup-batch \
  -n 128
```
PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
1024 128 0 1.296 790.34 5.271 24.28
1024 128 1024 1.358 754.32 5.141 24.90
1024 128 2048 1.401 730.66 5.149 24.86
1024 128 3072 1.450 706.21 5.194 24.65
1024 128 4096 1.493 685.70 5.221 24.52
1024 128 5120 1.530 669.12 5.240 24.43
1024 128 6144 1.575 650.02 5.253 24.37
1024 128 7168 1.599 640.30 5.287 24.21
1024 128 8192 1.643 623.10 5.281 24.24
1024 128 9216 1.681 609.02 5.302 24.14
1024 128 10240 1.792 571.54 5.332 24.01
1024 128 11264 1.763 580.93 5.335 23.99
1024 128 12288 1.806 566.90 5.369 23.84
1024 128 13312 1.847 554.44 5.397 23.72
1024 128 14336 1.885 543.22 5.402 23.69
1024 128 15360 1.929 530.94 5.431 23.57
1024 128 16384 1.980 517.04 5.440 23.53
1024 128 17408 2.062 496.52 5.496 23.29
1024 128 18432 2.060 497.02 5.511 23.22
1024 128 19456 2.087 490.75 5.568 22.99
1024 128 20480 2.141 478.34 5.645 22.67
1024 128 21504 2.160 474.08 5.627 22.75
1024 128 22528 2.224 460.40 5.634 22.72
1024 128 23552 2.258 453.41 5.689 22.50
1024 128 24576 2.422 422.76 5.692 22.49
1024 128 25600 2.327 440.10 5.683 22.52
1024 128 26624 2.367 432.61 5.762 22.22
1024 128 27648 2.410 424.86 5.788 22.12
1024 128 28672 2.444 419.04 5.817 22.00
1024 128 29696 2.501 409.44 5.843 21.91
1024 128 30720 2.581 396.74 5.853 21.87
1024 128 31744 2.610 392.41 5.853 21.87
1024 128 32768 2.669 383.70 5.842 21.91

mainline llama.cpp master@e68f2fb8 + ug/port-sweep-bench

```shell
model=/mnt/astrodata/llm/models/ubergarm/Qwen3.5-35B-A3B-GGUF/Qwen3.5-35B-A3B-Q4_0.gguf
./build/bin/llama-sweep-bench \
  --model "$model" \
  -ctk q8_0 -ctv q8_0 \
  -c 69632 \
  -ub 1024 -b 2048 \
  --threads 16 \
  -n 128
```
PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
1024 128 0 6.153 166.41 8.284 15.45
1024 128 1024 7.336 139.59 8.140 15.72
1024 128 2048 7.958 128.68 8.312 15.40
1024 128 3072 9.048 113.17 8.603 14.88
1024 128 4096 10.045 101.94 8.353 15.32
1024 128 5120 10.969 93.35 8.858 14.45
1024 128 6144 11.710 87.45 8.739 14.65
1024 128 7168 12.552 81.58 8.697 14.72
1024 128 8192 13.479 75.97 9.226 13.87
1024 128 9216 14.620 70.04 9.239 13.85
1024 128 10240 15.000 68.27 9.135 14.01
1024 128 11264 16.088 63.65 9.467 13.52
1024 128 12288 16.675 61.41 9.158 13.98
1024 128 13312 17.236 59.41 9.696 13.20
1024 128 14336 18.576 55.12 9.524 13.44
1024 128 15360 19.520 52.46 9.855 12.99
1024 128 16384 19.817 51.67 9.231 13.87
1024 128 17408 19.869 51.54 9.579 13.36
1024 128 18432 21.962 46.63 10.553 12.13
1024 128 19456 22.715 45.08 10.453 12.25
1024 128 20480 23.965 42.73 10.579 12.10
1024 128 21504 24.021 42.63 9.781 13.09
1024 128 22528 24.344 42.06 10.275 12.46
1024 128 23552 24.429 41.92 10.475 12.22
1024 128 24576 27.235 37.60 10.987 11.65
1024 128 25600 26.745 38.29 10.248 12.49
1024 128 26624 27.459 37.29 10.219 12.53
1024 128 27648 29.335 34.91 10.275 12.46
1024 128 28672 31.157 32.87 10.270 12.46
1024 128 29696 32.972 31.06 12.251 10.45
1024 128 30720 33.706 30.38 10.442 12.26
1024 128 31744 35.189 29.10 12.012 10.66
1024 128 32768 34.935 29.31 12.053 10.62

mainline llama.cpp PR19504 gated_delta_net@e0fbfc01 + ug/port-sweep-bench

```shell
model=/mnt/astrodata/llm/models/ubergarm/Qwen3.5-35B-A3B-GGUF/Qwen3.5-35B-A3B-Q4_0.gguf
./build/bin/llama-sweep-bench \
  --model "$model" \
  -ctk q8_0 -ctv q8_0 \
  -c 69632 \
  -ub 1024 -b 2048 \
  --threads 16 \
  -n 128
```
PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
1024 128 0 6.312 162.23 5.492 23.31
1024 128 1024 7.259 141.07 5.616 22.79
1024 128 2048 8.317 123.11 5.688 22.51
1024 128 3072 9.352 109.49 5.657 22.63
1024 128 4096 10.407 98.39 5.775 22.16
1024 128 5120 11.199 91.44 5.824 21.98
1024 128 6144 12.125 84.45 5.958 21.48
1024 128 7168 12.605 81.24 6.031 21.22
1024 128 8192 13.620 75.19 6.119 20.92
1024 128 9216 14.666 69.82 6.199 20.65
1024 128 10240 15.176 67.48 6.217 20.59
1024 128 11264 16.081 63.68 6.229 20.55
1024 128 12288 16.722 61.24 6.299 20.32
1024 128 13312 17.366 58.96 6.498 19.70
1024 128 14336 18.694 54.78 6.495 19.71
1024 128 15360 19.459 52.62 6.758 18.94
1024 128 16384 20.005 51.19 6.756 18.95
1024 128 17408 20.609 49.69 6.910 18.52
1024 128 18432 21.698 47.19 6.848 18.69
1024 128 19456 22.549 45.41 7.245 17.67
1024 128 20480 23.390 43.78 7.107 18.01
1024 128 21504 24.302 42.14 7.188 17.81
1024 128 22528 25.209 40.62 7.306 17.52
1024 128 23552 25.322 40.44 7.315 17.50
1024 128 24576 26.036 39.33 7.493 17.08
1024 128 25600 27.006 37.92 7.627 16.78
1024 128 26624 27.483 37.26 7.758 16.50
1024 128 27648 29.856 34.30 8.113 15.78
1024 128 28672 30.532 33.54 8.057 15.89
1024 128 29696 33.484 30.58 7.793 16.42
1024 128 30720 33.784 30.31 7.978 16.04
1024 128 31744 34.164 29.97 8.644 14.81
1024 128 32768 33.413 30.65 8.234 15.54

I didn't compare hybrid CPU+GPU performance, but I expect it will see better TG throughput as well. Some more details on how I compiled, and similar benchmarks without this PR, here.

Thanks and great work!

@am17an
Contributor Author

am17an commented Mar 7, 2026

I actually see a huge difference in PP on CPU when just using the autoregressive kernel instead of the current one, i.e. using the fused op regardless of n_tokens. But I think I will optimize this later.

@am17an am17an merged commit c5a7788 into ggml-org:master Mar 7, 2026
76 of 78 checks passed
@am17an am17an deleted the gated_delta_net branch March 7, 2026 07:41
@jacekpoplawski
Contributor

jacekpoplawski commented Mar 7, 2026

Great speedup on tg (Qwen Next and Qwen 3.5)!
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

@jeffbolznv
Collaborator

Hi @ProgenyAlpha, just wanted to check whether you still plan to submit a PR for the vulkan backend support.

@CISC
Member

CISC commented Mar 8, 2026

Huh, not sure exactly what's happening, but MUSA build is now throwing an ICE:
https://github.com/ggml-org/llama.cpp/actions/runs/22774203231/job/66063122971#step:6:416

Edit: This killed our Docker release as well:
https://github.com/ggml-org/llama.cpp/actions/runs/22814140371/job/66176342700#step:9:900

@am17an
Contributor Author

am17an commented Mar 8, 2026

@CISC not sure who maintains the MUSA backend, but it seems like a compiler bug

arkavo-com added a commit to arkavo-ai/llama.cpp that referenced this pull request Mar 8, 2026
Add a fused Metal kernel for the gated delta net recurrence op
(ggml-org#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ggerganov
Member

@yeahdongcn PTAL at the MUSA issue above.

@am17an In the meantime we can change supports_op to return false for MUSA

@yeahdongcn
Collaborator

@yeahdongcn PTAL at the MUSA issue above.

@am17an In the meantime we can change supports_op to return false for MUSA

No problem. I'll try a local build first and see if I should open an internal ticket. Thanks!

@ProgenyAlpha
Contributor

Hi @ProgenyAlpha, just wanted to check whether you still plan to submit a PR for the vulkan backend support.

I wasn't sure where the thread was going, so I wanted to let you guys cook and see how things unfolded before jumping back in. I'll rebase and work on that this week if I have time. Thanks for pinging me!

ggerganov pushed a commit that referenced this pull request Mar 10, 2026
Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 10, 2026
* ggml: add GATED_DELTA_NET op

* remove the transpose

* add KDA

* add qwen35 dense

* llama : check for fused gated delta net backend support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
ggerganov added a commit that referenced this pull request Mar 11, 2026
* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
ggerganov added a commit that referenced this pull request Mar 11, 2026
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix the fix

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix

* metal : add GDN kernel (#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* CUDA: AR gated delta net improvements (#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 2068908

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <devnull@uvos.xyz>

* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>
ProgenyAlpha pushed a commit to ProgenyAlpha/llama.cpp that referenced this pull request Mar 12, 2026
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix the fix

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix

* metal : add GDN kernel (ggml-org#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(ggml-org#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* CUDA: AR gated delta net improvements (ggml-org#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 2068908

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <devnull@uvos.xyz>

* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>