Skip to content

Commit 66d362a

Browse files
shakkernerdjalehman
authored andcommitted
fix: adapt OpenAI batch upload sizing
1 parent 3683222 commit 66d362a

4 files changed

Lines changed: 258 additions & 18 deletions

File tree

extensions/openai/embedding-batch.test.ts

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import { parseOpenAiBatchOutput, runOpenAiEmbeddingBatches } from "./embedding-b
44

55
const jsonlEncoder = new TextEncoder();
66

7-
function jsonResponse(body: unknown): Response {
7+
function jsonResponse(body: unknown, status = 200): Response {
88
return new Response(JSON.stringify(body), {
9-
status: 200,
9+
status,
1010
headers: { "Content-Type": "application/json" },
1111
});
1212
}
@@ -129,4 +129,118 @@ describe("OpenAI embedding batch output", () => {
129129
["2", [3]],
130130
]);
131131
});
132+
133+
it("adapts OpenAI-compatible upload groups after payload-size rejection", async () => {
134+
const requests: Parameters<typeof runOpenAiEmbeddingBatches>[0]["requests"] = Array.from(
135+
{ length: 4 },
136+
(_, index) => ({
137+
custom_id: String(index),
138+
method: "POST" as const,
139+
url: "/v1/embeddings",
140+
body: {
141+
model: "text-embedding-3-small",
142+
input: `payload-${index}`,
143+
},
144+
}),
145+
);
146+
const uploadedGroups: string[][] = [];
147+
const requestsByFileId = new Map<string, Array<{ custom_id?: string }>>();
148+
const outputByFileId = new Map<string, string>();
149+
const debug = vi.fn();
150+
let fileIndex = 0;
151+
let batchIndex = 0;
152+
const fetchImpl = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => {
153+
const url = fetchInputUrl(input);
154+
if (url.endsWith("/files") && init?.method === "POST") {
155+
const form = init.body as FormData;
156+
const file = form.get("file");
157+
if (!(file instanceof Blob)) {
158+
throw new Error("missing batch upload file");
159+
}
160+
const uploadedRequests = (await file.text())
161+
.split("\n")
162+
.map((line) => JSON.parse(line) as { custom_id?: string });
163+
const customIds = uploadedRequests.map((request) => request.custom_id ?? "");
164+
uploadedGroups.push(customIds);
165+
if (uploadedRequests.length > 2) {
166+
return jsonResponse(
167+
{
168+
error: {
169+
message: "Request body too large. Maximum allowed: 10 MB",
170+
type: "payload_too_large",
171+
code: "PAYLOAD_TOO_LARGE",
172+
},
173+
},
174+
413,
175+
);
176+
}
177+
const fileId = `file-${fileIndex}`;
178+
fileIndex += 1;
179+
requestsByFileId.set(fileId, uploadedRequests);
180+
return jsonResponse({ id: fileId });
181+
}
182+
if (url.endsWith("/batches") && init?.method === "POST") {
183+
const body = parseStringBody(init) as { input_file_id?: string };
184+
const batchId = `batch-${batchIndex}`;
185+
const outputFileId = `output-${batchIndex}`;
186+
batchIndex += 1;
187+
const uploadedRequests = requestsByFileId.get(body.input_file_id ?? "") ?? [];
188+
outputByFileId.set(
189+
outputFileId,
190+
uploadedRequests
191+
.map((request) =>
192+
JSON.stringify({
193+
custom_id: request.custom_id,
194+
response: {
195+
status_code: 200,
196+
body: { data: [{ embedding: [Number(request.custom_id) + 1] }] },
197+
},
198+
}),
199+
)
200+
.join("\n"),
201+
);
202+
return jsonResponse({ id: batchId, status: "completed", output_file_id: outputFileId });
203+
}
204+
const contentMatch = url.match(/\/files\/([^/]+)\/content$/);
205+
if (contentMatch) {
206+
return new Response(outputByFileId.get(contentMatch[1] ?? "") ?? "", { status: 200 });
207+
}
208+
return new Response("unexpected request", { status: 500 });
209+
});
210+
211+
const byCustomId = await runOpenAiEmbeddingBatches({
212+
openAi: {
213+
baseUrl: "https://openai-compatible.example/v1",
214+
headers: { Authorization: "Bearer test" },
215+
model: "text-embedding-3-small",
216+
fetchImpl,
217+
},
218+
agentId: "main",
219+
requests,
220+
wait: true,
221+
concurrency: 1,
222+
pollIntervalMs: 1000,
223+
timeoutMs: 60_000,
224+
debug,
225+
});
226+
227+
expect(uploadedGroups).toEqual([
228+
["0", "1", "2", "3"],
229+
["0", "1"],
230+
["2", "3"],
231+
]);
232+
expect(debug).toHaveBeenCalledWith(
233+
"memory embeddings: openai batch upload too large; splitting group",
234+
expect.objectContaining({
235+
requests: 4,
236+
parts: [2, 2],
237+
}),
238+
);
239+
expect([...byCustomId.entries()]).toEqual([
240+
["0", [1]],
241+
["1", [2]],
242+
["2", [3]],
243+
["3", [4]],
244+
]);
245+
});
132246
});

