Skip to content

Unified delta net handling for Qwen3Next and Kimi Linear models#18792

Closed
pwilkin wants to merge 8 commits intoggml-org:masterfrom
pwilkin:delta_net
Closed

Unified delta net handling for Qwen3Next and Kimi Linear models#18792
pwilkin wants to merge 8 commits intoggml-org:masterfrom
pwilkin:delta_net

Conversation

@pwilkin
Copy link
Collaborator

@pwilkin pwilkin commented Jan 12, 2026

Refactoring in preparation for #18755

Tested on CUDA - no performance regressions compared to @ngxson's optimized version.

AI Usage: yes. Opus 4.5.

Copy link
Contributor

@ngxson ngxson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just note that while working on #18683, I have been thinking about whether g sound be pre-broadcasted to [S_k, H_v, n_tokens, n_seqs] before entering this function (to make it the same shape as q and k). A broadcast should be fast, shouldn't hurt much performance

Probably we can play around with that idea, or you can reshape it to [1, n_tokens, H_k, n_seqs] as I suggested in the following comments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the file name should be graph-context-delta.cpp to match the graph-context-mamba.cpp naming

g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
} else {
// GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs]
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if g is reshaped to [1, n_tokens, H_k, n_seqs], then a large part of the logic below can be reused between KDA and GDA (see comments below)

g = ggml_pad(ctx0, g, 0, pad, 0, 0);
} else {
// GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first of, I think this branch can be removed if g shape: [1, n_tokens], so we pad along dim 1

beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

// Reshape g for chunks
ggml_tensor * g_cumsum;
Copy link
Contributor

@ngxson ngxson Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ggml_tensor * g_cumsum;
ggml_tensor * g_cumsum;
ggml_tensor * g_cumsum_t;

Since we need both versions, it can be a good idea to get the transposed version right here.

For the GDA branch, a transpose will be a simple reshape as the first dim is [1, n_tokens], so no need for ggml_cont

In other words, given a tensor A with shape: [n, 1, ...], then A.view(1, n, ...) == A^T

// Cumsum along chunk_size dimension (ne[1])
// GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs]
g_cumsum = ggml_cumsum(ctx0, g);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a quick note-to-self, but probably we need to support ggml_cumsum column-wise version, that should eliminate some transposes in the future. Or another idea, support non-cont tensors in ggml_cumsum

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be like the cumsum in pytorch that you can specify which dimension to cumsum.

// GDA: Use decay mask approach (g broadcasts over K dimension)
// g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
ggml_tensor * gcs_i = g_cumsum;
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
Copy link
Contributor

@ngxson ngxson Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this gcs_j should be equivalent to g_cumsum_t (or just g_cumsum, depending on what shape of g you consider to be the transposed version)

then g_exp_pos = ggml_exp(ctx0, g_cumsum_t) can be computed directly here

Comment on lines +251 to +258
if (is_kda) {
// KDA: Reuse g_exp_pos computed earlier
gexp = g_exp_pos;
} else {
// GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
gexp = ggml_exp(ctx0, g_cumsum_t);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be removed when you apply my last trick above

ggml_tensor * g_diff = ggml_sub(ctx0, g_last_broadcast, g_cumsum);
g_diff_exp = ggml_exp(ctx0, g_diff);
} else {
// GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure, but seems like this can be removed too, as we now have both g_cumsum and g_cumsum_t that you can play with

} else {
// GDA: g_last_exp [1, 1, n_chunks, H_k * n_seqs]
// Broadcasts over both K and V dimensions
gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can avoid this branching if g_last_exp is already boardcasted

@github-actions github-actions bot added the model Model specific label Jan 12, 2026
@ymcki
Copy link
Contributor

ymcki commented Jan 12, 2026

Thanks for your refactoring effort. I think my kda_autoregressive is better implemented as I used mul_mat to replace sum_rows. If we refactor, the new function should be based on kda_autoregressive.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 13, 2026

@ymcki indeed your version is better :) there's like another 4% performance gain on autoregressive passes in Qwen3Next.

@ymcki
Copy link
Contributor

ymcki commented Jan 13, 2026

This code in the chunking function will cause overflow without clamping. You either have to clamp or you have to use my mul_mat trick for exact solution.

        g_exp_pos = ggml_exp(ctx0, g_cumsum);
        g_exp_neg = ggml_exp(ctx0, ggml_neg(ctx0, g_cumsum));
        ggml_tensor * k_pos_beta = ggml_mul(ctx0, k_beta, g_exp_pos);
        ggml_tensor * k_neg = ggml_mul(ctx0, k, g_exp_neg);
        k_decay = ggml_mul_mat(ctx0, k_pos_beta, k_neg);

My mul_mat trick:

    const int64_t CHB = n_chunks * H_k * n_seqs;
    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB]

    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
    // decay_mask [chunk_size,chunk_size,S_k,CHB]
    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
    cb(decay_mask, "decay_mask", il);

    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    cb(decay_mask, "decay_masked", il);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);

    ggml_tensor * k_i = ggml_cont(ctx0, ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB));
    ggml_tensor * k_j = ggml_cont(ctx0, ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB));
    ggml_tensor * q_i = ggml_cont(ctx0, ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB));

    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);

    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 13, 2026

