Skip to content

Commit b29c82c

Browse files
cache: group models by provider and fix slash-in-model-ID lookups
Restructure ModelCache to group models by provider, eliminating per-model repetition of the provider_type, owned_by, and object fields. Drop the always-"model" Object field entirely. Fix the qualified-lookup fallthrough in splitModelSelector's callers: when a qualified lookup fails (e.g. "meta-llama/Meta-Llama-3-70B", where "meta-llama" is not a configured provider name), fall through to an unqualified lookup using the original model string instead of returning nil. Applied to GetProvider, GetModel, Supports, and GetProviderType. Harden cache tests with fatal guards before slice indexing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent edf2753 commit b29c82c

4 files changed

Lines changed: 172 additions & 145 deletions

File tree

internal/cache/cache.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,27 @@ import (
99
)
1010

1111
// ModelCache represents the cached model data structure.
12-
// This is the data that gets stored and retrieved from the cache.
12+
// Models are grouped by provider to avoid repeating shared fields (provider_type, owned_by)
13+
// on every model entry.
1314
type ModelCache struct {
14-
UpdatedAt time.Time `json:"updated_at"`
15-
Models []CachedModel `json:"models"`
15+
UpdatedAt time.Time `json:"updated_at"`
16+
Providers map[string]CachedProvider `json:"providers"`
1617
// ModelListData holds the raw JSON model registry bytes for cache persistence,
1718
// allowing the registry to restore its full model list without re-fetching.
1819
ModelListData json.RawMessage `json:"model_list_data,omitempty"`
1920
}
2021

21-
// CachedModel represents a single cached model entry.
22+
// CachedProvider holds shared fields for all models from a single provider.
23+
type CachedProvider struct {
24+
ProviderType string `json:"provider_type"`
25+
OwnedBy string `json:"owned_by"`
26+
Models []CachedModel `json:"models"`
27+
}
28+
29+
// CachedModel represents a single cached model entry within a provider group.
2230
type CachedModel struct {
23-
ModelID string `json:"model_id"`
24-
Provider string `json:"provider"`
25-
ProviderType string `json:"provider_type"`
26-
Object string `json:"object"`
27-
OwnedBy string `json:"owned_by"`
28-
Created int64 `json:"created"`
31+
ID string `json:"id"`
32+
Created int64 `json:"created"`
2933
}
3034

3135
// Cache defines the interface for model cache storage.

