Skip to content

Commit f152133

Browse files
Backport: feat(ai/core): support plain string model IDs in rerank() (#14214)
1 parent 99327b1 commit f152133

File tree

5 files changed

+169
-2
lines changed

5 files changed

+169
-2
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"ai": patch
3+
---
4+
5+
feat (ai/core): support plain string model IDs in `rerank()` function
6+
7+
The `rerank()` function now accepts plain model strings (e.g., `'cohere/rerank-v3.5'`) in addition to `RerankingModel` objects, matching the behavior of `generateText`, `embed`, and other core functions.

packages/ai/src/model/resolve-model.test.ts

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest';
44

55
import { MockEmbeddingModelV3 } from '../test/mock-embedding-model-v3';
66
import { MockLanguageModelV3 } from '../test/mock-language-model-v3';
7+
import { MockRerankingModelV3 } from '../test/mock-reranking-model-v3';
78
import { MockVideoModelV3 } from '../test/mock-video-model-v3';
89
import { customProvider } from '../registry/custom-provider';
910
import { MockImageModelV2 } from '../test/mock-image-model-v2';
1011
import {
1112
resolveEmbeddingModel,
1213
resolveImageModel,
1314
resolveLanguageModel,
15+
resolveRerankingModel,
1416
resolveVideoModel,
1517
} from './resolve-model';
1618

@@ -356,3 +358,129 @@ describe('resolveVideoModel', () => {
356358
});
357359
});
358360
});
361+
362+
describe('resolveRerankingModel', () => {
363+
describe('when a reranking model v3 is provided', () => {
364+
it('should return it as-is', () => {
365+
const originalModel = new MockRerankingModelV3({
366+
provider: 'test-provider',
367+
modelId: 'test-model-id',
368+
});
369+
370+
const resolvedModel = resolveRerankingModel(originalModel);
371+
372+
expect(resolvedModel).toBe(originalModel);
373+
expect(resolvedModel.specificationVersion).toBe('v3');
374+
});
375+
});
376+
377+
describe('when a reranking model v3 is provided', () => {
378+
it('should return it as v3', () => {
379+
const resolvedModel = resolveRerankingModel(
380+
new MockRerankingModelV3({
381+
provider: 'test-provider',
382+
modelId: 'test-model-id',
383+
}),
384+
);
385+
386+
expect(resolvedModel.provider).toBe('test-provider');
387+
expect(resolvedModel.modelId).toBe('test-model-id');
388+
expect(resolvedModel.specificationVersion).toBe('v3');
389+
});
390+
});
391+
392+
describe('when a string is provided and the global default provider is not set', () => {
393+
it('should return a gateway reranking model', () => {
394+
const mockModel = new MockRerankingModelV3({
395+
provider: 'gateway',
396+
modelId: 'test-model-id',
397+
});
398+
399+
const rerankingModelSpy = vi
400+
.spyOn(gateway, 'rerankingModel')
401+
.mockReturnValue(mockModel as any);
402+
403+
try {
404+
const resolvedModel = resolveRerankingModel('test-model-id');
405+
406+
expect(resolvedModel.provider).toBe('gateway');
407+
expect(resolvedModel.modelId).toBe('test-model-id');
408+
} finally {
409+
rerankingModelSpy.mockRestore();
410+
}
411+
});
412+
});
413+
414+
describe('when a string is provided and the global default provider is set', () => {
415+
beforeEach(() => {
416+
globalThis.AI_SDK_DEFAULT_PROVIDER = customProvider({
417+
rerankingModels: {
418+
'test-model-id': new MockRerankingModelV3({
419+
provider: 'global-test-provider',
420+
modelId: 'actual-test-model-id',
421+
}),
422+
},
423+
});
424+
});
425+
426+
afterEach(() => {
427+
delete globalThis.AI_SDK_DEFAULT_PROVIDER;
428+
});
429+
430+
it('should return a reranking model from the global default provider', () => {
431+
const resolvedModel = resolveRerankingModel('test-model-id');
432+
433+
expect(resolvedModel.provider).toBe('global-test-provider');
434+
expect(resolvedModel.modelId).toBe('actual-test-model-id');
435+
});
436+
});
437+
438+
describe('when a string is provided and the provider does not support reranking models', () => {
439+
beforeEach(() => {
440+
globalThis.AI_SDK_DEFAULT_PROVIDER = {
441+
specificationVersion: 'v3' as const,
442+
languageModel: () => {
443+
throw new Error('not implemented');
444+
},
445+
embeddingModel: () => {
446+
throw new Error('not implemented');
447+
},
448+
imageModel: () => {
449+
throw new Error('not implemented');
450+
},
451+
};
452+
});
453+
454+
afterEach(() => {
455+
delete globalThis.AI_SDK_DEFAULT_PROVIDER;
456+
});
457+
458+
it('should throw an error', () => {
459+
expect(() => resolveRerankingModel('test-model-id')).toThrow(
460+
'The default provider does not support reranking models.',
461+
);
462+
});
463+
});
464+
465+
describe('when a model with unsupported specification version is provided', () => {
466+
it('should throw UnsupportedModelVersionError', () => {
467+
const unsupportedModel = {
468+
specificationVersion: 'v1',
469+
provider: 'test-provider',
470+
modelId: 'test-model-id',
471+
} as any;
472+
473+
expect(() => resolveRerankingModel(unsupportedModel)).toThrow();
474+
});
475+
476+
it('should throw UnsupportedModelVersionError for v2 models', () => {
477+
const v2Model = {
478+
specificationVersion: 'v2',
479+
provider: 'test-provider',
480+
modelId: 'test-model-id',
481+
} as any;
482+
483+
expect(() => resolveRerankingModel(v2Model)).toThrow();
484+
});
485+
});
486+
});

