vulkan: fix SSM_CONV PP scaling with large ubatch sizes #20379

Merged: 0cc4m merged 2 commits into ggml-org:master from ProgenyAlpha:vulkan-ssm-conv-optimize, Mar 12, 2026

Conversation

@ProgenyAlpha (Contributor)

Fixes #18725

The SSM_CONV shader dispatched one token per Y workgroup, each doing only nc (typically 4) multiply-adds. At ubatch=2048 this meant 2048 workgroups in Y with almost no work per launch — workgroup dispatch overhead dominated.
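
For context, SSM_CONV is a short per-channel causal convolution; here is a minimal NumPy sketch of the computation (the `(d_inner, nc)` layout is my assumption about the general ssm_conv pattern, not the shader's exact buffer layout):

```python
import numpy as np

def ssm_conv_ref(x, w):
    """Hypothetical reference for SSM_CONV (not the shader itself).

    x: (d_inner, nc - 1 + n_tokens) input; first nc-1 columns are history
    w: (d_inner, nc) per-channel conv weights; nc = d_conv, typically 4
    Returns (d_inner, n_tokens).
    """
    d_inner, nc = w.shape
    n_tokens = x.shape[1] - (nc - 1)
    out = np.empty((d_inner, n_tokens), dtype=x.dtype)
    for t in range(n_tokens):  # master dispatched one Y workgroup per token
        # each output element is only an nc-wide dot product
        out[:, t] = (x[:, t:t + nc] * w).sum(axis=1)
    return out
```

Each output element costs only nc multiply-adds, which is why a one-token-per-workgroup dispatch leaves each workgroup almost idle.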

Changes:

  • Tile 16 tokens per workgroup using a 2D local size (32×16)
  • Add vec4 dot product fast path for the common nc=4 (d_conv) case
  • Pipeline wg_denoms updated from {32,1,1} to {32,16,1}
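
To put numbers on the dispatch change, a small sketch of the resulting workgroup grid (this helper and the 32-wide channel split in X are illustrative assumptions, not the actual host code):

```python
def ssm_conv_grid(d_inner, n_tokens, tokens_per_wg):
    """Hypothetical sketch of the SSM_CONV workgroup grid.

    Assumes X covers d_inner in chunks of 32 (the wg_denoms X value);
    tokens_per_wg is 1 on master, 16 with this PR.
    """
    wg_x = (d_inner + 31) // 32
    wg_y = (n_tokens + tokens_per_wg - 1) // tokens_per_wg
    return wg_x, wg_y

# master: one token per Y workgroup -> 2048 Y launches at ubatch=2048
# this PR: 16 tokens per Y workgroup -> 128 Y launches
```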

45/45 SSM_CONV backend-ops tests passing.

test-backend-ops perf (ne_a=[515,3328], nc=4):

|         | us/run | GB/s  |
| ------- | -----: | ----: |
| master  | 526.24 | 24.29 |
| this PR | 203.56 | 62.79 |

Speedup: 2.59x

Model bench (Qwen3-Coder-Next REAM Q4_K_M, pp2048, AMD 890M):

| ubatch | master (t/s) | this PR (t/s) | change |
| -----: | -----------: | ------------: | -----: |
|    256 |       148.43 |        150.36 |  +1.3% |
|    512 |       171.38 |        175.82 |  +2.6% |
|   1024 |       147.46 |        181.14 | +22.8% |
|   2048 |       126.34 |        161.56 | +27.9% |

Master shows the #18725 pattern — PP drops from 171 at ub512 to 126 at ub2048. With this fix, PP peaks at ub1024 (181) and stays strong at ub2048 (162). The degradation cliff is gone.

Tested on AMD Radeon 890M (RDNA3.5, 8 CUs, Strix Point integrated). Would appreciate testing from @lemmi on the discrete 8060S where the original issue was reported.

Tile tokens into 2D workgroups (32x16) to reduce workgroup launch
overhead at large ubatch sizes. Add vec4 fast path for nc=4 (common
d_conv size). Fixes PP performance degradation with ubatch > 512.

Ref: ggml-org#18725

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ProgenyAlpha ProgenyAlpha requested a review from 0cc4m as a code owner March 11, 2026 03:51
@github-actions github-actions bot added the Vulkan (issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels Mar 11, 2026
@jeffbolznv (Collaborator) left a comment:

LGTM

@lemmi commented Mar 11, 2026

All numbers are up. Especially the big model sees a huge improvement with larger u-batch sizes. There still is a noticeable drop in performance after a certain u-batch size.

master (e1a3999):

| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | --: | --: | ---: |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp2048 | 426.43 ± 0.00 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp2048 | 498.42 ± 0.00 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp2048 | 411.18 ± 0.00 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp2048 | 374.57 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp2048 | 635.30 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp2048 | 720.11 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp2048 | 589.41 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp2048 | 519.26 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp2048 | 168.10 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp2048 | 199.71 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp2048 | 195.53 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp2048 | 211.37 ± 0.00 |

PR (209464006):

| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | --: | --: | ---: |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp2048 | 439.39 ± 0.00 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp2048 | 518.92 ± 0.00 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp2048 | 538.21 ± 0.00 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp2048 | 462.32 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp2048 | 656.18 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp2048 | 777.28 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp2048 | 762.16 ± 0.00 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp2048 | 638.68 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp2048 | 170.49 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp2048 | 209.06 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp2048 | 229.97 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp2048 | 251.91 ± 0.00 |

Bonus pp8192 run for the 122B model, since I didn't see the drop:

| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | --: | --: | ---: |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 256 | 1 | 0 | 1 | pp8192 | 171.67 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 512 | 1 | 0 | 1 | pp8192 | 215.54 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1024 | 1 | 0 | 1 | pp8192 | 242.35 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 2048 | 1 | 0 | 1 | pp8192 | 259.36 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 4096 | 1 | 0 | 1 | pp8192 | 229.01 ± 0.00 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 8192 | 1 | 0 | 1 | pp8192 | 229.46 ± 0.00 |

@jeffbolznv (Collaborator)

A nice boost on 5090:

before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m c:\models\Qwen3.5-35B-A3B-Q4_K_M.gguf -p 2048 -n 0 -ub 512,1024,2048 -r 10 --prio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | n_ubatch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | --------------: | -------------------: |
| qwen35moe 35B.A3B Q4_K - Medium |  20.49 GiB |    34.66 B | Vulkan     |  99 |      512 |          pp2048 |      6467.69 ± 53.57 |
| qwen35moe 35B.A3B Q4_K - Medium |  20.49 GiB |    34.66 B | Vulkan     |  99 |     1024 |          pp2048 |      7618.89 ± 27.43 |
| qwen35moe 35B.A3B Q4_K - Medium |  20.49 GiB |    34.66 B | Vulkan     |  99 |     2048 |          pp2048 |      7851.51 ± 25.81 |

after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m c:\models\Qwen3.5-35B-A3B-Q4_K_M.gguf -p 2048 -n 0 -ub 512,1024,2048 -r 10 --prio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | n_ubatch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | --------------: | -------------------: |
| qwen35moe 35B.A3B Q4_K - Medium |  20.49 GiB |    34.66 B | Vulkan     |  99 |      512 |          pp2048 |      6755.28 ± 36.26 |
| qwen35moe 35B.A3B Q4_K - Medium |  20.49 GiB |    34.66 B | Vulkan     |  99 |     1024 |          pp2048 |      8031.69 ± 25.96 |
| qwen35moe 35B.A3B Q4_K - Medium |  20.49 GiB |    34.66 B | Vulkan     |  99 |     2048 |          pp2048 |      8281.31 ± 20.18 |

vulkan: remove unused shared memory declaration in SSM_CONV

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ProgenyAlpha (Contributor, Author)

@lemmi Those 122B numbers are solid — ub2048 going from 211 to 252 and holding through pp8192 is exactly what we want. The remaining drop at very large ubatch is likely CONCAT or memory bandwidth, no longer SSM_CONV. We could confirm with a perf-logger run if you're curious, but that's a separate issue.

@jeffbolznv Nice to see it helps on the 5090 too. +5% on already-fast hardware from a dispatch change is free money.

@0cc4m (Contributor) commented Mar 12, 2026

LGTM, good improvement, thank you.

@0cc4m 0cc4m merged commit 40c550d into ggml-org:master Mar 12, 2026
71 of 78 checks passed
tekintian added a commit to tekintian/llama.cpp that referenced this pull request Mar 12, 2026
* 'master' of github.com:ggml-org/llama.cpp: (33 commits)
  convert : better mtp check and fix return [no ci] (ggml-org#20419)
  vulkan: fix SSM_CONV PP scaling with large ubatch sizes (ggml-org#20379)
  New conversations now auto-select the first loaded model (ggml-org#20403)
  ggml-virtgpu: Fix some build commands (ggml-org#20341)
  metal : avoid divisions in bin kernel (ggml-org#20426)
  ci: Setup self-hosted CI for Intel Linux Vulkan backend (ggml-org#20154)
  vulkan: fix l2_norm epsilon handling (ggml-org#20350)
  vulkan: fix OOB check in flash_attn_mask_opt (ggml-org#20296)
  vulkan: Fix ErrorOutOfHostMemory on Intel GPU when loading large models with --no-mmap (ggml-org#20059)
  opencl: use larger workgroup size for get_rows (ggml-org#20316)
  opencl: add cumsum op (ggml-org#18981)
  hip: compile debug builds with -O2 on hip to avoid a compiler bug (ggml-org#20392)
  common/parser: add GigaChatV3/3.1 models support (ggml-org#19931)
  model : add support for Phi4ForCausalLMV (ggml-org#20168)
  graph : add optional scale parameter to build_lora_mm [no ci] (ggml-org#20427)
  common : fix --n-cpu-moe, --cpu-moe for models with fused gate + up (ggml-org#20416)
  ggml-webgpu: Add supports for `GGML_OP_REPEAT` (ggml-org#20230)
  llama : enable chunked fused GDN path (ggml-org#20340)
  llama : whitespace cleanup (ggml-org#20422)
  ggml : add NVFP4 quantization type support (ggml-org#19769)
  ...
am17an pushed a commit to am17an/llama.cpp that referenced this pull request Mar 12, 2026
* vulkan: optimize SSM_CONV workgroup dispatch for large ubatch

Tile tokens into 2D workgroups (32x16) to reduce workgroup launch
overhead at large ubatch sizes. Add vec4 fast path for nc=4 (common
d_conv size). Fixes PP performance degradation with ubatch > 512.

Ref: ggml-org#18725

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: remove unused shared memory declaration in SSM_CONV

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Progeny Alpha <ProgenyAlpha@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
ProgenyAlpha added a commit to ProgenyAlpha/llama.cpp that referenced this pull request Mar 13, 2026
The 2D tiling (32x16 workgroups) from ggml-org#20379 causes DeviceLost on
multi-GPU RADV setups. Revert to 1D dispatch but keep the vec4 dot
product fast path for nc=4.

Fixes ggml-org#20462

Labels: ggml (changes relating to the ggml tensor library for machine learning), Vulkan (issues specific to the Vulkan backend)

Successfully merging this pull request may close this issue:

Misc. bug: Qwen3-Next PP performance loss with larger ubatch-size (Strix Halo, Vulkan)

4 participants