internal/cache/cache_test.go

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ func TestLocalCache(t *testing.T) {
2929
// Set data
3030
data := &ModelCache{
3131
UpdatedAt: time.Now().UTC(),
32-
Models: []CachedModel{
33-
{
34-
ModelID: "test-model",
35-
Provider: "openai",
32+
Providers: map[string]CachedProvider{
33+
"openai": {
3634
ProviderType: "openai",
37-
Object: "model",
3835
OwnedBy: "openai",
39-
Created: 1234567890,
36+
Models: []CachedModel{
37+
{ID: "test-model", Created: 1234567890},
38+
},
4039
},
4140
},
4241
}
@@ -54,11 +53,12 @@ func TestLocalCache(t *testing.T) {
5453
if result == nil {
5554
t.Fatal("expected result, got nil")
5655
}
57-
if len(result.Models) != 1 {
58-
t.Errorf("expected 1 model, got %d", len(result.Models))
56+
p, ok := result.Providers["openai"]
57+
if !ok || len(p.Models) != 1 {
58+
t.Fatalf("expected 1 model in openai provider, got %v", result.Providers)
5959
}
60-
if result.Models[0].ModelID != "test-model" {
61-
t.Errorf("expected test-model in cache, got %q", result.Models[0].ModelID)
60+
if p.Models[0].ID != "test-model" {
61+
t.Errorf("expected test-model in cache, got %q", p.Models[0].ID)
6262
}
6363
})
6464

@@ -70,7 +70,7 @@ func TestLocalCache(t *testing.T) {
7070
ctx := context.Background()
7171

7272
data := &ModelCache{
73-
Models: []CachedModel{},
73+
Providers: map[string]CachedProvider{},
7474
}
7575

7676
err := cache.Set(ctx, data)
@@ -136,22 +136,20 @@ func TestModelCacheSerialization(t *testing.T) {
136136
t.Run("JSONRoundTrip", func(t *testing.T) {
137137
original := &ModelCache{
138138
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
139-
Models: []CachedModel{
140-
{
141-
ModelID: "gpt-4",
142-
Provider: "openai-main",
139+
Providers: map[string]CachedProvider{
140+
"openai-main": {
143141
ProviderType: "openai",
144-
Object: "model",
145142
OwnedBy: "openai",
146-
Created: 1234567890,
143+
Models: []CachedModel{
144+
{ID: "gpt-4", Created: 1234567890},
145+
},
147146
},
148-
{
149-
ModelID: "claude-3",
150-
Provider: "anthropic-main",
147+
"anthropic-main": {
151148
ProviderType: "anthropic",
152-
Object: "model",
153149
OwnedBy: "anthropic",
154-
Created: 1234567891,
150+
Models: []CachedModel{
151+
{ID: "claude-3", Created: 1234567891},
152+
},
155153
},
156154
},
157155
}
@@ -166,17 +164,22 @@ func TestModelCacheSerialization(t *testing.T) {
166164
t.Fatalf("failed to unmarshal: %v", err)
167165
}
168166

169-
if len(restored.Models) != len(original.Models) {
170-
t.Errorf("model count mismatch: got %d, want %d", len(restored.Models), len(original.Models))
167+
if len(restored.Providers) != len(original.Providers) {
168+
t.Fatalf("provider count mismatch: got %d, want %d", len(restored.Providers), len(original.Providers))
171169
}
172-
if restored.Models[0].ModelID != original.Models[0].ModelID {
173-
t.Errorf("first model ID mismatch: got %q, want %q", restored.Models[0].ModelID, original.Models[0].ModelID)
170+
openai, ok := restored.Providers["openai-main"]
171+
if !ok || len(openai.Models) == 0 {
172+
t.Fatalf("expected openai-main provider with models, got %v", restored.Providers)
174173
}
175-
if restored.Models[0].Provider != original.Models[0].Provider {
176-
t.Errorf("first provider mismatch: got %q, want %q", restored.Models[0].Provider, original.Models[0].Provider)
174+
if openai.Models[0].ID != "gpt-4" {
175+
t.Errorf("openai model ID mismatch: got %q, want %q", openai.Models[0].ID, "gpt-4")
177176
}
178-
if restored.Models[1].ProviderType != original.Models[1].ProviderType {
179-
t.Errorf("second provider type mismatch: got %q, want %q", restored.Models[1].ProviderType, original.Models[1].ProviderType)
177+
if openai.ProviderType != "openai" {
178+
t.Errorf("openai provider type mismatch: got %q, want %q", openai.ProviderType, "openai")
179+
}
180+
anthropic := restored.Providers["anthropic-main"]
181+
if anthropic.ProviderType != "anthropic" {
182+
t.Errorf("anthropic provider type mismatch: got %q, want %q", anthropic.ProviderType, "anthropic")
180183
}
181184
})
182185
}

internal/providers/registry.go

Lines changed: 69 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -247,31 +247,32 @@ func (r *ModelRegistry) LoadFromCache(ctx context.Context) (int, error) {
247247
}
248248
r.mu.RUnlock()
249249

250-
// Populate model maps from cache rows. Unqualified lookups keep "first provider wins".
251-
newModels := make(map[string]*ModelInfo, len(modelCache.Models))
250+
// Populate model maps from grouped cache structure. Unqualified lookups keep "first provider wins".
251+
newModels := make(map[string]*ModelInfo)
252252
newModelsByProvider := make(map[string]map[string]*ModelInfo)
253-
for _, cached := range modelCache.Models {
254-
provider, ok := nameToProvider[cached.Provider]
253+
for providerName, cachedProvider := range modelCache.Providers {
254+
provider, ok := nameToProvider[providerName]
255255
if !ok {
256-
// Provider not configured, skip this model
256+
// Provider not configured, skip all its models
257257
continue
258258
}
259-
info := &ModelInfo{
260-
Model: core.Model{
261-
ID: cached.ModelID,
262-
Object: cached.Object,
263-
OwnedBy: cached.OwnedBy,
264-
Created: cached.Created,
265-
},
266-
Provider: provider,
267-
}
268-
if _, ok := newModelsByProvider[cached.Provider]; !ok {
269-
newModelsByProvider[cached.Provider] = make(map[string]*ModelInfo)
270-
}
271-
newModelsByProvider[cached.Provider][cached.ModelID] = info
272-
if _, exists := newModels[cached.ModelID]; !exists {
273-
newModels[cached.ModelID] = info
259+
providerModels := make(map[string]*ModelInfo, len(cachedProvider.Models))
260+
for _, cached := range cachedProvider.Models {
261+
info := &ModelInfo{
262+
Model: core.Model{
263+
ID: cached.ID,
264+
Object: "model",
265+
OwnedBy: cachedProvider.OwnedBy,
266+
Created: cached.Created,
267+
},
268+
Provider: provider,
269+
}
270+
providerModels[cached.ID] = info
271+
if _, exists := newModels[cached.ID]; !exists {
272+
newModels[cached.ID] = info
273+
}
274274
}
275+
newModelsByProvider[providerName] = providerModels
275276
}
276277

277278
// Load model list data from cache if available
@@ -331,49 +332,58 @@ func (r *ModelRegistry) SaveToCache(ctx context.Context) error {
331332
return nil
332333
}
333334

334-
// Build cache structure as a slice of provider/model rows.
335+
// Build grouped cache structure: one entry per provider with its models.
335336
modelCache := &cache.ModelCache{
336337
UpdatedAt: time.Now().UTC(),
337-
Models: make([]cache.CachedModel, 0),
338+
Providers: make(map[string]cache.CachedProvider, len(modelsByProvider)),
338339
ModelListData: modelListRaw,
339340
}
340341

341-
providerNames := make([]string, 0, len(modelsByProvider))
342-
for providerName := range modelsByProvider {
343-
providerNames = append(providerNames, providerName)
344-
}
345-
sort.Strings(providerNames)
342+
var totalModels int
343+
for providerName, models := range modelsByProvider {
344+
// Determine provider type and owned_by from any model in this provider group.
345+
var pType, ownedBy string
346+
for _, info := range models {
347+
t, ok := providerTypes[info.Provider]
348+
if !ok {
349+
continue
350+
}
351+
pType = t
352+
ownedBy = info.Model.OwnedBy
353+
break
354+
}
355+
if pType == "" {
356+
// No known provider type for this provider, skip entirely.
357+
continue
358+
}
346359

347-
for _, providerName := range providerNames {
348-
modelIDs := make([]string, 0, len(modelsByProvider[providerName]))
349-
for modelID := range modelsByProvider[providerName] {
360+
modelIDs := make([]string, 0, len(models))
361+
for modelID := range models {
350362
modelIDs = append(modelIDs, modelID)
351363
}
352364
sort.Strings(modelIDs)
353365

366+
cachedModels := make([]cache.CachedModel, 0, len(modelIDs))
354367
for _, modelID := range modelIDs {
355-
info := modelsByProvider[providerName][modelID]
356-
pType, ok := providerTypes[info.Provider]
357-
if !ok {
358-
// Skip models without a known provider type.
359-
continue
360-
}
361-
modelCache.Models = append(modelCache.Models, cache.CachedModel{
362-
ModelID: modelID,
363-
Provider: providerName,
364-
ProviderType: pType,
365-
Object: info.Model.Object,
366-
OwnedBy: info.Model.OwnedBy,
367-
Created: info.Model.Created,
368+
info := models[modelID]
369+
cachedModels = append(cachedModels, cache.CachedModel{
370+
ID: modelID,
371+
Created: info.Model.Created,
368372
})
369373
}
374+
modelCache.Providers[providerName] = cache.CachedProvider{
375+
ProviderType: pType,
376+
OwnedBy: ownedBy,
377+
Models: cachedModels,
378+
}
379+
totalModels += len(cachedModels)
370380
}
371381

372382
if err := cacheBackend.Set(ctx, modelCache); err != nil {
373383
return fmt.Errorf("failed to save cache: %w", err)
374384
}
375385

376-
slog.Debug("saved models to cache", "models", len(modelCache.Models))
386+
slog.Debug("saved models to cache", "models", totalModels)
377387
return nil
378388
}
379389

@@ -427,10 +437,10 @@ func (r *ModelRegistry) GetProvider(model string) core.Provider {
427437
return info.Provider
428438
}
429439
}
430-
return nil
440+
// Fall through: the slash may be part of the model ID (e.g. "meta-llama/Meta-Llama-3-70B")
431441
}
432442

433-
if info, ok := r.models[modelID]; ok {
443+
if info, ok := r.models[model]; ok {
434444
return info.Provider
435445
}
436446
return nil
@@ -444,12 +454,14 @@ func (r *ModelRegistry) GetModel(model string) *ModelInfo {
444454
providerName, modelID := splitModelSelector(model)
445455
if providerName != "" {
446456
if providerModels, ok := r.modelsByProvider[providerName]; ok {
447-
return providerModels[modelID]
457+
if info, exists := providerModels[modelID]; exists {
458+
return info
459+
}
448460
}
449-
return nil
461+
// Fall through: the slash may be part of the model ID
450462
}
451463

452-
if info, ok := r.models[modelID]; ok {
464+
if info, ok := r.models[model]; ok {
453465
return info
454466
}
455467
return nil
@@ -462,15 +474,15 @@ func (r *ModelRegistry) Supports(model string) bool {
462474

463475
providerName, modelID := splitModelSelector(model)
464476
if providerName != "" {
465-
providerModels, ok := r.modelsByProvider[providerName]
466-
if !ok {
467-
return false
477+
if providerModels, ok := r.modelsByProvider[providerName]; ok {
478+
if _, exists := providerModels[modelID]; exists {
479+
return true
480+
}
468481
}
469-
_, ok = providerModels[modelID]
470-
return ok
482+
// Fall through: the slash may be part of the model ID
471483
}
472484

473-
_, ok := r.models[modelID]
485+
_, ok := r.models[model]
474486
return ok
475487
}
476488

@@ -522,14 +534,12 @@ func (r *ModelRegistry) GetProviderType(model string) string {
522534
return r.providerTypes[info.Provider]
523535
}
524536
}
525-
return ""
537+
// Fall through: the slash may be part of the model ID
526538
}
527539

528-
info, ok := r.models[modelID]
529-
if ok {
540+
if info, ok := r.models[model]; ok {
530541
return r.providerTypes[info.Provider]
531542
}
532-
533543
return ""
534544
}
535545

0 commit comments

Comments
 (0)