
CUDA: Enable cuda graphs for qwen3 next-style architectures #19521

Closed

ORippler wants to merge 3 commits into ggml-org:master from ORippler:osimons/cuda_graphs_qwen3_next
Conversation

ORippler (Collaborator) commented Feb 11, 2026

This is achieved via:

  1. Adding regexes for pico-batched adds that happen even when the ubatch size is 1.
  2. Bumping the consecutive-updates threshold by 1.
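For illustration, the name matching in point 1 can be sketched as a small standalone helper. This is a hedged sketch, not the actual ggml-cuda code: the helper and its structure are hypothetical, and only the two prefix strings come from this PR's diff.

```cpp
#include <cstring>
#include <string>
#include <vector>

// Hypothetical helper: returns true if a graph node's name starts with one of
// the prefixes for batched adds that should NOT disable CUDA graph capture.
// The prefix strings mirror those added in this PR's diff; everything else is
// illustrative.
static bool is_known_batched_add(const char * node_name) {
    static const std::vector<std::string> prefixes = {
        "qwen3next_diag_mask",     // exclusion added by this PR
        "delta_net_linear_attn",   // exclusion added by this PR
    };
    for (const auto & p : prefixes) {
        if (std::strncmp(node_name, p.c_str(), p.size()) == 0) {
            return true;   // name begins with a known prefix
        }
    }
    return false;
}
```

The real code in ggml-cuda.cu inlines these checks directly into one large condition (see the reviewed lines further below) rather than looping over a list.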

While 1. is already known tech debt in the CUDA backend from other models, 2. was new to me. I did try to investigate what happens with the frequent graph updates and the hybrid model, but could not figure it out fully:

  1. We seem to build different shapes of conv_states_updated-0 two times at ~128 token intervals, going from 0 -> 24576 -> 0. This seems to come from llm_graph_input_rs? My understanding is that conv_states_updated-0 somewhat reflects a per-sequence mapping of our conv states, whereas we only have a single sequence in llama-bench's token-generation phase (so it should always be 0 in that case? At least in the reused ggml_cgraph case it ends up being 0). This change in ggml_cgraph topology is, however, seemingly captured by llm_graph_result::can_reuse.
  2. node_143 (this is somewhere in the first build_layer_attn_linear of qwen3next) seems to have additional data added to it despite the graph being marked as reusable by llm_graph_result::can_reuse. Is this potentially a bug? I was under the impression that graph topology is not allowed to change when the graph is reused in llama_context::process_ubatch.
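As a minimal sketch (not the actual ggml-cuda implementation; the struct and field names here are simplified stand-ins), the per-node property comparison that decides whether a captured graph can be replayed works roughly like this:

```cpp
#include <array>
#include <cstdint>

constexpr int MAX_DIMS = 4;   // stand-in for GGML_MAX_DIMS

// Simplified snapshot of the per-node properties the CUDA backend records
// at capture time.
struct node_props {
    std::array<int64_t, MAX_DIMS> ne;        // logical shape
    std::array<int64_t, MAX_DIMS> nb;        // byte strides
    const void *                  src_data;  // data pointer of one source tensor
};

// A captured CUDA graph can only be replayed if shape, strides, and source
// pointers are all unchanged; any mismatch forces a graph update.
static bool props_match(const node_props & a, const node_props & b) {
    return a.ne == b.ne && a.nb == b.nb && a.src_data == b.src_data;
}
```

Both mismatches in the log below are of exactly these kinds: an ne[0] change (24576 vs 0) and a src data-pointer change.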
Relevant sections from the log output and the diff patch for the above two points:
./build-x64-linux-gcc-reldbg/bin/llama-bench -m /mnt/share/gguf/Qwen/Qwen3-Coder-Next-GGUF/Qwen3-Coder-Next-Q4_K_M/Qwen3-Coder-Next-Q4_K_M-00001-of-00004.gguf -mmp 0 -dio 1 -fa 1 -p 0 2>&1 | tee run.txt

Middle of llama-bench
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 0
node->ne[0] mismatch: 24576 vs 0 for node->name: conv_states_updated-0 (reshaped) (view)
node 6 properties do not match, node name: conv_states_updated-0 (reshaped) (view)
can_reuse: can reuse graph = 0
node->ne[0] mismatch: 0 vs 24576 for node->name: conv_states_updated-0 (reshaped) (view)
node 6 properties do not match, node name: conv_states_updated-0 (reshaped) (view)
can_reuse: can reuse graph = 1
node->src[2]->data mismatch: 0x7f41f801d780 vs (nil) for node->name: node_144
node 143 properties do not match, node name: node_144
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1
can_reuse: can reuse graph = 1

Applied diff to get that log

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2ac48ca99..66217ee23 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2961,6 +2961,7 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
 
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != props->ne[i]) {
+            printf("node->ne[%d] mismatch: %ld vs %ld for node->name: %s\n", i, node->ne[i], props->ne[i], node->name);
             return false;
         }
         if (node->nb[i] != props->nb[i]) {
@@ -2978,6 +2979,7 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
             }
 
             if (node->src[i]->data != props->src_data[i]) {
+                printf("node->src[%d]->data mismatch: %p vs %p for node->name: %s\n", i, node->src[i]->data, props->src_data[i], node->name);
                 return false;
             }
         }
@@ -3027,6 +3029,7 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
             props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
         }
         if (!props_match) {
+            printf("node %d properties do not match, node name: %s\n", i, cgraph->nodes[i]->name);
             res = true;
         }
         ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index bba747d37..f542cdbed 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -795,6 +795,7 @@ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
     if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
     }
+    printf("%s: can reuse graph = %d\n", __func__, res);
 
     return res;
 }

Maybe folks from #16490 and #16095 (@ggerganov @pwilkin @gabe-l-hart ) could chime in here.


I'll try to collect numbers to see if we can remove that consecutive-update counter in general. This should be feasible if capture + launch of a cudaGraph is always faster than a naive launch on our workloads under both Linux and Windows.
In the meantime, this closes #19345
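The consecutive-updates heuristic mentioned above can be sketched as follows; the threshold value, the names, and the permanent-disable behavior are illustrative assumptions, not the actual ggml-cuda code:

```cpp
// Sketch of a fallback heuristic: after too many graph re-captures in a row,
// give up on CUDA graphs and fall back to naive kernel launches.
struct cuda_graph_heuristic {
    int  consecutive_updates = 0;
    int  threshold           = 5;     // hypothetical value; this PR bumps the real one by 1
    bool disabled            = false;

    // Call once per graph evaluation; returns true if CUDA graphs may be used.
    bool on_eval(bool update_required) {
        if (disabled) {
            return false;
        }
        consecutive_updates = update_required ? consecutive_updates + 1 : 0;
        if (consecutive_updates >= threshold) {
            disabled = true;   // fall back to naive launches from here on
        }
        return !disabled;
    }
};
```

With such a counter, one extra graph rebuild per eval window (as qwen3next triggers) is enough to trip the threshold, which is why bumping it by 1 re-enables graphs for this architecture.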

I did verify perf

./build-x64-linux-gcc-reldbg/bin/llama-bench -m /mnt/share/gguf/Qwen/Qwen3-Coder-Next-GGUF/Qwen3-Coder-Next-Q8_0/Qwen3-Coder-Next-Q8_0-00001-of-00004.gguf -mmp 0 -dio 1 -fa 1 -p 0
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa | dio |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |   1 |           tg128 |        114.20 ± 0.81 |

build: 2a5964425 (8002)

and functional correctness.

./build-x64-linux-gcc-reldbg/bin/llama-completion -m /mnt/share/gguf/Qwen/Qwen3-Coder-Next-GGUF/Qwen3-Coder-Next-Q8_0/Qwen3-Coder-Next-Q8_0-00001-of-00004.gguf --no-mmap -dio -fa 1 -p "Hello there"
...
- If you want to submit another line, end your input with '\'.
- Not using system message. To change it, set a different value via -sys PROMPT

user
Hello there
assistant
Hello! How can I help you today? 😊

@ORippler ORippler requested a review from CISC as a code owner February 11, 2026 17:49
@github-actions github-actions bot added labels: model (Model specific), Nvidia GPU (Issues specific to Nvidia GPUs), ggml (changes relating to the ggml tensor library for machine learning) on Feb 11, 2026
pwilkin (Contributor) commented Feb 11, 2026

@ORippler which one is node_143? Can you paste the operations line (op and source nodes)?

am17an (Contributor) commented Feb 12, 2026

Also maybe try with --no-warmup for the logs

ORippler (Collaborator Author):

@ORippler which one is node_143? Can you paste the operations line (op and source nodes)?

This turned out to be an issue in the CUDA backend, see 23b757a

@am17an lmk if you wish to isolate 23b757a into its own PR

ORippler (Collaborator Author) commented Feb 12, 2026

Though I'm still confused about why we rebuild the graph twice before being able to reuse it (basically conv_states_updated-0 -> 24576 on the first build, and conv_states_updated-0 -> 0 on the second build).

ORippler (Collaborator Author):

I just saw there is ongoing work on a better model graph for qwen3-next style architectures in llama.cpp (#19504 and #19375).

I still feel we could merge this PR regardless so people get good perf on the CUDA backend in the meantime (happy to update the heuristics to the new node names should they change in the new model graph).

am17an (Contributor) commented Feb 12, 2026

@ORippler let's do a separate PR for that

ORippler (Collaborator Author) commented Feb 12, 2026

@ORippler let's do a separate PR for that

Done. As I don't know if it's even possible to maintain stacked PRs across forks on GitHub, I removed the commit here (this PR now requires #19566 to be merged first and depends on it to actually work).

EDIT: the change is still in this branch; I'll do an interactive rebase once #19566 has been merged.

am17an (Contributor) commented Feb 12, 2026

BTW, just for completeness, can you run a larger correctness test for TG? Perhaps llama-perplexity with -ub 1 on a small corpus.

jacekpoplawski (Contributor) commented

very nice speedup!

@ORippler ORippler force-pushed the osimons/cuda_graphs_qwen3_next branch from 28177e3 to 89decb7 Compare February 13, 2026 09:41
ORippler (Collaborator Author) commented Feb 13, 2026

I rebased the branch now that #19566 was merged.

BTW just for completeness, can you run a larger correctness test for TG. Perhaps llama-perplexity with -ub 1 on a small corpus

Here are numbers for wiki.test.raw that show equivalent PPL:

CUDA Graphs on (confirmed manually by debugging; also note the faster throughput/ETA)

./build-x64-linux-gcc-reldbg/bin/llama-perplexity -m /mnt/share/gguf/Qwen/Qwen3-Coder-Next-GGUF/Qwen3-Coder-Next-Q4_K_M/Qwen3-Coder-Next-Q4_K_M-00001-of-00004.gguf -dio -fa 1 -f wikitext-2-raw/wiki.test.raw -ub 1 --no-warmup -b 1 2>&1 | tee perplexity_CG_ON.log

perplexity: calculating perplexity over 584 chunks, n_ctx=512, batch_size=1, n_seq=1
perplexity: 4.09 seconds per pass - ETA 39.82 minutes
[1]4.5973,[2]6.6801,[3]5.6123,[4]4.9002,[5]4.8127,[6]5.0074,[7]5.0961,[8]5.2216,[9]5.1102,[10]5.1793,[11]5.1086,
...
[579]8.3659,[580]8.3814,[581]8.3829,[582]8.3947,[583]8.3787,[584]8.3729,
Final estimate: PPL = 8.3729 +/- 0.06545

CUDA Graphs off

GGML_CUDA_DISABLE_GRAPHS=1 ./build-x64-linux-gcc-reldbg/bin/llama-perplexity -m /mnt/share/gguf/Qwen/Qwen3-Coder-Next-GGUF/Qwen3-Coder-Next-Q4_K_M/Qwen3-Coder-Next-Q4_K_M-00001-of-00004.gguf -dio -fa 1 -f wikitext-2-raw/wiki.test.raw -ub 1 --no-warmup -b 1 2>&1 | tee perplexity_CG_OFF.log

perplexity: 5.61 seconds per pass - ETA 54.62 minutes
[1]4.5973,[2]6.6801,[3]5.6123,[4]4.9002,[5]4.8127,[6]5.0074,[7]5.0961,[8]5.2216,[9]5.1102,[10]5.1793,[11]5.1086,
...
[579]8.3659,[580]8.3814,[581]8.3829,[582]8.3947,[583]8.3787,[584]8.3729,
Final estimate: PPL = 8.3729 +/- 0.06545

This is less strict, but should still do the job and be faster.
Also matches how we compare the other prefixes
ggerganov (Member) left a comment
This change will not work after we merge #19375, but I will update this PR to support the new graph. Will merge it after that.

Comment on lines +2899 to 2918
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
strncmp(node->name, qwen3next_diag_mask.c_str(), qwen3next_diag_mask.size()) != 0 &&
strncmp(node->name, delta_net_linear_attn_prefix.c_str(), delta_net_linear_attn_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
ggerganov (Member) commented on these lines:
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's project_per_layer_input operation

Can we confirm that the only reason we have this entire logic is because of Gemma3n - as explained in the comment?

// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.

Haven't looked at the source of this comment, but changes in the batch size / context size should already be detected by the node properties comparison, correct?

ggerganov (Member) commented

I merged this change directly into #19375 - see 4521751. We can close this PR.

Looking at this logic, I think we can simply remove the GGML_OP_ADD checks - they should not be needed anymore because of the node properties matching that we perform. Unless I am missing something?

ORippler (Collaborator Author):

Looking at this logic, I think we can simply remove the GGML_OP_ADD checks - they should not be needed anymore because of the node properties matching that we perform. Unless I am missing something?

The reason we added this is historical: users raised segfaults in multi-GPU settings in the PR that added CUDA Graphs support, likely caused by us failing to detect the need to recapture the CUDA Graph. Instead of root-causing and fixing these issues, band-aids have been put in place and updated ever since. I've been meaning to remove the update-counter and batch-size heuristics once we validate:

  1. Multi-GPU + CUDA Graphs works stably in the PP phase (we already know it works for TG in multi-GPU settings).
  2. Frequent recapture of CUDA Graphs within the PP phase due to ubatch-size changes does not yield a perf regression compared to just using CUDA's launch API.

@ORippler ORippler deleted the osimons/cuda_graphs_qwen3_next branch February 16, 2026 08:35

Successfully merging this pull request may close these issues.

Eval bug: Llama.cpp 40% slower than VLLM + high CPU usage when running Qwen Coder Next