@@ -247,31 +247,32 @@ func (r *ModelRegistry) LoadFromCache(ctx context.Context) (int, error) {
247247 }
248248 r .mu .RUnlock ()
249249
250- // Populate model maps from cache rows . Unqualified lookups keep "first provider wins".
251- newModels := make (map [string ]* ModelInfo , len ( modelCache . Models ) )
250+ // Populate model maps from grouped cache structure . Unqualified lookups keep "first provider wins".
251+ newModels := make (map [string ]* ModelInfo )
252252 newModelsByProvider := make (map [string ]map [string ]* ModelInfo )
253- for _ , cached := range modelCache .Models {
254- provider , ok := nameToProvider [cached . Provider ]
253+ for providerName , cachedProvider := range modelCache .Providers {
254+ provider , ok := nameToProvider [providerName ]
255255 if ! ok {
256- // Provider not configured, skip this model
256+ // Provider not configured, skip all its models
257257 continue
258258 }
259- info := & ModelInfo {
260- Model : core. Model {
261- ID : cached . ModelID ,
262- Object : cached . Object ,
263- OwnedBy : cached .OwnedBy ,
264- Created : cached . Created ,
265- } ,
266- Provider : provider ,
267- }
268- if _ , ok := newModelsByProvider [ cached . Provider ]; ! ok {
269- newModelsByProvider [ cached . Provider ] = make ( map [ string ] * ModelInfo )
270- }
271- newModelsByProvider [cached.Provider ][cached. ModelID ] = info
272- if _ , exists := newModels [cached .ModelID ]; ! exists {
273- newModels [ cached . ModelID ] = info
259+ providerModels := make ( map [ string ] * ModelInfo , len ( cachedProvider . Models ))
260+ for _ , cached := range cachedProvider . Models {
261+ info := & ModelInfo {
262+ Model : core. Model {
263+ ID : cached .ID ,
264+ Object : "model" ,
265+ OwnedBy : cachedProvider . OwnedBy ,
266+ Created : cached . Created ,
267+ },
268+ Provider : provider ,
269+ }
270+ providerModels [ cached . ID ] = info
271+ if _ , exists := newModels [cached .ID ]; ! exists {
272+ newModels [cached .ID ] = info
273+ }
274274 }
275+ newModelsByProvider [providerName ] = providerModels
275276 }
276277
277278 // Load model list data from cache if available
@@ -331,49 +332,58 @@ func (r *ModelRegistry) SaveToCache(ctx context.Context) error {
331332 return nil
332333 }
333334
334- // Build cache structure as a slice of provider/model rows .
335+ // Build grouped cache structure: one entry per provider with its models .
335336 modelCache := & cache.ModelCache {
336337 UpdatedAt : time .Now ().UTC (),
337- Models : make ([ ]cache.CachedModel , 0 ),
338+ Providers : make (map [ string ]cache.CachedProvider , len ( modelsByProvider ) ),
338339 ModelListData : modelListRaw ,
339340 }
340341
341- providerNames := make ([]string , 0 , len (modelsByProvider ))
342- for providerName := range modelsByProvider {
343- providerNames = append (providerNames , providerName )
344- }
345- sort .Strings (providerNames )
342+ var totalModels int
343+ for providerName , models := range modelsByProvider {
344+ // Determine provider type and owned_by from any model in this provider group.
345+ var pType , ownedBy string
346+ for _ , info := range models {
347+ t , ok := providerTypes [info .Provider ]
348+ if ! ok {
349+ continue
350+ }
351+ pType = t
352+ ownedBy = info .Model .OwnedBy
353+ break
354+ }
355+ if pType == "" {
356+ // No known provider type for this provider, skip entirely.
357+ continue
358+ }
346359
347- for _ , providerName := range providerNames {
348- modelIDs := make ([]string , 0 , len (modelsByProvider [providerName ]))
349- for modelID := range modelsByProvider [providerName ] {
360+ modelIDs := make ([]string , 0 , len (models ))
361+ for modelID := range models {
350362 modelIDs = append (modelIDs , modelID )
351363 }
352364 sort .Strings (modelIDs )
353365
366+ cachedModels := make ([]cache.CachedModel , 0 , len (modelIDs ))
354367 for _ , modelID := range modelIDs {
355- info := modelsByProvider [providerName ][modelID ]
356- pType , ok := providerTypes [info .Provider ]
357- if ! ok {
358- // Skip models without a known provider type.
359- continue
360- }
361- modelCache .Models = append (modelCache .Models , cache.CachedModel {
362- ModelID : modelID ,
363- Provider : providerName ,
364- ProviderType : pType ,
365- Object : info .Model .Object ,
366- OwnedBy : info .Model .OwnedBy ,
367- Created : info .Model .Created ,
368+ info := models [modelID ]
369+ cachedModels = append (cachedModels , cache.CachedModel {
370+ ID : modelID ,
371+ Created : info .Model .Created ,
368372 })
369373 }
374+ modelCache .Providers [providerName ] = cache.CachedProvider {
375+ ProviderType : pType ,
376+ OwnedBy : ownedBy ,
377+ Models : cachedModels ,
378+ }
379+ totalModels += len (cachedModels )
370380 }
371381
372382 if err := cacheBackend .Set (ctx , modelCache ); err != nil {
373383 return fmt .Errorf ("failed to save cache: %w" , err )
374384 }
375385
376- slog .Debug ("saved models to cache" , "models" , len ( modelCache . Models ) )
386+ slog .Debug ("saved models to cache" , "models" , totalModels )
377387 return nil
378388}
379389
@@ -427,10 +437,10 @@ func (r *ModelRegistry) GetProvider(model string) core.Provider {
427437 return info .Provider
428438 }
429439 }
430- return nil
440+ // Fall through: the slash may be part of the model ID (e.g. "meta-llama/Meta-Llama-3-70B")
431441 }
432442
433- if info , ok := r .models [modelID ]; ok {
443+ if info , ok := r .models [model ]; ok {
434444 return info .Provider
435445 }
436446 return nil
@@ -444,12 +454,14 @@ func (r *ModelRegistry) GetModel(model string) *ModelInfo {
444454 providerName , modelID := splitModelSelector (model )
445455 if providerName != "" {
446456 if providerModels , ok := r .modelsByProvider [providerName ]; ok {
447- return providerModels [modelID ]
457+ if info , exists := providerModels [modelID ]; exists {
458+ return info
459+ }
448460 }
449- return nil
461+ // Fall through: the slash may be part of the model ID
450462 }
451463
452- if info , ok := r .models [modelID ]; ok {
464+ if info , ok := r .models [model ]; ok {
453465 return info
454466 }
455467 return nil
@@ -462,15 +474,15 @@ func (r *ModelRegistry) Supports(model string) bool {
462474
463475 providerName , modelID := splitModelSelector (model )
464476 if providerName != "" {
465- providerModels , ok := r .modelsByProvider [providerName ]
466- if ! ok {
467- return false
477+ if providerModels , ok := r .modelsByProvider [providerName ]; ok {
478+ if _ , exists := providerModels [modelID ]; exists {
479+ return true
480+ }
468481 }
469- _ , ok = providerModels [modelID ]
470- return ok
482+ // Fall through: the slash may be part of the model ID
471483 }
472484
473- _ , ok := r .models [modelID ]
485+ _ , ok := r .models [model ]
474486 return ok
475487}
476488
@@ -522,14 +534,12 @@ func (r *ModelRegistry) GetProviderType(model string) string {
522534 return r .providerTypes [info .Provider ]
523535 }
524536 }
525- return ""
537+ // Fall through: the slash may be part of the model ID
526538 }
527539
528- info , ok := r .models [modelID ]
529- if ok {
540+ if info , ok := r .models [model ]; ok {
530541 return r .providerTypes [info .Provider ]
531542 }
532-
533543 return ""
534544}
535545
0 commit comments