
vulkan: add GATED_DELTA_NET op support #20334

Merged
0cc4m merged 6 commits into ggml-org:master from ProgenyAlpha:vulkan-gated-delta-net
Mar 12, 2026

Conversation

@ProgenyAlpha
Contributor

@ProgenyAlpha ProgenyAlpha commented Mar 10, 2026

Summary

First pass at a Vulkan compute shader for GGML_OP_GATED_DELTA_NET, covering both the standard (scalar gate) and KDA (per-row vector gate) variants. This is the core recurrence op used by Qwen3.5 and Qwen3-Next models.

What's here:

  • Autoregressive kernel: one workgroup per (head, sequence), one thread per state column
  • Supports S_V = 32, 64, 128 via specialization constants
  • GQA broadcast via v_repeat ratios
  • Phase 1 optimizations: vec4 dot products, cached exp(g) in shared memory for KDA
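
For orientation, the recurrence this kernel implements can be sketched in pure Python. This is a hedged sketch of one common gated delta-rule formulation; the names (`S`, `beta`) and the exact update order are illustrative assumptions, not the PR's shader code:

```python
import math

def gated_delta_step(S, q, k, v, g, beta):
    """One token of a scalar-gate delta-rule recurrence (illustrative sketch).

    S is a d_k x d_v state matrix (list of rows). In the shader, one
    workgroup handles one (head, sequence) and one thread owns one state
    column; here the column loop plays that role.
    """
    d_k, d_v = len(S), len(S[0])
    decay = math.exp(g)  # scalar gate; the KDA variant uses a per-row vector
    # read out what the current state predicts for key k (per column)
    kS = [sum(k[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]
    # fused decay + rank-1 delta update: S <- decay*S + beta*k*(v - decay*k^T S)
    for j in range(d_v):          # "one thread per state column"
        for i in range(d_k):
            S[i][j] = decay * S[i][j] + beta * k[i] * (v[j] - decay * kS[j])
    # output: query the updated state with q
    return [sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]
```

The token loop around this step is inherently sequential, which is why PP throughput needs the separate chunked kernel discussed below.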

Benchmarks — AMD Radeon 8060S / Strix Halo (by @lemmi, vs master):

| Model | master TG (t/s) | PR TG (t/s) | Improvement |
| --- | ---: | ---: | ---: |
| Qwen3-Coder-Next UD-Q4_K_XL | 38.04 | 46.64 | +22.6% |
| Qwen3.5-35B-A3B Q8_0 | 43.90 | 53.33 | +21.5% |
| Qwen3.5-122B-A10B UD-Q5_K_XL | 18.36 | 21.59 | +17.6% |

Benchmarks — AMD Radeon 890M / Strix Point (integrated):

| Model | PP-512 (t/s) | PP-2048 (t/s) | TG-128 (t/s) |
| --- | ---: | ---: | ---: |
| Qwen3.5-0.8B Q4_K_M | 1337 | 1361 | 85 |
| Qwen3.5-4B Q4_K_M | 346 | 381 | 21 |
| Qwen3.5-9B Q4_K_M | 249 | 272 | 12 |

Vulkan vs CPU on 9B: 1.7x PP, 7.5x TG.

Op-level perf (Phase 1 vs baseline scalar shader):

  • KDA TG 32h×d128: +5.4% (cached exp eliminates ~16K redundant calls/token)
  • Non-KDA: no regressions, slight gains from vec4
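
The shared-memory exp(g) cache amounts to hoisting the transcendental out of the per-column work. A minimal Python sketch of the idea (names and loop structure are illustrative, not the shader's):

```python
import math

def kda_decay_cached(S, g):
    """Apply a per-row (KDA) gate to the state, computing each exp once.

    exp_g plays the role of the shared-memory cache: without it, every
    state column would recompute exp(g[i]) for every row i, i.e.
    d_k * d_v exp() calls per token instead of d_k.
    """
    d_k, d_v = len(S), len(S[0])
    exp_g = [math.exp(gi) for gi in g]  # cache: d_k exps per (head, token)
    for j in range(d_v):                # one "thread" per state column
        for i in range(d_k):
            S[i][j] *= exp_g[i]         # reuse the cached value
    return S
```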

Status:

13/13 test-backend-ops passing. All head sizes, GQA, multi-seq, permuted layouts, both KDA and non-KDA covered.

What's next:

Open to collaboration — if anyone wants to work on the chunked kernel or test on different hardware, happy to coordinate.

cc @jhen0409 (re: #14909)

@github-actions github-actions bot added Vulkan Issues specific to the Vulkan backend ggml changes relating to the ggml tensor library for machine learning labels Mar 10, 2026
@lemmi

lemmi commented Mar 10, 2026

Quickly did a couple of benchmarks on strix halo (8060s). PP basically unchanged (maybe a little slower), TG quite a bit faster:

| Model | master (23fbfcb) tg128 (t/s) | PR tg128 (t/s) |
| --- | ---: | ---: |
| Qwen3-Coder-Next:UD-Q4_K_XL | 38.04 | 46.64 |
| Qwen3.5-35B-A3B:Q8_0 | 43.90 | 53.33 |
| Qwen3.5-122B-A10B:UD-Q5_K_XL | 18.36 | 21.59 |

PP performance still exhibits the same issue as described in #18725.

@github-actions github-actions bot added the testing Everything test related label Mar 10, 2026
@ProgenyAlpha
Contributor Author

ProgenyAlpha commented Mar 10, 2026

Thanks for testing! Those TG numbers are great — 22% on Coder-Next, 21% on 35B-A3B, 18% on 122B-A10B.

The PP issue from #18725 makes sense — this shader only affects the deltanet recurrence layers, and PP throughput is still bottlenecked by the autoregressive token loop in the current kernel. A chunked parallel kernel (Phase 2) would fix that, but it's a much bigger piece of work. Hope to work on it tomorrow.

Updated the PR description with your benchmark numbers. Really helpful to have data from the larger Strix Halo 8060S alongside the integrated 890M results.

@IIIIIllllIIIIIlllll

Same 8060S setup. In my testing, although the current PR shows a performance improvement over master, Qwen3.5-35B-A3B cannot reach the reported speed (~50 token/s), which is quite strange :(

command:

/home/mark/llama.cpp/llama.cpp-vulkan-gated-delta-net/build/bin/llama-bench 
-m /home/mark/Models/Q8/Qwen3.5-35B-A3B-Q8_0/Qwen3.5-35B-A3B-UD-Q8_K_XL.gguf 
--repetitions 5 --output md --mmap 0 --delay 0 
--n-prompt 512 
--n-gen 128 
--n-depth 0 
--batch-size 512 
--ubatch-size 512 
--cache-type-k f16 --cache-type-v f16 --threads 16 
--n-gpu-layers 99 
--n-cpu-moe 0 --flash-attn 1 --direct-io 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon Graphics (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
| model                          |       size |     params | backend    | ngl | n_batch | fa | mmap | dio |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | ---: | --: | --------------: | -------------------: |
| qwen35moe 35B.A3B Q8_0         |  45.33 GiB |    34.66 B | Vulkan     |  99 |     512 |  1 |    0 |   1 |           pp512 |        567.96 ± 2.42 |
| qwen35moe 35B.A3B Q8_0         |  45.33 GiB |    34.66 B | Vulkan     |  99 |     512 |  1 |    0 |   1 |           tg128 |         25.87 ± 0.02 |

build: unknown (0)

Contributor

@0cc4m 0cc4m left a comment


Thank you for getting this out so quick! I have a few comments, but overall it looks good. I can confirm it runs correctly.

Let's focus on getting it running first, optimization can be a follow-up PR.

@lemmi

lemmi commented Mar 10, 2026

@IIIIIllllIIIIIlllll The difference in performance is most likely caused by the quantization: UD-Q8_K_XL vs Q8_0.

Additionally my Minisforum MS-S1 has a very generous power budget (I think the highest of all available boards with 160W short term and 130W long term).

@digitalscream

digitalscream commented Mar 10, 2026

For what it's worth, I'm seeing some crazy numbers here in terms of performance on my R9700s:

| Model | # GPUs | tg (6c770d1) | tg PR | Diff |
| --- | ---: | ---: | ---: | ---: |
| Qwen3.5-35B-A3B-UD-Q5_K_XL | 1 | 98.23 | 122.07 | +24.2% |
| Qwen3-Coder-Next-UD-Q4_K_XL | 2 | 54.81 | 77.99 | +42.3% |

These are anecdotal numbers from ad-hoc testing, but they're pretty consistent across prompts - just trying to give an indication of why my jaw dropped.

The single-GPU boost is great, but I had to check the dual-GPU result with Qwen3-Coder-Next a bunch of times; it's genuinely usable in an agentic context now. With Cline at the helm, it never drops below 50t/s even as the context fills up past 180k, and both GPUs have gone from ~63W constant to 70W with spikes up to 180W (still nowhere near the 300W max, but that's on the tensor parallel PR when it's fixed for Vulkan...?).

@ProgenyAlpha - I feel like some of us may owe you a beer or five. I might actually be getting some value out of these cards now!

@ProgenyAlpha
Contributor Author

Thanks for the review — all 7 items addressed. Pipeline declarations moved to a [3][2] array, shader now uses A_TYPE/D_TYPE/FLOAT_TYPE conventions, scale moved to push constants, dead code removed. 16/16 tests passing. Had a couple of these on my list already but your feedback saved me time and caught things I'd have missed.

@ProgenyAlpha
Contributor Author

Pushed the review fixes — all 7 items from @0cc4m addressed. Also ran a full model bench on 890M (integrated):

Qwen3-Coder-Next REAM Q4_K_M (60B)

  • PP-512: 157.67 t/s
  • TG-128: 20.90 t/s

@digitalscream those dual-GPU numbers are wild — a 42% improvement and usable agentic speeds at 180k context is exactly the kind of thing that makes this worth doing. I have a tip jar, but I'm doing this for the community, so it's not necessary. I love getting feedback from different platforms, so feel free to keep testing changes.

@lemmi thanks again for the Strix Halo data. PP issue is upstream of this shader (SSM_CONV workgroup scaling, #18725) so it won't change here.

@IIIIIllllIIIIIlllll lemmi nailed it — UD-Q8_K_XL uses importance-weighted mixed precision which generally outperforms uniform Q8_0 across backends.

@ITankForCAD

Great work @ProgenyAlpha

@zedbytes

@ProgenyAlpha thanks for this, 24% to 42% on R9700 is a godsend!

If I may, you could have a look at the following for your next adventure:
#19890 (reply in thread)
The nocompute flag gives a 16% improvement, which may indicate room for improvement in the graphics-queue/compute-queue logic for AMD Vulkan.

@ggerganov
Member

Note that for the chunked version you'll need to have the changes from #20340

@ProgenyAlpha
Contributor Author

ProgenyAlpha commented Mar 10, 2026

Pushed f16 mixed-precision state — stores the 128-element state array in float16_t, keeps all arithmetic in float32. No precision loss (16/16 tests), lower register pressure.

890M benchmarks (Qwen3-Coder-Next REAM Q4_K_M):

| Metric | Before | After | Change |
| --- | ---: | ---: | ---: |
| PP-512 | 157.67 t/s | 170.71 t/s | +8.3% |
| TG-128 | 20.90 t/s | 21.20 t/s | +1.4% |

f16 pipeline auto-selects when the device supports shaderFloat16, falls back to f32 otherwise.
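
The storage/precision trade can be felt with Python's stdlib half-precision packing (`struct` format `'e'`). This is just an illustration of the binary16 storage format, not the shader's code path:

```python
import struct

def f16_store(x: float) -> float:
    """Round-trip a value through IEEE-754 binary16, like storing the
    state in float16_t while keeping the arithmetic in float32."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# binary16 keeps ~3 decimal digits: 1.0 survives exactly, while 0.1
# rounds to 0.0999755859375 (~2.4e-5 absolute error) when stored.
```

The point is that each state element is written back through this rounding step, halving state memory traffic and register pressure while all intermediate sums stay in f32.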

@ggerganov saw your note about #20340 — will rebase onto that once it lands. The rq1neqk1 broadcast change needs the CPU side too. I'll work on that and have first pass chunked shaders ready and waiting to be tested.

@zedbytes interesting find on the nocompute flag. That's a driver-level queue scheduling optimization — separate from shader work but worth investigating for AMD Vulkan in general.

@ProgenyAlpha
Contributor Author

@digitalscream @lemmi if you get a chance, would be great to see updated numbers with the latest push. The f16 state change gave +8.3% PP on my 890M — curious if the improvement scales differently on discrete cards with more CUs.

@digitalscream

@ProgenyAlpha - OK, with the latest push:

| Model | # GPUs | Before PP | Before TG | After PP | After TG |
| --- | ---: | ---: | ---: | ---: | ---: |
| Qwen3.5-35B-A3B-UD-Q5_K_XL | 1 | 273.7 | 124.4 | 292.95 | 123.38 |
| Qwen3-Coder-Next-UD-Q4_K_XL | 2 | 157.11 | 78.22 | 157.2 | 78.26 |

So...+7% PP and -0.8% TG for single GPU, negligible changes for dual GPU.

@0cc4m
Contributor

0cc4m commented Mar 10, 2026

Please split out the chunked support into a follow-up instead of expanding the scope of this PR. Even float16 might be further than step 1 should be.

@lemmi

lemmi commented Mar 11, 2026

Wasn't able to see any meaningful changes on 23fbfcb for PP on my end.

@IIIIIllllIIIIIlllll


Sorry to bother everyone.
On my Strix Halo, this commit (a57f2ac) did not result in any performance improvement. @ProgenyAlpha

@ProgenyAlpha
Contributor Author

@0cc4m Done — stripped this PR back to autoregressive-only. Removed chunked shaders, f16 state, and all the infrastructure that came with them. This is now just the base GATED_DELTA_NET op support + review fixes.

Split into two follow-up PRs:

13/13 backend-ops tests passing. Benchmarks unchanged from before since the autoregressive path is the same.

@ProgenyAlpha
Contributor Author

@lemmi @IIIIIllllIIIIIlllll Thanks for retesting. The f16 state improvement looks hardware-dependent — +8.3% PP on my 890M, +7% on @digitalscream's R9700S, but flat on Strix Halo. Might be related to how the register file handles f16 on different RDNA3.5 configs. Either way, f16 is now split out into its own PR (#20376) so it won't hold up the base support.

This PR is stripped back to autoregressive-only per @0cc4m's review — should be ready for another look.

@ProgenyAlpha
Contributor Author

@lemmi heads up — took a look at the PP/ubatch scaling issue from #18725 and put up a fix in #20379. SSM_CONV workgroup dispatch was the bottleneck. +28% at ub2048 on my 890M, degradation cliff gone. Would love your numbers on the 8060S.

@0cc4m
Contributor

0cc4m commented Mar 11, 2026

Mark the PR as ready when you are done, please.

@ProgenyAlpha ProgenyAlpha marked this pull request as ready for review March 11, 2026 07:20
@ProgenyAlpha ProgenyAlpha requested a review from ggerganov as a code owner March 11, 2026 07:20
@github-actions github-actions bot added server SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language Apple Metal https://en.wikipedia.org/wiki/Metal_(API) OpenCL Issues specific to the OpenCL backend labels Mar 12, 2026
@ProgenyAlpha
Contributor Author

Rebased on master and fixed the Q/K broadcast to use the interleaved layout from #20340. 13/13 tests passing.
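
The grouped-vs-interleaved distinction can be sketched like this. This is a hedged illustration of the `head_id / rq1 → head_id % neq1` change from the commit message: the names `n_q_heads`/`n_kv_heads` are mine, and `neq1` is assumed here to be the head count along the broadcast dimension, not taken from the PR's code:

```python
def kv_head_grouped(head_id: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Old convention: consecutive q heads share one kv head (head_id / rq1)."""
    return head_id // (n_q_heads // n_kv_heads)

def kv_head_interleaved(head_id: int, n_kv_heads: int) -> int:
    """Interleaved convention from #20340 (head_id % neq1 in the commit)."""
    return head_id % n_kv_heads
```

With 8 query heads over 2 kv heads, the grouped mapping assigns heads 0-3 to kv head 0 and 4-7 to kv head 1, while the interleaved mapping alternates 0,1,0,1,… — same repeat ratio, different layout, so the shader's index math has to match the convention the rest of ggml uses.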

@0cc4m
Contributor

0cc4m commented Mar 12, 2026

No, that was not a correct rebase.

ProgenyAlpha and others added 6 commits March 12, 2026 01:33
Implements the fused gated delta net recurrence as a Vulkan compute
shader with full support for scalar gate, KDA vector gate, GQA
broadcast, multi-token sequences, and permuted (non-contiguous) q/k
inputs. Specialization constants select head size (32/64/128) and
KDA mode at pipeline creation time.

Passes all 13 test-backend-ops cases on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- vec4 dot products on all inner loops (dp4 hardware intrinsic)
- Cache exp(g) in shared memory for KDA path, eliminating ~32K
  redundant global reads and ~16K redundant exp() calls per token
- vec4 fused decay + rank-1 update (3 vec4 ops vs 12 scalar ops)
- Add perf benchmark cases for GATED_DELTA_NET to test-backend-ops

KDA TG: +5.4% throughput. Non-KDA: no regressions.
13/13 test-backend-ops passing on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pipeline array refactor [3][2], A_TYPE/D_TYPE/FLOAT_TYPE shader macros,
scale in push constants, supports_op fix, dispatch restructuring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Wrap data_q, data_k, and data_g buffer reads with FLOAT_TYPE() casts
to ensure correct behavior across all Vulkan configurations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adapt to the interleaved broadcast convention from ggml-org#20340:
head_id / rq1 → head_id % neq1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ProgenyAlpha ProgenyAlpha force-pushed the vulkan-gated-delta-net branch from d72268f to d5300db Compare March 12, 2026 05:39
@ProgenyAlpha
Contributor Author

Saw it immediately after pushing — the rebase had picked up duplicate commits from master and I didn't catch it in time. Cleaned up with cherry-pick; it should now be 6 commits on top of master.

@a4lg

a4lg commented Mar 12, 2026

Thanks for the correct rebase!

EDIT: Checked.
Although I haven't done exhaustive tests, there are at least no obvious failures (Qwen3.5 models seem to be working correctly).

@0cc4m
Contributor

0cc4m commented Mar 12, 2026

I'm gonna skip waiting for the CI to unblock other PRs. I checked locally that it works.
