vulkan: add v_dot2_f32_f16 support in matrix-matrix multiplication and Flash Attention #24123

Merged
0cc4m merged 4 commits into master from 0cc4m/vulkan-valve-dot2 on Jun 9, 2026
@0cc4m

0cc4m commented Jun 4, 2026

Contributor

Overview

This PR adds basic support for the Vulkan extension VK_VALVE_shader_mixed_float_dot_product. Background: AMD Vega20, Navi14 and RDNA2+ GPUs have fp16 dot2 instructions for machine learning acceleration, but the shader compiler does not emit them because of numerical inconsistencies. The extension lets shaders emit them explicitly.

This PR adds support for v_dot2_f32_f16, the packed fp16 dot product with an fp32 accumulator, in matrix-matrix multiplication and Flash Attention. This is a good improvement for AMD GPUs that have this instruction but lack coopmat support.
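
As a rough illustration of what a packed dot2 with an fp32 accumulator computes, here is a minimal scalar sketch in C++. It is not the shader code from this PR: `Half2`, `dot2_f32_f16` and `dot_row_col` are hypothetical names, plain `float` stands in for fp16 so the sketch compiles anywhere, and the hardware's exact rounding behaviour is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for a packed pair of fp16 values (assumption: plain floats are
// used here for portability; a real shader would use a 16-bit vector type).
struct Half2 {
    float x;
    float y;
};

// Hypothetical scalar model of a dot2 with fp32 accumulator:
// acc + a.x * b.x + a.y * b.y. Hardware rounding is not modelled.
inline float dot2_f32_f16(Half2 a, Half2 b, float acc) {
    return acc + a.x * b.x + a.y * b.y;
}

// Inner loop of a matrix multiplication for one output element:
// each step consumes two k-elements from the row of A and the column of B.
inline float dot_row_col(const std::vector<Half2>& a_row,
                         const std::vector<Half2>& b_col) {
    assert(a_row.size() == b_col.size());
    float acc = 0.0f;
    for (std::size_t i = 0; i < a_row.size(); ++i) {
        acc = dot2_f32_f16(a_row[i], b_col[i], acc);
    }
    return acc;
}
```

The point of the sketch is only the shape of the loop: the accumulator stays in fp32 while each step folds in two fp16 products, so the k-loop runs over half as many iterations as a one-element-per-step loop.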

AMD Radeon Pro VII (Vega20) Benchmarks

| Test | Before | After | Δ% |
| --- | ---: | ---: | ---: |
| MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 6.220 | 8.170 | +31.35% |
| MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.490 | 9.840 | +119.15% |
| MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 5.380 | 5.320 | -1.12% |
| MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 15.370 | 15.230 | -0.91% |
| MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 14.950 | 14.800 | -1.00% |
| MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 9.980 | 9.900 | -0.80% |
| MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 9.890 | 9.820 | -0.71% |
| MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 13.190 | 13.150 | -0.30% |
| MUL_MAT(type_a=q1_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.880 | 10.730 | +119.88% |
| MUL_MAT(type_a=mxfp4,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 14.690 | 14.610 | -0.54% |
| MUL_MAT(type_a=nvfp4,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.740 | 10.230 | +115.82% |
| MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 7.240 | 7.220 | -0.28% |
| MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 10.360 | 10.310 | -0.48% |
| MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 11.590 | 11.550 | -0.35% |
| MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 8.190 | 8.180 | -0.12% |
| MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 7.950 | 7.940 | -0.13% |
| MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.690 | 9.900 | +111.09% |
| MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.670 | 9.810 | +110.06% |
| MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.610 | 9.640 | +109.11% |
| MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.500 | 9.050 | +101.11% |
| MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.700 | 9.990 | +112.55% |
| MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.640 | 9.760 | +110.34% |
| MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.750 | 10.220 | +115.16% |
| MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.500 | 8.990 | +99.78% |
| MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) | 4.560 | 9.510 | +108.55% |
| MUL_MAT_ID(type_a=f32,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 2.640 | 3.000 | +13.64% |
| MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 2.370 | 4.700 | +98.31% |
| MUL_MAT_ID(type_a=q4_0,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 6.320 | 6.280 | -0.63% |
| MUL_MAT_ID(type_a=q8_0,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 5.660 | 5.700 | +0.71% |
| MUL_MAT_ID(type_a=q4_K,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 6.030 | 6.040 | +0.17% |
| MUL_MAT_ID(type_a=q6_K,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 4.400 | 4.420 | +0.45% |
| MUL_MAT_ID(type_a=iq2_xs,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=512,k=2048) | 2.400 | 5.060 | +110.83% |
| MUL_MAT_ID(type_a=f32,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 3.400 | 3.470 | +2.06% |
| MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 3.270 | 5.760 | +76.15% |
| MUL_MAT_ID(type_a=q4_0,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 8.900 | 8.430 | -5.28% |
| MUL_MAT_ID(type_a=q8_0,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 7.950 | 7.640 | -3.90% |
| MUL_MAT_ID(type_a=q4_K,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 7.940 | 8.030 | +1.13% |
| MUL_MAT_ID(type_a=q6_K,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 6.110 | 6.190 | +1.31% |
| MUL_MAT_ID(type_a=iq2_xs,type_b=f32,n_mats=32,n_used=4,b=0,m=1792,n=512,k=2048) | 3.080 | 6.640 | +115.58% |
| MUL_MAT_ID(type_a=mxfp4,type_b=f32,n_mats=32,n_used=4,b=0,m=2880,n=512,k=2880) | 8.210 | 8.420 | +2.56% |

| Test | Before | After | Δ% |
| --- | ---: | ---: | ---: |
| FLASH_ATTN_EXT(hsk=72,hsv=72,nh=16,nr23=[1,1],kv=5776,nb=5776,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 2.740 | 2.550 | -6.93% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 1.410 | 1.490 | +5.67% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=4,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 1.890 | 2.080 | +10.05% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q4_0,type_V=q4_0,permute=[0,1,2,3]) | 1.350 | 1.350 | +0.00% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=512,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q4_0,type_V=q4_0,permute=[0,1,2,3]) | 2.930 | 2.880 | -1.71% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q8_0,type_V=q8_0,permute=[0,1,2,3]) | 1.380 | 1.390 | +0.72% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[8,1],kv=7680,nb=512,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=q8_0,type_V=q8_0,permute=[0,1,2,3]) | 3.150 | 3.160 | +0.32% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.307 | 0.311 | +1.19% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.822 | 0.890 | +8.29% |
| FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.446 | 0.456 | +2.06% |
| FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 1.220 | 1.280 | +4.92% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.420 | 0.415 | -1.24% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.997 | 1.110 | +11.33% |
| FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.542 | 0.542 | -0.14% |
| FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 1.380 | 1.450 | +5.07% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.499 | 0.481 | -3.75% |
| FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 1.280 | 1.450 | +13.28% |
| FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 0.591 | 0.590 | -0.02% |
| FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_K=f16,type_V=f16,permute=[0,1,2,3]) | 1.550 | 1.640 | +5.81% |

| model | size | params | ngl | fa | mmap | test | t/s (before) | t/s (after) | diff |
| --- | ---: | ---: | ---: | ---: | ---: | --- | ---: | ---: | ---: |
| llama 1B F16 | 2.05 GiB | 1.10 B | -1 | 1 | 0 | pp512 | 2188.04 ± 7.71 | 4112.38 ± 13.61 | +87.9% |
| llama 1B F16 | 2.05 GiB | 1.10 B | -1 | 1 | 0 | tg128 | 227.67 ± 1.46 | 224.17 ± 0.67 | -1.5% |
| llama 1B F16 | 2.05 GiB | 1.10 B | -1 | 1 | 0 | pp512 @ d4096 | 1596.55 ± 10.42 | 2596.67 ± 5.88 | +62.6% |
| llama 1B F16 | 2.05 GiB | 1.10 B | -1 | 1 | 0 | tg128 @ d4096 | 191.41 ± 0.83 | 193.39 ± 0.86 | +1.0% |
| llama 1B F16 | 2.05 GiB | 1.10 B | -1 | 1 | 0 | pp512 @ d8192 | 1228.05 ± 8.12 | 1865.52 ± 16.46 | +51.9% |
| llama 1B F16 | 2.05 GiB | 1.10 B | -1 | 1 | 0 | tg128 @ d8192 | 175.61 ± 0.61 | 179.68 ± 0.34 | +2.3% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | -1 | 1 | 0 | pp512 | 846.17 ± 2.87 | 828.88 ± 1.90 | -2.0% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | -1 | 1 | 0 | tg128 | 101.36 ± 0.16 | 102.01 ± 0.29 | +0.6% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | -1 | 1 | 0 | pp512 @ d4096 | 346.27 ± 0.94 | 449.96 ± 3.69 | +29.9% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | -1 | 1 | 0 | tg128 @ d4096 | 87.77 ± 0.11 | 88.57 ± 0.11 | +0.9% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | -1 | 1 | 0 | pp512 @ d8192 | 190.72 ± 1.93 | 275.47 ± 1.97 | +44.4% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | -1 | 1 | 0 | tg128 @ d8192 | 79.34 ± 0.02 | 79.22 ± 0.94 | -0.2% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | -1 | 1 | 0 | pp512 | 755.22 ± 5.48 | 790.23 ± 9.42 | +4.6% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | -1 | 1 | 0 | tg128 | 71.77 ± 0.05 | 72.30 ± 0.01 | +0.7% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | -1 | 1 | 0 | pp512 @ d4096 | 379.66 ± 0.80 | 438.63 ± 1.63 | +15.5% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | -1 | 1 | 0 | tg128 @ d4096 | 52.90 ± 0.01 | 53.14 ± 0.01 | +0.5% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | -1 | 1 | 0 | pp512 @ d8192 | 251.56 ± 0.33 | 286.02 ± 0.75 | +13.7% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | -1 | 1 | 0 | tg128 @ d8192 | 42.19 ± 0.10 | 42.30 ± 0.02 | +0.3% |
| llama 8B IQ4_XS - 4.25 bpw | 4.13 GiB | 8.03 B | -1 | 1 | 0 | pp512 | 325.59 ± 0.34 | 633.87 ± 1.11 | +94.7% |
| llama 8B IQ4_XS - 4.25 bpw | 4.13 GiB | 8.03 B | -1 | 1 | 0 | tg128 | 61.86 ± 0.06 | 61.70 ± 0.05 | -0.3% |
| llama 8B IQ4_XS - 4.25 bpw | 4.13 GiB | 8.03 B | -1 | 1 | 0 | pp512 @ d4096 | 207.74 ± 0.77 | 388.60 ± 2.32 | +87.1% |
| llama 8B IQ4_XS - 4.25 bpw | 4.13 GiB | 8.03 B | -1 | 1 | 0 | tg128 @ d4096 | 56.53 ± 0.05 | 56.69 ± 0.02 | +0.3% |
| llama 8B IQ4_XS - 4.25 bpw | 4.13 GiB | 8.03 B | -1 | 1 | 0 | pp512 @ d8192 | 138.80 ± 1.06 | 251.68 ± 4.89 | +81.3% |
| llama 8B IQ4_XS - 4.25 bpw | 4.13 GiB | 8.03 B | -1 | 1 | 0 | tg128 @ d8192 | 52.49 ± 0.02 | 52.89 ± 0.06 | +0.8% |
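
For readers checking the Δ% and diff columns, the figures are consistent with a plain relative change of the "after" value over the "before" value. A small C++ sketch of that arithmetic follows; `percent_change` is a hypothetical helper, and reading the columns this way is an assumption inferred from the numbers, not something the PR states.

```cpp
#include <cassert>

// Relative change in percent: (after - before) / before * 100.
// Assumption: this is how the Δ% / diff columns were derived.
inline double percent_change(double before, double after) {
    assert(before != 0.0);
    return (after - before) / before * 100.0;
}
```

For example, `percent_change(6.220, 8.170)` evaluates to roughly 31.35, in line with the first MUL_MAT row. Rows computed from unrounded measurements can differ slightly from what the rounded Before and After values give.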

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, Claude wrote the code, I reviewed and tested it.

0cc4m requested a review from a team as a code owner June 4, 2026 12:42
github-actions bot added the Vulkan (issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels Jun 4, 2026
Comment thread ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp Outdated
@netrunnereve

Collaborator

Wow they actually did it! Strange that it's a Valve/radv exclusive extension though, but I guess it really only makes sense for AMD.

@0cc4m

0cc4m commented Jun 8, 2026

Contributor Author

@ggml-org/maintainers Another approval needed.

0cc4m force-pushed the 0cc4m/vulkan-valve-dot2 branch from 16cfaba to 62a989f June 8, 2026 09:06
@0cc4m

0cc4m commented Jun 9, 2026

Contributor Author

@ggml-org/maintainers Sorry, need another approval because of the merge conflict.

0cc4m merged commit b4e3dc6 into master Jun 9, 2026
29 checks passed
0cc4m deleted the 0cc4m/vulkan-valve-dot2 branch June 9, 2026 11:27
Jcfunk added a commit to Jcfunk/llama.cpp that referenced this pull request Jun 11, 2026
* upstream/HEAD: (329 commits)
  vendor : update LibreSSL to 4.3.2 (ggml-org#24397)
  Remove padding and multiple D2D copies for MTP (ggml-org#24086)
  chat: fix LFM2/LFM2.5 ignoring json_schema (ggml-org#24377)
  CUDA: Fix ssm_scan_f32 data-races (ggml-org#24360)
  ci : bump komac version (ggml-org#24396)
  speculative : fix "ngram-map-k4v" name in logging (ggml-org#24253)
  webui: implement pinned conversations support (ggml-org#21387)
  graph: Fix granite speech model inference by applying embedding scale when deepstack is not used (ggml-org#24357)
  ci : fix windows release (ggml-org#24369)
  ui: add opt-in run_javascript frontend tool (ggml-org#24244)
  mtmd: build_vit batching (ggml-org#24352)
  vulkan: reduce iq1 shared memory usage for mul_mm (ggml-org#24287)
  vulkan: add `v_dot2_f32_f16` support in matrix-matrix multiplication and Flash Attention (ggml-org#24123)
  ui: Fix excessive style recalculation on hover (ggml-org#24243)
  mtmd: refactor video subproc handling (ggml-org#24316)
  server: log prompts to directory (ggml-org#22031)
  ui: fix mobile chat form overflow and bust stale bundle cache (ggml-org#24158)
  ggml : add GGML_OP_COL2IM_1D (ggml-org#24206)
  server : do not clear slots without unified KV cache (ggml-org#24190)
  models : fix plamo2 attention_key/value_length regression (ggml-org#24317)
  ...
Labels

ggml (changes relating to the ggml tensor library for machine learning), Vulkan (issues specific to the Vulkan backend)

5 participants