
model: try to improve Qwen3 Next#18683

Merged
ngxson merged 16 commits into ggml-org:master from ngxson:xsn/qwen3next_improve
Jan 11, 2026

Conversation

@ngxson
Contributor

@ngxson ngxson commented Jan 8, 2026

Important

If you're using an old GGUF and it no longer loads, be sure to update to this fix: #18762

I was quite curious why there was a function called fix_query_key_value_ordering in the transformers code (which was mirrored into the llama.cpp implementation), and wondered what they were trying to fix.

Turns out, the projected QKVZ was in the wrong order. Not sure why they don't fix the original weight instead of permuting the results, but that's not very important.

I took my trusty pen & paper to see what can be done:

[image: pen-and-paper derivation]

Since the projected matrix is big, I supposed it would give a significant boost in perf. But I was quite disappointed to see only a 1% improvement in pp512, so I don't even know if it's worth the fix. GGUFs will need to be reconverted to take advantage of this.
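As an aside, the weight-fixing idea can be sketched with toy numbers (the matrix, input, and `perm` below are made up for illustration, not the actual Qwen3 Next ordering): permuting the rows of the projection weight once at conversion time yields the same output as permuting the projected activations on every forward pass.

```python
# Toy sketch (hypothetical shapes and perm): bake an output permutation
# into the projection weight at conversion time, so no runtime permute
# of the large projected activation is needed.

def matvec(W, x):
    # y = W @ x for a row-major weight matrix
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

W    = [[1, 2], [3, 4], [5, 6], [7, 8]]  # toy 4x2 projection weight
x    = [1, 2]                            # toy input
perm = [2, 0, 3, 1]                      # toy output reordering

# old path: project, then permute the projected activation every step
y_runtime = [matvec(W, x)[p] for p in perm]

# new path: permute the weight rows once (e.g. during GGUF conversion)
W_fixed = [W[p] for p in perm]
y_baked = matvec(W_fixed, x)

assert y_runtime == y_baked  # identical results, one runtime permute removed
```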

(Update: there are more improvements that bring anywhere from a 5% to a 20% boost depending on the backend; see my comment below)

master:

| model                 |       size |  params | backend    | threads |  test |            t/s |
| --------------------- | ---------: | ------: | ---------- | ------: | ----: | -------------: |
| qwen3next 80B.A3B F16 | 148.50 GiB | 79.67 B | Metal,BLAS |      24 | pp512 | 790.87 ± 17.93 |
| qwen3next 80B.A3B F16 | 148.50 GiB | 79.67 B | Metal,BLAS |      24 | tg128 |   22.33 ± 0.08 |

PR:

| model                 |       size |  params | backend    | threads |  test |           t/s |
| --------------------- | ---------: | ------: | ---------- | ------: | ----: | ------------: |
| qwen3next 80B.A3B F16 | 148.50 GiB | 79.67 B | Metal,BLAS |      24 | pp512 | 965.15 ± 8.00 |
| qwen3next 80B.A3B F16 | 148.50 GiB | 79.67 B | Metal,BLAS |      24 | tg128 |  28.87 ± 0.06 |

Hardware: Mac Studio M3, 256 GB RAM


I uploaded a q8_0 here: https://huggingface.co/ngxson/qwen3_next_fixed/tree/main (converted from the Instruct version)

It can be used to test against the q8_0 that already exists on the internet: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-GGUF/blob/main/Qwen3-Next-80B-A3B-Instruct-Q8_0.gguf

@ngxson ngxson requested a review from CISC as a code owner January 8, 2026 00:00
@ngxson ngxson requested review from pwilkin and removed request for CISC January 8, 2026 00:00
@ngxson ngxson marked this pull request as draft January 8, 2026 00:00
@pwilkin
Collaborator

pwilkin commented Jan 8, 2026

I like the idea (and it surely does make the computation clearer), but I'm not sure if people would appreciate having to recreate and redownload GGUFs at this point for a 1% performance increase. I don't mind, but would like to see what people who use the model more think first.

EDIT: never mind, I only now looked at the code in detail and saw that you added a compatibility path. Yeah, I think it's worth doing :) Maybe it'll help more on other backends?

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

Yeah, I would appreciate help with testing on other backends too. Curious to see if it's better on CUDA. The Apple unified memory thingy can make things a bit complicated to measure.

(My optimization should only take effect on systems with constrained memory bandwidth)

@jeffbolznv
Collaborator

Do I need to do anything special to test it (e.g. regenerate ggufs)?

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

@jeffbolznv yes, you will need to generate the GGUF with this PR:

Download the model here: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct

Then convert it as normal:

python ../../llama.cpp/convert_hf_to_gguf.py . --outtype f16 --outfile model.gguf

@pwilkin
Collaborator

pwilkin commented Jan 8, 2026

It's actually BF16 :)

"torch_dtype": "bfloat16"

@jacekpoplawski
Contributor

I understand that many 1% improvements add up to something significant.
If someone publishes a quantized GGUF, it will be much easier for people to benchmark it on different configurations.

@am17an
Collaborator

am17an commented Jan 8, 2026

A merged QKV in normal models gives about a 3-4% improvement in PP, last time I checked (#16813). However, many backends already support computing these projections in parallel, so the Metal backend perhaps won't benefit as much. You can try the CUDA backend; I suspect the gains will be slightly more pronounced.
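The fused-projection effect described above can be sketched with toy numbers (the weights and input below are hypothetical): stacking Wq, Wk, and Wv into one matrix turns three small matmuls into a single larger one that produces the concatenated q, k, v directly.

```python
# Toy sketch (hypothetical weights): one fused QKV projection vs
# three separate projections.

def matvec(W, x):
    # y = W @ x for a row-major weight matrix
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

Wq = [[1, 0], [0, 1]]
Wk = [[2, 0], [0, 2]]
Wv = [[1, 1], [1, -1]]
x  = [3, 4]

# three separate projections (three small matmuls)
q, k, v = matvec(Wq, x), matvec(Wk, x), matvec(Wv, x)

# one fused projection over the row-stacked weight (one bigger matmul)
W_qkv = Wq + Wk + Wv
qkv   = matvec(W_qkv, x)

assert qkv == q + k + v  # same values, fewer kernel launches
```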

@jeffbolznv
Collaborator

I have limited cycles for the next week or so; if somebody can make a quantized model easily accessible, it'll be easier for me to try it.

@IIIIIllllIIIIIlllll

I don't have a CUDA device at the moment, so I simply did the test on AI MAX+ 395.

Master

command: /home/mark/llama.cpp/llama.cpp-master/build/bin/llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-Q8_0/next-master.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        579.45 ± 1.62 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         28.98 ± 0.04 |

build: unknown (0)

This PR

command: /home/mark/llama.cpp/llama.cpp-xsn-qwen3next_improve/build/bin/llama-bench -m /home/mark/Models/Q8/Test-Qwen3-Next/model.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        593.37 ± 1.45 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         29.18 ± 0.03 |

build: unknown (0)

@ggerganov
Member

Btw, I'm also planning to apply the #18550 functionality to Qwen3 Next when it is ready in order to make the ggml graphs static in memory and avoid the graph reallocations that currently occur because of the chunking. It still won't allow CUDA graphs to be used, but at least we should be able to avoid a lot of overhead that we currently have from the changing number of nodes based on the batch size.

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

@am17an the problem with qwen3 next is that qkv is not actually used for attention. Instead, this qkvz tensor is used for the ssm conv (the naming is a bit confusing indeed). The old logic does: qkvz projection --> permute --> concat, which is redundant (compared to attention: projection --> rope --> permute --> attention)

As I expected, it seems like this really does help systems with less memory bandwidth, as @IIIIIllllIIIIIlllll confirmed. The projected qkvz tensor is large, so it should be faster if we don't need to do any permutation on it.

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

@jeffbolznv I'll try to upload a q8_0 later today

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

I uploaded a q8_0 here: https://huggingface.co/ngxson/qwen3_next_fixed/tree/main (converted from the Instruct version)

It can be used to test against the q8_0 that already exists on the internet: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-GGUF/blob/main/Qwen3-Next-80B-A3B-Instruct-Q8_0.gguf

Note: @pwilkin we usually test on f16 because it's the most compatible type across all backends. Some backends like CPU internally convert bf16 to f16 if the hardware doesn't support it; for hardware that supports bf16, there should be no difference between the two.

But that's not important for this PR: we are comparing perf before/after, so the important thing is to make sure we're using the same type.

@pwilkin
Collaborator

pwilkin commented Jan 8, 2026

Ah, okay, got it.

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

@ggerganov I'm thinking of this idea: can we enforce the number of chunks? For example, we can enforce cgraph to always allocate n_chunks = ubatch_size / CHUNK_SIZE, and unused chunks will have 0 elements, allowing backends to skip them.

This should make the cgraph topology more static, although it is unclear to me whether CUDA graphs expect the tensor shapes to be unchanged (the case of 0 elements for unused chunks that I mentioned above).
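The fixed-chunk layout can be sketched like this (the CHUNK_SIZE value and the zero-size-tail convention are assumptions for illustration, not existing llama.cpp code):

```python
# Hypothetical sketch: always build a fixed number of chunks
# (n_chunks_max = ceil(ubatch_max / CHUNK_SIZE)); unused tail chunks
# get 0 elements so a backend could skip them, keeping the graph
# topology identical across batch sizes.

CHUNK_SIZE = 64  # illustrative value

def chunk_sizes(n_tokens, ubatch_max):
    n_chunks_max = (ubatch_max + CHUNK_SIZE - 1) // CHUNK_SIZE
    sizes = []
    remaining = n_tokens
    for _ in range(n_chunks_max):
        take = min(CHUNK_SIZE, remaining)
        sizes.append(take)  # 0 for unused chunks
        remaining -= take
    return sizes
```

The list always has the same length regardless of how many tokens actually arrived, which is what keeps the topology static.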

One extra question: I remember we had a mechanism in ggml to detect no-ops (like GGML_OP_VIEW). I'm wondering if it could be beneficial to extend it to treat singleton ops as no-ops (for example: transposing a tensor with n_elements = 1).

@jeffbolznv
Collaborator

My test system is an RTX 5090 with 32GB, so I'm not sure I can get useful data from this 80GB model. I ran with -ngl 19, which is the most layers I can fit, and see something like a 10x slowdown with the new model/changes, but it's probably due either to paging or to worse CPU inference performance. I've done my previous qwen3next testing on a Q2_K_L model, which is around 29GB.

@ggerganov
Member

@ggerganov I'm thinking of this idea: can we enforce the number of chunks? For example, we can enforce cgraph to always allocate n_chunks = ubatch_size / CHUNK_SIZE, and unused chunks will have 0 elements, allowing backends to skip them.

Not sure how robust ggml is for zero-sized tensors. I think it will need some significant changes to be compatible. But we can definitely give this idea a try.

One extra question: I remember we had a mechanism in ggml to detect no-ops (like GGML_OP_VIEW). I'm wondering if it could be beneficial to extend it to treat singleton ops as no-ops (for example: transposing a tensor with n_elements = 1).

Hm, not sure I understand. Transpose is already a no-op:

static bool ggml_op_is_empty(enum ggml_op op) {
    switch (op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
            return true;
        default:
            return false;
    }
}

@ngxson
Contributor Author

ngxson commented Jan 8, 2026

@ggerganov what I mean is that there can also be cases like ggml_cont on a contiguous tensor, or ggml_sum_rows on a tensor with t->ne[0] == 1, etc., which can all be considered no-ops. That's just an idea for maybe merging the two branches (chunked and autoregressive) together in the future (although it could be much more difficult to actually do).
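A minimal sketch of that extended predicate, using op names as plain strings (purely illustrative, not the ggml API):

```python
# Hypothetical extension of the no-op check: besides pure view ops,
# treat ops whose result provably equals their input as empty too.

VIEW_OPS = {"NONE", "RESHAPE", "TRANSPOSE", "VIEW", "PERMUTE"}

def op_is_empty(op, ne0=None, contiguous=False):
    # view-style ops are always no-ops (mirrors ggml_op_is_empty)
    if op in VIEW_OPS:
        return True
    # extension: cont on an already-contiguous tensor changes nothing
    if op == "CONT" and contiguous:
        return True
    # extension: summing rows of length 1 returns the input values
    if op == "SUM_ROWS" and ne0 == 1:
        return True
    return False
```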

@danbev
Member

danbev commented Jan 9, 2026

System Config: DGX Spark (Grace CPU + Blackwell GPU)
# OS & Kernel
Description:	Ubuntu 24.04.3 LTS
6.11.0-1016-nvidia

# CPU (Grace)
Architecture:                         aarch64
Byte Order:                           Little Endian
Vendor ID:                            ARM
Model name:                           Cortex-X925
Model:                                1
Thread(s) per core:                   1
Model name:                           Cortex-A725
Model:                                1
Thread(s) per core:                   1

# GPU (Blackwell)
name, driver_version, memory.total [MiB], compute_cap
NVIDIA GB10, 580.95.05, [N/A], 12.1

# CUDA Version
Cuda compilation tools, release 13.0, V13.0.88

I ran the following command for the benchmarks (same as used in comment):

./build/bin/llama-bench -m models/Qwen3-Next-80B-A3B-Instruct-bf16-Q8_0.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

Let me know if there are other configurations that you'd like me to run.

master:

| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |          pp2048 |        899.72 ± 2.06 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |            tg32 |         29.77 ± 0.04 |

build: 8ece3836b (7681)

PR:

| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |          pp2048 |        924.32 ± 2.21 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |            tg32 |         31.20 ± 0.52 |

build: 65602e899 (7686)

@ngxson ngxson requested review from CISC and ggerganov January 10, 2026 17:38
@IIIIIllllIIIIIlllll

There is no further performance improvement on my AI MAX+ 395 either.

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           pp512 |        605.15 ± 5.10 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | ROCm       |  99 |  1 |    0 |           tg128 |         30.00 ± 0.01 |

build: unknown (0)

@jeffbolznv
Collaborator

I'm not at my computer for a while, but I don't think there should be any splits, except for the initial layer(s).

To enable the logging, set GGML_SCHED_DEBUG=2 and use -v.

@jeffbolznv
Collaborator

Confirming that I only see the initial split 0 with GET_ROWS and then the remaining nodes in split 1.

@danbev
Member

danbev commented Jan 11, 2026

| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |          pp2048 |        978.86 ± 2.65 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |     2048 |  1 |    0 |            tg32 |         31.16 ± 0.69 |

build: d5a085696 (7704)

@jacekpoplawski
Contributor

three GPUs

Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

original:

| model                           |      size |  params | backend | ngl |  test |           t/s |
| ------------------------------- | --------: | ------: | ------- | --: | ----: | ------------: |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | pp512 | 721.64 ± 2.92 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | tg128 |  67.14 ± 2.78 |

build: 707cbaf (7700)

improved:

| model                           |      size |  params | backend | ngl |  test |           t/s |
| ------------------------------- | --------: | ------: | ------- | --: | ----: | ------------: |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | pp512 | 743.86 ± 3.90 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | tg128 |  71.88 ± 3.15 |

build: d5a0856 (7704)

two GPUs

Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

original:

| model                           |      size |  params | backend | ngl |  test |           t/s |
| ------------------------------- | --------: | ------: | ------- | --: | ----: | ------------: |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | pp512 | 988.48 ± 6.17 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | tg128 |  74.99 ± 0.06 |

build: 707cbaf (7700)

improved:

| model                           |      size |  params | backend | ngl |  test |            t/s |
| ------------------------------- | --------: | ------: | ------- | --: | ----: | -------------: |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | pp512 | 1029.37 ± 8.31 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA    |  99 | tg128 |   79.99 ± 0.07 |

build: d5a0856 (7704)

@ngxson ngxson merged commit 506bb6e into ggml-org:master Jan 11, 2026
80 checks passed
@ngxson
Contributor Author

ngxson commented Jan 11, 2026

@bartowski1182 @danielhanchen if possible, could you re-generate the GGUF to take advantage of this change? Thanks!!

@SirSchnobi

SirSchnobi commented Jan 11, 2026

I saw this PR on Reddit. I was wondering why my GGUF no longer loads since this update.

I get the error below when using an old, non-fixed GGUF.
Is this expected, or are old models just not supported anymore?

I'm running Vulkan radv with llama.cpp b7703 in a Docker container on Strix Halo.

Attaching to llama-vulkan-radv
llama-vulkan-radv  | [39389] ggml_vulkan: Found 1 Vulkan devices:
llama-vulkan-radv  | [39389] ggml_vulkan: 0 = Radeon 8060S Graphics (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
llama-vulkan-radv  | [39389] main: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true
llama-vulkan-radv  | [39389] build: 7703 (506bb6e01) with GNU 15.2.1 for Linux x86_64
llama-vulkan-radv  | [39389] system info: n_threads = 32, n_threads_batch = 32, total_threads = 32
llama-vulkan-radv  | [39389] 
llama-vulkan-radv  | [39389] system_info: n_threads = 32 (n_threads_batch = 32) / 32 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
llama-vulkan-radv  | [39389] 
llama-vulkan-radv  | [39389] init: using 31 threads for HTTP server
llama-vulkan-radv  | [39389] start: binding port with default address family
llama-vulkan-radv  | [39389] main: loading model
llama-vulkan-radv  | [39389] srv    load_model: loading model '/models/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_XL.gguf'
llama-vulkan-radv  | [39389] common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama-vulkan-radv  | [39389] /opt/llama.cpp/ggml/src/ggml.c:3553: GGML_ASSERT(ggml_is_contiguous(a)) failed
llama-vulkan-radv  | [39389] /lib64/libggml-base.so.0(+0x35a5) [0x7f593929f5a5]
llama-vulkan-radv  | [39389] /lib64/libggml-base.so.0(ggml_print_backtrace+0x1eb) [0x7f593929f96b]
llama-vulkan-radv  | [39389] /lib64/libggml-base.so.0(ggml_abort+0x11f) [0x7f593929faef]
llama-vulkan-radv  | [39389] /lib64/libggml-base.so.0(ggml_reshape_2d+0xab) [0x7f59392a5d6b]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(_ZN19llm_build_qwen3next23build_layer_attn_linearEP18llm_graph_input_rsP11ggml_tensorS3_S3_S3_i+0xa66) [0x7f593cc21246]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(_ZN19llm_build_qwen3nextC1ERK11llama_modelRK16llm_graph_params+0x195) [0x7f593cc21785]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(_ZNK11llama_model11build_graphERK16llm_graph_params+0x755) [0x7f593cb39105]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(_ZN13llama_context13graph_reserveEjjjPK22llama_memory_context_ibPm+0x193) [0x7f593cac3b33]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(_ZN13llama_contextC2ERK11llama_model20llama_context_params+0x17f6) [0x7f593cac86c6]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(llama_init_from_model+0x113) [0x7f593cac92d3]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(+0x2664a) [0x7f593ca9f64a]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(+0x27767) [0x7f593caa0767]
llama-vulkan-radv  | [39389] /lib64/libllama.so.0(llama_params_fit+0x4e) [0x7f593caa442e]
llama-vulkan-radv  | [39389] /usr/bin/llama-server() [0x64e951]
llama-vulkan-radv  | [39389] /usr/bin/llama-server() [0x6512c9]
llama-vulkan-radv  | [39389] /usr/bin/llama-server() [0x50736a]
llama-vulkan-radv  | [39389] /usr/bin/llama-server() [0x46a86d]
llama-vulkan-radv  | [39389] /lib64/libc.so.6(+0x35b5) [0x7f5938d145b5]
llama-vulkan-radv  | [39389] /lib64/libc.so.6(__libc_start_main+0x88) [0x7f5938d14668]
llama-vulkan-radv  | [39389] /usr/bin/llama-server() [0x46ce85]

@Som-anon

@SirSchnobi try this #18761 and comment there. Commenting on merged pull requests is not the best way to report bugs...

ngxson added a commit to ngxson/llama.cpp that referenced this pull request Jan 11, 2026
@ngxson
Contributor Author

ngxson commented Jan 11, 2026

will be fixed in #18762

ngxson added a commit that referenced this pull request Jan 11, 2026
@danielhanchen
Contributor

@ngxson Nice work, will update today!

@danielhanchen
Contributor

Updated Instruct at https://huggingface.co/unsloth/Qwen3-Next-80B-A3B-Instruct-GGUF and Thinking is ongoing! Had to redo the imatrix calibration :)

@bartowski1182
Contributor

Both of mine have been updated, good work on this

https://huggingface.co/bartowski/Qwen_Qwen3-Next-80B-A3B-Thinking-GGUF

https://huggingface.co/bartowski/Qwen_Qwen3-Next-80B-A3B-Instruct-GGUF

gary149 pushed a commit to gary149/llama-agent that referenced this pull request Jan 13, 2026
* qwen3next: simplify qkvz projection

* use ggml_swiglu_split

* revert swiglu_split, but remove redundant repeat()

* fix missing reshape

* rm 2 redundant transposes

* move mul_mat(k,q) to outside of chunking

* rm redundant cont

* improve g_cs_chunk

* add comments about no cont

* use std::pair instead of ggml_concat

* vectorize key_gdiff calculation

* rm unused tensor

* avoid ggml_concat inside loop

* bring back ggml_concat as it may not work on other backend

* nits
gary149 pushed a commit to gary149/llama-agent that referenced this pull request Jan 13, 2026
dillon-blake pushed a commit to Boxed-Logic/llama.cpp that referenced this pull request Jan 15, 2026
dillon-blake pushed a commit to Boxed-Logic/llama.cpp that referenced this pull request Jan 15, 2026

Labels

model (Model specific), python (python script changes)
