Remove padding and multiple D2D copies for MTP#24086
Conversation
|
Marking this as a draft as it has changes only for the CUDA and Vulkan backends. Changes in the op implementation are fairly simple. I will work on changes for other backends once I get initial feedback. |
ggerganov
left a comment
There was a problem hiding this comment.
Yes, let's do the change. I think we the "determine K from the tensor shape" was never really needed.
| // TODO: remove pad + simplify | ||
| ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); | ||
| ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); | ||
| ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); |
There was a problem hiding this comment.
I don't think we need to collapse the state anymore. Mainly for consistency reasons I think the state should remain a 4D tensor with shape: [S_v, S_v, H_v, n_seqs].
| // state holds the initial state s0 only, shape (S_v*S_v*H, 1, n_seqs). K is the number of | ||
| // output snapshot slots: | ||
| // K == 1: output carries the final state only. | ||
| // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) | ||
| // per-token snapshots into the trailing slots | ||
| // K > 1: output carries K snapshots, most-recent first (slot 0 = final state, slot s = | ||
| // state s tokens back); when n_tokens < K only slots 0..n_tokens-1 are written. | ||
| GGML_API struct ggml_tensor * ggml_gated_delta_net( | ||
| struct ggml_context * ctx, | ||
| struct ggml_tensor * q, | ||
| struct ggml_tensor * k, | ||
| struct ggml_tensor * v, | ||
| struct ggml_tensor * g, | ||
| struct ggml_tensor * beta, | ||
| struct ggml_tensor * state); | ||
| struct ggml_tensor * state, | ||
| int64_t K); |
There was a problem hiding this comment.
Would be useful to document in a comment the shapes of the tensors of this op.
Changes are now made in all backends. I have thoroughly reviewed the change, but I have only been able to test CUDA, Vulkan, and CPU. I am hoping CI will be able to help capture any issues on other backends. It will be great if maintainers can help test these changes on other backends as well. |
| const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); | ||
|
|
||
| ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); | ||
| } | ||
| // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) | ||
| ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, | ||
| D, n_seqs, n_written, | ||
| ggml_row_size(gdn_out->type, D), | ||
| ggml_row_size(gdn_out->type, state_size_per_snap), | ||
| ggml_row_size(gdn_out->type, attn_score_elems)); | ||
|
|
||
| ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, | ||
| D, n_seqs, n_written, | ||
| ssm_states_all->nb[1], | ||
| (size_t) mem_size * row_size, | ||
| (size_t) kv_head * row_size); | ||
|
|
||
| ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); |
There was a problem hiding this comment.
@ggerganov could you please help review this part of the code? The earlier code was copying uninitialized snapshots if n_seq_tokens < K.
I have now limited the number of snapshot copies to min(n_seq_tokens, K).
There was a problem hiding this comment.
Yup, looks fine. It's unfortunate we can't deduce the n_written from the output shape (technically we can, but too complicated).
…, n_seqs) and passes the snapshot count K as an op parameter instead of inferring it from state->ne[1]. Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy
|
@gaugarg-nv It's no bad. Thank you! |
ggerganov
left a comment
There was a problem hiding this comment.
Testing on Mac seems to work OK after the changes.
| const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); | ||
|
|
||
| ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); | ||
| } | ||
| // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) | ||
| ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, | ||
| D, n_seqs, n_written, | ||
| ggml_row_size(gdn_out->type, D), | ||
| ggml_row_size(gdn_out->type, state_size_per_snap), | ||
| ggml_row_size(gdn_out->type, attn_score_elems)); | ||
|
|
||
| ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, | ||
| D, n_seqs, n_written, | ||
| ssm_states_all->nb[1], | ||
| (size_t) mem_size * row_size, | ||
| (size_t) kv_head * row_size); | ||
|
|
||
| ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); |
There was a problem hiding this comment.
Yup, looks fine. It's unfortunate we can't deduce the n_written from the output shape (technically we can, but too complicated).
|
@ggml-org/maintainers, could I please get another approval? @ggml-org/ggml-hexagon @ggml-org/ggml-webgpu @ggml-org/ggml-opencl Could you please take a look at this as well? |
* upstream/HEAD: (329 commits) vendor : update LibreSSL to 4.3.2 (ggml-org#24397) Remove padding and multiple D2D copies for MTP (ggml-org#24086) chat: fix LFM2/LFM2.5 ignoring json_schema (ggml-org#24377) CUDA: Fix ssm_scan_f32 data-races (ggml-org#24360) ci : bump komac version (ggml-org#24396) speculative : fix "ngram-map-k4v" name in logging (ggml-org#24253) webui: implement pinned conversations support (ggml-org#21387) graph: Fix granite speech model inference by applying embedding scale when deepstack is not used (ggml-org#24357) ci : fix windows release (ggml-org#24369) ui: add opt-in run_javascript frontend tool (ggml-org#24244) mtmd: build_vit batching (ggml-org#24352) vulkan: reduce iq1 shared memory usage for mul_mm (ggml-org#24287) vulkan: add `v_dot2_f32_f16` support in matrix-matrix multiplication and Flash Attention (ggml-org#24123) ui: Fix excessive style recalculation on hover (ggml-org#24243) mtmd: refactor video subproc handling (ggml-org#24316) server: log prompts to directory (ggml-org#22031) ui: fix mobile chat form overflow and bust stale bundle cache (ggml-org#24158) ggml : add GGML_OP_COL2IM_1D (ggml-org#24206) server : do not clear slots without unified KV cache (ggml-org#24190) models : fix plamo2 attention_key/value_length regression (ggml-org#24317) ...
Sync 205 upstream commits. The VE backend is out-of-tree (ggml-ve/), so the merge is conflict-free. The one upstream change that touches the backend is ggml-org#24086 (e95dae1, the GATED_DELTA_NET op-contract change), re-ported in the following commit.
Upstream ggml-org#24086 (e95dae1) changed GATED_DELTA_NET's recurrent state from the old 3D [S_v*S_v*H, K, n_seqs] layout to a 4D s0 [S_v, S_v, H, n_seqs], moved the snapshot slot count K to op_params[0], and reversed the snapshot ordering (slot 0 = newest). Without this the VE GDN supports() rejects the new 4D state, the op falls to CPU, and Qwen3.5/3.6 garble. Make supports() accept BOTH the pre- and post-ggml-org#24086 contracts, read K from op_params[0] (new) or state->ne[1] (old), pass the correct per-seq state stride (nb[3] new / nb[2] old), and thread a new_contract flag into the kernel so the K>1 snapshot slot uses (n_tokens-1-t) under the new order. Validated on ve-backend-port after the 205-commit master sync: Llama-3.2-3B stays coherent; Qwen3.5-4B-rys (GDN) produces coherent output ("...Paris.") instead of garble. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Scanned ggml-org/llama.cpp 160 commits ahead of base 354ebac for any MTP perf fixes that might obviate the planned patch 9 (MTP net-negative on small models / Gemma E2B). Findings: - e95dae1 (Remove padding and multiple D2D copies for MTP, ggml-org#24086) is the only MTP-tagged perf commit, but it targets ggml_gated_delta_net, the Qwen3.5 DeltaNet recurrent path. Gemma4 is a plain transformer so this commit does not address the E2B net-negative case. - a66d505 / 88a3927 / 260862b / 7acb4e8 noted as not applicable or pure refactor. Conclusion: patch 9 would need original investigation work (IME2 coverage of the MTP block, n_max tuning, backend sampling re-enablement). With E4B/12B already net-positive on the same binary the patch may be deprioritisable.
* Make ggml_gated_delta_net take only the initial recurrent state (D, 1, n_seqs) and passes the snapshot count K as an op parameter instead of inferring it from state->ne[1]. Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy * Make GDN changes in all backends. Address review comments. * Fix CI build errors
Based on @ggerganov's suggestion at #23940 (comment)
Make
ggml_gated_delta_nettake only the initial recurrent state (D, 1, n_seqs) and pass the snapshot count K as an op parameter instead of inferring it from state->ne[1].Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy
This improves MTP performance by about 4% on DGX Spark.
Command:
llama-server -m Qwen3.6-35B-A3B-UD-Q4_K_M.gguf --spec-type draft-mtpMaster:
PR:
Requirements