@ymcki aight, migrated the KDA branch to use decay mask as well.

@ymcki
Copy link
Contributor

ymcki commented Jan 16, 2026

I think my Kimi Linear PR is almost done, so I can start working on refactoring now.

Do we want to do refactoring along with block matrix multiplication?

The idea is that since we don't care about the upper triangle in Akk and Aqk, so we can take bigger blocks and divide them into chunk size of 64 blocks. For example, if we handle n_seq_tokens >192, then we can pad it to 256 and then break it down to 4x4 64x64 blocks. Then we only need to do mul_mat on 10/16 blocks and apply diag_mask only on the diagonal blocks, ie 4/16 blocks.

If we only do refactoring, then maybe only Kimi will be a few % faster. If we include block mul_mat, then both Qwen3Next and Kimi will see significant gain.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 16, 2026

@ymcki Sure, can try, sounds like a good idea at least in theory, let's see what we can get out of this in practice.

@ymcki
Copy link
Contributor

ymcki commented Jan 19, 2026

Implemented a version that breaks the 64x64 Akk/Aqk chunks into 4x4 16x16 blocks. About 16% pp gain.

I will try to see if I can implement a version that starts with 256x256 Akk/Aqk shards and then breaks it into 4x4 64x64 chunks.

If I fail to implement this new version, I think this 16% gain version is still pretty good.

Original Code: pp 725t/s tg 34t/s ./build/bin/llama-bench -m ~/Kimi-Linear-48B-A3B-Instruct-GGUF/Kimi-Linear-48B-A3B-Instruct-jp-imatrix.Q2_K.gguf -n 32 -d 8192 -b 64,128,256,512,1024,2048,4096,8192,16384 -ngl 100 ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes | model | size | params | backend | ngl | n_batch | test | t/s | | ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 64 | pp512 @ d8192 | 511.71 ± 1.51 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 64 | tg32 @ d8192 | 34.07 ± 0.16 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 128 | pp512 @ d8192 | 515.66 ± 1.31 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 128 | tg32 @ d8192 | 34.03 ± 0.21 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 256 | pp512 @ d8192 | 638.60 ± 1.90 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 256 | tg32 @ d8192 | 33.95 ± 0.19 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 512 | pp512 @ d8192 | 729.91 ± 7.39 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 512 | tg32 @ d8192 | 34.06 ± 0.11 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 1024 | pp512 @ d8192 | 726.15 ± 6.85 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 1024 | tg32 @ d8192 | 33.98 ± 0.14 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 2048 | pp512 @ d8192 | 725.90 ± 7.98 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 2048 | tg32 @ d8192 | 33.85 ± 0.28 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 4096 | pp512 @ d8192 | 725.57 ± 6.77 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 4096 | tg32 @ d8192 | 34.01 ± 0.14 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 8192 | pp512 @ d8192 | 722.58 ± 7.37 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 8192 | tg32 @ d8192 | 34.01 ± 0.14 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 16384 | pp512 @ d8192 | 720.68 ± 6.48 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 16384 | tg32 @ d8192 | 33.97 ± 0.26 |

