SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc)#21580
Conversation
BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478
|
@PMZFX Here is no performance increase: base: ./build/bin/llama-bench -m ../models/qwen2.5-1.5b-instruct-fp16.gguf
This PR:
|
There was a problem hiding this comment.
@PMZFX
The performance is increased.
Some UT cases are fault:
./build/bin/test-backend-ops -b SYCL0
MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1)
MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],k_v=2112,o=1)
MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[2,3],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1)
MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[8,3],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1)
MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[8,3],nr=[4,1],per=[0,1,2,3],k_v=2112,o=1)
MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[8,3],nr=[4,1],per=[0,1,2,3],k_v=2112,o=1)
....
|
@PMZFX Thank you! |
The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
Found and fixed. The Fix: tightened F16 tests passed with k=1056 because those test tensors are non-contiguous, routing F16 through Full |
arthw
left a comment
There was a problem hiding this comment.
It shows good performance increasing:
gemma-4-E4B-it-BF16.gguf on B60
| Test | Base t/s | Primary t/s | Increase Rate (Primary vs Base) |
|---|---|---|---|
| pp512 | 1010.34 | 1051.54 | 4.08% |
| tg128 | 6.37 | 35.94 | 464.21% |
Thank you!
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
* origin/master: server: only parse empty msg if continuing an assistant msg (ggml-org#23506) perplexity : fix integer overflow (ggml-org#23496) SYCL: improve MoE prefill throughput (ggml-org#23142) sycl : Level Zero detection in ggml_sycl_init (ggml-org#23097) SYCL : gated_delta_net K>1 (ggml-org#23174) SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc) (ggml-org#21580) docs: Update documentation with Granite 4.0/4.1 (ggml-org#23404) ggml-zendnn : add Q8_0 quantization support (ggml-org#23414) cmake : build router app only during standalone builds (ggml-org#23521) vocab : fix HybridDNA tokenizer (ggml-org#23466) cmake : add install() for impl libraries + fix apple builds (ggml-org#23511) CUDA: fix PDL CC check for JIT compilation (ggml-org#23471) cmake : remove STATIC from impl libraries, enable LLAMA_BUILD_APP by default (ggml-org#23462) Update WebGPU support and add link to blog/demo (ggml-org#23483) vulkan: fuse snake activation (mul, sin, sqr, mul, add) (ggml-org#22855)
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> (cherry picked from commit 8cc67ef)
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…l-org#21580) * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes ggml-org#20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
BF16 models currently have no dedicated token generation (tg) kernel in the SYCL backend. During single-token generation, BF16 falls through to the generic
ggml_sycl_op_mul_mat_syclGEMM path, which dequantizes to FP32 and runs a full matrix multiply — far too heavy for a memory-bound batch=1 operation.This adds BF16 to the DMMV (dequantize mul-mat-vec) path, following the existing F16 pattern.
Changes
ggml/src/ggml-sycl/dmmv.cpp:convert_bf16()— readssycl::ext::oneapi::bfloat16, casts to float (mirrorsconvert_f16)convert_mul_mat_vec_bf16_sycl()— kernel launcher (mirrors F16 version)src1_convert_f16list for half-precision intrinsics whenGGML_SYCL_F16is enabledGGML_SYCL_DMMV_HAS_BF16(compile-time bfloat16 header detection)ggml/src/ggml-sycl/ggml-sycl.cpp:GGML_TYPE_BF16toggml_sycl_supports_dmmv()Benchmark — Qwen2.5-1.5B, Intel Arc Pro B70 (Xe2), single GPU
BF16 bandwidth utilization goes from ~14% to ~58% of theoretical (608 GB/s).
Testing
-DGGML_SYCL=ON -DGGML_SYCL_F16=ONHardware
Note
This addresses the tg (token generation) path only. BF16 is still not included in the F16-specific special paths for permuted/batched operations (KQ, KQV). Those are separate and would be a broader change.
Fixes #20478
AI Disclosure
AI (Claude) assisted with investigating the dispatch path and drafting the kernel code. All code was human-reviewed, tested, and benchmarked on real hardware.