Skip to content

Commit bfc980a

Browse files
simbasimba authored and committed
feat: Phase 2 — JSON mode, VLM vision support, multipart content, extra sampling params
- Add response_format: { type: 'json_object' } with prompt injection + fence stripping
- Add --vision CLI flag for VLM model loading via VLMModelFactory
- Parse OpenAI multipart content (string or [{type:'text',...},{type:'image_url',...}])
- Decode base64 data URIs and HTTP URLs into UserInput.Image for VLM inference
- Accept top_k, frequency_penalty, presence_penalty (API compat)
- Add MLXVLM package dependency
- Add 4 new regression tests (Tests 14-17), total: 38 assertions

All 38 tests pass.
1 parent 519bfda commit bfc980a

3 files changed

Lines changed: 248 additions & 15 deletions

File tree

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ let package = Package(
2323
dependencies: [
2424
.product(name: "MLX", package: "mlx-swift"),
2525
.product(name: "MLXLLM", package: "mlx-swift-lm"),
26+
.product(name: "MLXVLM", package: "mlx-swift-lm"),
2627
.product(name: "MLXLMCommon", package: "mlx-swift-lm"),
2728
.product(name: "Transformers", package: "swift-transformers"),
2829
.product(name: "Hummingbird", package: "hummingbird"),

Sources/mlx-server/main.swift

Lines changed: 153 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
// mlx-server --model mlx-community/Qwen2.5-3B-Instruct-4bit --port 5413
1111

1212
import ArgumentParser
13+
import CoreImage
1314
import Foundation
1415
import HTTPTypes
1516
import Hummingbird
1617
import MLX
1718
import MLXLLM
1819
import MLXLMCommon
20+
import MLXVLM
1921

2022
// ── CLI ──────────────────────────────────────────────────────────────────────
2123

@@ -56,6 +58,9 @@ struct MLXServer: AsyncParsableCommand {
5658
@Flag(name: .long, help: "Enable thinking/reasoning mode (Qwen3.5 etc). Default: disabled")
5759
var thinking: Bool = false
5860

61+
@Flag(name: .long, help: "Enable VLM (vision-language model) mode for image inputs")
62+
var vision: Bool = false
63+
5964
@Option(name: .long, help: "Allowed CORS origin (* for all, or a specific origin URL)")
6065
var cors: String?
6166

@@ -79,11 +84,23 @@ struct MLXServer: AsyncParsableCommand {
7984
modelConfig = ModelConfiguration(id: modelId)
8085
}
8186

82-
let container = try await LLMModelFactory.shared.loadContainer(
83-
configuration: modelConfig
84-
) { progress in
85-
let pct = Int(progress.fractionCompleted * 100)
86-
print("[mlx-server] Download: \(pct)%")
87+
let isVision = self.vision
88+
let container: ModelContainer
89+
if isVision {
90+
print("[mlx-server] Loading VLM (vision-language model)...")
91+
container = try await VLMModelFactory.shared.loadContainer(
92+
configuration: modelConfig
93+
) { progress in
94+
let pct = Int(progress.fractionCompleted * 100)
95+
print("[mlx-server] Download: \(pct)%")
96+
}
97+
} else {
98+
container = try await LLMModelFactory.shared.loadContainer(
99+
configuration: modelConfig
100+
) { progress in
101+
let pct = Int(progress.fractionCompleted * 100)
102+
print("[mlx-server] Download: \(pct)%")
103+
}
87104
}
88105

89106
print("[mlx-server] Model loaded. Starting HTTP server on \(host):\(port)")
@@ -96,7 +113,8 @@ struct MLXServer: AsyncParsableCommand {
96113
temp: self.temp,
97114
topP: self.topP,
98115
repeatPenalty: self.repeatPenalty,
99-
thinking: self.thinking
116+
thinking: self.thinking,
117+
isVision: isVision
100118
)
101119

102120
let parallelSlots = self.parallel
@@ -170,7 +188,7 @@ struct MLXServer: AsyncParsableCommand {
170188
"port": port,
171189
"model": modelId,
172190
"engine": "mlx",
173-
"vision": false
191+
"vision": isVision
174192
]
175193
if let data = try? JSONSerialization.data(withJSONObject: readyEvent),
176194
let json = String(data: data, encoding: .utf8) {
@@ -192,6 +210,7 @@ struct ServerConfig: Sendable {
192210
let topP: Float
193211
let repeatPenalty: Float?
194212
let thinking: Bool
213+
let isVision: Bool
195214
}
196215

197216
// ── Request Body Extraction ──────────────────────────────────────────────────
@@ -212,6 +231,7 @@ func handleChatCompletion(
212231
) async throws -> Response {
213232
let chatReq = try JSONDecoder().decode(ChatCompletionRequest.self, from: bodyData)
214233
let isStream = chatReq.stream ?? false
234+
let jsonMode = chatReq.responseFormat?.type == "json_object"
215235

216236
// ── Merge per-request overrides with CLI defaults ──
217237
let tokenLimit = chatReq.maxTokens ?? config.maxTokens
@@ -221,6 +241,11 @@ func handleChatCompletion(
221241
let stopSequences = chatReq.stop ?? []
222242
let includeUsage = chatReq.streamOptions?.includeUsage ?? false
223243

244+
// Log extra sampling params if provided (accepted for API compat, not all are used)
245+
if chatReq.topK != nil || chatReq.frequencyPenalty != nil || chatReq.presencePenalty != nil {
246+
// These are accepted but may not affect generation if MLX doesn't support them
247+
}
248+
224249
let params = GenerateParameters(
225250
maxTokens: tokenLimit,
226251
maxKVSize: config.ctxSize,
@@ -234,15 +259,27 @@ func handleChatCompletion(
234259
MLXRandom.seed(UInt64(seed))
235260
}
236261

237-
// Convert request messages → Chat.Message
238-
let chatMessages: [Chat.Message] = chatReq.messages.compactMap { msg in
262+
// ── Parse messages with multipart content support (for VLM images) ──
263+
var chatMessages: [Chat.Message] = []
264+
for msg in chatReq.messages {
265+
let textContent = msg.textContent
266+
let images = msg.extractImages()
239267
switch msg.role {
240-
case "system": return .system(msg.content)
241-
case "assistant": return .assistant(msg.content)
242-
default: return .user(msg.content)
268+
case "system":
269+
chatMessages.append(.system(textContent, images: images))
270+
case "assistant":
271+
chatMessages.append(.assistant(textContent, images: images))
272+
default:
273+
chatMessages.append(.user(textContent, images: images))
243274
}
244275
}
245276

277+
// ── JSON mode: inject system prompt for JSON output ──
278+
if jsonMode {
279+
let jsonSystemMsg = Chat.Message.system("You must respond with valid JSON only. No markdown code fences, no explanation text, no preamble. Output raw JSON.")
280+
chatMessages.insert(jsonSystemMsg, at: 0)
281+
}
282+
246283
// Convert OpenAI tools format → [String: any Sendable] for UserInput
247284
let toolSpecs: [[String: any Sendable]]? = chatReq.tools?.map { tool in
248285
var spec: [String: any Sendable] = ["type": tool.type]
@@ -273,12 +310,13 @@ func handleChatCompletion(
273310
if isStream {
274311
return handleChatStreaming(
275312
stream: stream, modelId: modelId, stopSequences: stopSequences,
276-
includeUsage: includeUsage, promptTokenCount: promptTokenCount, semaphore: semaphore
313+
includeUsage: includeUsage, promptTokenCount: promptTokenCount,
314+
jsonMode: jsonMode, semaphore: semaphore
277315
)
278316
} else {
279317
return try await handleChatNonStreaming(
280318
stream: stream, modelId: modelId, stopSequences: stopSequences,
281-
promptTokenCount: promptTokenCount, semaphore: semaphore
319+
promptTokenCount: promptTokenCount, jsonMode: jsonMode, semaphore: semaphore
282320
)
283321
}
284322
}
@@ -291,6 +329,7 @@ func handleChatStreaming(
291329
stopSequences: [String],
292330
includeUsage: Bool,
293331
promptTokenCount: Int,
332+
jsonMode: Bool = false,
294333
semaphore: AsyncSemaphore
295334
) -> Response {
296335
let (sseStream, cont) = AsyncStream<String>.makeStream()
@@ -357,6 +396,7 @@ func handleChatNonStreaming(
357396
modelId: String,
358397
stopSequences: [String],
359398
promptTokenCount: Int,
399+
jsonMode: Bool = false,
360400
semaphore: AsyncSemaphore
361401
) async throws -> Response {
362402
var fullText = ""
@@ -389,6 +429,18 @@ func handleChatNonStreaming(
389429
finishReason = "stop"
390430
}
391431

432+
// ── JSON mode validation ──
433+
if jsonMode {
434+
// Strip markdown code fences if model wrapped response
435+
let stripped = fullText
436+
.replacingOccurrences(of: "```json\n", with: "")
437+
.replacingOccurrences(of: "```json", with: "")
438+
.replacingOccurrences(of: "```\n", with: "")
439+
.replacingOccurrences(of: "```", with: "")
440+
.trimmingCharacters(in: .whitespacesAndNewlines)
441+
fullText = stripped
442+
}
443+
392444
let totalTokens = promptTokenCount + completionTokenCount
393445
let hasToolCalls = !collectedToolCalls.isEmpty
394446

@@ -765,11 +817,89 @@ struct StreamOptions: Decodable {
765817
}
766818
}
767819

820+
/// OpenAI `response_format` request object.
/// Only the `type` discriminator (e.g. `"json_object"`) is read by the server.
struct ResponseFormat: Decodable {
    let type: String
}
823+
768824
struct ChatCompletionRequest: Decodable {
825+
/// One chat message from the OpenAI request body.
/// `content` is either a plain string or an array of content parts (text + image_url).
struct Message: Decodable {
    let role: String
    let content: MessageContent

    /// Plain text of the message: the string itself, or — for multipart
    /// content — every `"text"` part joined with newlines.
    var textContent: String {
        switch content {
        case .string(let s): return s
        case .parts(let parts):
            return parts.compactMap { $0.type == "text" ? $0.text : nil }
                .joined(separator: "\n")
        }
    }

    /// Images referenced by multipart content, for VLM inference.
    ///
    /// Supports `data:` base64 URIs (decoded into a `CIImage`) and any string
    /// `URL(string:)` accepts — http(s), file, etc. Entries that fail to parse
    /// or decode are silently dropped (best-effort, matching the server's
    /// lenient request handling).
    func extractImages() -> [UserInput.Image] {
        guard case .parts(let parts) = content else { return [] }
        return parts.compactMap { part -> UserInput.Image? in
            guard part.type == "image_url", let imageUrl = part.imageUrl else { return nil }
            let urlStr = imageUrl.url
            // Base64 data URIs: data:image/png;base64,... — payload follows the first comma.
            if urlStr.hasPrefix("data:") {
                guard let commaIdx = urlStr.firstIndex(of: ",") else { return nil }
                let base64Str = String(urlStr[urlStr.index(after: commaIdx)...])
                // .ignoreUnknownCharacters tolerates the line breaks / whitespace
                // some clients embed in long base64 payloads, which the strict
                // decoder would otherwise reject outright.
                guard let data = Data(base64Encoded: base64Str, options: .ignoreUnknownCharacters),
                      let ciImage = CIImage(data: data) else { return nil }
                return .ciImage(ciImage)
            }
            // Any parseable URL (http, https, file, ...). The previous dedicated
            // http/https branch was unreachable in practice: this catch-all
            // accepted every URL the scheme check did, so it has been removed.
            if let url = URL(string: urlStr) {
                return .url(url)
            }
            return nil
        }
    }
}
869+
870+
/// Message content: either a plain string or structured multipart content.
enum MessageContent: Decodable {
    case string(String)
    case parts([ContentPart])

    /// Accepts both OpenAI shapes: a bare JSON string, or an array of
    /// content-part objects. Anything else degrades to an empty string
    /// rather than failing the whole request.
    init(from decoder: Decoder) throws {
        let single = try decoder.singleValueContainer()
        if let flat = try? single.decode(String.self) {
            self = .string(flat)
            return
        }
        if let multipart = try? single.decode([ContentPart].self) {
            self = .parts(multipart)
            return
        }
        // Lenient fallback for unknown shapes (e.g. null).
        self = .string("")
    }
}
886+
887+
/// A single element of multipart message content.
/// `type` is `"text"` (with `text` set) or `"image_url"` (with `imageUrl` set).
struct ContentPart: Decodable {
    let type: String
    let text: String?
    let imageUrl: ImageUrlContent?

    enum CodingKeys: String, CodingKey {
        case type, text
        case imageUrl = "image_url"
    }
}

/// The `image_url` payload: a URL or data URI, plus an optional detail hint
/// (accepted for OpenAI API compatibility).
struct ImageUrlContent: Decodable {
    let url: String
    let detail: String?
}
902+
773903
struct ToolDef: Decodable {
774904
let type: String
775905
let function: ToolFuncDef
@@ -785,18 +915,26 @@ struct ChatCompletionRequest: Decodable {
785915
let maxTokens: Int?
786916
let temperature: Double?
787917
let topP: Double?
918+
let topK: Int?
788919
let repetitionPenalty: Double?
920+
let frequencyPenalty: Double?
921+
let presencePenalty: Double?
789922
let tools: [ToolDef]?
790923
let stop: [String]?
791924
let seed: Int?
792925
let streamOptions: StreamOptions?
926+
let responseFormat: ResponseFormat?
793927

794928
enum CodingKeys: String, CodingKey {
795929
case model, messages, stream, temperature, tools, stop, seed
796930
case maxTokens = "max_tokens"
797931
case topP = "top_p"
932+
case topK = "top_k"
798933
case repetitionPenalty = "repetition_penalty"
934+
case frequencyPenalty = "frequency_penalty"
935+
case presencePenalty = "presence_penalty"
799936
case streamOptions = "stream_options"
937+
case responseFormat = "response_format"
800938
}
801939
}
802940

0 commit comments

Comments
 (0)