build: e87ac9b (7816)

Breaks 64x64 Akk/Aqk chunks into 4x4 16x16 blocks: pp 840t/s tg 34t/s ./build/bin/llama-bench -m ~/Kimi-Linear-48B-A3B-Instruct-GGUF/Kimi-Linear-48B-A3B-Instruct-jp-imatrix.Q2_K.gguf -n 32 -d 8192 -b 64,128,256,512,1024,2048,4096,8192,16384 -ngl 100 ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes | model | size | params | backend | ngl | n_batch | test | t/s | | ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 64 | pp512 @ d8192 | 509.68 ± 3.69 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 64 | tg32 @ d8192 | 34.08 ± 0.15 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 128 | pp512 @ d8192 | 538.52 ± 0.83 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 128 | tg32 @ d8192 | 34.05 ± 0.12 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 256 | pp512 @ d8192 | 682.39 ± 1.46 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 256 | tg32 @ d8192 | 34.07 ± 0.16 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 512 | pp512 @ d8192 | 844.40 ± 14.26 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 512 | tg32 @ d8192 | 33.89 ± 0.16 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 1024 | pp512 @ d8192 | 842.77 ± 13.82 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 1024 | tg32 @ d8192 | 33.76 ± 0.15 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 2048 | pp512 @ d8192 | 841.18 ± 15.00 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 2048 | tg32 @ d8192 | 33.76 ± 0.16 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 4096 | pp512 @ d8192 | 841.19 ± 14.22 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 4096 | tg32 @ d8192 | 33.65 ± 0.36 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 8192 | pp512 @ d8192 | 838.83 ± 14.66 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 8192 | tg32 @ d8192 | 33.76 ± 0.14 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 16384 | pp512 @ d8192 | 838.24 ± 12.62 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 16384 | tg32 @ d8192 | 33.79 ± 0.12 |

build: e87ac9b (7816)

@ymcki
Copy link
Contributor

ymcki commented Jan 20, 2026

Just discovered that this 4x4 16x16 blocks version reduces CUDA0 compute buffer size from 2397.39MB to 1432.55MB such that I can increase context from 96k to 160k running IQ3_M on my 3090.

@ymcki
Copy link
Contributor

ymcki commented Jan 22, 2026

Somehow managed to move ggml_solve_tri inside the loop computing Akk and Aqk. This further improves pp to 860t/s, ie 18.6% gain. However, CUDA0 compute buffer running IQ2_M @ 400k context increases from 1432.55MB to 1512.52MB. Let me see if I further optimize it.

I think optimization probably is better to focus on memory saving than speed. Running more context is way more important than a few % increase in pp speed for users.

Can llama-bench also display CUDA0 compute buffer usage?