packages/ai/src/model/resolve-model.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
ImageModelV3,
66
LanguageModelV3,
77
ProviderV3,
8+
RerankingModelV3,
89
SpeechModelV3,
910
TranscriptionModelV3,
1011
} from '@ai-sdk/provider';
@@ -19,6 +20,7 @@ import { asLanguageModelV3 } from './as-language-model-v3';
1920
import { asSpeechModelV3 } from './as-speech-model-v3';
2021
import { asTranscriptionModelV3 } from './as-transcription-model-v3';
2122
import { ImageModel } from '../types/image-model';
23+
import { RerankingModel } from '../types/reranking-model';
2224
import { VideoModel } from '../types/video-model';
2325

2426
export function resolveLanguageModel(model: LanguageModel): LanguageModelV3 {
@@ -154,6 +156,33 @@ export function resolveVideoModel(
154156
return model;
155157
}
156158

159+
export function resolveRerankingModel(model: RerankingModel): RerankingModelV3 {
160+
if (typeof model === 'string') {
161+
const provider = getGlobalProvider();
162+
const rerankingModel = provider.rerankingModel;
163+
164+
if (!rerankingModel) {
165+
throw new Error(
166+
'The default provider does not support reranking models. ' +
167+
'Please use a RerankingModel object from a provider (e.g., gateway.rerankingModel("model-id")).',
168+
);
169+
}
170+
171+
return rerankingModel(model);
172+
}
173+
174+
if (model.specificationVersion !== 'v3') {
175+
const unsupportedModel: any = model;
176+
throw new UnsupportedModelVersionError({
177+
version: unsupportedModel.specificationVersion,
178+
provider: unsupportedModel.provider,
179+
modelId: unsupportedModel.modelId,
180+
});
181+
}
182+
183+
return model;
184+
}
185+
157186
function getGlobalProvider(): ProviderV3 {
158187
return globalThis.AI_SDK_DEFAULT_PROVIDER ?? gateway;
159188
}

packages/ai/src/rerank/rerank.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { getTracer } from '../telemetry/get-tracer';
77
import { recordSpan } from '../telemetry/record-span';
88
import { selectTelemetryAttributes } from '../telemetry/select-telemetry-attributes';
99
import { TelemetrySettings } from '../telemetry/telemetry-settings';
10+
import { resolveRerankingModel } from '../model/resolve-model';
1011
import { RerankingModel } from '../types';
1112
import { RerankResult } from './rerank-result';
1213
import { logWarnings } from '../logger/log-warnings';
@@ -28,7 +29,7 @@ import { logWarnings } from '../logger/log-warnings';
2829
* @returns A result object that contains the reranked documents, the reranked indices, and additional information.
2930
*/
3031
export async function rerank<VALUE extends JSONObject | string>({
31-
model,
32+
model: modelArg,
3233
documents,
3334
query,
3435
topN,
@@ -88,6 +89,8 @@ export async function rerank<VALUE extends JSONObject | string>({
8889
*/
8990
providerOptions?: ProviderOptions;
9091
}): Promise<RerankResult<VALUE>> {
92+
const model = resolveRerankingModel(modelArg);
93+
9194
if (documents.length === 0) {
9295
return new DefaultRerankResult({
9396
originalDocuments: [],

packages/ai/src/types/reranking-model.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ import { RerankingModelV3 } from '@ai-sdk/provider';
33
/**
44
* Reranking model that is used by the AI SDK.
55
*/
6-
export type RerankingModel = RerankingModelV3;
6+
export type RerankingModel = string | RerankingModelV3;

0 commit comments

Comments
 (0)