Skip to content

Commit 4923427

Browse files
mbelinky and vincentkoc
authored and committed
fix(plugins): pass embedding input type labels
1 parent 60f1e08 commit 4923427

4 files changed

Lines changed: 100 additions & 0 deletions

File tree

extensions/memory-core/src/memory/generic-embedding-provider.integration.test.ts

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,9 @@ function createMemoryEmbeddingOptions(overrides?: {
127127
provider: overrides?.provider ?? "openai-compatible",
128128
fallback: "none",
129129
model: overrides?.model ?? "text-embedding-bge-m3",
130+
inputType: "default",
131+
queryInputType: "query",
132+
documentInputType: "document",
130133
remote: {
131134
baseUrl: overrides?.baseUrl,
132135
apiKey: "fixture-token",
@@ -193,6 +196,9 @@ describe("memory-core generic embedding provider bridge", () => {
193196
baseUrl: server.baseUrl,
194197
model: "text-embedding-bge-m3",
195198
dimensions: 3,
199+
inputType: "default",
200+
queryInputType: "query",
201+
documentInputType: "document",
196202
headers: {
197203
accept: "application/json",
198204
"content-type": "application/json",
@@ -230,6 +236,7 @@ describe("memory-core generic embedding provider bridge", () => {
230236
model: "text-embedding-bge-m3",
231237
input: ["hello"],
232238
dimensions: 3,
239+
input_type: "query",
233240
},
234241
});
235242
expect(server.requests[0]?.body).not.toHaveProperty("encoding_format");
@@ -240,11 +247,13 @@ describe("memory-core generic embedding provider bridge", () => {
240247
model: "text-embedding-bge-m3",
241248
input: ["a", "abcd"],
242249
dimensions: 3,
250+
input_type: "document",
243251
});
244252
expect(server.requests[2]?.body).toEqual({
245253
model: "text-embedding-bge-m3",
246254
input: ["xy"],
247255
dimensions: 3,
256+
input_type: "document",
248257
});
249258
});
250259

extensions/openai-compatible-embeddings/src/embedding-provider.test.ts

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,7 @@ describe("openai-compatible generic embedding provider", () => {
242242
},
243243
});
244244
expect(server.requests[0]?.body).not.toHaveProperty("encoding_format");
245+
expect(server.requests[0]?.body).not.toHaveProperty("input_type");
245246
expect(server.requests[0]?.headers["content-type"]).toContain("application/json");
246247
expect(server.requests[0]?.headers.accept).toBe("application/json");
247248
expect(server.requests[0]?.headers["x-local-runtime"]).toBe("ollama");
@@ -252,6 +253,58 @@ describe("openai-compatible generic embedding provider", () => {
252253
});
253254
});
254255