solve_tri inside Akk loop: pp 860t/s tg 32.6t/s ./build/bin/llama-bench -m ~/Kimi-Linear-48B-A3B-Instruct-GGUF/Kimi-Linear-48B-A3B-Instruct-jp-imatrix.Q2_K.gguf -n 32 -d 8192 -b 64,128,256,512,1024,2048,4096,8192,16384 -ngl 100 ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes | model | size | params | backend | ngl | n_batch | test | t/s | | ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 64 | pp512 @ d8192 | 413.93 ± 9.46 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 64 | tg32 @ d8192 | 32.42 ± 0.14 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 128 | pp512 @ d8192 | 679.76 ± 1.83 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 128 | tg32 @ d8192 | 32.36 ± 0.25 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 256 | pp512 @ d8192 | 796.51 ± 6.17 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 256 | tg32 @ d8192 | 32.27 ± 0.71 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 512 | pp512 @ d8192 | 863.52 ± 17.90 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 512 | tg32 @ d8192 | 32.51 ± 0.24 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 1024 | pp512 @ d8192 | 863.01 ± 17.10 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 1024 | tg32 @ d8192 | 32.68 ± 0.10 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 2048 | pp512 @ d8192 | 860.94 ± 16.71 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 2048 | tg32 @ d8192 | 32.65 ± 0.16 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 4096 | pp512 @ d8192 | 859.43 ± 17.37 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 4096 | tg32 @ d8192 | 32.68 ± 0.15 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 8192 | pp512 @ d8192 | 859.30 ± 17.15 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 8192 | tg32 @ d8192 | 32.60 ± 0.17 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 16384 | pp512 @ d8192 | 858.63 ± 17.41 | | kimi-linear 48B.A3B Q2_K - Medium | 16.78 GiB | 49.12 B | CUDA | 100 | 16384 | tg32 @ d8192 | 32.58 ± 0.21 |

build: e87ac9b (7816)

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 22, 2026

@ymcki can you upload it somewhere so I could take a look?

@ymcki
Copy link
Contributor

ymcki commented Jan 27, 2026

Dear all, I have opened a PR to pwilkin's repo that modified his code to work with kimi linear.

@pwilkin, please test it with Qwen3Next as I don't have the resources to properly test it.

This code is based on a slightly earlier llama.cpp code that doesn't break my Kimi Linear MLA code.

I think we can work on this version first and see what optimizations can be done.

@IIIIIllllIIIIIlllll
Copy link

Dear all, I have opened a PR to pwilkin's repo that modified his code to work with kimi linear.

@pwilkin, please test it with Qwen3Next as I don't have the resources to properly test it.

This code is based on a slightly earlier llama.cpp code that doesn't break my Kimi Linear MLA code.

I think we can work on this version first and see what optimizations can be done.

let me try, any tips or things I should watch out for?

@ymcki
Copy link
Contributor

ymcki commented Jan 27, 2026

Dear all, I have opened a PR to pwilkin's repo that modified his code to work with kimi linear.
@pwilkin, please test it with Qwen3Next as I don't have the resources to properly test it.
This code is based on a slightly earlier llama.cpp code that doesn't break my Kimi Linear MLA code.
I think we can work on this version first and see what optimizations can be done.

let me try, any tips or things I should watch out for?

Just run it with the new Qwen3Next ggufs and see if they work. If you like, you can also run Kimi Linear as well.

If they both works to your satisfaction, then it should be ok. In this unified version, they should share quite a lot of code.

@pwilkin pwilkin mentioned this pull request Feb 5, 2026
3 tasks
@ymcki
Copy link
Contributor

ymcki commented Feb 6, 2026

Since my PR was merged, so I also uploaded a 4x4 16x16 block computation of Akk and Aqk to my repo.
https://github.com/ymcki/llama.cpp/blob/Kimi-Linear/src/models/kimi-linear.cpp

I believe this is also the way the Kimi Linear team did it at:
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/chunk_intra_token_parallel.py

Based on my test, it achieved about 20% speed up in pp and 25% VRAM saving. I presume if @pwilkin can also do it for Qwen3Next, it should have similar pp and vram gain.

They also managed to put the solve_tri code inside the loop:
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/chunk_intra.py

I also have an implementation of this but it is not satisfactory enough, so it is not published for now. In the first glance, this optimization doesn't bring as much gain while taking more vram, so probably we don't miss too much for now.

I will now update my delta_net repo to sync with the latest code and then send another PR to pwilkin's repo.

@ymcki
Copy link
Contributor

ymcki commented Feb 6, 2026

Done with updating to the latest code.

Oops. Submitted PR previously to the wrong place. Now it should be ok.

pwilkin#6

