@@ -968,7 +968,7 @@ func handleChatCompletion(
     let prefillStart = Date()
 
     // ── Cache-aware generation ──
-    let stream: AsyncStream<Generation> = try await container.perform { context in
+    let (stream, onPrefillDone) = try await container.perform { context -> (AsyncStream<Generation>, (() async -> Void)?) in
         let cache = context.model.newCache(parameters: params)
 
         // ── TurboQuant: enable 3-bit KV compression on every KVCacheSimple layer ──
@@ -983,6 +983,7 @@
         }
 
         // Try to restore via token-by-token prefix match (llama-server style)
+        var stream: AsyncStream<Generation>
         if let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
             // Cache hit: KV state is pre-populated up to cachedCount tokens.
             // Only compute the remaining (new) tokens.
@@ -996,21 +997,22 @@
             }
             let remainingTokens = lmInput.text.tokens[startIndex...]
             let trimmedInput = LMInput(tokens: remainingTokens)
-            let stream = try MLXLMCommon.generate(
+            stream = try MLXLMCommon.generate(
                 input: trimmedInput, cache: cache, parameters: params, context: context
             )
-            // Save prompt tokens + KV state synchronously after the partial prefill.
-            await promptCache.save(tokens: promptTokens, cache: cache)
-            return stream
         } else {
             // Cache miss: process the full prompt.
-            let stream = try MLXLMCommon.generate(
+            stream = try MLXLMCommon.generate(
                 input: lmInput, cache: cache, parameters: params, context: context
             )
-            // Save prompt tokens + KV state synchronously after the full prefill.
+        }
+
+        // Return a closure that will save the cache state synchronously AFTER
+        // the generator stream has evaluated the prefill (on its very first token).
+        let onPrefillDone: (() async -> Void)? = {
             await promptCache.save(tokens: promptTokens, cache: cache)
-            return stream
         }
+        return (stream, onPrefillDone)
     }
 
     let modelId = config.modelId
@@ -1020,14 +1022,14 @@
             stream: stream, modelId: modelId, stopSequences: stopSequences,
             includeUsage: includeUsage, promptTokenCount: promptTokenCount,
             enableThinking: enableThinking, jsonMode: jsonMode, semaphore: semaphore,
-            stats: stats, genStart: genStart, prefillStart: prefillStart
+            stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
         )
     } else {
         return try await handleChatNonStreaming(
             stream: stream, modelId: modelId, stopSequences: stopSequences,
             promptTokenCount: promptTokenCount, enableThinking: enableThinking,
             jsonMode: jsonMode, semaphore: semaphore,
-            stats: stats, genStart: genStart, prefillStart: prefillStart
+            stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
         )
     }
 }
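
The ordering here is the point of the change: the generation stream is lazy, so the KV cache is only fully populated once the prefill has actually run, which the caller observes as the first yielded token. Saving eagerly inside container.perform (the removed lines) would snapshot the cache before that. A minimal sketch of the pattern in isolation, with hypothetical makeStream/saveCache stand-ins for MLXLMCommon.generate and the promptCache actor (not names from this commit):

// Sketch only: defer the cache save until the stream proves the prefill ran.
func collectWithDeferredSave(
    makeStream: () -> AsyncStream<String>,
    saveCache: () async -> Void
) async -> String {
    var output = ""
    var firstToken = true
    for await token in makeStream() {
        if firstToken {
            // The first yielded token means the prefill has been evaluated,
            // so the cache snapshot taken now reflects the full prompt.
            await saveCache()
            firstToken = false
        }
        output += token
    }
    return output
}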
@@ -1122,7 +1124,8 @@ func handleChatStreaming(
     semaphore: AsyncSemaphore,
     stats: ServerStats,
     genStart: Date,
-    prefillStart: Date
+    prefillStart: Date,
+    onPrefillDone: (() async -> Void)? = nil
 ) -> Response {
     let (sseStream, cont) = AsyncStream<String>.makeStream()
 
@@ -1177,6 +1180,7 @@
                 let prefillTokPerSec = prefillDur > 0 ? Double(promptTokenCount) / prefillDur : 0
                 print("srv slot update: id 0 | prefill done | n_tokens=\(promptTokenCount), t=\(String(format: "%.2f", prefillDur))s, \(String(format: "%.1f", prefillTokPerSec)) t/s")
                 print("srv generate: id 0 | ", terminator: "")
+                if let onPrefillDone { await onPrefillDone() }
                 firstToken = false
             }
             print(text, terminator: "")
@@ -1292,7 +1296,8 @@ func handleChatNonStreaming(
     semaphore: AsyncSemaphore,
     stats: ServerStats,
     genStart: Date,
-    prefillStart: Date
+    prefillStart: Date,
+    onPrefillDone: (() async -> Void)? = nil
 ) async throws -> Response {
     var fullText = ""
     var completionTokenCount = 0
@@ -1315,6 +1320,7 @@
             let prefillTokPerSec = prefillDur > 0 ? Double(promptTokenCount) / prefillDur : 0
             print("srv slot update: id 0 | prefill done | n_tokens=\(promptTokenCount), t=\(String(format: "%.2f", prefillDur))s, \(String(format: "%.1f", prefillTokPerSec)) t/s")
             print("srv generate: id 0 | ", terminator: "")
+            if let onPrefillDone { await onPrefillDone() }
             firstToken = false
         }
         print(text, terminator: "")
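
For context, the "token-by-token prefix match (llama-server style)" that promptCache.restore performs boils down to a common-prefix scan over token IDs; the returned cachedCount is the number of leading tokens whose KV state can be reused. An illustrative sketch, with names that are assumptions rather than the actual promptCache implementation (the real restore would also have to leave the KV cache trimmed to exactly this length):

// Sketch only: how many leading tokens of the new prompt match the cached one.
func commonPrefixCount(cachedTokens: [Int], newTokens: [Int]) -> Int {
    var n = 0
    while n < cachedTokens.count, n < newTokens.count,
          cachedTokens[n] == newTokens[n] {
        n += 1
    }
    return n
}

// Example: a cached "system + turn 1" prompt vs. a prompt that appends turn 2
// shares its first three tokens, so only the suffix needs a fresh prefill.
// commonPrefixCount(cachedTokens: [1, 5, 9, 2], newTokens: [1, 5, 9, 7, 8]) == 3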