Skip to content

Commit 35e6172

Browse files
committed
fix(swiftlmchat): wire MLXInferenceCore sources + SPM packages into Xcode project
- Regenerate project.pbxproj with: - MLXInferenceCore .swift files as direct compile sources - XCLocalSwiftPackageReference for mlx-swift and mlx-swift-lm - XCSwiftPackageProductDependency for MLX, MLXLLM, MLXLMCommon - Add ModelManagementView (download manager, disk usage, delete) - Add ModelDownloadManager to MLXInferenceCore - Wire progress callbacks in InferenceEngine.load()
1 parent 01df003 commit 35e6172

6 files changed

Lines changed: 866 additions & 140 deletions

File tree

Sources/MLXInferenceCore/InferenceEngine.swift

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public struct GenerationToken: Sendable {
3333
public final class InferenceEngine: ObservableObject {
3434
@Published public private(set) var state: ModelState = .idle
3535

36+
/// Shared download manager — exposes download progress and local cache state.
37+
public let downloadManager = ModelDownloadManager()
38+
3639
private var container: ModelContainer?
3740
private var currentModelId: String?
3841
private var generationTask: Task<Void, Never>?
@@ -49,19 +52,27 @@ public final class InferenceEngine: ObservableObject {
4952
currentModelId = modelId
5053

5154
do {
52-
// Configure download progress reporting
5355
let config = ModelConfiguration(id: modelId)
5456
container = try await LLMModelFactory.shared.loadContainer(
5557
configuration: config
5658
) { [weak self] progress in
5759
Task { @MainActor in
60+
guard let self else { return }
5861
let pct = progress.fractionCompleted
59-
let speed = progress.throughput.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
60-
self?.state = .downloading(progress: pct, speed: speed)
62+
let speedMBps = progress.throughput.map { $0 / 1_000_000 }
63+
let speedStr = speedMBps.map { String(format: "%.1f MB/s", $0) } ?? ""
64+
self.state = .downloading(progress: pct, speed: speedStr)
65+
self.downloadManager.updateProgress(ModelDownloadProgress(
66+
modelId: modelId,
67+
fractionCompleted: pct,
68+
speedMBps: speedMBps
69+
))
6170
}
6271
}
72+
downloadManager.completeDownload(modelId: modelId)
6373
state = .ready(modelId: modelId)
6474
} catch {
75+
downloadManager.cancelDownload(modelId: modelId)
6576
state = .error("Failed to load \(modelId): \(error.localizedDescription)")
6677
container = nil
6778
}
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
// ModelDownloadManager.swift — HuggingFace cache inspection and model lifecycle
2+
// Manages local model storage: downloaded status, disk size, deletion, persistence.
3+
4+
import Foundation
5+
import Combine
6+
7+
/// Represents a locally downloaded model entry.
8+
public struct DownloadedModel: Identifiable, Sendable {
9+
public let id: String // HuggingFace model ID
10+
public let cacheDirectory: URL // Local cache path
11+
public let sizeBytes: Int64 // Total bytes on disk
12+
public let modifiedDate: Date? // Last access/modification date
13+
14+
public var displaySize: String {
15+
let gb = Double(sizeBytes) / 1_073_741_824
16+
let mb = Double(sizeBytes) / 1_048_576
17+
if gb >= 1.0 { return String(format: "%.1f GB", gb) }
18+
return String(format: "%.0f MB", mb)
19+
}
20+
}
21+
22+
/// Download progress for an in-flight download.
23+
public struct ModelDownloadProgress: Sendable {
24+
public let modelId: String
25+
public let fractionCompleted: Double // 0.0–1.0
26+
public let speedMBps: Double? // nil if unknown
27+
28+
public var speedString: String {
29+
guard let s = speedMBps else { return "" }
30+
return String(format: "%.1f MB/s", s)
31+
}
32+
public var percentString: String { "\(Int(fractionCompleted * 100))%" }
33+
}
34+
35+
/// Manages the HuggingFace model cache for SwiftLM Chat.
36+
/// Thread-safe: all mutations happen on MainActor.
37+
@MainActor
38+
public final class ModelDownloadManager: ObservableObject {
39+
40+
// MARK: — Published state
41+
42+
@Published public private(set) var downloadedModels: [DownloadedModel] = []
43+
@Published public private(set) var activeDownloads: [String: ModelDownloadProgress] = [:]
44+
@Published public private(set) var totalDiskUsageBytes: Int64 = 0
45+
46+
// MARK: — Persistence
47+
48+
private let lastModelKey = "swiftlm.lastLoadedModelId"
49+
public var lastLoadedModelId: String? {
50+
get { UserDefaults.standard.string(forKey: lastModelKey) }
51+
set { UserDefaults.standard.set(newValue, forKey: lastModelKey) }
52+
}
53+
54+
// MARK: — HuggingFace cache paths
55+
56+
/// Primary HF hub cache directory.
57+
public static var huggingFaceCacheRoot: URL {
58+
// Respect $HF_HUB_CACHE > $HF_HOME > default
59+
if let hfCache = ProcessInfo.processInfo.environment["HF_HUB_CACHE"] {
60+
return URL(fileURLWithPath: hfCache)
61+
}
62+
if let hfHome = ProcessInfo.processInfo.environment["HF_HOME"] {
63+
return URL(fileURLWithPath: hfHome).appendingPathComponent("hub")
64+
}
65+
return FileManager.default.homeDirectoryForCurrentUser
66+
.appendingPathComponent(".cache/huggingface/hub")
67+
}
68+
69+
/// Convert a HuggingFace model ID to its cache directory name.
70+
/// e.g. "mlx-community/Qwen2.5-7B-Instruct-4bit" → "models--mlx-community--Qwen2.5-7B-Instruct-4bit"
71+
public static func cacheDirName(for modelId: String) -> String {
72+
"models--" + modelId.replacingOccurrences(of: "/", with: "--")
73+
}
74+
75+
/// Returns the cache directory URL for a given model ID, or nil if not found.
76+
public static func cacheDirectory(for modelId: String) -> URL? {
77+
let dir = huggingFaceCacheRoot.appendingPathComponent(cacheDirName(for: modelId))
78+
return FileManager.default.fileExists(atPath: dir.path) ? dir : nil
79+
}
80+
81+
// MARK: — Public API
82+
83+
public init() {
84+
refresh()
85+
}
86+
87+
/// Re-scan the HuggingFace cache and update downloaded model list.
88+
public func refresh() {
89+
let root = Self.huggingFaceCacheRoot
90+
guard FileManager.default.fileExists(atPath: root.path) else {
91+
downloadedModels = []
92+
totalDiskUsageBytes = 0
93+
return
94+
}
95+
96+
var found: [DownloadedModel] = []
97+
let fm = FileManager.default
98+
99+
// Enumerate all "models--*" directories
100+
guard let contents = try? fm.contentsOfDirectory(
101+
at: root, includingPropertiesForKeys: [.contentModificationDateKey],
102+
options: [.skipsHiddenFiles]
103+
) else { return }
104+
105+
for dir in contents {
106+
guard dir.lastPathComponent.hasPrefix("models--") else { continue }
107+
108+
// Map directory name back to model ID
109+
let dirName = dir.lastPathComponent
110+
let modelId = dirName
111+
.replacingOccurrences(of: "^models--", with: "", options: .regularExpression)
112+
.replacingOccurrences(of: "--", with: "/")
113+
114+
// Only include models in our catalog
115+
guard ModelCatalog.all.contains(where: { $0.id == modelId }) else { continue }
116+
117+
let size = directorySize(at: dir)
118+
let modDate = (try? dir.resourceValues(forKeys: [.contentModificationDateKey]))?.contentModificationDate
119+
120+
found.append(DownloadedModel(
121+
id: modelId,
122+
cacheDirectory: dir,
123+
sizeBytes: size,
124+
modifiedDate: modDate
125+
))
126+
}
127+
128+
downloadedModels = found.sorted { ($0.modifiedDate ?? .distantPast) > ($1.modifiedDate ?? .distantPast) }
129+
totalDiskUsageBytes = found.reduce(0) { $0 + $1.sizeBytes }
130+
}
131+
132+
/// Returns true if the model has been fully downloaded to local cache.
133+
public func isDownloaded(_ modelId: String) -> Bool {
134+
downloadedModels.contains(where: { $0.id == modelId })
135+
}
136+
137+
/// Returns the downloaded model entry for a given ID, if available.
138+
public func downloadedModel(for modelId: String) -> DownloadedModel? {
139+
downloadedModels.first(where: { $0.id == modelId })
140+
}
141+
142+
/// Delete a model from local cache, freeing disk space.
143+
public func delete(_ modelId: String) throws {
144+
guard let dir = Self.cacheDirectory(for: modelId) else { return }
145+
try FileManager.default.removeItem(at: dir)
146+
refresh()
147+
}
148+
149+
/// Update active download progress (called by InferenceEngine during load).
150+
public func updateProgress(_ progress: ModelDownloadProgress) {
151+
activeDownloads[progress.modelId] = progress
152+
}
153+
154+
/// Mark a download as complete.
155+
public func completeDownload(modelId: String) {
156+
activeDownloads.removeValue(forKey: modelId)
157+
refresh()
158+
lastLoadedModelId = modelId
159+
}
160+
161+
/// Cancel an active download tracking entry.
162+
public func cancelDownload(modelId: String) {
163+
activeDownloads.removeValue(forKey: modelId)
164+
}
165+
166+
// MARK: — Helpers
167+
168+
private func directorySize(at url: URL) -> Int64 {
169+
let fm = FileManager.default
170+
guard let enumerator = fm.enumerator(
171+
at: url,
172+
includingPropertiesForKeys: [.fileSizeKey],
173+
options: [.skipsHiddenFiles]
174+
) else { return 0 }
175+
176+
var total: Int64 = 0
177+
for case let fileURL as URL in enumerator {
178+
let size = (try? fileURL.resourceValues(forKeys: [.fileSizeKey]))?.fileSize ?? 0
179+
total += Int64(size)
180+
}
181+
return total
182+
}
183+
}

0 commit comments

Comments
 (0)