Comment on lines +544 to +550
// Equivalence to previous version:
// Previous: kv_mem = sum_k(state * k) using elementwise mult + sum_rows
// Current: k_state = state_t @ k_t using matrix multiplication
// These are equivalent because: sum_k(A * B) = A @ B when dimensions align
ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
ggml_tensor * k_t = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k_t);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my version of the autoregressive part in #19375 is better - it keep the ggml_sum_rows variant. The idea is to transpose the state at the start and then the rest of the operations line up.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But in my tests, the mul_mat variant did increase PP by about 10%. Let me check both versions and I'll tell you.

Copy link
Member

@ggerganov ggerganov Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ar path should not affect the pp

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is in autoregressive... sorry, wrong fragment then.

Anyway, I'll redownload the IQ2_S quant and check :>

@pwilkin
Copy link
Collaborator Author

pwilkin commented Feb 6, 2026

@ggerganov fixed it methinks, PPL is back to normal, ran a long context check:

> /read ggml/src/ggml-quants.c

Loaded text from 'ggml/src/ggml-quants.c'

> Please tell me what quants are available in this library and how each of them works.

Okay, let's break down the quantization methods available in this GGML-based codebase. This file, `ggml-common.c`, is part of the broader **GGML** (GG Machine Learning) library, which is the core engine behind projects like llama.cpp.

The key to understanding quantization here is that **GGML aims to be a universal backend for efficient inference of large language models (LLMs)**. The "quantization" is the process of converting high-precision floating-point weights (like `float32`) into lower-precision integer representations to save memory and accelerate computation, often at the cost of some accuracy.

Here's a breakdown of the quantization schemes implemented in this code, how they work, and their trade-offs:

### 1. **Classical Block Quantization (Q8_0, Q4_0, etc.)**
These are the simplest and oldest quantization methods. The idea is to process data in fixed-size blocks (e.g., `QK8_0 = 32` values) and compute a single scale factor for the entire block.

*   **Q8_0**: Quantizes to 8-bit signed integers.
    *   **How it works**:
        1.  Find the maximum absolute value (`amax`) in the block.
        2.  Compute scale `d = amax / 127` (so that 127 maps to `amax`).
        3.  Quantize each value `x[i]` to `round(x[i] * d)`, clamped to `[-127, 127]`.
        4.  Store the 8-bit quantized values (`qs`) and the scale `d` (in half-precision).
    *   **Structure**: `d` (fp16) + `qs[32]` (uint8_t). Total size: `4 + 32 = 36 bytes` for 128 bits of data (`32 float32`).
    *   **Pros**: Simple, fast, good baseline accuracy.
    *   **Cons**: A single scale per block means all values are compressed equally, which isn't ideal if the data has varying distributions.

*   **Q4_0, Q4_1, Q5_0, Q5_1**: Quantizes to 4-bit integers.
    *   **How it works**:
        *   Similar to Q8_0, but now the scale `d` is used to quantize to 16 levels instead of 256.
        *   `Q4_1` and `Q5_1` also store a minimum value `m` (like `dmin`), making them `min-max quantization`.
    *   **Pros**: Significantly smaller memory footprint (e.g., `Q4_0` uses `2 + 16 = 18 bytes` for the same data).
    *   **Cons**: Lower precision than higher-bit quantization.

### 2. **Block Quantization with Adaptive Scales (Q2_K, Q3_K, Q4_K, Q5_K, Q6_K)**
These are more modern methods that use **multiple scales per block** to capture finer-grained data distributions.

*   **Q2_K, Q3_K, Q4_K, Q5_K, Q6_K**: These are part of the `QX_K` family (e.g., `Q4_K`, `Q5_K`).
*   **How it works**:
    *   The large block (`QK_K = 1024`) is divided into smaller sub-blocks (e.g., `QK_K/16 = 64` values per scale).
    *   **Multiple scales** are computed for these sub-blocks, and **multiple quantized values** are stored per sub-block.
    *   For example, `Q4_K` has a scale per sub-block and a `dmin`, storing 4-bit quantized values.
