Skip to content

model : refactor QKV into common build_qkv and create_tensor_qkv helpers#21245

Merged
CISC merged 2 commits into
ggml-org:masterfrom
JoursBleu:refactor/build-qkv-helper
Apr 16, 2026
Merged

model : refactor QKV into common build_qkv and create_tensor_qkv helpers#21245
CISC merged 2 commits into
ggml-org:masterfrom
JoursBleu:refactor/build-qkv-helper

Conversation

@JoursBleu

@JoursBleu JoursBleu commented Apr 1, 2026

Copy link
Copy Markdown
Contributor

Overview

Currently llama.cpp supports 112 model files in src/models/.

We modified the 85 applicable model files. Our changes abstract the duplicated
Q/K/V tensors' loading and graph-building code into two reusable helpers,
following the create_tensor_gate_up_exps pattern (#19139).

create_tensor_qkv (llama-model.cpp): tries fused wqkv/bqkv first (TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL), falls back to separate wq/wk/wv. Supports adding biases.

build_qkv (llama-graph.h/cpp): returns {Qcur, Kcur, Vcur} as 3D tensors. Fused case: single fused qkv matmul + ggml_view_3d split. Separate case: 3 separate matmuls + ggml_reshape_3d.

Test: test-llama-archs — all OK, 0 FAIL. Zero diff on llama-arch.cpp.

The remaining 27 models are not modified for the following reasons:

Reason Count Models
Non-attention (SSM/linear/RNN) 10 mamba, mamba-base, rwkv6, rwkv6-base, rwkv6qwen2, rwkv7, rwkv7-base, arwkv7, delta-net-base, wavtokenizer-dec
MLA attention 4 deepseek2, minicpm3, minimax-m2, plm
Graph directly uses layer.wqkv (non-standard layout) 3 cogvlm, openelm, plamo2
Q+gate joint projection 4 qwen35, qwen35moe, qwen3next, plamo3
n_embd_head_k != n_embd_head_v 2 step35-iswa, mimo2-iswa
No fused wqkv_enc 1 t5-enc
Other special architectures 3 olmo2, olmoe, kimi-linear

Additional information

Basing on the discussion in #20628 (@am17an, @ngxson). The plan is:

  1. This PR: This PR does not modify any logic, it simply extracts the redundant code into
    the two functions above, and adds handling for the fused qkv case.
  2. Future PR: add --fuse-qkv to convert_hf_to_gguf.py.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES - used as a translation tool for translating the PR description

@github-actions github-actions Bot added the model Model specific label Apr 1, 2026
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch 3 times, most recently from da129d5 to 26e72e0 Compare April 1, 2026 04:18
@JoursBleu JoursBleu marked this pull request as ready for review April 1, 2026 06:01
@JoursBleu JoursBleu requested a review from CISC as a code owner April 1, 2026 06:01
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp
Comment thread src/llama-model.cpp
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 26e72e0 to bcc69fd Compare April 1, 2026 09:05
@JoursBleu

Copy link
Copy Markdown
Contributor Author

hi @CISC,

  • Removed the has_bias flag.
  • Bias tensors are now always created with TENSOR_NOT_REQUIRED.
  • Fixed the incomplete conversions and typos mentioned above.

Comment thread src/llama-model.cpp
Comment thread src/llama-model.cpp Outdated
Comment thread src/llama-model.cpp
@JoursBleu

Copy link
Copy Markdown
Contributor Author

@CISC Done:

  • Remove unnecessary comments/restore comments that should be retained.
  • JAIS2 restores manually created bias tensors.

Comment thread src/models/afmoe.cpp
@JoursBleu

Copy link
Copy Markdown
Contributor Author

@CISC Done:

  • Remove the Vcur reshape in afmoe.cpp

@CISC CISC left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OP is inaccurate, there's nothing special about these:

  • nemotron-h: just add build_qkv in llm_build_nemotron_h::build_attention_layer
  • granite-hybrid: just add build_qkv in lm_build_granite_hybrid::build_attention_layer
  • olmo/mpt/dbrx: use build_qkv, add clamping
  • gemma3n-iswa: just do build_qkv
  • t5-dec/t5-enc: do build_qkv on normal self-attention
  • bert: use build_qkv
  • lfm2: do build_qkv in build_attn_block

Comment thread src/llama-graph.cpp Outdated
@JoursBleu JoursBleu marked this pull request as draft April 2, 2026 12:56
@CISC

CISC commented Apr 4, 2026

Copy link
Copy Markdown
Member

I meant move the clamping to build_qkv.

@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 09d8066 to 04506d4 Compare April 6, 2026 01:27
Comment thread src/llama-graph.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch 2 times, most recently from 050b5a9 to 623ed29 Compare April 9, 2026 01:17
@JoursBleu JoursBleu marked this pull request as ready for review April 9, 2026 01:17
@JoursBleu

Copy link
Copy Markdown
Contributor Author

@CISC Done:

  • Extended build_qkv to bert, mpt, dbrx, olmo, lfm2, nemotron-h, granite-hybrid, gemma3n-iswa, t5-dec, t5-enc;
  • Clamping handled internally in build_qkv using hparams.f_clamp_kqv.

Comment thread src/models/gemma3n-iswa.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 623ed29 to ccd1f60 Compare April 10, 2026 13:39
Comment thread src/llama-graph.h Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from ccd1f60 to 67a8492 Compare April 11, 2026 05:29
@CISC CISC requested a review from ngxson April 11, 2026 09:36
Comment thread src/models/mimo2-iswa.cpp Outdated
Comment thread src/models/openai-moe-iswa.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from 67a8492 to d8bf733 Compare April 12, 2026 07:18
@JoursBleu

JoursBleu commented Apr 13, 2026

Copy link
Copy Markdown
Contributor Author

@ngxson @am17an @ggerganov This PR is ready. Could you take a look when you have time?

@am17an am17an left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job!

Comment thread src/llama-graph.cpp Outdated
@JoursBleu JoursBleu force-pushed the refactor/build-qkv-helper branch from d8bf733 to 51dbd8c Compare April 16, 2026 09:05
@ggerganov ggerganov added the merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. label Apr 16, 2026
@CISC CISC merged commit 9db77a0 into ggml-org:master Apr 16, 2026
48 of 50 checks passed
cnsiva pushed a commit to saas-home/llama.cpp that referenced this pull request Apr 17, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
samuraieng pushed a commit to samuraieng/llama.cpp that referenced this pull request Apr 19, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
jspadgett pushed a commit to jspadgett/llama.cpp that referenced this pull request Apr 20, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
mengqin pushed a commit to mengqin/llama.cpp that referenced this pull request Apr 20, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Apr 21, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Apr 23, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
ljubomirj pushed a commit to ljubomirj/llama.cpp that referenced this pull request May 6, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
JoursBleu added a commit to JoursBleu/llama.cpp that referenced this pull request May 7, 2026
…F conversion

Adds an opt-in '--fuse-qkv' flag to convert_hf_to_gguf.py that concatenates
separate Q / K / V weight tensors into a single fused attn_qkv tensor during
HF -> GGUF conversion. Fusion happens in the shared base ModelBase.modify_tensors()
sink that subclass overrides forward into, so per-layer Q/K/V tensors are buffered
until all three are seen, then emitted as one fused tensor via torch.cat([q,k,v], dim=0).

At runtime the existing build_qkv helper introduced in ggml-org#21245 already handles the
fused path with one matmul + ggml_view_3d split, so a GGUF produced with --fuse-qkv
mostly reuses that path.

This branch also adds two small C++ correctness fixes for fused-weight + separate-bias
models (qwen2 / phi2 / starcoder2 / stablelm and similar):

  - llama-model.cpp: when wqkv is found but wqkv_b is absent, fall back to loading
    separate wq_b / wk_b / wv_b. --fuse-qkv only fuses weights, not biases.
  - llama-graph.cpp: in build_qkv fused path, when wqkv_b is absent but per-head
    biases exist, concat wq_b + wk_b + wv_b with ggml_concat and add after the
    fused matmul. Without this fix biases were silently dropped, producing garbage
    output.

gguf-py/gguf/constants.py registers MODEL_TENSOR.ATTN_QKV on every arch that
already declares ATTN_Q + ATTN_K + ATTN_V; without this the writer rejects the
fused tensor at format_tensor_name().

Tested on 4x AMD R9700 (gfx1201, ROCm 7.2): 34 architectures match nofuse output
bit-for-bit at Q8_0 and Q4_0; representative pp512 speedups: refact +22.7%,
qwen2 +13.1%, granite +12.5%, seed-oss +12.3%, phi2 +6.3%, mistral3 +5.3%.
test-llama-archs passes (0 FAIL).

This is PR 2 of the split discussed in ggml-org#20628; PR 1 (ggml-org#21245) is already merged.
my-other-github-account pushed a commit to my-other-github-account/llama.cpp that referenced this pull request May 15, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
my-other-github-account pushed a commit to my-other-github-account/llama.cpp that referenced this pull request May 15, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
zapabob pushed a commit to zapabob/llama.cpp that referenced this pull request May 17, 2026
…ache

Wholesale sync of 384 upstream commits since merge-base 7fc1c4e (2026-04-22).
Headline upstream feature: MTP / Multi-Token Prediction (ggml-org#22673) + spec-decoding stack
(ggml-org#22838 parallel drafting, ggml-org#22227 spec-simple checkpoints, ggml-org#19493 server spec checkpointing,
plus 5 spec bug-fixes).

11 conflicts resolved across CUDA fattn / Metal / Vulkan / common:

ggml/src/ggml-cuda/fattn-mma-f16.cuh
  RDNA config matrix: union TQ's (640, 512) entries with upstream's expanded
  (112..576) RDNA matrix. Took upstream's new sentinel fallback (no ampere
  fallback for RDNA).

ggml/src/ggml-cuda/fattn.cu
  - Extended hoisted ncols2_max to include 640 head-dim.
  - Volta: dropped TQ's local ncols2_max redefinition in favor of upstream's
    hoisted version (with 640 added).
  - WMMA gate: union exclusions (40, 72, 192, 512, 576, 640).
  - Preserved TQ's RDNA4 vector-kernel branch for TurboQuant cache types
    (renamed inner gqa_ratio_eff_rdna4 to avoid shadowing); took upstream's
    restructured MFMA/CDNA path verbatim.

ggml/src/ggml-cuda/ggml-cuda.cu
  Supported-op switch: union TQ's GGML_OP_TURBO_WHT case with upstream's
  GGML_OP_ADD/SUB/MUL/DIV FP16 cases.

ggml/src/ggml-metal/ggml-metal-device.h
  Kept TQ's get_pipeline_turbo_wht declaration; took upstream's new
  get_pipeline_mul_mv_ext(lib, const ggml_tensor * op, ...) signature
  (replaces split tsrc0/tsrc1 args).

ggml/src/ggml-metal/ggml-metal-device.cpp
  Kept TQ's get_pipeline_turbo_wht implementation; took upstream's new
  get_pipeline_mul_mv_ext signature — body already uses op-> for tsrc0/tsrc1/ne12/r2/r3.

ggml/src/ggml-metal/ggml-metal-ops.cpp
  Preserved TQ's is_tq_weight rotate→matmul→unrotate path with original
  hardcoded dispatch shape. Updated non-TQ fallback to upstream's pipeline-
  param dispatch (pipeline.nr0 / nr1 / nsg + (ne11+nr1-1)/nr1 shape).

ggml/src/ggml-vulkan/* (3 files)
  Upstream-wholesale via `git checkout --theirs`. Upstream architecturally
  refactored FA from compile-time DATA_A_* variants to runtime
  FaTypeK/FaTypeV spec-constant switches. TQ's TURBO3_0 GLSL path is
  DEFERRED — Vulkan TURBO3_0 support needs re-implementation against the
  new architecture in a follow-up PR. Mac mini + M5 Max have no Vulkan;
  no in-house validation path for an immediate re-adaptation.

common/arg.cpp
  --spec-default: took upstream's new struct shape
  (params.speculative.types vector + params.speculative.ngram_mod.{n_match,n_min,n_max}).

common/speculative.cpp
  Low-acceptance reset: took upstream's sinfo.n_low / sinfo.i_last
  (variables moved into sinfo struct).

NOT-CONFLICTED upstream additions that touch TQ-adjacent surface
(auto-merged clean, but worth eyes during review):
  - src/llama-memory-recurrent.{cpp,h}  (MTP rollback API)
  - src/llama-memory-hybrid.{cpp,h}     (recall feedback_llama_memory_types
                                         + feedback_layer0_hybrid_trap)
  - src/llama-graph.cpp, src/llama-kv-cache.cpp, src/llama-context.cpp
  - tools/server/server-context.cpp     (+~1100 lines: MTP + parallel
                                         drafting + spec checkpointing)
  - src/models/qwen35*.cpp, qwen3next.cpp, delta-net-base.cpp
    (entirely new in upstream — MTP draft-head integration)

Known-deferred follow-ups:
  1. Vulkan TURBO3_0 re-implementation against runtime spec-constant FA arch
  2. PR ggml-org#21245 QKV refactor helpers — landed; TQ models not migrated to use
     them. Migrate in a focused follow-up; do not bundle here.

Validation gate (pending):
  M2 mini PPL/decode comparison @ Qwen2.5-7B-Q8_0 K=q8_0/V=turbo4 asymmetric
  ctx 2048 + 16384 — see PR body.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026
…ers (ggml-org#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. model Model specific

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants