Skip to content

SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc)#21580

Merged
ggerganov merged 2 commits into
ggml-org:masterfrom
PMZFX:sycl-bf16-dmmv
May 22, 2026
Merged

SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc)#21580
ggerganov merged 2 commits into
ggml-org:masterfrom
PMZFX:sycl-bf16-dmmv

Conversation

@PMZFX

@PMZFX PMZFX commented Apr 7, 2026

Copy link
Copy Markdown
Contributor

Summary

BF16 models currently have no dedicated token generation (tg) kernel in the SYCL backend. During single-token generation, BF16 falls through to the generic ggml_sycl_op_mul_mat_sycl GEMM path, which dequantizes to FP32 and runs a full matrix multiply — far too heavy for a memory-bound batch=1 operation.

This adds BF16 to the DMMV (dequantize mul-mat-vec) path, following the existing F16 pattern.

Changes

ggml/src/ggml-sycl/dmmv.cpp:

  • convert_bf16() — reads sycl::ext::oneapi::bfloat16, casts to float (mirrors convert_f16)
  • convert_mul_mat_vec_bf16_sycl() — kernel launcher (mirrors F16 version)
  • Added BF16 to the DMMV dispatch switch
  • Added BF16 to the src1_convert_f16 list for half-precision intrinsics when GGML_SYCL_F16 is enabled
  • All BF16 code guarded behind GGML_SYCL_DMMV_HAS_BF16 (compile-time bfloat16 header detection)

ggml/src/ggml-sycl/ggml-sycl.cpp:

  • Added GGML_TYPE_BF16 to ggml_sycl_supports_dmmv()

Benchmark — Qwen2.5-1.5B, Intel Arc Pro B70 (Xe2), single GPU

Format Size pp512 (before) pp512 (after) tg128 (before) tg128 (after) tg speedup
Q4_K_M 1.04 GiB 8777 8778 202.6 202.6
Q8_0 1.76 GiB 9304 9304 180.6 180.6
BF16 2.88 GiB 2580 4887 29.7 123.9 4.2x

BF16 bandwidth utilization goes from ~14% to ~58% of theoretical (608 GB/s).

Testing

  • Builds cleanly with -DGGML_SYCL=ON -DGGML_SYCL_F16=ON
  • Token generation produces correct output (verified text coherence)
  • No regressions on Q4_K_M, Q8_0, or larger 9B models
  • Tested on Qwen2.5-1.5B and Qwen3.5-9B
  • Not yet tested on Intel Arc A-series (Alchemist) — would appreciate community testing

Hardware

  • Intel Arc Pro B70 (BMG-G31, 32 GB GDDR6 ECC, 608 GB/s)
  • Driver: libze-intel-gpu1 26.09.37435.1, IGC 2.30.1
  • oneAPI DPC++ 2025.3.3

Note

This addresses the tg (token generation) path only. BF16 is still not included in the F16-specific special paths for permuted/batched operations (KQ, KQV). Those are separate and would be a broader change.

Fixes #20478

AI Disclosure

AI (Claude) assisted with investigating the dispatch path and drafting the kernel code. All code was human-reviewed, tested, and benchmarked on real hardware.

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478
@PMZFX PMZFX requested a review from a team as a code owner April 7, 2026 19:53
@github-actions github-actions Bot added ggml changes relating to the ggml tensor library for machine learning SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language labels Apr 7, 2026
@NeoZhangJianyu

Copy link
Copy Markdown
Contributor

@PMZFX
Sorry, I missed this PR.
I test the PR on B60 with Qwen/qwen2.5-1.5b-instruct-fp16.gguf.
What's the gguf you used?

Here is no performance increase:

base:

./build/bin/llama-bench -m ../models/qwen2.5-1.5b-instruct-fp16.gguf

model size params backend ngl test t/s
qwen2 1.5B F16 3.31 GiB 1.78 B SYCL 99 pp512 1326.90 ± 0.79
qwen2 1.5B F16 3.31 GiB 1.78 B SYCL 99 tg128 24.49 ± 0.22

This PR:
./build/bin/llama-bench -m ../models/qwen2.5-1.5b-instruct-fp16.gguf

model size params backend ngl test t/s
qwen2 1.5B F16 3.31 GiB 1.78 B SYCL 99 pp512 1383.78 ± 1.30
qwen2 1.5B F16 3.31 GiB 1.78 B SYCL 99 tg128 23.96 ± 0.13

Comment thread ggml/src/ggml-sycl/dmmv.cpp

@arthw arthw left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PMZFX
The performance is increased.

Some UT cases are fault:

./build/bin/test-backend-ops -b SYCL0


MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1)
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],k_v=2112,o=1)
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[2,3],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1)
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[8,3],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1)
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[8,3],nr=[4,1],per=[0,1,2,3],k_v=2112,o=1)
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[8,3],nr=[4,1],per=[0,1,2,3],k_v=2112,o=1)
....

@arthw

arthw commented May 14, 2026

Copy link
Copy Markdown
Contributor

@PMZFX
Could you check the failed UT cases?

Thank you!

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@PMZFX

PMZFX commented May 14, 2026

Copy link
Copy Markdown
Contributor Author

Found and fixed. The dequantize_mul_mat_vec<1, 1> kernel (qk=1, used by F16 and BF16) iterates in steps of 2*GGML_SYCL_DMMV_X = 64. When ncols=1056 (divisible by 32 but not 64), the last warp iteration reads past the end of the row, producing NaN. The dispatch check only required ne0 % DMMV_X == 0 (% 32) — too permissive.

Fix: tightened can_use_dequantize_mul_mat_vec to require ne0 % (2*DMMV_X) == 0 for F16/BF16 (quantized types have different access patterns and keep the existing check). ASSERT in the BF16 launcher updated to match.

F16 tests passed with k=1056 because those test tensors are non-contiguous, routing F16 through mul_mat_vec_nc before DMMV is considered. BF16 has no such path.

Full test-backend-ops -b SYCL0 passes on Intel Arc Pro B70.

@arthw arthw left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shows good performance increasing:

gemma-4-E4B-it-BF16.gguf on B60

Test Base t/s Primary t/s Increase Rate (Primary vs Base)
pp512 1010.34 1051.54 4.08%
tg128 6.37 35.94 464.21%

Thank you!

Comment thread ggml/src/ggml-sycl/dmmv.cpp
Comment thread ggml/src/ggml-sycl/dmmv.cpp
@arthw arthw added the merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. label May 20, 2026
@ggerganov ggerganov merged commit 8cc67ef into ggml-org:master May 22, 2026
49 of 50 checks passed
Alex7MV pushed a commit to Alex7MV/claude_llama.cpp that referenced this pull request May 22, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
ProTekk pushed a commit to ProTekk/buun-llama-cpp that referenced this pull request May 22, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 22, 2026
* origin/master:
server: only parse empty msg if continuing an assistant msg (ggml-org#23506)
perplexity : fix integer overflow (ggml-org#23496)
SYCL: improve MoE prefill throughput (ggml-org#23142)
sycl : Level Zero detection in ggml_sycl_init (ggml-org#23097)
SYCL : gated_delta_net K>1 (ggml-org#23174)
SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc) (ggml-org#21580)
docs: Update documentation with Granite 4.0/4.1 (ggml-org#23404)
ggml-zendnn : add Q8_0 quantization support (ggml-org#23414)
cmake : build router app only during standalone builds (ggml-org#23521)
vocab : fix HybridDNA tokenizer (ggml-org#23466)
cmake : add install() for impl libraries + fix apple builds (ggml-org#23511)
CUDA: fix PDL CC check for JIT compilation (ggml-org#23471)
cmake : remove STATIC from impl libraries, enable LLAMA_BUILD_APP by default (ggml-org#23462)
Update WebGPU support and add link to blog/demo (ggml-org#23483)
vulkan: fuse snake activation (mul, sin, sqr, mul, add) (ggml-org#22855)
baramofme pushed a commit to baramofme/llama-cpp-turboquant that referenced this pull request May 23, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
srossitto79 pushed a commit to srossitto79/llama.cpp that referenced this pull request May 23, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
kashif pushed a commit to kashif/llama.cpp that referenced this pull request May 23, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
carlosfundora pushed a commit to carlosfundora/llama.cpp-1-bit-turbo that referenced this pull request May 24, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
(cherry picked from commit 8cc67ef)
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
turbo-tan pushed a commit to turbo-tan/llama.cpp-tq3 that referenced this pull request Jun 2, 2026
…l-org#21580)

* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

BF16 models had no dedicated token generation kernel — they fell through
to the generic full-GEMM path, resulting in ~14% memory bandwidth
utilization on Intel Arc GPUs. This adds BF16 support to the DMMV
(dequantize mul-mat-vec) path, matching the existing F16 implementation.

Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

The qk=1 kernel (used for F16 and BF16) iterates with stride
2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When
ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last
warp iteration accesses elements at col >= ncols, producing NaN for the
final row and wrong values for interior rows.

Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] %
(2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16
launcher to match. Quantized types use block-structured kernels with
different access patterns and keep the existing DMMV_X check.

Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70.
Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Misc. bug: SYCL: BF16 falling to CPU

6 participants