Skip to content

Commit 35bbb2f

Browse files
fix: return defensive copies from cached model lists and bound category cache
Two issues in the sorted-list caching introduced in the previous commit: 1. Cached slices were returned by reference, allowing callers to mutate shared internal state (core.Model contains a *ModelMetadata pointer). Now all return paths copy the cached slice before returning. 2. categoryCache accepted any ModelCategory string value, allowing unbounded map growth from arbitrary input. Now only the 6 known categories (text_generation, embedding, image, audio, video, utility) are cached; unrecognized categories still produce correct results but skip the cache. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent be49a8e commit 35bbb2f

1 file changed

Lines changed: 38 additions & 18 deletions

File tree

internal/providers/registry.go

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -370,19 +370,20 @@ func (r *ModelRegistry) Supports(model string) bool {
370370

371371
// ListModels returns all models in the registry, sorted by model ID for consistent ordering.
372372
// The sorted slice is cached and rebuilt only when the underlying models change.
373+
// Returns a defensive copy so callers cannot mutate the internal cache.
373374
func (r *ModelRegistry) ListModels() []core.Model {
374375
r.mu.RLock()
375376
if cached := r.sortedModels; cached != nil {
376377
r.mu.RUnlock()
377-
return cached
378+
return append([]core.Model(nil), cached...)
378379
}
379380
r.mu.RUnlock()
380381

381382
r.mu.Lock()
382383
defer r.mu.Unlock()
383384
// Double-check: another goroutine may have built it while we waited for the lock.
384385
if r.sortedModels != nil {
385-
return r.sortedModels
386+
return append([]core.Model(nil), r.sortedModels...)
386387
}
387388

388389
models := make([]core.Model, 0, len(r.models))
@@ -392,7 +393,7 @@ func (r *ModelRegistry) ListModels() []core.Model {
392393
sort.Slice(models, func(i, j int) bool { return models[i].ID < models[j].ID })
393394

394395
r.sortedModels = models
395-
return models
396+
return append([]core.Model(nil), models...)
396397
}
397398

398399
// ModelCount returns the number of registered models
@@ -424,18 +425,19 @@ type ModelWithProvider struct {
424425

425426
// ListModelsWithProvider returns all models with their provider types, sorted by model ID.
426427
// The sorted slice is cached and rebuilt only when the underlying models change.
428+
// Returns a defensive copy so callers cannot mutate the internal cache.
427429
func (r *ModelRegistry) ListModelsWithProvider() []ModelWithProvider {
428430
r.mu.RLock()
429431
if cached := r.sortedModelsWithProvider; cached != nil {
430432
r.mu.RUnlock()
431-
return cached
433+
return append([]ModelWithProvider(nil), cached...)
432434
}
433435
r.mu.RUnlock()
434436

435437
r.mu.Lock()
436438
defer r.mu.Unlock()
437439
if r.sortedModelsWithProvider != nil {
438-
return r.sortedModelsWithProvider
440+
return append([]ModelWithProvider(nil), r.sortedModelsWithProvider...)
439441
}
440442

441443
result := make([]ModelWithProvider, 0, len(r.models))
@@ -448,31 +450,47 @@ func (r *ModelRegistry) ListModelsWithProvider() []ModelWithProvider {
448450
sort.Slice(result, func(i, j int) bool { return result[i].Model.ID < result[j].Model.ID })
449451

450452
r.sortedModelsWithProvider = result
451-
return result
453+
return append([]ModelWithProvider(nil), result...)
454+
}
455+
456+
// cacheableCategory reports whether category is a known value that should be cached.
457+
// CategoryAll is handled separately (delegates to ListModelsWithProvider).
458+
var cacheableCategories = map[core.ModelCategory]struct{}{
459+
core.CategoryTextGeneration: {},
460+
core.CategoryEmbedding: {},
461+
core.CategoryImage: {},
462+
core.CategoryAudio: {},
463+
core.CategoryVideo: {},
464+
core.CategoryUtility: {},
452465
}
453466

454467
// ListModelsWithProviderByCategory returns models filtered by category, sorted by model ID.
455468
// If category is CategoryAll, returns all models (same as ListModelsWithProvider).
456-
// Results are cached per category and rebuilt only when the underlying models change.
469+
// Results for known categories are cached and rebuilt only when the underlying models change.
470+
// Returns a defensive copy so callers cannot mutate the internal cache.
457471
func (r *ModelRegistry) ListModelsWithProviderByCategory(category core.ModelCategory) []ModelWithProvider {
458472
if category == core.CategoryAll {
459473
return r.ListModelsWithProvider()
460474
}
461475

462-
r.mu.RLock()
463-
if r.categoryCache != nil {
464-
if cached, ok := r.categoryCache[category]; ok {
465-
r.mu.RUnlock()
466-
return cached
476+
_, cacheable := cacheableCategories[category]
477+
478+
if cacheable {
479+
r.mu.RLock()
480+
if r.categoryCache != nil {
481+
if cached, ok := r.categoryCache[category]; ok {
482+
r.mu.RUnlock()
483+
return append([]ModelWithProvider(nil), cached...)
484+
}
467485
}
486+
r.mu.RUnlock()
468487
}
469-
r.mu.RUnlock()
470488

471489
r.mu.Lock()
472490
defer r.mu.Unlock()
473-
if r.categoryCache != nil {
491+
if cacheable && r.categoryCache != nil {
474492
if cached, ok := r.categoryCache[category]; ok {
475-
return cached
493+
return append([]ModelWithProvider(nil), cached...)
476494
}
477495
}
478496

@@ -488,10 +506,12 @@ func (r *ModelRegistry) ListModelsWithProviderByCategory(category core.ModelCate
488506
}
489507
sort.Slice(result, func(i, j int) bool { return result[i].Model.ID < result[j].Model.ID })
490508

491-
if r.categoryCache == nil {
492-
r.categoryCache = make(map[core.ModelCategory][]ModelWithProvider)
509+
if cacheable {
510+
if r.categoryCache == nil {
511+
r.categoryCache = make(map[core.ModelCategory][]ModelWithProvider)
512+
}
513+
r.categoryCache[category] = result
493514
}
494-
r.categoryCache[category] = result
495515
return result
496516
}
497517

0 commit comments

Comments
 (0)