*   **Pros**: Much better accuracy for the same bit-width compared to `Q8_0` or `Q4_0`, due to adaptive scaling.
*   **Cons**: Slightly more complex, and requires more scale parameters to store.

### 3. **Mixed-Precision Quantization (IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_S, IQ1_S, IQ1_M, IQ4_NL, IQ4_XS)**
This is the most advanced family, where the quantization grid is **not uniform** and is defined by pre-computed lookup tables.

*   **IQ2_XXS, IQ2_XS, IQ2_S**: "2-bit" methods using a 2-bit grid.
*   **IQ3_XXS, IQ3_S**: "3-bit" methods using a 3-bit grid.
*   **IQ1_S, IQ1_M**: "1-bit" methods (extremely low precision, using only 3 possible values: `-1`, `0`, `1`).
*   **IQ4_NL, IQ4_XS**: Non-linear quantization, using a lookup table (`kvalues_iq4nl`) to map quantized indices to actual float values. This is key to `IQ4_NL` and `IQ4_XS`.
*   **How it works**:
    1.  **Grid Definition**: A grid of representative float values is pre-computed and stored in lookup tables (`kgrid_...`, `kmap_...`, `kneighbors_...`). These grids are not linearly spaced.
    2.  **Nearest Neighbor Search**: For each sub-block, the algorithm searches for the nearest grid point that minimizes the reconstru    ]
ction error, using the quantization weights if provided.
    3.  **Variable-Length Encoding**: The indices to the grid and sign bits are stored, often in a compact way (e.g., bits packed into bytes).
