
Commit 03025a4

Fix Gemma 4 sliding window rotating KV cache regression and weight mapping
1 parent a2b70dc

1 file changed: 18 additions & 12 deletions

Sources/SwiftLM/Server.swift
@@ -968,7 +968,7 @@ func handleChatCompletion(
     let prefillStart = Date()
 
     // ── Cache-aware generation ──
-    let stream: AsyncStream<Generation> = try await container.perform { context in
+    let (stream, onPrefillDone) = try await container.perform { context -> (AsyncStream<Generation>, (() async -> Void)?) in
         let cache = context.model.newCache(parameters: params)
 
         // ── TurboQuant: enable 3-bit KV compression on every KVCacheSimple layer ──
@@ -983,6 +983,7 @@ func handleChatCompletion(
         }
 
         // Try to restore via token-by-token prefix match (llama-server style)
+        var stream: AsyncStream<Generation>
         if let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
            // Cache hit: KV state is pre-populated up to cachedCount tokens.
            // Only compute the remaining (new) tokens.
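The "token-by-token prefix match" in the comment above is the llama-server trick of reusing the KV state for the longest shared prefix between the cached prompt and the incoming one, then prefilling only the suffix. The restore internals are not part of this diff; the following is a minimal sketch of the matching step only, with hypothetical names (`commonPrefixLength` is not the repository's API):

```swift
// Hypothetical sketch of a llama-server-style prefix match: count how many
// leading tokens of the new prompt agree with the previously cached prompt,
// i.e. how much of the existing KV state can be reused verbatim.
func commonPrefixLength(cached: [Int], newTokens: [Int]) -> Int {
    var n = 0
    while n < cached.count, n < newTokens.count, cached[n] == newTokens[n] {
        n += 1
    }
    return n
}
```

A restore would then report that count as `cachedCount`, and the caller prefills only the suffix `newTokens[n...]` on top of the restored cache.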
@@ -996,21 +997,22 @@ func handleChatCompletion(
            }
            let remainingTokens = lmInput.text.tokens[startIndex...]
            let trimmedInput = LMInput(tokens: remainingTokens)
-           let stream = try MLXLMCommon.generate(
+           stream = try MLXLMCommon.generate(
                input: trimmedInput, cache: cache, parameters: params, context: context
            )
-           // Save prompt tokens + KV state synchronously after the partial prefill.
-           await promptCache.save(tokens: promptTokens, cache: cache)
-           return stream
        } else {
            // Cache miss: process the full prompt.
-           let stream = try MLXLMCommon.generate(
+           stream = try MLXLMCommon.generate(
                input: lmInput, cache: cache, parameters: params, context: context
            )
-           // Save prompt tokens + KV state synchronously after the full prefill.
+        }
+
+        // Return a closure that will save the cache state synchronously AFTER
+        // the generator stream has evaluated the prefill (on its very first token).
+        let onPrefillDone: (() async -> Void)? = {
            await promptCache.save(tokens: promptTokens, cache: cache)
-           return stream
        }
+        return (stream, onPrefillDone)
    }
 
    let modelId = config.modelId
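The key change in this hunk: `container.perform` no longer saves the prompt cache eagerly. `MLXLMCommon.generate` hands back an `AsyncStream` that does no work until it is iterated, so saving right after constructing the stream would snapshot the KV cache before the prefill had actually been evaluated. The fix pairs the stream with a save hook for the consumer to fire later. A self-contained sketch of this defer-until-first-element pattern (simplified types; `makeStreamWithHook` is illustrative, not the repo's API):

```swift
// Sketch: pair a lazy stream with a hook that must run only after the
// stream's first element proves the expensive setup (the prefill) happened.
func makeStreamWithHook() -> (AsyncStream<Int>, () async -> Void) {
    let stream = AsyncStream<Int> { continuation in
        Task {
            for i in 0..<3 { continuation.yield(i) }  // stand-in for generation
            continuation.finish()
        }
    }
    let onFirstElement: () async -> Void = {
        print("state is now valid; safe to persist")  // e.g. save the KV cache
    }
    return (stream, onFirstElement)
}

// The consumer fires the hook exactly once, on the first element.
func consume() async {
    let (stream, hook) = makeStreamWithHook()
    var first = true
    for await value in stream {
        if first { await hook(); first = false }
        print(value)
    }
}
```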
@@ -1020,14 +1022,14 @@ func handleChatCompletion(
            stream: stream, modelId: modelId, stopSequences: stopSequences,
            includeUsage: includeUsage, promptTokenCount: promptTokenCount,
            enableThinking: enableThinking, jsonMode: jsonMode, semaphore: semaphore,
-           stats: stats, genStart: genStart, prefillStart: prefillStart
+           stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
        )
    } else {
        return try await handleChatNonStreaming(
            stream: stream, modelId: modelId, stopSequences: stopSequences,
            promptTokenCount: promptTokenCount, enableThinking: enableThinking,
            jsonMode: jsonMode, semaphore: semaphore,
-           stats: stats, genStart: genStart, prefillStart: prefillStart
+           stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
        )
    }
 }
@@ -1122,7 +1124,8 @@ func handleChatStreaming(
    semaphore: AsyncSemaphore,
    stats: ServerStats,
    genStart: Date,
-   prefillStart: Date
+   prefillStart: Date,
+   onPrefillDone: (() async -> Void)? = nil
 ) -> Response {
    let (sseStream, cont) = AsyncStream<String>.makeStream()
 
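Both handlers take the hook as a trailing optional parameter defaulting to `nil`, which keeps every existing call site source-compatible. A small illustration of the pattern (hypothetical function, not from this file):

```swift
// A trailing optional callback with a nil default is source-compatible:
// old call sites need no edits, new ones can opt in.
func handle(request: String, onPrefillDone: (() async -> Void)? = nil) async {
    print("handling \(request)")
    if let onPrefillDone { await onPrefillDone() }  // no-op when nil
}
// Both compile: `await handle(request: "a")`
//               `await handle(request: "b", onPrefillDone: { print("saved") })`
```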
@@ -1177,6 +1180,7 @@ func handleChatStreaming(
            let prefillTokPerSec = prefillDur > 0 ? Double(promptTokenCount) / prefillDur : 0
            print("srv slot update: id 0 | prefill done | n_tokens=\(promptTokenCount), t=\(String(format: "%.2f", prefillDur))s, \(String(format: "%.1f", prefillTokPerSec))t/s")
            print("srv generate: id 0 | ", terminator: "")
+           if let onPrefillDone { await onPrefillDone() }
            firstToken = false
        }
        print(text, terminator: "")
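Both the streaming path above and the non-streaming path below gate the hook on the same `firstToken` flag used for the prefill timing log, so the save runs at most once per request and only after the first generated token shows the prefill has completed. The save is awaited inline, which delays the second token slightly, but it guarantees the persisted tokens-plus-KV state reflects a finished prefill rather than a cache whose prefill had not yet been evaluated.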
@@ -1292,7 +1296,8 @@ func handleChatNonStreaming(
    semaphore: AsyncSemaphore,
    stats: ServerStats,
    genStart: Date,
-   prefillStart: Date
+   prefillStart: Date,
+   onPrefillDone: (() async -> Void)? = nil
 ) async throws -> Response {
    var fullText = ""
    var completionTokenCount = 0
@@ -1315,6 +1320,7 @@ func handleChatNonStreaming(
            let prefillTokPerSec = prefillDur > 0 ? Double(promptTokenCount) / prefillDur : 0
            print("srv slot update: id 0 | prefill done | n_tokens=\(promptTokenCount), t=\(String(format: "%.2f", prefillDur))s, \(String(format: "%.1f", prefillTokPerSec))t/s")
            print("srv generate: id 0 | ", terminator: "")
+           if let onPrefillDone { await onPrefillDone() }
            firstToken = false
        }
        print(text, terminator: "")
