Skip to content

Commit ec18852

Browse files
Backport: feat(gateway): add reranking model support (#14204)
1 parent 5b155e6 commit ec18852

File tree

8 files changed

+487
-0
lines changed

8 files changed

+487
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@ai-sdk/gateway": patch
3+
---
4+
5+
feat (provider/gateway): add reranking model support with `rerankingModel()` and `reranking()` methods

content/providers/01-ai-sdk-providers/00-ai-gateway.mdx

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,34 @@ const { text } = await generateText({
176176

177177
AI Gateway language models can also be used in the `streamText` function and support structured data generation with [`Output`](/docs/reference/ai-sdk-core/output) (see [AI SDK Core](/docs/ai-sdk-core)).
178178

179+
## Reranking Models
180+
181+
You can create reranking models using the `rerankingModel` method on the provider instance (`reranking` is an equivalent shorthand):
182+
183+
```ts
184+
import { rerank } from 'ai';
185+
import { gateway } from '@ai-sdk/gateway';
186+
187+
const { ranking } = await rerank({
188+
model: gateway.rerankingModel('cohere/rerank-v3.5'),
189+
query: 'What is the capital of France?',
190+
documents: [
191+
'Paris is the capital of France.',
192+
'Berlin is the capital of Germany.',
193+
'Madrid is the capital of Spain.',
194+
],
195+
topN: 2,
196+
});
197+
198+
console.log(ranking);
199+
// [
200+
// { originalIndex: 0, score: 0.89, document: 'Paris is the capital of France.' },
201+
// { originalIndex: 2, score: 0.15, document: 'Madrid is the capital of Spain.' },
202+
// ]
203+
```
204+
205+
Reranking models are useful for improving search results in retrieval-augmented generation (RAG) pipelines by re-scoring candidate documents after an initial retrieval step.
206+
179207
## Available Models
180208

181209
The AI Gateway supports models from OpenAI, Anthropic, Google, Meta, xAI, Mistral, DeepSeek, Amazon Bedrock, Cohere, Perplexity, Alibaba, and other providers.

packages/gateway/src/gateway-provider.test.ts

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { NoSuchModelError } from '@ai-sdk/provider';
1010
import { GatewayEmbeddingModel } from './gateway-embedding-model';
1111
import { GatewayImageModel } from './gateway-image-model';
1212
import { GatewayVideoModel } from './gateway-video-model';
13+
import { GatewayRerankingModel } from './gateway-reranking-model';
1314
import { getVercelOidcToken, getVercelRequestId } from './vercel-environment';
1415
import { resolve } from '@ai-sdk/provider-utils';
1516
import { GatewayLanguageModel } from './gateway-language-model';
@@ -133,6 +134,37 @@ function getGatewayVideoModelInternalConfig(
133134
return config;
134135
}
135136

137+
/**
 * Test-only description of the fields this test file reads from a
 * `GatewayRerankingModel`'s `config` property (retrieved via `Reflect.get`).
 * Only `fetch` is optional.
 */
type GatewayRerankingModelInternalConfig = {
138+
provider: string;
139+
baseURL: string;
140+
headers: () => Promise<Record<string, string>>;
141+
fetch?: typeof fetch;
142+
o11yHeaders: () => Promise<Record<string, string>>;
143+
};
144+
145+
/**
 * Assertion function that narrows `value` to
 * `GatewayRerankingModelInternalConfig`.
 *
 * Checks that `value` is a non-null object whose `provider` and `baseURL`
 * are strings and whose `headers` and `o11yHeaders` are functions. The
 * optional `fetch` field is not inspected.
 *
 * @param value - Candidate config object of unknown shape.
 * @throws Error with message 'Invalid GatewayRerankingModel configuration'
 *   when any of the checks above fails.
 */
function assertIsGatewayRerankingModelInternalConfig(
146+
value: unknown,
147+
): asserts value is GatewayRerankingModelInternalConfig {
148+
if (
149+
!value ||
150+
typeof value !== 'object' ||
151+
typeof (value as { provider?: unknown }).provider !== 'string' ||
152+
typeof (value as { baseURL?: unknown }).baseURL !== 'string' ||
153+
typeof (value as { headers?: unknown }).headers !== 'function' ||
154+
typeof (value as { o11yHeaders?: unknown }).o11yHeaders !== 'function'
155+
) {
156+
throw new Error('Invalid GatewayRerankingModel configuration');
157+
}
158+
}
159+
160+
/**
 * Reads the `config` property off a `GatewayRerankingModel` and returns it
 * typed as `GatewayRerankingModelInternalConfig`.
 *
 * @param model - Model instance whose `config` should be inspected.
 * @returns The model's `config` value after it has passed shape validation.
 * @throws Error if the `config` value does not have the expected shape.
 */
function getGatewayRerankingModelInternalConfig(
161+
model: GatewayRerankingModel,
162+
): GatewayRerankingModelInternalConfig {
163+
// Reflect.get reads the property by name without going through the
// model's declared type.
const config = Reflect.get(model as object, 'config');
164+
assertIsGatewayRerankingModelInternalConfig(config);
165+
return config;
166+
}
167+
136168
describe('GatewayProvider', () => {
137169
beforeEach(() => {
138170
vi.clearAllMocks();
@@ -324,6 +356,36 @@ describe('GatewayProvider', () => {
324356
expect(o11yHeaders).toEqual({ 'ai-o11y-request-id': 'mock-request-id' });
325357
});
326358

359+
// Verifies that `provider.rerankingModel(...)` returns a GatewayRerankingModel
// whose config carries the 'gateway' provider name and the configured baseURL.
it('should create GatewayRerankingModel for rerankingModel', () => {
360+
const provider = createGatewayProvider({
361+
baseURL: 'https://api.example.com',
362+
apiKey: 'test-api-key',
363+
});
364+
365+
const model = provider.rerankingModel('cohere/rerank-v3.5');
366+
367+
// NOTE(review): no import of `fail` is visible in this hunk — confirm it is
// available as a global in this test environment.
if (!(model instanceof GatewayRerankingModel)) {
368+
fail('Expected GatewayRerankingModel to be created');
369+
}
370+
371+
const config = getGatewayRerankingModelInternalConfig(model);
372+
expect(config.provider).toBe('gateway');
373+
expect(config.baseURL).toBe('https://api.example.com');
374+
});
375+
376+
// Verifies that the `provider.reranking(...)` shorthand also returns a
// GatewayRerankingModel instance. Unlike the `rerankingModel` test, the
// model's config is not inspected here.
it('should create GatewayRerankingModel for reranking alias', () => {
377+
const provider = createGatewayProvider({
378+
baseURL: 'https://api.example.com',
379+
apiKey: 'test-api-key',
380+
});
381+
382+
const model = provider.reranking('cohere/rerank-v3.5');
383+
384+
if (!(model instanceof GatewayRerankingModel)) {
385+
fail('Expected GatewayRerankingModel to be created');
386+
}
387+
});
388+
327389
it('should fetch available models', async () => {
328390
mockGetAvailableModels.mockReturnValue({ models: [] });
329391

packages/gateway/src/gateway-provider.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import { GatewayLanguageModel } from './gateway-language-model';
2727
import { GatewayEmbeddingModel } from './gateway-embedding-model';
2828
import { GatewayImageModel } from './gateway-image-model';
2929
import { GatewayVideoModel } from './gateway-video-model';
30+
import { GatewayRerankingModel } from './gateway-reranking-model';
3031
import type { GatewayEmbeddingModelId } from './gateway-embedding-model-settings';
3132
import type { GatewayImageModelId } from './gateway-image-model-settings';
33+
import type { GatewayRerankingModelId } from './gateway-reranking-model-settings';
3234
import type { GatewayVideoModelId } from './gateway-video-model-settings';
3335
import { gatewayTools } from './gateway-tools';
3436
import { getVercelOidcToken, getVercelRequestId } from './vercel-environment';
@@ -37,6 +39,7 @@ import type {
3739
LanguageModelV3,
3840
EmbeddingModelV3,
3941
ImageModelV3,
42+
RerankingModelV3,
4043
Experimental_VideoModelV3,
4144
ProviderV3,
4245
} from '@ai-sdk/provider';
@@ -117,6 +120,16 @@ export interface GatewayProvider extends ProviderV3 {
117120
*/
118121
videoModel(modelId: GatewayVideoModelId): Experimental_VideoModelV3;
119122

123+
/**
124+
* Creates a model for reranking documents. Shorthand alias for `rerankingModel`.
125+
*/
126+
reranking(modelId: GatewayRerankingModelId): RerankingModelV3;
127+
128+
/**
129+
* Creates a model for reranking documents from the given gateway model id.
130+
*/
131+
rerankingModel(modelId: GatewayRerankingModelId): RerankingModelV3;
132+
120133
/**
121134
* Gateway-specific tools executed server-side.
122135
*/
@@ -354,6 +367,17 @@ export function createGatewayProvider(
354367
o11yHeaders: createO11yHeaders(),
355368
});
356369
};
370+
// Factory for gateway reranking models. Each model is constructed with the
// 'gateway' provider name plus the `baseURL`, `getHeaders`, `options.fetch`
// and `createO11yHeaders()` values from the enclosing scope.
const createRerankingModel = (modelId: GatewayRerankingModelId) => {
371+
return new GatewayRerankingModel(modelId, {
372+
provider: 'gateway',
373+
baseURL,
374+
headers: getHeaders,
375+
fetch: options.fetch,
376+
o11yHeaders: createO11yHeaders(),
377+
});
378+
};
379+
// `rerankingModel` and `reranking` are bound to the same factory, so the two
// method names behave identically.
provider.rerankingModel = createRerankingModel;
380+
provider.reranking = createRerankingModel;
357381
provider.chat = provider.languageModel;
358382
provider.embedding = provider.embeddingModel;
359383
provider.image = provider.imageModel;

packages/gateway/src/gateway-reranking-model-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/**
 * Identifier of a reranking model served through the gateway.
 *
 * 'cohere/rerank-v3.5' is listed as a known literal; the `(string & {})`
 * member keeps the type open to any other string id while preserving the
 * literal for editor completion.
 */
export type GatewayRerankingModelId = 'cohere/rerank-v3.5' | (string & {});

0 commit comments

Comments
 (0)