*   **Pros**: Achieves state-of-the-art accuracy for a given bit-width, especially at very low bit-widths (like 2-bit or 1-bit).
*   **Cons**: Significantly more complex, requires large pre-computed grids and neighbor tables, and the quantization process is much slower (it's not a simple rounding operation).

### 4. **Custom Formats (MXFP4, TQ1_0, TQ2_0)**
These are specialized formats, often inspired by hardware or research papers.

*   **MXFP4**: Based on the "Maximal Floating Point" format used by NVIDIA Hopper GPUs. It uses an 8-bit exponent and a 4-bit mantissa, allowing for a wider dynamic range than pure integer quantization.
*   **TQ1_0, TQ2_0**: "Ternary quantization" (values are `-1`, `0`, `1`), which is the core idea behind the BitNet (b1.58) and TriLMs research. These are extremely memory-efficient.

### 5. **Reference Implementations vs. Optimized Implementations**
The file contains both **`_ref`** (reference) and **optimized** implementations.

*   **`_ref` functions**: These are written in portable C and are meant to be the "ground truth" for quantization. They are slow but easy to understand and debug.
*   **Optimized implementations**: In a real system, the actual quantization routines (`quantize_iq2_xxs`, `quantize_q4_K`, etc.) would use SIMD instructions (AVX2, NEON) for speed. The `_ref` functions are often used as a baseline to ensure correctness.

### How to Use These Quantizations

The key point is that **this file provides the core quantization logic**. In a production system, this would be wrapped by a higher-level API that handles:

1.  **Model Loading**: Reading a float32 model and converting it to a quantized format (e.g., `Q4_K`) using these functions.
2.  **Runtime Inference**: Running the quantized model, which calls functions like `dequantize_row_q4_K` and `ggml_mul_mat_q4_K` to perform matrix multiplication directly on the quantized data (without dequantizing to float32 first).

In summary, the library supports a wide spectrum of quantization methods, from simple fixed-bit quantization to highly adaptive grid-based methods, allowing users to choose a balance between model size, speed, and accuracy.

[ Prompt: 922,4 t/s | Generation: 37,5 t/s ]

@ymcki
Copy link
Contributor

ymcki commented Feb 7, 2026

Someone at r/LocalLlama reported crash when running vulkan (-fit on) with the main branch code. Can someone take a look and see what's going on?

llama.cpp\ggml\src\ggml-backend.cpp:809: pre-allocated tensor (cache_k_l15) in a buffer (Vulkan1) that cannot run the operation (NONE)

@ymcki
Copy link
Contributor

ymcki commented Feb 7, 2026

Compiled a vulkan llama.cpp for my 3090. I can't replicate the crash reported.

However, while the main branch works, my block implementation generates gibberish for vulkan. So I will look into it and see what's going on.

@ymcki
Copy link
Contributor

ymcki commented Feb 7, 2026

Replaced ggml_acc with ggml_set and now my block implementation works for CPU, CUDA and vulkan. Probably something is wrong with the acc implementation in vulkan?

@ggerganov
Copy link
Member

Probably something is wrong with the acc implementation in vulkan?

Try to add tests to test-backend-ops that reproduce the failure. Usually, capture the exact shapes of the tensors (elements and strides) and replicate this in the tests.

@ymcki
Copy link
Contributor

ymcki commented Feb 8, 2026

Added this test case to emulate my code. It does indeed fail in vulkan but not in cuda in some cases. What's next?

// GGML_OP_ACC - block accumulation test
struct test_acc_block: public test_case {
    const ggml_type type;
    const int64_t block_size;
    const int64_t n_blocks;
    const int64_t ne2;
    const int64_t ne3;
    
    std::string vars() override {
        return VARS_TO_STR5(type, block_size, n_blocks, ne2, ne3); 
    }
    
    test_acc_block(ggml_type type = GGML_TYPE_F32,
            int64_t block_size = 16,
            int64_t n_blocks = 4,
            int64_t ne2 = 1,
            int64_t ne3 = 1)
        : type(type), block_size(block_size), n_blocks(n_blocks), ne2(ne2), ne3(ne3) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        const int64_t chunk_size = block_size * n_blocks;

        // Base tensor initialized to zero using ggml_clamp
        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, chunk_size, chunk_size, ne2, ne3);
        ggml_set_param(a);
        ggml_set_name(a, "a");
        ggml_tensor * acc = ggml_clamp(ctx, a, 0.0f, 0.0f);
    
        // Source blocks that will be accumulated at different offsets
        // Mimics the lower-triangular block pattern from the original code
        for (int64_t j = 0; j < n_blocks; ++j) {
            for (int64_t i = 0; i <= j; ++i) {
                ggml_tensor * block = ggml_new_tensor_4d(ctx, type,
                    block_size, block_size, ne2, ne3);
                ggml_set_param(block);
    
                char name[64];
                snprintf(name, sizeof(name), "block_%ld_%ld", (long)j, (long)i);
                ggml_set_name(block, name);
    
                // Accumulate block at position [i*block_size, j*block_size]
                // This is the same pattern as the original code:
                //   offset = i_start * nb[0] + j_start * nb[1]
                size_t offset = (i * block_size) * ggml_type_size(type)
                              + (j * block_size) * (chunk_size * ggml_type_size(type));

                acc = ggml_acc(ctx, acc, block,
                    acc->nb[1], acc->nb[2], acc->nb[3],
                    offset);
            }
        }

        ggml_set_name(acc, "out");
        return acc;
    }
};
CUDA test ./build/bin/test-backend-ops test -o ACC ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes Testing 2 devices

Backend 1/2: CUDA0
Device description: NVIDIA GeForce RTX 3090
Device memory: 24154 MB (23871 MB free)

ACC(type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]): OK
ACC(type=f32,block_size=16,n_blocks=4,ne2=1,ne3=1): OK
ACC(type=f32,block_size=16,n_blocks=4,ne2=3,ne3=2): OK
ACC(type=f32,block_size=8,n_blocks=8,ne2=1,ne3=1): OK
ACC(type=f32,block_size=32,n_blocks=4,ne2=2,ne3=2): OK
5/5 tests passed
Backend CUDA0: OK
Backend 2/2: CPU
Skipping CPU backend
2/2 backends passed
OK

vulkan test ./build/bin/test-backend-ops test -o ACC -b Vulkan1 ggml_vulkan: Found 2 Vulkan devices: ggml_vulkan: 0 = NVIDIA GeForce RTX 3050 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat ggml_vulkan: 1 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat Testing 3 devices

