server-context: fall back to full seq clear when partial KV eviction is refused #23280

Merged
ggerganov merged 3 commits into ggml-org:master from ServeurpersoCom:fix/kv-partial-evict-fallback
May 19, 2026
Conversation

@ServeurpersoCom ServeurpersoCom commented May 18, 2026

Contributor

Overview

To reproduce on master, run llama-server with a recent hybrid attention model such as Qwen3.6-MoE, fill the KV cache with a few conversation turns, then click "Continue" at the end of an assistant reply and watch the server abort on a partial seq_rm refusal.

Additional information

Fix this

[57067] 0.52.065.769 W slot update_slots: id  0 | task 1900 | n_past was set to 1739
[57067] /root/llama.cpp.pascal/common/common.cpp:1489: failed to remove sequence 0 with p0=1739, p1=-1
[57067]
[57067] [New LWP 304479]
...
[57067] [New LWP 304336]
[57067] [Thread debugging using libthread_db enabled]
[57067] Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[57067] __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
[57067] warning: 56     ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S: Aucun fichier ou dossier de ce nom
[57067] #0  __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
[57067] 56      in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
[57067] #1  0x00007f9d601ab668 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:49
[57067] warning: 49     ./nptl/cancellation.c: Aucun fichier ou dossier de ce nom
[57067] #2  0x00007f9d601ab6ad in __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
[57067] 75      in ./nptl/cancellation.c
[57067] #3  0x00007f9d602167c7 in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
[57067] warning: 30     ../sysdeps/unix/sysv/linux/wait4.c: Aucun fichier ou dossier de ce nom
[57067] #4  0x00007f9d6075f6bb in ggml_print_backtrace () from /root/llama.cpp.pascal/build/bin/libggml-base.so.0
[57067] #5  0x00007f9d6075f80e in ggml_abort () from /root/llama.cpp.pascal/build/bin/libggml-base.so.0
[57067] #6  0x00007f9d60dd3c1b in common_context_seq_rm(llama_context*, int, int, int) () from /root/llama.cpp.pascal/build/bin/libllama-common.so.0
[57067] #7  0x000055e24f870716 in server_context_impl::update_slots() ()
[57067] #8  0x000055e24f904e11 in server_queue::start_loop(long) ()
[57067] #9  0x000055e24f7ce42b in main ()
[57067] [Inferior 1 (process 304330) detached]
1.07.545.637 E srv    operator(): http client error: Failed to read connection
Continue

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES. Opus 4.7 + MCP rootless containers with GPU access

…is refused

The startup probe in common_context_can_seq_rm only tests a 2-token tail
removal on seq 0, so it cannot guarantee that every partial eviction will
succeed at any position on any live seq. The previous code aborted the
process via GGML_ABORT in common_context_seq_rm whenever the backend
refused the partial removal, taking down the server on a recoverable
condition.

On refusal we now clear the whole seq on both target and draft contexts,
reset the prompt cache counters, and let update_slots reprefill from
zero on the current iteration. The server stays alive: the slot loses
its prefix cache and pays for a single reprefill instead of crashing.
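
The fallback described above can be sketched with a toy model. This is a minimal illustration under stated assumptions, not the code from this PR: `ToyMemory`, `ToySlot`, and `truncate_or_reset` are hypothetical names, and the refusal rule only imitates a backend that cannot rewind to a mid-sequence position.

```cpp
#include <algorithm>

// Hypothetical stand-in for one sequence's memory; not a llama.cpp type.
struct ToyMemory {
    int  n_cached      = 0;    // positions currently stored for the sequence
    bool allow_partial = true; // false imitates a backend that cannot rewind mid-sequence

    // Remove positions [p0, end). A removal from p0 == 0 is a full clear and always succeeds.
    bool seq_rm(int p0) {
        if (p0 > 0 && p0 < n_cached && !allow_partial) {
            return false; // partial removal refused
        }
        n_cached = std::min(n_cached, p0);
        return true;
    }
};

// Hypothetical slot state: how many cached prompt tokens the slot wants to reuse.
struct ToySlot {
    int n_past = 0;
};

// Try the partial truncation on both contexts. On refusal, clear both sequences
// and reset n_past so the caller re-prefills from zero instead of aborting.
// Returns true if the cached prefix survived.
bool truncate_or_reset(ToyMemory & tgt, ToyMemory & dft, ToySlot & slot) {
    const bool ok = tgt.seq_rm(slot.n_past) && dft.seq_rm(slot.n_past);
    if (ok) {
        return true;
    }
    tgt.seq_rm(0);
    dft.seq_rm(0);
    slot.n_past = 0;
    return false;
}
```

With `allow_partial = false` on the target, 100 cached positions and `n_past = 40`, `truncate_or_reset` returns false and leaves both memories empty with `n_past == 0`. The trade-off is the one stated in the commit message: one full reprefill in exchange for keeping the process alive.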
@ServeurpersoCom ServeurpersoCom requested a review from a team as a code owner May 18, 2026 14:10
@ggerganov

Member

Could you try the following patch applied to master instead:

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 0f3fb9efa..7b801eac0 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2583,9 +2583,9 @@ private:
                             llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
 
                             // the largest pos_min required for a checkpoint to be useful
-                            const auto pos_min_thold = std::max(0, pos_next - n_swa);
+                            const auto pos_min_thold = std::max(0, pos_next - n_swa - 1);
 
-                            if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
+                            if (n_past > 0 && n_past <= slot.prompt.n_tokens()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
                                 if (pos_min == -1) {
                                     SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);

I think this will fix both this problem and also #23223.

@ServeurpersoCom

Contributor Author

Looks better, I'll try it now!

@ServeurpersoCom

Contributor Author

Successfully tested on master plus my sse-replay-buffer PR. It fixes the problem at its source, much better than my fallback. Thanks!

Reproduces in master on hybrid models by asking the assistant to
continue its last reply in a multi-turn conversation: the LCP match is
perfect, the deep partial seq_rm is refused by the recurrent backend,
and common_context_seq_rm aborts the process via GGML_ABORT.

Patch by @ggerganov routes the n_past == slot.prompt.n_tokens() case
through the existing do_reset path.
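
To make the boundary case concrete, the two comparisons changed by the patch can be isolated as plain functions. This is only a sketch: the function names are hypothetical, and just the expressions themselves are taken from the diff.

```cpp
#include <algorithm>

// Branch condition before the patch: an exact match (n_past == n_prompt_tokens) is excluded.
bool takes_branch_old(int n_past, int n_prompt_tokens) {
    return n_past > 0 && n_past < n_prompt_tokens;
}

// Branch condition after the patch: an exact match is included.
bool takes_branch_new(int n_past, int n_prompt_tokens) {
    return n_past > 0 && n_past <= n_prompt_tokens;
}

// Threshold before the patch.
int pos_min_thold_old(int pos_next, int n_swa) {
    return std::max(0, pos_next - n_swa);
}

// Threshold after the patch: one position lower, still clamped at zero.
int pos_min_thold_new(int pos_next, int n_swa) {
    return std::max(0, pos_next - n_swa - 1);
}
```

For an exact match such as `n_past = 1739` with 1739 prompt tokens, `takes_branch_old` is false while `takes_branch_new` is true, so only the patched condition reaches the reset logic. For any `0 < n_past < n_prompt_tokens` the two conditions agree.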
@ggerganov

Member

OK, let's also do some testing with non-recurrent models to make sure I am not overlooking something, and then we can merge.

@ServeurpersoCom

Contributor Author

Tested with Qwen3 30B A3B, GPT-OSS, and Llama 3.3 (all pure transformers): multi-turn continuation works as expected, with no regression.

@ggerganov ggerganov merged commit ccee426 into ggml-org:master May 19, 2026
46 of 49 checks passed
Jcfunk added a commit to Jcfunk/llama.cpp that referenced this pull request May 23, 2026
* upstream/HEAD:
  ci : install server kleidiai runner dependencies (ggml-org#23259)
  server-context: guarantee there is at least 1 token to decode (ggml-org#23280)
  server : print graphs reused in slot timings (ggml-org#23279)
  save-load-state : refactor tests and improve readability (ggml-org#23196)
  llama-eval : add per-task summary stats (ggml-org#23151)
  ggml-webgpu : extend GDN for K>1 (ggml-org#23299)
  [SCYL] add chapter for performance reference in SYCL.md (ggml-org#23315)
  convert : filter lora tensor names (ggml-org#23077)
  sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle (ggml-org#22153)
  rpc : keep last_graph_uid in the device context (ggml-org#23273)
jimbothigpen added a commit to jimbothigpen/llama.cpp that referenced this pull request May 29, 2026
Reverts mainline commit ccee426 (PR ggml-org#23280) in tools/server/server-context.cpp
which we picked up via 2026-05-25 forward-sync. The change introduced a KV cache
reuse regression on Qwen3.6-35B-A3B (and likely Qwen3.5-35B-A3B-MTP) where a
full batch of cached tokens is dropped per turn on multi-turn requests.

Mainline issue: ggml-org#23589
RC + reproducer: orangeswim 2026-05-24

§-RISK: This is a naked revert per the issue author's mainline test; it may
reintroduce the hybrid-attention crash that ggml-org#23280 was fixing. Build + smoke
verify gated on GPU-lockout-clear; follow-up worker required before FF-merge.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
Jcfunk added a commit to Jcfunk/llama.cpp that referenced this pull request Jun 11, 2026
* upstream/HEAD: (25 commits)
  metal : optimize pad + cpy (ggml-org#23354)
  snapdragon: update toolchain to v0.6 (ggml-org#23369)
  ggml-cuda: tune RDNA3 Q6_K MMVQ nwarps (ggml-org#23349)
  opencl: add MoE support for q4_k, q5_k, q6_k on Adreno (ggml-org#23303)
  hexagon: add MROPE and IMROPE support in HTP rope op (ggml-org#23317)
  refactor: Chat Screen UI rendering (ggml-org#23333)
  github: mention --log-file in issue templates (ggml-org#23277)
  common: fix --help for --verbosity (ggml-org#23278)
  common: fix --fit verbosity with --verbosity 4 (ggml-org#23282)
  convert : update mtp related help (ggml-org#23334)
  hexagon: enable support for NORM op (ggml-org#23319)
  model : clarify MTP layer comment in qwen35.cpp [no ci] (ggml-org#23338)
  llama : MTP clean-up (ggml-org#23269)
  ui: Bump packages + address build warnings (ggml-org#23300)
  ci : install libssl-dev (ggml-org#23325)
  ci : install server kleidiai runner dependencies (ggml-org#23259)
  server-context: guarantee there is at least 1 token to decode (ggml-org#23280)
  server : print graphs reused in slot timings (ggml-org#23279)
  save-load-state : refactor tests and improve readability (ggml-org#23196)
  llama-eval : add per-task summary stats (ggml-org#23151)
  ...
TheTom pushed a commit to TheTom/llama-cpp-turboquant that referenced this pull request Jun 12, 2026
…odels

When an incoming prompt exactly matches the slot's cached tokens, the
server backs off one token (n_past--) to guarantee at least one token is
decoded for logits [TAG_PROMPT_LOGITS]. The subsequent truncation then
calls seq_rm with p0 = n_past > 0, which recurrent memory cannot satisfy
(the state cannot be rewound to a mid-sequence position when the rollback
exceeds n_rs_seq), so common_context_seq_rm hits GGML_ABORT and the whole
server dies:

  common/common.cpp:1472: failed to remove sequence 0 with p0=490, p1=-1

Observed in production with Qwen3.6-35B-A3B (GatedDeltaNet layers): any
client re-sending an identical prompt with cache_prompt enabled
(regenerate / retry) crashed the server. Reproducible with any recurrent
model, e.g. mamba-130m, by sending the same /completion prompt twice.
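
Why an identical prompt ends in a partial removal can be worked through with a small sketch. It assumes the simplified model described above (back off one token on an exact match, then truncate from `n_past`). `Truncation` and `plan_truncation` are hypothetical names, and the token counts in the note below are illustrative.

```cpp
// Hypothetical result type: where truncation would start, and whether it is a
// mid-sequence (partial) removal.
struct Truncation {
    int  p0;
    bool is_partial;
};

// n_cached:   tokens already stored for the slot
// n_incoming: tokens in the new prompt
// n_matched:  length of the common prefix between the two
Truncation plan_truncation(int n_cached, int n_incoming, int n_matched) {
    int n_past = n_matched;

    // If the whole incoming prompt is already cached, back off by one so that
    // at least one token is decoded to produce logits.
    if (n_past == n_incoming && n_past > 0) {
        n_past--;
    }

    // Removing [n_past, end) is partial when it starts after position 0 and
    // before the end of the cached tokens.
    return { n_past, n_past > 0 && n_past < n_cached };
}
```

With 491 cached tokens and the same 491-token prompt sent again, the plan is `p0 = 490` with `is_partial == true`, which is the kind of request a recurrent state cannot satisfy. If the new prompt instead extends the cached one (say 600 incoming tokens, 491 matched), `p0 == 491` and nothing needs to be removed.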

Fix: include the exact-match case (n_past == slot.prompt.n_tokens()) in
the existing checkpoint-restore/full-reprocess branch by relaxing its
condition to n_past <= slot.prompt.n_tokens() and extending pos_min_thold
by 1 when there are no new tokens, so the upcoming back-off is accounted
for. Recurrent/hybrid models now restore a context checkpoint (or fall
back to full re-processing) instead of aborting; attention models are
unaffected.

This backports upstream ggml-org/llama.cpp PRs ggml-org#23280 and ggml-org#24110
(commits ccee426, 6f3a9f3).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>