1010// mlx-server --model mlx-community/Qwen2.5-3B-Instruct-4bit --port 5413
1111
1212import ArgumentParser
13+ import CoreImage
1314import Foundation
1415import HTTPTypes
1516import Hummingbird
1617import MLX
1718import MLXLLM
1819import MLXLMCommon
20+ import MLXVLM
1921
2022// ── CLI ──────────────────────────────────────────────────────────────────────
2123
@@ -56,6 +58,9 @@ struct MLXServer: AsyncParsableCommand {
5658 @Flag ( name: . long, help: " Enable thinking/reasoning mode (Qwen3.5 etc). Default: disabled " )
5759 var thinking : Bool = false
5860
61+ @Flag ( name: . long, help: " Enable VLM (vision-language model) mode for image inputs " )
62+ var vision : Bool = false
63+
5964 @Option ( name: . long, help: " Allowed CORS origin (* for all, or a specific origin URL) " )
6065 var cors : String ?
6166
@@ -79,11 +84,23 @@ struct MLXServer: AsyncParsableCommand {
7984 modelConfig = ModelConfiguration ( id: modelId)
8085 }
8186
82- let container = try await LLMModelFactory . shared. loadContainer (
83- configuration: modelConfig
84- ) { progress in
85- let pct = Int ( progress. fractionCompleted * 100 )
86- print ( " [mlx-server] Download: \( pct) % " )
87+ let isVision = self . vision
88+ let container : ModelContainer
89+ if isVision {
90+ print ( " [mlx-server] Loading VLM (vision-language model)... " )
91+ container = try await VLMModelFactory . shared. loadContainer (
92+ configuration: modelConfig
93+ ) { progress in
94+ let pct = Int ( progress. fractionCompleted * 100 )
95+ print ( " [mlx-server] Download: \( pct) % " )
96+ }
97+ } else {
98+ container = try await LLMModelFactory . shared. loadContainer (
99+ configuration: modelConfig
100+ ) { progress in
101+ let pct = Int ( progress. fractionCompleted * 100 )
102+ print ( " [mlx-server] Download: \( pct) % " )
103+ }
87104 }
88105
89106 print ( " [mlx-server] Model loaded. Starting HTTP server on \( host) : \( port) " )
@@ -96,7 +113,8 @@ struct MLXServer: AsyncParsableCommand {
96113 temp: self . temp,
97114 topP: self . topP,
98115 repeatPenalty: self . repeatPenalty,
99- thinking: self . thinking
116+ thinking: self . thinking,
117+ isVision: isVision
100118 )
101119
102120 let parallelSlots = self . parallel
@@ -170,7 +188,7 @@ struct MLXServer: AsyncParsableCommand {
170188 " port " : port,
171189 " model " : modelId,
172190 " engine " : " mlx " ,
173- " vision " : false
191+ " vision " : isVision
174192 ]
175193 if let data = try ? JSONSerialization . data ( withJSONObject: readyEvent) ,
176194 let json = String ( data: data, encoding: . utf8) {
@@ -192,6 +210,7 @@ struct ServerConfig: Sendable {
192210 let topP : Float
193211 let repeatPenalty : Float ?
194212 let thinking : Bool
213+ let isVision : Bool
195214}
196215
197216// ── Request Body Extraction ──────────────────────────────────────────────────
@@ -212,6 +231,7 @@ func handleChatCompletion(
212231) async throws -> Response {
213232 let chatReq = try JSONDecoder ( ) . decode ( ChatCompletionRequest . self, from: bodyData)
214233 let isStream = chatReq. stream ?? false
234+ let jsonMode = chatReq. responseFormat? . type == " json_object "
215235
216236 // ── Merge per-request overrides with CLI defaults ──
217237 let tokenLimit = chatReq. maxTokens ?? config. maxTokens
@@ -221,6 +241,11 @@ func handleChatCompletion(
221241 let stopSequences = chatReq. stop ?? [ ]
222242 let includeUsage = chatReq. streamOptions? . includeUsage ?? false
223243
244+ // Log extra sampling params if provided (accepted for API compat, not all are used)
245+ if chatReq. topK != nil || chatReq. frequencyPenalty != nil || chatReq. presencePenalty != nil {
246+ // These are accepted but may not affect generation if MLX doesn't support them
247+ }
248+
224249 let params = GenerateParameters (
225250 maxTokens: tokenLimit,
226251 maxKVSize: config. ctxSize,
@@ -234,15 +259,27 @@ func handleChatCompletion(
234259 MLXRandom . seed ( UInt64 ( seed) )
235260 }
236261
237- // Convert request messages → Chat.Message
238- let chatMessages : [ Chat . Message ] = chatReq. messages. compactMap { msg in
262+ // ── Parse messages with multipart content support (for VLM images) ──
263+ var chatMessages : [ Chat . Message ] = [ ]
264+ for msg in chatReq. messages {
265+ let textContent = msg. textContent
266+ let images = msg. extractImages ( )
239267 switch msg. role {
240- case " system " : return . system( msg. content)
241- case " assistant " : return . assistant( msg. content)
242- default : return . user( msg. content)
268+ case " system " :
269+ chatMessages. append ( . system( textContent, images: images) )
270+ case " assistant " :
271+ chatMessages. append ( . assistant( textContent, images: images) )
272+ default :
273+ chatMessages. append ( . user( textContent, images: images) )
243274 }
244275 }
245276
277+ // ── JSON mode: inject system prompt for JSON output ──
278+ if jsonMode {
279+ let jsonSystemMsg = Chat . Message. system ( " You must respond with valid JSON only. No markdown code fences, no explanation text, no preamble. Output raw JSON. " )
280+ chatMessages. insert ( jsonSystemMsg, at: 0 )
281+ }
282+
246283 // Convert OpenAI tools format → [String: any Sendable] for UserInput
247284 let toolSpecs : [ [ String : any Sendable ] ] ? = chatReq. tools? . map { tool in
248285 var spec : [ String : any Sendable ] = [ " type " : tool. type]
@@ -273,12 +310,13 @@ func handleChatCompletion(
273310 if isStream {
274311 return handleChatStreaming (
275312 stream: stream, modelId: modelId, stopSequences: stopSequences,
276- includeUsage: includeUsage, promptTokenCount: promptTokenCount, semaphore: semaphore
313+ includeUsage: includeUsage, promptTokenCount: promptTokenCount,
314+ jsonMode: jsonMode, semaphore: semaphore
277315 )
278316 } else {
279317 return try await handleChatNonStreaming (
280318 stream: stream, modelId: modelId, stopSequences: stopSequences,
281- promptTokenCount: promptTokenCount, semaphore: semaphore
319+ promptTokenCount: promptTokenCount, jsonMode : jsonMode , semaphore: semaphore
282320 )
283321 }
284322}
@@ -291,6 +329,7 @@ func handleChatStreaming(
291329 stopSequences: [ String ] ,
292330 includeUsage: Bool ,
293331 promptTokenCount: Int ,
332+ jsonMode: Bool = false ,
294333 semaphore: AsyncSemaphore
295334) -> Response {
296335 let ( sseStream, cont) = AsyncStream< String> . makeStream( )
@@ -357,6 +396,7 @@ func handleChatNonStreaming(
357396 modelId: String ,
358397 stopSequences: [ String ] ,
359398 promptTokenCount: Int ,
399+ jsonMode: Bool = false ,
360400 semaphore: AsyncSemaphore
361401) async throws -> Response {
362402 var fullText = " "
@@ -389,6 +429,18 @@ func handleChatNonStreaming(
389429 finishReason = " stop "
390430 }
391431
432+ // ── JSON mode validation ──
433+ if jsonMode {
434+ // Strip markdown code fences if model wrapped response
435+ let stripped = fullText
436+ . replacingOccurrences ( of: " ```json \n " , with: " " )
437+ . replacingOccurrences ( of: " ```json " , with: " " )
438+ . replacingOccurrences ( of: " ``` \n " , with: " " )
439+ . replacingOccurrences ( of: " ``` " , with: " " )
440+ . trimmingCharacters ( in: . whitespacesAndNewlines)
441+ fullText = stripped
442+ }
443+
392444 let totalTokens = promptTokenCount + completionTokenCount
393445 let hasToolCalls = !collectedToolCalls. isEmpty
394446
@@ -765,11 +817,89 @@ struct StreamOptions: Decodable {
765817 }
766818}
767819
/// OpenAI-compatible `response_format` request field.
/// Only `type` is decoded (the handler compares it against "json_object" to
/// enable JSON mode); other OpenAI fields such as `json_schema` are ignored —
/// TODO confirm whether schema-constrained output should be supported.
struct ResponseFormat: Decodable {
    // e.g. "text" or "json_object".
    let type: String
}
823+
768824struct ChatCompletionRequest : Decodable {
825+ /// Message content can be a plain string or an array of content parts (text + image_url)
/// One chat message. `content` may be a plain string or an array of
/// multipart content parts (text and image_url), as in the OpenAI API.
struct Message: Decodable {
    let role: String
    let content: MessageContent

    /// Plain text of the message: the string itself, or all `text`
    /// parts of multipart content joined with newlines.
    var textContent: String {
        switch content {
        case .string(let s):
            return s
        case .parts(let parts):
            return parts
                .compactMap { $0.type == "text" ? $0.text : nil }
                .joined(separator: "\n")
        }
    }

    /// Images referenced by multipart content parts.
    ///
    /// - `data:` URIs are base64-decoded into a `CIImage`.
    /// - Any other URL-parseable reference (http, https, file, …) is
    ///   returned as `.url`.
    /// - Returns: images in message order; malformed entries are dropped.
    func extractImages() -> [UserInput.Image] {
        guard case .parts(let parts) = content else { return [] }
        return parts.compactMap { part -> UserInput.Image? in
            guard part.type == "image_url", let imageUrl = part.imageUrl else { return nil }
            let urlStr = imageUrl.url
            // Base64 data URI: data:image/png;base64,<payload> — take
            // everything after the first comma as the payload.
            if urlStr.hasPrefix("data:") {
                guard let commaIdx = urlStr.firstIndex(of: ",") else { return nil }
                let base64Str = String(urlStr[urlStr.index(after: commaIdx)...])
                // .ignoreUnknownCharacters tolerates newlines/whitespace in
                // line-wrapped base64 payloads, which strict decoding rejects.
                guard let data = Data(base64Encoded: base64Str, options: .ignoreUnknownCharacters),
                      let ciImage = CIImage(data: data) else { return nil }
                return .ciImage(ciImage)
            }
            // Single fallthrough for every URL-style reference. (The original's
            // separate http/https branch was redundant: the general fallback
            // returned the identical value for those schemes.)
            guard let url = URL(string: urlStr) else { return nil }
            return .url(url)
        }
    }
}
869+
870+ /// Message content: either a plain string or structured multipart content
/// Content of a chat message: either a bare string or an array of
/// structured parts (OpenAI multipart format).
enum MessageContent: Decodable {
    case string(String)
    case parts([ContentPart])

    init(from decoder: Decoder) throws {
        let single = try decoder.singleValueContainer()
        // A JSON value cannot be both a string and an array, so trying the
        // string shape first never shadows multipart input.
        if let text = try? single.decode(String.self) {
            self = .string(text)
            return
        }
        if let contentParts = try? single.decode([ContentPart].self) {
            self = .parts(contentParts)
            return
        }
        // Lenient fallback: unrecognized shapes become empty text rather
        // than failing the whole request decode.
        self = .string("")
    }
}
886+
/// One element of multipart message content.
struct ContentPart: Decodable {
    // Discriminator: "text" or "image_url" (other values are skipped by the
    // consumers `textContent` / `extractImages()`).
    let type: String
    // Present when type == "text".
    let text: String?
    // Present when type == "image_url"; decoded from the snake_case key below.
    let imageUrl: ImageUrlContent?

    enum CodingKeys: String, CodingKey {
        case type, text
        case imageUrl = "image_url"
    }
}
897+
/// The `image_url` payload of a content part.
struct ImageUrlContent: Decodable {
    // Either a base64 `data:` URI or an http(s)/file URL string.
    let url: String
    // OpenAI "detail" hint ("low"/"high"/"auto"); decoded for API
    // compatibility but not read by image extraction in this file.
    let detail: String?
}
902+
773903 struct ToolDef : Decodable {
774904 let type : String
775905 let function : ToolFuncDef
@@ -785,18 +915,26 @@ struct ChatCompletionRequest: Decodable {
785915 let maxTokens : Int ?
786916 let temperature : Double ?
787917 let topP : Double ?
918+ let topK : Int ?
788919 let repetitionPenalty : Double ?
920+ let frequencyPenalty : Double ?
921+ let presencePenalty : Double ?
789922 let tools : [ ToolDef ] ?
790923 let stop : [ String ] ?
791924 let seed : Int ?
792925 let streamOptions : StreamOptions ?
926+ let responseFormat : ResponseFormat ?
793927
794928 enum CodingKeys : String , CodingKey {
795929 case model, messages, stream, temperature, tools, stop, seed
796930 case maxTokens = " max_tokens "
797931 case topP = " top_p "
932+ case topK = " top_k "
798933 case repetitionPenalty = " repetition_penalty "
934+ case frequencyPenalty = " frequency_penalty "
935+ case presencePenalty = " presence_penalty "
799936 case streamOptions = " stream_options "
937+ case responseFormat = " response_format "
800938 }
801939}
802940
0 commit comments