256+
// Verifies that the configured inputType / queryInputType / documentInputType
// labels end up as the `input_type` field of outgoing embedding requests, and
// that the same labels are exposed through `runtime.cacheKeyData`.
it("maps configured memory input_type labels onto query and document requests", async () => {
257+
// Stub server: returns one embedding per input text, shaped as
// [text length, index + 0.25, 1], so each result can be matched to its input.
const server = await startEmbeddingServer({
258+
respond: ({ body }) => {
259+
const input = body.input;
260+
const texts = Array.isArray(input) ? input : [input];
261+
return {
262+
object: "list",
263+
data: texts.map((text, index) => ({
264+
object: "embedding",
265+
embedding: [String(text).length, index + 0.25, 1],
266+
index,
267+
})),
268+
model: String(body.model),
269+
};
270+
},
271+
});
272+
273+
// The labels are deliberately padded with whitespace; the assertions below
// expect the trimmed values.
const result = await openAICompatibleEmbeddingProviderAdapter.create(
274+
createOptions({
275+
model: "text-embedding-bge-m3",
276+
inputType: " default ",
277+
queryInputType: " query ",
278+
documentInputType: " document ",
279+
remote: { baseUrl: server.baseUrl },
280+
}),
281+
);
282+
const provider = result.provider;
283+
if (!provider) {
284+
throw new Error("expected openai-compatible provider");
285+
}
286+
287+
// Cache key data is expected to carry the trimmed labels.
expect(result.runtime?.cacheKeyData).toMatchObject({
288+
inputType: "default",
289+
queryInputType: "query",
290+
documentInputType: "document",
291+
});
292+
293+
// Three calls: a query, a document batch, and a call whose inputType
// ("semantic") is neither "query" nor "document".
await expect(provider.embed("hello", { inputType: "query" })).resolves.toEqual([5, 0.25, 1]);
294+
await expect(provider.embedBatch(["doc"], { inputType: "document" })).resolves.toEqual([
295+
[3, 0.25, 1],
296+
]);
297+
await expect(provider.embed("semantic", { inputType: "semantic" })).resolves.toEqual([
298+
8, 0.25, 1,
299+
]);
300+
301+
// Requests are checked in call order; the "semantic" call is expected to be
// sent with the generic "default" label.
expect(server.requests.map((request) => request.body.input_type)).toEqual([
302+
"query",
303+
"document",
304+
"default",
305+
]);
306+
});
307+
255308
it("omits Authorization when no apiKey is configured", async () => {
256309
const server = await startEmbeddingServer();
257310
const { provider, client } = createOpenAICompatibleEmbeddingProvider(

extensions/openai-compatible-embeddings/src/embedding-provider.ts

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ import type {
22
EmbeddingInput,
33
EmbeddingProvider,
44
EmbeddingProviderAdapter,
5+
EmbeddingProviderCallOptions,
56
EmbeddingProviderCreateOptions,
67
} from "openclaw/plugin-sdk/embedding-providers";
78
import { normalizeResolvedSecretInputString } from "openclaw/plugin-sdk/secret-input";
@@ -19,6 +20,9 @@ export type OpenAICompatibleEmbeddingClient = {
1920
ssrfPolicy?: SsrFPolicy;
2021
model: string;
2122
dimensions?: number;
23+
inputType?: string;
24+
queryInputType?: string;
25+
documentInputType?: string;
2226
};
2327

2428
type OpenAICompatibleEmbeddingResponse = {
@@ -55,6 +59,24 @@ function normalizeDimensions(value: number | undefined): number | undefined {
5559
return value;
5660
}
5761

62+
/**
 * Normalizes an optional input-type label.
 *
 * @param value - Raw label, possibly undefined or padded with whitespace.
 * @returns The trimmed label, or `undefined` when `value` is undefined or is
 *   empty after trimming.
 */
function normalizeOptionalInputType(value: string | undefined): string | undefined {
63+
const inputType = value?.trim();
64+
// An empty string is falsy, so whitespace-only labels collapse to undefined.
return inputType ? inputType : undefined;
65+
}
66+
67+
/**
 * Chooses which configured input-type label applies to a single request.
 *
 * @param client - Client configuration holding the optional `inputType`,
 *   `queryInputType` and `documentInputType` labels.
 * @param kind - The per-call input type, if the caller supplied one.
 * @returns `client.queryInputType` for `"query"` calls and
 *   `client.documentInputType` for `"document"` calls, each falling back to
 *   `client.inputType`; `client.inputType` for every other `kind`, including
 *   `undefined`. The result is `undefined` when no applicable label is set.
 */
function resolveRequestInputType(
68+
client: OpenAICompatibleEmbeddingClient,
69+
kind: EmbeddingProviderCallOptions["inputType"] | undefined,
70+
): string | undefined {
71+
if (kind === "query") {
72+
// `??` falls back only on null/undefined, not on an empty string.
return client.queryInputType ?? client.inputType;
73+
}
74+
if (kind === "document") {
75+
return client.documentInputType ?? client.inputType;
76+
}
77+
// Any other kind (or none) uses the generic label.
return client.inputType;
78+
}
79+
5880
/**
 * Normalizes a header name by trimming surrounding whitespace and lower-casing it.
 *
 * @param name - Header name as supplied.
 * @returns The trimmed, lower-cased header name.
 */
function normalizeHeaderName(name: string): string {
5981
return name.trim().toLowerCase();
6082
}
@@ -164,12 +186,15 @@ async function postEmbeddingRequest(params: {
164186
client: OpenAICompatibleEmbeddingClient;
165187
input: string[];
166188
signal?: AbortSignal;
189+
inputType?: EmbeddingProviderCallOptions["inputType"];
167190
}): Promise<number[][]> {
168191
const { client, input } = params;
192+
const inputType = resolveRequestInputType(client, params.inputType);
169193
const body = {
170194
model: client.model,
171195
input,
172196
...(typeof client.dimensions === "number" ? { dimensions: client.dimensions } : {}),
197+
...(inputType ? { input_type: inputType } : {}),
173198
};
174199
const { response, release } = await fetchWithSsrFGuard({
175200
url: `${client.baseUrl}/embeddings`,
@@ -206,6 +231,9 @@ export function createOpenAICompatibleEmbeddingClient(
206231
value: options.remote?.apiKey,
207232
path: "embeddingProviders.openai-compatible.remote.apiKey",
208233
})?.trim();
234+
const inputType = normalizeOptionalInputType(options.inputType);
235+
const queryInputType = normalizeOptionalInputType(options.queryInputType);
236+
const documentInputType = normalizeOptionalInputType(options.documentInputType);
209237
return {
210238
baseUrl,
211239
headers: buildHeaders({ apiKey, extra: options.remote?.headers }),
@@ -214,6 +242,9 @@ export function createOpenAICompatibleEmbeddingClient(
214242
...(options.dimensions !== undefined
215243
? { dimensions: normalizeDimensions(options.dimensions) }
216244
: {}),
245+
...(inputType ? { inputType } : {}),
246+
...(queryInputType ? { queryInputType } : {}),
247+
...(documentInputType ? { documentInputType } : {}),
217248
};
218249
}
219250

@@ -230,6 +261,7 @@ export function createOpenAICompatibleEmbeddingProvider(options: EmbeddingProvid
230261
client,
231262
input: inputs.map(embeddingInputToText),
232263
signal: callOptions?.signal,
264+
inputType: callOptions?.inputType,
233265
});
234266
};
235267
return {
@@ -266,6 +298,9 @@ export const openAICompatibleEmbeddingProviderAdapter: EmbeddingProviderAdapter
266298
baseUrl: client.baseUrl,
267299
model: client.model,
268300
...(typeof client.dimensions === "number" ? { dimensions: client.dimensions } : {}),
301+
...(client.inputType ? { inputType: client.inputType } : {}),
302+
...(client.queryInputType ? { queryInputType: client.queryInputType } : {}),
303+
...(client.documentInputType ? { documentInputType: client.documentInputType } : {}),
269304
...(cacheHeaders ? { headers: cacheHeaders } : {}),
270305
},
271306
},

src/plugins/embedding-providers.ts

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,9 @@ export type EmbeddingProviderCreateOptions = {
4545
headers?: Record<string, string>;
4646
};
4747
model: string;
48+
inputType?: string;
49+
queryInputType?: string;
50+
documentInputType?: string;
4851
local?: {
4952
modelPath?: string;
5053
modelCacheDir?: string;

0 commit comments

Comments
 (0)