Backend 1/3: Vulkan0
Skipping
Backend 2/3: Vulkan1
Device description: NVIDIA GeForce RTX 3090
Device memory: 24822 MB (24347 MB free)

ACC(type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]): OK
ACC(type=f32,block_size=16,n_blocks=4,ne2=1,ne3=1): OK
[ACC] ERR = 1.010783488 > 0.000000100 [ACC] ERR = 1.013714212 > 0.000000100 [ACC] ERR = 1.004541856 > 0.000000100 [ACC] ERR = 0.991297134 > 0.000000100 [ACC] ERR = 0.986481766 > 0.000000100 [ACC] ERR = 0.993282490 > 0.000000100 [ACC] ERR = 1.007556116 > 0.000000100 [ACC] ERR = 1.010410332 > 0.000000100 [ACC] ERR = 0.998806424 > 0.000000100 [ACC] ERR = 0.996936909 > 0.000000100 ACC(type=f32,block_size=16,n_blocks=4,ne2=3,ne3=2): FAIL
ACC(type=f32,block_size=8,n_blocks=8,ne2=1,ne3=1): OK
[ACC] ERR = 1.022488583 > 0.000000100 [ACC] ERR = 0.997224865 > 0.000000100 [ACC] ERR = 0.999178091 > 0.000000100 [ACC] ERR = 1.009064961 > 0.000000100 [ACC] ERR = 1.012894905 > 0.000000100 [ACC] ERR = 1.008775052 > 0.000000100 [ACC] ERR = 1.010901200 > 0.000000100 [ACC] ERR = 1.011180978 > 0.000000100 [ACC] ERR = 1.007173060 > 0.000000100 [ACC] ERR = 1.004951028 > 0.000000100 ACC(type=f32,block_size=32,n_blocks=4,ne2=2,ne3=2): FAIL
3/5 tests passed

Failing tests:
ACC(type=f32,block_size=16,n_blocks=4,ne2=3,ne3=2)
ACC(type=f32,block_size=32,n_blocks=4,ne2=2,ne3=2)
Backend Vulkan1: FAIL
Backend 3/3: CPU
Skipping
2/3 backends passed
FAIL

@ggerganov
Copy link
Member

Make a separate PR to master that adds these tests so that we can fix the backends that are currently failing.

@ymcki
Copy link
Contributor

ymcki commented Feb 8, 2026

Make a separate PR to master that adds these tests so that we can fix the backends that are currently failing.

#19426

PR submitted with a fix

@ggerganov
Copy link
Member

This branch still produces wrong results compared to master. Converting to draft until it is resolved.

@ggerganov ggerganov marked this pull request as draft February 9, 2026 14:35
@pwilkin
Copy link
Collaborator Author

pwilkin commented Feb 9, 2026

This branch still produces wrong results compared to master. Converting to draft until it is resolved.

Ah, sorry, forgot to push the fix on this branch. Should be OK now.

@jacekpoplawski
Copy link
Contributor

jacekpoplawski commented Feb 10, 2026

Details p g

master vs https://github.com/ymcki/llama.cpp/tree/Kimi-Linear

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@ymcki
Copy link
Contributor

ymcki commented Feb 11, 2026

I noticed that about 1 in 4 responses are repetitive, incomplete and sometimes with Chinese characters if I run llama-server in parallel.

./build/bin/llama-server -c 16384 --parallel 8 --mmap -m ~/Kimi-Linear-48B-A3B-Instruct-GGUF/Kimi-Linear-48B-A3B-Instruct-jp-imatrix.IQ3_M.gguf -ngl 100

However, if I run llama-completion, it is mostly perfect.

Seems fixed with this PR
#19531

@pwilkin
Copy link
Collaborator Author

pwilkin commented Mar 13, 2026

Closing as obsoleted.

@pwilkin pwilkin closed this Mar 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

model Model specific

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants