
models : optimizing qwen3next graph #19375

Merged
ggerganov merged 19 commits into master from gg/qwen3-next-opt on Feb 14, 2026
Conversation

@ggerganov
Member

@ggerganov ggerganov commented Feb 5, 2026

Reworking the ggml compute graph to avoid many unnecessary copies.

M2 Ultra:

| Model | Test | t/s b7946 | t/s gg/qwen3-next-opt | Speedup |
| ----- | ---- | --------: | --------------------: | ------: |
| qwen3next 80B.A3B Q4_0 | pp1 | 37.92 | 51.99 | 1.37 |
| qwen3next 80B.A3B Q4_0 | pp2 | 43.20 | 57.55 | 1.33 |
| qwen3next 80B.A3B Q4_0 | pp3 | 60.03 | 78.88 | 1.31 |
| qwen3next 80B.A3B Q4_0 | pp4 | 77.27 | 100.71 | 1.30 |
| qwen3next 80B.A3B Q4_0 | pp5 | 89.65 | 119.23 | 1.33 |
| qwen3next 80B.A3B Q4_0 | pp6 | 109.17 | 140.88 | 1.29 |
| qwen3next 80B.A3B Q4_0 | pp7 | 122.24 | 156.66 | 1.28 |
| qwen3next 80B.A3B Q4_0 | pp8 | 137.75 | 176.36 | 1.28 |
| qwen3next 80B.A3B Q4_0 | pp512 | 930.70 | 1125.73 | 1.21 |
| qwen3next 80B.A3B Q4_0 | pp2048 | 1049.91 | 1352.31 | 1.29 |
| qwen3next 80B.A3B Q4_0 | tg32 | 38.02 | 50.39 | 1.33 |
| qwen3next 80B.A3B Q4_K_M | pp1 | 34.00 | 46.47 | 1.37 |
| qwen3next 80B.A3B Q4_K_M | pp2 | 40.94 | 52.84 | 1.29 |
| qwen3next 80B.A3B Q4_K_M | pp3 | 56.43 | 72.38 | 1.28 |
| qwen3next 80B.A3B Q4_K_M | pp4 | 68.36 | 87.54 | 1.28 |
| qwen3next 80B.A3B Q4_K_M | pp5 | 81.44 | 104.15 | 1.28 |
| qwen3next 80B.A3B Q4_K_M | pp6 | 96.56 | 121.09 | 1.25 |
| qwen3next 80B.A3B Q4_K_M | pp7 | 105.85 | 131.26 | 1.24 |
| qwen3next 80B.A3B Q4_K_M | pp8 | 117.27 | 145.61 | 1.24 |
| qwen3next 80B.A3B Q4_K_M | pp512 | 842.57 | 1001.46 | 1.19 |
| qwen3next 80B.A3B Q4_K_M | pp2048 | 977.30 | 1232.47 | 1.26 |
| qwen3next 80B.A3B Q4_K_M | tg32 | 34.63 | 46.43 | 1.34 |
| qwen3next 80B.A3B Q8_0 | pp1 | 34.38 | 43.98 | 1.28 |
| qwen3next 80B.A3B Q8_0 | pp2 | 39.88 | 50.97 | 1.28 |
| qwen3next 80B.A3B Q8_0 | pp3 | 52.48 | 67.62 | 1.29 |
| qwen3next 80B.A3B Q8_0 | pp4 | 66.37 | 83.92 | 1.26 |
| qwen3next 80B.A3B Q8_0 | pp5 | 75.80 | 95.88 | 1.26 |
| qwen3next 80B.A3B Q8_0 | pp6 | 89.07 | 109.19 | 1.23 |
| qwen3next 80B.A3B Q8_0 | pp7 | 98.05 | 118.99 | 1.21 |
| qwen3next 80B.A3B Q8_0 | pp8 | 107.72 | 130.04 | 1.21 |
| qwen3next 80B.A3B Q8_0 | pp512 | 928.72 | 1113.64 | 1.20 |
| qwen3next 80B.A3B Q8_0 | pp2048 | 1047.39 | 1338.82 | 1.28 |
| qwen3next 80B.A3B Q8_0 | tg32 | 33.75 | 43.78 | 1.30 |

DGX Spark:

| Model | Test | t/s b7946 | t/s gg/qwen3-next-opt | Speedup |
| ----- | ---- | --------: | --------------------: | ------: |
| qwen3next 80B.A3B Q4_0 | pp512 | 1055.58 | 1161.67 | 1.10 |
| qwen3next 80B.A3B Q4_0 | pp2048 | 1059.00 | 1324.66 | 1.25 |
| qwen3next 80B.A3B Q4_0 | tg32 | 43.11 | 59.58 | 1.38 |
| qwen3next 80B.A3B Q8_0 | pp512 | 886.51 | 965.89 | 1.09 |
| qwen3next 80B.A3B Q8_0 | pp2048 | 1009.43 | 1246.61 | 1.23 |
| qwen3next 80B.A3B Q8_0 | tg32 | 31.13 | 39.68 | 1.27 |

Related backend optimizations and refactorings:

Notes:

Next PRs

@github-actions github-actions bot added the model Model specific label Feb 5, 2026
@jeffbolznv
Collaborator

On my system:

before:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -m c:\models\Qwen3-Next-80B-A3B-Instruct-Q2_K_L.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           pp512 |      4608.85 ± 16.61 |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           tg128 |        165.05 ± 0.22 |

after:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -m c:\models\Qwen3-Next-80B-A3B-Instruct-Q2_K_L.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           pp512 |      4589.41 ± 15.89 |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           tg128 |        168.55 ± 0.39 |

So a couple percent for TG, no benefit for PP.

@pwilkin
Collaborator

pwilkin commented Feb 5, 2026

@ggerganov did you look at the optimizations in #18792?

I think there's a ~15% improvement in PP there.

(For reference, I'm waiting for #18755 to get finished and merged so I can actually merge the delta-net code.)

@ggerganov
Member Author

I missed this PR, thanks for pointing it out. I'll rebase on top of it then.

@jacekpoplawski
Contributor

jacekpoplawski commented Feb 5, 2026

No problems on master (tested Release only).

Assert on Release:
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007ffff6c4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007ffff6c288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7dfd156 in ggml_abort () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libggml-base.so.0
#6  0x00007ffff6b21a26 in ggml_compute_forward_pad () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libggml-cpu.so.0
#7  0x00007ffff6aa67f7 in ggml_graph_compute_thread.isra () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libggml-cpu.so.0
#8  0x00007ffff72c9329 in GOMP_parallel () from /lib/x86_64-linux-gnu/libgomp.so.1
#9  0x00007ffff6aa9071 in ggml_graph_compute () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libggml-cpu.so.0
#10 0x00007ffff6aa9512 in ggml_backend_cpu_graph_compute(ggml_backend*, ggml_cgraph*) () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libggml-cpu.so.0
#11 0x00007ffff7e19dc7 in ggml_backend_sched_graph_compute_async () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libggml-base.so.0
#12 0x00007ffff74ba0b1 in llama_context::graph_compute(ggml_cgraph*, bool) () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libllama.so.0
#13 0x00007ffff74bc022 in llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) ()
   from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libllama.so.0
#14 0x00007ffff74c252f in llama_context::decode(llama_batch const&) () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libllama.so.0
#15 0x00007ffff74c408f in llama_decode () from /home/jacek/git/llama.cpp/build_2026.02.05_next/bin/libllama.so.0
#16 0x00005555557fb862 in common_init_from_params(common_params&) ()
#17 0x00005555556d6b2d in server_context_impl::load_model(common_params const&) ()
#18 0x0000555555623e5c in main ()
(gdb)
Assert on Debug:
llama-server: /home/jacek/git/llama.cpp/src/llama.cpp:404: uint32_t llama_params_fit_impl(const char*, llama_model_params*, llama_context_params*, float*, llama_model_tensor_buft_override*, size_t*, uint32_t, ggml_log_level)::ngl_t::n_full() const: Assertion `n_layer >= n_part' failed.
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007fffedc28830 in __assert_fail_base (fmt=<optimized out>, assertion=<optimized out>, file=<optimized out>, line=<optimized out>, function=<optimized out>) at ./assert/assert.c:118
#6  0x00007ffff6e50c99 in ngl_t::n_full (this=0x55555b846604) at /home/jacek/git/llama.cpp/src/llama.cpp:404
#7  0x00007ffff6e50e95 in operator() (__closure=0x7fffffff57d0, ngl_per_device=std::vector of length 4, capacity 4 = {...}, overflow_bufts=std::vector of length 4, capacity 4 = {...}, mparams=...)
    at /home/jacek/git/llama.cpp/src/llama.cpp:430
#8  0x00007ffff6e512ef in operator() (__closure=0x7fffffff5850, func_name=0x7ffff71474bd "llama_params_fit_impl", ngl_per_device=std::vector of length 4, capacity 4 = {...},
    overflow_bufts=std::vector of length 4, capacity 4 = {...}) at /home/jacek/git/llama.cpp/src/llama.cpp:458
#9  0x00007ffff6e541ca in llama_params_fit_impl (path_model=0x5555562fd2b0 "/mnt/models2/Qwen/Qwen_Qwen3-Next-80B-A3B-Instruct-Q5_K_M-00001-of-00002.gguf", mparams=0x7fffffff5a40,
    cparams=0x7fffffff5a90, tensor_split=0x555556378bf0, tensor_buft_overrides=0x555556391610, margins_s=0x5555562fd660, n_ctx_min=4096, log_level=GGML_LOG_LEVEL_ERROR)
    at /home/jacek/git/llama.cpp/src/llama.cpp:689
#10 0x00007ffff6e54faa in llama_params_fit (path_model=0x5555562fd2b0 "/mnt/models2/Qwen/Qwen_Qwen3-Next-80B-A3B-Instruct-Q5_K_M-00001-of-00002.gguf", mparams=0x7fffffff5a40,
    cparams=0x7fffffff5a90, tensor_split=0x555556378bf0, tensor_buft_overrides=0x555556391610, margins=0x5555562fd660, n_ctx_min=4096, log_level=GGML_LOG_LEVEL_ERROR)
    at /home/jacek/git/llama.cpp/src/llama.cpp:748
#11 0x00005555558f368c in common_init_result::common_init_result (this=0x5555562567a0, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1099
#12 0x00005555558f4124 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#13 0x000055555570f999 in server_context_impl::load_model (this=0x5555563788f0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#14 0x00005555556e8ac8 in server_context::load_model (this=0x7fffffff79e0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#15 0x00005555556167bd in main (argc=3, argv=0x7fffffffe0f8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)

@ngxson
Contributor

ngxson commented Feb 6, 2026

Just a note that I'm also experimenting with some smaller optimizations on my side, but I'm not sure if I'll have time to finish them.

The main idea is this:

// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);

This operation broadcasts [1, n] to [n, n], then multiplies it with [n, 1], so it's essentially an outer product and can be converted to a mul_mat.

And one more place:

    // again, since it's over dim = -2, transpose, sum, transpose back
    ggml_tensor * core_attn_out =
        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));

The code above seems to be replaceable with a mul_mat too:

image

@ggerganov
Member Author

I noticed that this can become a matmul, but my tests show that the sum-rows version is actually faster - likely a matmul with ne00 = 1 is too degenerate a case. Note that in my version, I have avoided the transpose + sum_rows + transpose.

@jacekpoplawski
Contributor

jacekpoplawski commented Feb 10, 2026

still crashing on my system

Details
This GDB supports auto-downloading debuginfo from the following URLs:
  <https://debuginfod.ubuntu.com>
Enable debuginfod for this session? (y or [n]) [answered N; input not from terminal]
Debuginfod has been disabled.
To make this setting permanent, add 'set debuginfod enabled off' to .gdbinit.
Function(s) ^std::(move|forward|as_const|(__)?addressof) will be skipped when stepping.
Function(s) ^std::(shared|unique)_ptr<.*>::(get|operator) will be skipped when stepping.
Function(s) ^std::(basic_string|vector|array|deque|(forward_)?list|(unordered_|flat_)?(multi)?(map|set)|span)<.*>::(c?r?(begin|end)|front|back|data|size|empty) will be skipped when stepping.
Function(s) ^std::(basic_string|vector|array|deque|span)<.*>::operator.] will be skipped when stepping.
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
__syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
warning: 56     ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S: No such file or directory
#0  __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56      in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1  0x0000711c1409eb63 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=0, a6=0, nr=61) at ./nptl/cancellation.c:49
warning: 49     ./nptl/cancellation.c: No such file or directory
#2  __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75      in ./nptl/cancellation.c
#3  0x0000711c1411ae9f in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30     ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4  0x0000711c15242fd3 in ggml_print_backtrace () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-base.so.0
#5  0x0000711c1524317b in ggml_abort () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-base.so.0
#6  0x0000711c131ef9c1 in void launch_bin_bcast_pack<&(op_mul(float, float)), float, float, float, 0ul>(ggml_tensor const*, ggml_tensor const*, ggml_tensor*, float const*, float const*, float*, CUstream_st*, std::integer_sequence<unsigned long, 0ul>) [clone .constprop.0] () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-cuda.so.0
#7  0x0000711c132290f8 in ggml_cuda_op_mul(ggml_backend_cuda_context&, ggml_tensor*) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-cuda.so.0
#8  0x0000711c1329e25f in ggml_cuda_compute_forward(ggml_backend_cuda_context&, ggml_tensor*) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-cuda.so.0
#9  0x0000711c132a230a in ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context*, ggml_cgraph*, bool, bool, void const*) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-cuda.so.0
#10 0x0000711c132a44be in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-cuda.so.0
#11 0x0000711c1525fdc7 in ggml_backend_sched_graph_compute_async () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libggml-base.so.0
#12 0x0000711c148bb6b1 in llama_context::graph_compute(ggml_cgraph*, bool) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libllama.so.0
#13 0x0000711c148bd622 in llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libllama.so.0
#14 0x0000711c148c3b2f in llama_context::decode(llama_batch const&) () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libllama.so.0
#15 0x0000711c148c568f in llama_decode () from /home/jacek/git/llama.cpp/build_2026.02.10_qwen/bin/libllama.so.0
#16 0x0000601e043b57ac in server_context_impl::update_slots() ()
#17 0x0000601e04402716 in server_queue::start_loop(long) ()
#18 0x0000601e0430e3b1 in main ()
[Inferior 1 (process 13641) detached]
Aborted (core dumped)

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Feb 10, 2026
@jacekpoplawski
Contributor

Without the asserts, I was able to run the code and even measure some speedup over master.

Qwen_Qwen3-Next-80B-A3B-Instruct-Q5_K_M on 3x3090

@github-actions github-actions bot added the Apple Metal https://en.wikipedia.org/wiki/Metal_(API) label Feb 11, 2026
@pwilkin
Collaborator

pwilkin commented Feb 12, 2026

@ggerganov BTW, I think this is a good moment to ask: back when I first tried to implement the chunking, my idea was like this:

  • preallocate the entire attn_out tensor
  • create a view for the slice
  • ggml_cpy the relevant result into the slice

I remember @am17an mentioning he tried something similar. However, that always failed due to scheduler errors - it seemed like the view didn't get a backend assigned to it properly. Any idea why? It still might be faster than doing all the CONCATs.

@ggerganov
Member Author

I'll see if we can avoid the concats next.

@ggerganov
Member Author

Using ggml_set() is the proper way to avoid concats in such scenarios. I think it should work now - would appreciate some tests. GGML_OP_SET is missing in some backends (e.g. Metal), so I'll be implementing it now.

However, that always failed due to scheduler errors - seemed like the view didn't get a backend assigned to it properly. Any idea why?

Preallocating a tensor with ggml_new_tensor() is always an anti-pattern. The scheduler has no way to know on which backend to allocate the tensor. Also it can increase memory consumption significantly. As a rule of thumb, always try hard to avoid ggml_new_tensor() in the compute graphs, except for inputs.

@github-actions github-actions bot added the testing Everything test related label Feb 12, 2026
@ngxson
Contributor

ngxson commented Feb 12, 2026

Using ggml_set() is the proper way to avoid concats in such scenarios. I think it should work now - would appreciate some tests. GGML_OP_SET is missing in some backends (e.g. Metal), so I'll be implementing it now.

Interestingly, I have a draft PR #18759 that uses ggml_set_rows because it seems to be already supported by all backends. It seems to improve the perf significantly on Vulkan. Ref the issue I created recently: #19432

@ggerganov
Member Author

ggml_set_rows() should also work, but it needs to prepare the destination indices on the host and, as you noted, it increases the code complexity. There might be some way to utilize ggml_arange() and simplify - not sure.

The current version also improves the perf significantly, especially for larger ubatches where we have more chunks. I think adding support to the missing backends should not be difficult and would be quite useful in the future in similar situations.

@ngxson
Contributor

ngxson commented Feb 12, 2026

One more reason I considered using ggml_set_rows is that core_attn_out needs to be permuted at the end, so I was wondering if I could calculate the correct indices in advance to avoid the permute. But it seems too complicated.

// permute back to (S_v, H_v, n_tokens, n_seqs)
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);

@pwilkin
Collaborator

pwilkin commented Feb 12, 2026

Preallocating a tensor with ggml_new_tensor() is always an anti-pattern. The scheduler has no way to know on which backend to allocate the tensor. Also it can increase memory consumption significantly. As a rule of thumb, always try hard to avoid ggml_new_tensor() in the compute graphs, except for inputs.

The above seems like something worth adding to the model-adding tutorial :)

@ggerganov ggerganov merged commit 1725e31 into master Feb 14, 2026
77 of 78 checks passed
@ggerganov ggerganov deleted the gg/qwen3-next-opt branch February 14, 2026 10:57
@jacekpoplawski
Contributor

Excellent results

@Mainframework

Mainframework commented Feb 14, 2026

Feedback for the code owners/authors

Since #18266, tps and memory (also RAM) usage went downhill. This is my assessment, because the test results show inference in Qwen Next up to 3x faster using previous builds from 71xx to 7552, and the same overall in other models. I guess we all have to thank JohannesGaessler for the change. Also, can you guys take it easy with the bloatware? It feels more and more invasive and slower. Going by the bylaws of llama.cpp, the purpose should be the use of AI with minimal clean code, meaning no bloatware, no telemetry.

With this said, I appreciate your hard work and hope you'll get your funding round soon.

Edit: All good, build 8068 works flawlessly and is damn fast. Qwen3.5, Minimax 2.5, GLM5 all flying. Thank you :)

@ymcki
Contributor

ymcki commented Feb 14, 2026

Very impressive ggml magic here for Qwen3Next. I think most of the changes can also be replicated for Kimi Linear. Is @ggerganov planning to do that as well? Or should I do it in a new PR, or within the unified delta_net PR?

@ggerganov
Member Author

@ymcki I think you should try to integrate the Kimi Linear delta net into the new llm_build_delta_net_base class that we will introduce in #19597. Also keep track of the progress in #19504 - eventually the new dedicated delta net op should be compatible with the kda too. Not sure about the details yet, so feedback is appreciated.

@ymcki
Contributor

ymcki commented Feb 14, 2026

Feedback for the code owners/authors

Since #18266, tps and memory (also RAM) usage went downhill. This is my assessment, because the test results show inference in Qwen Next up to 3x faster using previous builds from 71xx to 7552, and the same overall in other models. I guess we all have to thank JohannesGaessler for the change. Also, can you guys take it easy with the bloatware? It feels more and more invasive and slower. Going by the bylaws of llama.cpp, the purpose should be the use of AI with minimal clean code, meaning no bloatware, no telemetry.

With this said, I appreciate your hard work and hope you'll get your funding round soon.

Are you talking about this issue?
#18258

@tarruda

tarruda commented Feb 15, 2026

Here are some benchmarks on M1 Ultra:

before:

% llama-bench -m Qwen3-Coder-Next-GGUF/Q8_0/Qwen3-Coder-Next-Q8_0-00001-of-00003.gguf -fa 1 -t 1 -ngl 99 -b 2048 -ub 2048 -d 0,10000,20000,30000,40000,50000,60000,70000,80000,90000,100000                                                             
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices
ggml_metal_library_init: using embedded metal library
ggml_metal_library_init: loaded in 0.023 sec
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0
ggml_metal_device_init: GPU family: MTLGPUFamilyApple7  (1007)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true
ggml_metal_device_init: has unified memory    = true
ggml_metal_device_init: has bfloat            = true
ggml_metal_device_init: has tensor            = false
ggml_metal_device_init: use residency sets    = true
ggml_metal_device_init: use shared buffers    = true
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 134217.73 MB
| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        599.06 ± 3.33 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         26.00 ± 0.03 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        521.51 ± 3.15 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         24.90 ± 0.02 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        455.07 ± 2.90 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         24.60 ± 0.03 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        400.96 ± 2.74 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         23.88 ± 0.02 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        360.77 ± 2.62 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         23.18 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d50000 |        326.46 ± 3.03 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d50000 |         22.56 ± 0.03 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d60000 |        298.34 ± 1.14 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d60000 |         22.15 ± 0.02 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d70000 |        269.18 ± 1.52 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d70000 |         21.55 ± 0.07 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d80000 |        249.38 ± 1.37 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d80000 |         20.85 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d90000 |        231.09 ± 1.12 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d90000 |         20.32 ± 0.00 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | pp512 @ d100000 |        214.36 ± 1.49 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | tg128 @ d100000 |         19.86 ± 0.01 |

build: 8872ad212 (7966)

after

% llama-bench -m Qwen3-Coder-Next-Q8_0-00001-of-00003.gguf -fa 1 -t 1 -ngl 99 -b 2048 -ub 2048 -d 0,10000,20000,30000,40000,50000,60000,70000,80000,90000,100000,150000,200000,250000                                                                   
ggml_metal_device_init: tensor API disabled for pre-M5 and pre-A19 devices                                               
ggml_metal_library_init: using embedded metal library                                                                    
ggml_metal_library_init: loaded in 0.013 sec        
ggml_metal_rsets_init: creating a residency set collection (keep_alive = 180 s)
ggml_metal_device_init: GPU name:   MTL0                                                                                 
ggml_metal_device_init: GPU family: MTLGPUFamilyApple7  (1007)
ggml_metal_device_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_device_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_device_init: simdgroup reduction   = true
ggml_metal_device_init: simdgroup matrix mul. = true                                                                     
ggml_metal_device_init: has unified memory    = true                                                                                                                                                                                              
ggml_metal_device_init: has bfloat            = true                                                                                                                                                                                              
ggml_metal_device_init: has tensor            = false                                                                                                                                                                                             
ggml_metal_device_init: use residency sets    = true                                                                                                                                                                                              
ggml_metal_device_init: use shared buffers    = true                                                                                                                                                                                              
ggml_metal_device_init: recommendedMaxWorkingSetSize  = 134217.73 MB                                                                                                                                                                              
| model                          |       size |     params | backend    | threads | n_ubatch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |           pp512 |        722.52 ± 1.36 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |           tg128 |         31.49 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d10000 |        611.33 ± 4.90 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d10000 |         30.24 ± 0.02 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d20000 |        523.43 ± 3.74 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d20000 |         29.25 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d30000 |        454.62 ± 3.84 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d30000 |         28.14 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d40000 |        402.00 ± 2.54 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d40000 |         27.30 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d50000 |        360.59 ± 2.14 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d50000 |         26.24 ± 0.03 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d60000 |        326.34 ± 0.79 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d60000 |         25.52 ± 0.02 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d70000 |        292.57 ± 2.95 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d70000 |         24.78 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d80000 |        267.50 ± 2.12 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d80000 |         24.17 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  pp512 @ d90000 |        249.85 ± 1.94 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 |  tg128 @ d90000 |         23.47 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | pp512 @ d100000 |        230.61 ± 1.89 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | tg128 @ d100000 |         22.83 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | pp512 @ d150000 |        170.15 ± 1.96 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | tg128 @ d150000 |         19.96 ± 0.00 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | pp512 @ d200000 |        135.10 ± 1.65 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | tg128 @ d200000 |         17.79 ± 0.01 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | pp512 @ d250000 |        112.77 ± 0.64 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | MTL,BLAS   |       1 |     2048 |  1 | tg128 @ d250000 |         16.05 ± 0.00 |
                                                            
build: 01d8eaa28 (8054)
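For context, the speedup column in the PR description at the top is simply the ratio of optimized to baseline throughput, rounded to two decimals. A minimal sketch (the three pairs are taken from the Q4_0 rows of that table; the dict layout is mine):

```python
# Speedup = optimized t/s divided by baseline t/s, rounded to 2 decimals.
# Pairs are (baseline b7946, gg/qwen3-next-opt) from the Q4_0 rows above.
pairs = {
    "pp1":   (37.92, 51.99),
    "pp512": (930.70, 1125.73),
    "tg32":  (38.02, 50.39),
}

speedup = {test: round(new / old, 2) for test, (old, new) in pairs.items()}
print(speedup)  # {'pp1': 1.37, 'pp512': 1.21, 'tg32': 1.33}
```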

liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026
* models : optimizing qwen3next graph

* cont

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* cont : remove redundant q, g chunking

* minor

* minor

* avoid passing masks around

* avoid concats during chunking

* naming + shapes

* update names and use prefix to disable CUDA graphs

@rhjdvsgsgks (Contributor)

Notes:

@bartowski1182 I found that all of your Qwen3 Next quants have this problem. Do you have any plans to update them?

@bartowski1182 (Contributor)

Hmm, fair callout. Yeah, I should probably do that; I have some other improvements that will apply too.

bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 2, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Mar 3, 2026

Labels

* Apple Metal (https://en.wikipedia.org/wiki/Metal_(API))
* ggml (changes relating to the ggml tensor library for machine learning)
* model (Model specific)
* Nvidia GPU (Issues specific to Nvidia GPUs)
* testing (Everything test related)

10 participants