Extend llama_kv_cache_seq_rm to allow matching any sequence #3843
Merged
KerfuffleV2 merged 2 commits into ggml-org:master on Oct 29, 2023
Conversation
ggerganov approved these changes Oct 29, 2023

ggerganov (Member) left a comment
Let's remove llama_kv_cache_tokens_rm, add llama_kv_cache_clear, and extend llama_kv_cache_seq_rm as proposed. Use llama_kv_cache_clear for cache clearing, and change calls to llama_kv_cache_tokens_rm that want to delete by position to use the llama_kv_cache_seq_rm functionality.
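A minimal sketch of what such a llama_kv_cache_clear could look like, using simplified stand-in types (the struct layout and field names here are assumptions mirroring the discussion, not the real llama.cpp definitions):

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical, simplified stand-ins for the llama.cpp KV cache types;
// field names mirror the PR discussion but this is not the real layout.
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

struct llama_kv_cell {
    llama_pos pos = -1;               // -1 marks a free cell
    std::set<llama_seq_id> seq_id;    // sequences using this cell
};

struct llama_kv_cache {
    uint32_t head = 0;                // where the slot search starts
    uint32_t size = 0;
    std::vector<llama_kv_cell> cells;
};

// Sketch of the proposed llama_kv_cache_clear: reset every cell and rewind
// the head so the slot search restarts at the beginning of the cache.
static void llama_kv_cache_clear(llama_kv_cache & cache) {
    for (uint32_t i = 0; i < cache.size; ++i) {
        cache.cells[i].pos = -1;
        cache.cells[i].seq_id.clear();
    }
    cache.head = 0;
}
```

Resetting head to 0 is what makes a dedicated clear cheaper than per-sequence removal: the next slot search can start from the front without scanning for freed cells.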
Contributor
Author
How about this? Tested, seems to work. Pretty hard change for even me to screw up.
KerfuffleV2 commented Oct 29, 2023
```diff
 for (uint32_t i = 0; i < cache.size; ++i) {
-    if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-        cache.cells[i].seq_id.erase(seq_id);
+    if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
```
Contributor
Author
Probably not worth the complexity, but we could optimize to:
```cpp
static void llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1) {
    uint32_t new_head = cache.size;

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    if (seq_id < 0) {
        // Removing for any sequence over the full range is just a clear.
        if (p0 == 0 && p1 >= cache.size) {
            llama_kv_cache_clear(cache);
            return;
        }
        for (uint32_t i = 0; i < cache.size; ++i) {
            if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
                cache.cells[i].seq_id.clear();
                cache.cells[i].pos = -1;
                if (new_head == cache.size) new_head = i;
            }
        }
    } else {
        for (uint32_t i = 0; i < cache.size; ++i) {
            if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
                cache.cells[i].seq_id.erase(seq_id);
                if (cache.cells[i].seq_id.empty()) {
                    cache.cells[i].pos = -1;
                    if (new_head == cache.size) new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cache.size) cache.head = new_head;
}
```

This avoids checking if seq_id < 0 each iteration, but a single int test probably wouldn't be noticeable even for huge KV caches.
Member
Yup, either way would be fine. You can either change it directly and then merge, or just merge as it is.
Contributor
Author
Since you don't have a preference, I'll just leave it as is. I don't think it's worth the added complexity.
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request on Oct 30, 2023:

…#3843) * Extend llama_kv_cache_seq_rm to allow matching any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear. Use llama_kv_cache_clear for cache clearing. Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality.
brittlewis12 added a commit to brittlewis12/llmfarm_core.swift that referenced this pull request on Nov 17, 2023
olexiyb pushed a commit to Sanctum-AI/llama.cpp that referenced this pull request on Nov 23, 2023:

…#3843) * Extend llama_kv_cache_seq_rm to allow matching any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear. Use llama_kv_cache_clear for cache clearing. Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality.
brittlewis12 added a commit to brittlewis12/llmfarm_core.swift that referenced this pull request on Nov 30, 2023
Since llama_seq_id allows for negative values, this extends llama_kv_cache_seq_rm to use sequence ids < 0 to match any sequence. Based on #3840, it doesn't seem like llama_kv_cache_tokens_rm really can be used for anything other than just clearing the KV cache. Some of the existing uses seem incorrect as well.

This pull is mergeable but incomplete. If it's on the right track, I'll just replace the existing calls to llama_kv_cache_tokens_rm that just clear the cache (llama_kv_cache_tokens_rm(ctx, -1, -1)) with llama_kv_cache_seq_rm(ctx, -1, -1, -1), and ones like llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1), which seem to actually want a position rather than a cell index, with llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1). llama_kv_cache_tokens_rm could possibly be removed entirely (I can't see how one would actually use it; where would knowing a KV cell index be meaningful? Maybe something like KV cache defragging if it ever gets added?). Could possibly also add a llama_kv_cache_clear function to clear the KV cache a bit more efficiently than using the new llama_kv_cache_seq_rm method.

Closes #3840
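The "seq_id < 0 matches any sequence" semantics described above can be sketched against a simplified, self-contained model of the KV cache. The types and the re-implementation below are illustrative assumptions (no head tracking, not the real llama.cpp structs); only the matching behavior follows the PR:

```cpp
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

// Hypothetical, simplified stand-ins for the llama.cpp KV cache types.
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

struct llama_kv_cell {
    llama_pos pos = -1;               // -1 marks a free cell
    std::set<llama_seq_id> seq_id;
    bool has_seq_id(llama_seq_id id) const { return seq_id.count(id) > 0; }
};

struct llama_kv_cache {
    uint32_t size = 0;
    std::vector<llama_kv_cell> cells;
};

// Sketch of the extended removal semantics: seq_id < 0 matches any
// sequence, and negative p0/p1 mean "unbounded" on that side.
static void llama_kv_cache_seq_rm(llama_kv_cache & cache, llama_seq_id seq_id,
                                  llama_pos p0, llama_pos p1) {
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].pos < p0 || cache.cells[i].pos >= p1) continue;
        if (seq_id < 0) {
            cache.cells[i].seq_id.clear();          // match any sequence
        } else if (cache.cells[i].has_seq_id(seq_id)) {
            cache.cells[i].seq_id.erase(seq_id);    // match one sequence
        }
        if (cache.cells[i].seq_id.empty()) {
            cache.cells[i].pos = -1;                // cell is now free
        }
    }
}
```

Under this model, llama_kv_cache_seq_rm(cache, -1, -1, -1) frees every cell, and llama_kv_cache_seq_rm(cache, -1, n_matching_session_tokens, -1) frees only cells at positions at or beyond n_matching_session_tokens, regardless of sequence, matching the call replacements described above.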