Skip to content

Bug: Hybrid models re-evaluate the prompt, invalidating the KV cache #1576

@nicman23

Description

@nicman23

What happened?

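The KV cache gets invalidated: with hybrid/recurrent models, the server re-evaluates the whole prompt instead of reusing the cached state.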

Name and Version

master

What operating system are you seeing the problem on?

Linux

Relevant log output

I do not have the log output, but I have a fix :P

The corresponding llama.cpp fix is https://github.com/evq/llama.cpp/commit/c2507187078d1b9e706410cdaac26e4d59c7036d

I asked Qwen to edit the example server in this project accordingly, and it worked. The resulting diff:

diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index ccaa7bd3..7602ad26 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -2801,7 +2801,15 @@ void  server_context::create_checkpoint_at_interval(server_slot & slot, const gp
 
 void server_context::apply_checkpoint(server_slot & slot) {
     llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
-    const auto pos_min_thold = std::max(0, pos_next - 1);
+
+    const bool is_recurrent = llama_model_is_recurrent(model);
+
+    // note: when n_swa == 0, the model does not use SWA
+    const auto n_swa = std::max(0, llama_model_n_swa(model));
+
+    // For hybrid/recurrent: SWA threshold not meaningful, set to 0.
+    // For pure transformer/SWA: preserve existing behavior.
+    const auto pos_min_thold = is_recurrent ? (llama_pos) 0 : (llama_pos) std::max(0, pos_next - n_swa);
     if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
         int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
 
@@ -2812,8 +2820,14 @@ void server_context::apply_checkpoint(server_slot & slot) {
             const auto it = std::find_if(
                 slot.server_cached_prompt.checkpoints.rbegin(),
                 slot.server_cached_prompt.checkpoints.rend(),
-                [&](const auto & cur) {
+                [&, func_name = __func__](const auto & cur) {
+                    if (is_recurrent) {
+                        // For hybrid/recurrent: use position-matching semantics
+                        return cur.pos_max <= n_past && cur.pos_max < pos_next;
+                    }
                     // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                    LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
+                        func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
                     return cur.pos_min < pos_min_thold;
                 }
             );
@@ -2840,10 +2854,16 @@ void server_context::apply_checkpoint(server_slot & slot) {
             }
 
             if (do_reset) {
-                SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
-                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                slot.n_past = 0;
-                slot.n_past_prompt = 0;
+                if (is_recurrent) {
+                    // For hybrid/recurrent: preserve current state, do not zero n_past
+                    SLT_WRN(slot, "no matching recurrent checkpoint; preserving prompt state (n_past = %d)\n", n_past);
+                } else {
+                    SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
+                        "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                    pos_next = 0;
+                    slot.n_past = 0;
+                    slot.n_past_prompt = 0;
+                }
             }
         }
     }
@@ -2852,8 +2872,8 @@ void server_context::apply_checkpoint(server_slot & slot) {
         // erase any checkpoints with pos_min > pos_min_thold
         for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
             const auto & cur = *it;
-            if (cur.pos_min > pos_min_thold) {
-                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
+            if (cur.pos_min > pos_next) {
+                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, pos_next, (float)cur.data.size() / 1024 / 1024);
                 it = slot.server_cached_prompt.checkpoints.erase(it);
             } else {
                 ++it;

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions