@@ -4,13 +4,15 @@ import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest';
44
55import { MockEmbeddingModelV3 } from '../test/mock-embedding-model-v3' ;
66import { MockLanguageModelV3 } from '../test/mock-language-model-v3' ;
7+ import { MockRerankingModelV3 } from '../test/mock-reranking-model-v3' ;
78import { MockVideoModelV3 } from '../test/mock-video-model-v3' ;
89import { customProvider } from '../registry/custom-provider' ;
910import { MockImageModelV2 } from '../test/mock-image-model-v2' ;
1011import {
1112 resolveEmbeddingModel ,
1213 resolveImageModel ,
1314 resolveLanguageModel ,
15+ resolveRerankingModel ,
1416 resolveVideoModel ,
1517} from './resolve-model' ;
1618
@@ -356,3 +358,129 @@ describe('resolveVideoModel', () => {
356358 } ) ;
357359 } ) ;
358360} ) ;
361+
362+ describe ( 'resolveRerankingModel' , ( ) => {
363+ describe ( 'when a reranking model v3 is provided' , ( ) => {
364+ it ( 'should return it as-is' , ( ) => {
365+ const originalModel = new MockRerankingModelV3 ( {
366+ provider : 'test-provider' ,
367+ modelId : 'test-model-id' ,
368+ } ) ;
369+
370+ const resolvedModel = resolveRerankingModel ( originalModel ) ;
371+
372+ expect ( resolvedModel ) . toBe ( originalModel ) ;
373+ expect ( resolvedModel . specificationVersion ) . toBe ( 'v3' ) ;
374+ } ) ;
375+ } ) ;
376+
377+ describe ( 'when a reranking model v3 is provided' , ( ) => {
378+ it ( 'should return it as v3' , ( ) => {
379+ const resolvedModel = resolveRerankingModel (
380+ new MockRerankingModelV3 ( {
381+ provider : 'test-provider' ,
382+ modelId : 'test-model-id' ,
383+ } ) ,
384+ ) ;
385+
386+ expect ( resolvedModel . provider ) . toBe ( 'test-provider' ) ;
387+ expect ( resolvedModel . modelId ) . toBe ( 'test-model-id' ) ;
388+ expect ( resolvedModel . specificationVersion ) . toBe ( 'v3' ) ;
389+ } ) ;
390+ } ) ;
391+
392+ describe ( 'when a string is provided and the global default provider is not set' , ( ) => {
393+ it ( 'should return a gateway reranking model' , ( ) => {
394+ const mockModel = new MockRerankingModelV3 ( {
395+ provider : 'gateway' ,
396+ modelId : 'test-model-id' ,
397+ } ) ;
398+
399+ const rerankingModelSpy = vi
400+ . spyOn ( gateway , 'rerankingModel' )
401+ . mockReturnValue ( mockModel as any ) ;
402+
403+ try {
404+ const resolvedModel = resolveRerankingModel ( 'test-model-id' ) ;
405+
406+ expect ( resolvedModel . provider ) . toBe ( 'gateway' ) ;
407+ expect ( resolvedModel . modelId ) . toBe ( 'test-model-id' ) ;
408+ } finally {
409+ rerankingModelSpy . mockRestore ( ) ;
410+ }
411+ } ) ;
412+ } ) ;
413+
414+ describe ( 'when a string is provided and the global default provider is set' , ( ) => {
415+ beforeEach ( ( ) => {
416+ globalThis . AI_SDK_DEFAULT_PROVIDER = customProvider ( {
417+ rerankingModels : {
418+ 'test-model-id' : new MockRerankingModelV3 ( {
419+ provider : 'global-test-provider' ,
420+ modelId : 'actual-test-model-id' ,
421+ } ) ,
422+ } ,
423+ } ) ;
424+ } ) ;
425+
426+ afterEach ( ( ) => {
427+ delete globalThis . AI_SDK_DEFAULT_PROVIDER ;
428+ } ) ;
429+
430+ it ( 'should return a reranking model from the global default provider' , ( ) => {
431+ const resolvedModel = resolveRerankingModel ( 'test-model-id' ) ;
432+
433+ expect ( resolvedModel . provider ) . toBe ( 'global-test-provider' ) ;
434+ expect ( resolvedModel . modelId ) . toBe ( 'actual-test-model-id' ) ;
435+ } ) ;
436+ } ) ;
437+
438+ describe ( 'when a string is provided and the provider does not support reranking models' , ( ) => {
439+ beforeEach ( ( ) => {
440+ globalThis . AI_SDK_DEFAULT_PROVIDER = {
441+ specificationVersion : 'v3' as const ,
442+ languageModel : ( ) => {
443+ throw new Error ( 'not implemented' ) ;
444+ } ,
445+ embeddingModel : ( ) => {
446+ throw new Error ( 'not implemented' ) ;
447+ } ,
448+ imageModel : ( ) => {
449+ throw new Error ( 'not implemented' ) ;
450+ } ,
451+ } ;
452+ } ) ;
453+
454+ afterEach ( ( ) => {
455+ delete globalThis . AI_SDK_DEFAULT_PROVIDER ;
456+ } ) ;
457+
458+ it ( 'should throw an error' , ( ) => {
459+ expect ( ( ) => resolveRerankingModel ( 'test-model-id' ) ) . toThrow (
460+ 'The default provider does not support reranking models.' ,
461+ ) ;
462+ } ) ;
463+ } ) ;
464+
465+ describe ( 'when a model with unsupported specification version is provided' , ( ) => {
466+ it ( 'should throw UnsupportedModelVersionError' , ( ) => {
467+ const unsupportedModel = {
468+ specificationVersion : 'v1' ,
469+ provider : 'test-provider' ,
470+ modelId : 'test-model-id' ,
471+ } as any ;
472+
473+ expect ( ( ) => resolveRerankingModel ( unsupportedModel ) ) . toThrow ( ) ;
474+ } ) ;
475+
476+ it ( 'should throw UnsupportedModelVersionError for v2 models' , ( ) => {
477+ const v2Model = {
478+ specificationVersion : 'v2' ,
479+ provider : 'test-provider' ,
480+ modelId : 'test-model-id' ,
481+ } as any ;
482+
483+ expect ( ( ) => resolveRerankingModel ( v2Model ) ) . toThrow ( ) ;
484+ } ) ;
485+ } ) ;
486+ } ) ;
0 commit comments