extensions/openai/embedding-batch.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,25 @@ async function fetchOpenAiBatchResource<T>(params: {
134134
});
135135
}
136136

137+
function formatOpenAiBatchError(error: unknown): string {
138+
return error instanceof Error ? error.message : String(error);
139+
}
140+
141+
function isOpenAiBatchUploadTooLargeError(error: unknown): boolean {
142+
const message = formatOpenAiBatchError(error);
143+
if (!/openai batch file upload failed/i.test(message)) {
144+
return false;
145+
}
146+
return (
147+
/\b413\b/.test(message) ||
148+
/payload too large/i.test(message) ||
149+
/request body too large/i.test(message) ||
150+
/file too large/i.test(message) ||
151+
/maximum allowed/i.test(message) ||
152+
/max(?:imum)? (?:body|payload|file) (?:size )?(?:exceeded|limit)/i.test(message)
153+
);
154+
}
155+
137156
export function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
138157
if (!text.trim()) {
139158
return [];
@@ -294,6 +313,15 @@ export async function runOpenAiEmbeddingBatches(
294313
maxJsonlBytes: params.maxJsonlBytes ?? OPENAI_BATCH_MAX_JSONL_BYTES,
295314
debugLabel: "memory embeddings: openai batch submit",
296315
}),
316+
shouldSplitGroupOnError: isOpenAiBatchUploadTooLargeError,
317+
onSplitGroup: ({ error, group, parts, depth }) => {
318+
params.debug?.("memory embeddings: openai batch upload too large; splitting group", {
319+
requests: group.length,
320+
parts: parts.map((part) => part.length),
321+
depth,
322+
error: formatOpenAiBatchError(error),
323+
});
324+
},
297325
runGroup: async ({ group, groupIndex, groups, byCustomId, pollIntervalMs, timeoutMs }) => {
298326
const batchInfo = await submitOpenAiBatch({
299327
openAi: params.openAi,

packages/memory-host-sdk/src/host/batch-runner.test.ts

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,65 @@ describe("buildEmbeddingBatchGroupOptions", () => {
9393

9494
expect(groups).toEqual([["one", "two"], ["three"]]);
9595
});
96+
97+
it("splits provider-rejected batch groups when the error is splittable", async () => {
98+
const uploadTooLarge = new Error("batch upload failed: 413 payload too large");
99+
const calls: string[][] = [];
100+
const onSplitGroup = vi.fn();
101+
102+
await runEmbeddingBatchGroups({
103+
requests: ["one", "two", "three", "four"],
104+
maxRequests: 100,
105+
wait: true,
106+
pollIntervalMs: 1000,
107+
timeoutMs: 60_000,
108+
concurrency: 1,
109+
debugLabel: "embedding batch submit",
110+
shouldSplitGroupOnError: (error) => error === uploadTooLarge,
111+
onSplitGroup,
112+
runGroup: async ({ group }) => {
113+
calls.push([...group]);
114+
if (group.length === 4) {
115+
throw uploadTooLarge;
116+
}
117+
},
118+
});
119+
120+
expect(calls).toEqual([
121+
["one", "two", "three", "four"],
122+
["one", "two"],
123+
["three", "four"],
124+
]);
125+
expect(onSplitGroup).toHaveBeenCalledWith(
126+
expect.objectContaining({
127+
error: uploadTooLarge,
128+
group: ["one", "two", "three", "four"],
129+
parts: [
130+
["one", "two"],
131+
["three", "four"],
132+
],
133+
depth: 0,
134+
}),
135+
);
136+
});
137+
138+
it("does not split a single rejected batch request", async () => {
139+
const uploadTooLarge = new Error("batch upload failed: 413 payload too large");
140+
141+
await expect(
142+
runEmbeddingBatchGroups({
143+
requests: ["one"],
144+
maxRequests: 100,
145+
wait: true,
146+
pollIntervalMs: 1000,
147+
timeoutMs: 60_000,
148+
concurrency: 1,
149+
debugLabel: "embedding batch submit",
150+
shouldSplitGroupOnError: () => true,
151+
runGroup: async () => {
152+
throw uploadTooLarge;
153+
},
154+
}),
155+
).rejects.toThrow(uploadTooLarge);
156+
});
96157
});

packages/memory-host-sdk/src/host/batch-runner.ts

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@ export type EmbeddingBatchExecutionParams = {
1414
debug?: (message: string, data?: Record<string, unknown>) => void;
1515
};
1616

17+
type EmbeddingBatchGroupRunArgs<TRequest> = {
18+
group: TRequest[];
19+
groupIndex: number;
20+
groups: number;
21+
byCustomId: Map<string, number[]>;
22+
pollIntervalMs: number;
23+
timeoutMs: number;
24+
};
25+
26+
type EmbeddingBatchSplitArgs<TRequest> = {
27+
error: unknown;
28+
group: TRequest[];
29+
parts: TRequest[][];
30+
groupIndex: number;
31+
groups: number;
32+
depth: number;
33+
};
34+
1735
/** Clamp polling to both configured poll interval and total timeout budget. */
1836
function resolveEmbeddingBatchPollIntervalMs(params: {
1937
pollIntervalMs: number;
@@ -40,14 +58,9 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
4058
concurrency: EmbeddingBatchExecutionParams["concurrency"];
4159
debugLabel: string;
4260
debug?: EmbeddingBatchExecutionParams["debug"];
43-
runGroup: (args: {
44-
group: TRequest[];
45-
groupIndex: number;
46-
groups: number;
47-
byCustomId: Map<string, number[]>;
48-
pollIntervalMs: number;
49-
timeoutMs: number;
50-
}) => Promise<void>;
61+
shouldSplitGroupOnError?: (error: unknown, group: TRequest[]) => boolean;
62+
onSplitGroup?: (args: EmbeddingBatchSplitArgs<TRequest>) => void;
63+
runGroup: (args: EmbeddingBatchGroupRunArgs<TRequest>) => Promise<void>;
5164
}): Promise<Map<string, number[]>> {
5265
if (params.requests.length === 0) {
5366
return new Map();
@@ -58,15 +71,39 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
5871
});
5972
const byCustomId = new Map<string, number[]>();
6073
const pollIntervalMs = resolveEmbeddingBatchPollIntervalMs(params);
74+
const runGroup = async (group: TRequest[], groupIndex: number, depth = 0): Promise<void> => {
75+
try {
76+
await params.runGroup({
77+
group,
78+
groupIndex,
79+
groups: groups.length,
80+
byCustomId,
81+
pollIntervalMs,
82+
timeoutMs: params.timeoutMs,
83+
});
84+
} catch (error) {
85+
if (group.length <= 1 || !params.shouldSplitGroupOnError?.(error, group)) {
86+
throw error;
87+
}
88+
const splitAt = Math.ceil(group.length / 2);
89+
const parts = [group.slice(0, splitAt), group.slice(splitAt)].filter(
90+
(part) => part.length > 0,
91+
);
92+
params.onSplitGroup?.({
93+
error,
94+
group,
95+
parts,
96+
groupIndex,
97+
groups: groups.length,
98+
depth,
99+
});
100+
for (const part of parts) {
101+
await runGroup(part, groupIndex, depth + 1);
102+
}
103+
}
104+
};
61105
const tasks = groups.map((group, groupIndex) => async () => {
62-
await params.runGroup({
63-
group,
64-
groupIndex,
65-
groups: groups.length,
66-
byCustomId,
67-
pollIntervalMs,
68-
timeoutMs: params.timeoutMs,
69-
});
106+
await runGroup(group, groupIndex);
70107
});
71108

72109
params.debug?.(params.debugLabel, {

0 commit comments

Comments
 (0)