Skip to content

Commit d454c0c

Browse files
committed
feat(swiftlmchat): proactive iOS lifecycle — unload on background, reload on foreground
InferenceEngine:
- willResignActiveNotification → stopGeneration() + unload() + save backgroundedModelId
- didBecomeActiveNotification → reload backgroundedModelId (or lastLoadedModelId)
- autoOffloadOnBackground: Bool (default true on iOS, false on macOS)
- Observers consolidated into [NSObjectProtocol] for clean deinit
- Reactive memory warning still kept as safety fallback
- Thermal observer migrated to same consolidated array
- Background unload sets .idle (not .error) — clean UX on return
1 parent dc4069f commit d454c0c

2 files changed

Lines changed: 127 additions & 40 deletions

File tree

Sources/MLXInferenceCore/InferenceEngine.swift

Lines changed: 89 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,54 +55,117 @@ public final class InferenceEngine: ObservableObject {
5555
@Published public private(set) var state: ModelState = .idle
5656
@Published public private(set) var thermalLevel: ThermalLevel = .nominal
5757

58+
/// Whether to automatically unload the model when the app backgrounds
59+
/// and reload it when returning to foreground.
60+
/// Defaults to true on iOS (prevents jetsam), false on macOS.
61+
public var autoOffloadOnBackground: Bool = {
62+
#if os(iOS)
63+
return true
64+
#else
65+
return false
66+
#endif
67+
}()
68+
5869
/// Shared download + storage manager.
5970
public let downloadManager = ModelDownloadManager()
6071

6172
private var container: ModelContainer?
6273
private var currentModelId: String?
6374
private var generationTask: Task<Void, Never>?
64-
private var pressureObserver: NSObjectProtocol?
65-
private var thermalObserver: NSObjectProtocol?
75+
76+
// All NotificationCenter observers collected for clean deregistration
77+
private var observers: [NSObjectProtocol] = []
78+
79+
// Track the model ID that was active before we backgrounded,
80+
// so we can restore it when returning to foreground.
81+
private var backgroundedModelId: String?
6682

6783
public init() {
6884
setupPressureHandlers()
6985
}
7086

7187
deinit {
72-
if let o = pressureObserver { NotificationCenter.default.removeObserver(o) }
73-
if let o = thermalObserver { NotificationCenter.default.removeObserver(o) }
88+
observers.forEach { NotificationCenter.default.removeObserver($0) }
7489
}
7590

7691
// MARK: — Pressure Handlers
7792

7893
private func setupPressureHandlers() {
79-
// iOS memory pressure → unload model weights immediately
8094
#if canImport(UIKit)
81-
pressureObserver = NotificationCenter.default.addObserver(
82-
forName: UIApplication.didReceiveMemoryWarningNotification,
83-
object: nil,
84-
queue: .main
85-
) { [weak self] _ in
86-
Task { @MainActor [weak self] in
87-
guard let self else { return }
88-
// Only unload if not actively generating
89-
if case .generating = self.state { return }
90-
self.unload()
91-
self.state = .error("Unloaded due to memory pressure. Tap to reload.")
95+
// ── REACTIVE: Memory warning (last resort) ────────────────────────────
96+
// OS sends this *after* pressure builds. We still handle it as a fallback
97+
// in case the proactive unload wasn't triggered (e.g. app was already
98+
// under pressure from another process).
99+
observers.append(
100+
NotificationCenter.default.addObserver(
101+
forName: UIApplication.didReceiveMemoryWarningNotification,
102+
object: nil, queue: .main
103+
) { [weak self] _ in
104+
Task { @MainActor [weak self] in
105+
guard let self else { return }
106+
if case .generating = self.state { return } // don't interrupt mid-stream
107+
self.unload()
108+
self.state = .error("Unloaded due to memory pressure. Tap to reload.")
109+
}
92110
}
93-
}
111+
)
112+
113+
// ── PROACTIVE: App will background ────────────────────────────────────
114+
// Fire BEFORE iOS hands control back to springboard.
115+
// At this moment the process is still fully foregrounded — Metal context
116+
// is valid, memory limit hasn't changed. We unload now so iOS never
117+
// accumulates memory pressure against us in the background.
118+
observers.append(
119+
NotificationCenter.default.addObserver(
120+
forName: UIApplication.willResignActiveNotification,
121+
object: nil, queue: .main
122+
) { [weak self] _ in
123+
Task { @MainActor [weak self] in
124+
guard let self, self.autoOffloadOnBackground else { return }
125+
// Remember what was loaded so we can restore it
126+
self.backgroundedModelId = self.currentModelId
127+
// Stop any in-flight generation cleanly
128+
self.stopGeneration()
129+
self.unload()
130+
self.state = .idle // clean slate — no error banner on return
131+
}
132+
}
133+
)
134+
135+
// ── PROACTIVE: App returned to foreground ─────────────────────────────
136+
// Silently reload the model the user was using before they left.
137+
// We show .loading state so the chat UI doesn't appear broken.
138+
observers.append(
139+
NotificationCenter.default.addObserver(
140+
forName: UIApplication.didBecomeActiveNotification,
141+
object: nil, queue: .main
142+
) { [weak self] _ in
143+
Task { @MainActor [weak self] in
144+
guard let self, self.autoOffloadOnBackground else { return }
145+
// Prefer the model that was active when we backgrounded;
146+
// fall back to the last persisted model the user chose.
147+
let modelToReload = self.backgroundedModelId
148+
?? self.downloadManager.lastLoadedModelId
149+
self.backgroundedModelId = nil
150+
if let modelId = modelToReload {
151+
await self.load(modelId: modelId)
152+
}
153+
}
154+
}
155+
)
94156
#endif
95157

96-
// Thermal state monitoring (all platforms)
97-
thermalObserver = NotificationCenter.default.addObserver(
98-
forName: ProcessInfo.thermalStateDidChangeNotification,
99-
object: nil,
100-
queue: .main
101-
) { [weak self] _ in
102-
Task { @MainActor [weak self] in
103-
self?.updateThermalLevel()
158+
// ── Thermal state monitoring (all platforms) ──────────────────────────
159+
observers.append(
160+
NotificationCenter.default.addObserver(
161+
forName: ProcessInfo.thermalStateDidChangeNotification,
162+
object: nil, queue: .main
163+
) { [weak self] _ in
164+
Task { @MainActor [weak self] in
165+
self?.updateThermalLevel()
166+
}
104167
}
105-
}
168+
)
106169
updateThermalLevel()
107170
}
108171

Sources/SwiftLM/Server.swift

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,10 +1065,14 @@ struct ThinkingStateTracker {
10651065

10661066
// ── Chat Streaming ───────────────────────────────────────────────────────────
10671067

1068-
/// A lightweight actor-based boolean flag used to coordinate the prefill heartbeat task.
1069-
private actor BoolFlag {
1070-
private(set) var value: Bool = false
1071-
func set() { value = true }
1068+
/// Tracks prefill progress: whether it is done, and how many tokens have been processed.
1069+
/// n_past is updated by activePrefillProgressHook (called from LLMModel.prepare after each chunk)
1070+
/// and read by the SSE heartbeat task every 2 s.
1071+
private actor PrefillState {
1072+
private(set) var done: Bool = false
1073+
private(set) var nPast: Int = 0
1074+
func finish() { done = true }
1075+
func update(nPast: Int) { self.nPast = nPast }
10721076
}
10731077

10741078
func handleChatStreaming(
@@ -1086,16 +1090,25 @@ func handleChatStreaming(
10861090
) -> Response {
10871091
let (sseStream, cont) = AsyncStream<String>.makeStream()
10881092

1089-
// ── Prefill heartbeat: emit progress events while prompt is being processed ──
1090-
// This prevents clients from seeing a silent/dead connection during long prefills.
1091-
let prefillDone = BoolFlag()
1093+
// ── Prefill heartbeat: emit llama-server-style slot_update progress every 2 s ──
1094+
// n_past is updated by activePrefillProgressHook in LLMModel.prepare() after each
1095+
// 512-token chunk; single-chunk prompts only show elapsed_seconds.
1096+
let prefillState = PrefillState()
1097+
activePrefillProgressHook = { nPast, _ in
1098+
Task { await prefillState.update(nPast: nPast) }
1099+
}
10921100
Task {
10931101
var elapsed = 0
1094-
while await !prefillDone.value {
1102+
while await !prefillState.done {
10951103
try? await Task.sleep(for: .seconds(2))
1096-
if await !prefillDone.value {
1104+
if await !prefillState.done {
10971105
elapsed += 2
1098-
_ = cont.yield(ssePrefillChunk(modelId: modelId, promptTokens: promptTokenCount, elapsedSeconds: elapsed))
1106+
let nPast = await prefillState.nPast
1107+
_ = cont.yield(ssePrefillChunk(
1108+
modelId: modelId,
1109+
nPast: nPast,
1110+
promptTokens: promptTokenCount,
1111+
elapsedSeconds: elapsed))
10991112
}
11001113
}
11011114
}
@@ -1121,7 +1134,9 @@ func handleChatStreaming(
11211134
}
11221135
// Signal first token — stops the prefill heartbeat task
11231136
if firstToken {
1124-
await prefillDone.set()
1137+
// First decode token: stop heartbeat and clear the prefill progress hook
1138+
activePrefillProgressHook = nil
1139+
await prefillState.finish()
11251140
let prefillDur = Date().timeIntervalSince(prefillStart)
11261141
let prefillTokPerSec = prefillDur > 0 ? Double(promptTokenCount) / prefillDur : 0
11271142
print("srv slot update: id 0 | prefill done | n_tokens=\(promptTokenCount), t=\(String(format: "%.2f", prefillDur))s, \(String(format: "%.1f", prefillTokPerSec))t/s")
@@ -1175,7 +1190,8 @@ func handleChatStreaming(
11751190
toolCallIndex += 1
11761191

11771192
case .info(let info):
1178-
await prefillDone.set()
1193+
activePrefillProgressHook = nil
1194+
await prefillState.finish()
11791195
if !stopped {
11801196
var reason: String
11811197
switch info.stopReason {
@@ -1711,15 +1727,23 @@ func sseChunk(modelId: String, reasoningContent: String?, content: String?, fini
17111727

17121728
/// Prefill-progress heartbeat chunk — emitted every 2 s while the server is processing the prompt.
17131729
/// Uses object type "prefill_progress" so clients can filter it without confusing it with real tokens.
1714-
func ssePrefillChunk(modelId: String, promptTokens: Int, elapsedSeconds: Int) -> String {
1730+
/// Format mirrors llama-server's slot_update event:
1731+
/// n_past : tokens evaluated so far (real value from chunked prefill, or 0 for single-chunk)
1732+
/// n_prompt_tokens : total prompt token count
1733+
/// fraction : n_past / n_prompt_tokens (0.0–1.0), useful for progress bars
1734+
/// elapsed_seconds : wall-clock time since the request started
1735+
func ssePrefillChunk(modelId: String, nPast: Int = 0, promptTokens: Int, elapsedSeconds: Int) -> String {
1736+
let fraction = promptTokens > 0 ? Double(nPast) / Double(promptTokens) : 0.0
17151737
let chunk: [String: Any] = [
17161738
"id": "prefill-\(UUID().uuidString)",
17171739
"object": "prefill_progress",
17181740
"created": Int(Date().timeIntervalSince1970),
17191741
"model": modelId,
17201742
"prefill": [
17211743
"status": "processing",
1722-
"prompt_tokens": promptTokens,
1744+
"n_past": nPast,
1745+
"n_prompt_tokens": promptTokens,
1746+
"fraction": fraction,
17231747
"elapsed_seconds": elapsedSeconds
17241748
]
17251749
]

0 commit comments

Comments (0)