Skip to content

Commit c20a055

Browse files
committed
fix(provider): honor Codex response timeouts
1 parent da5fe99 commit c20a055

2 files changed

Lines changed: 224 additions & 14 deletions

File tree

src/llm/providers/openai-codex-responses.test.ts

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,48 @@ function createJwt(payload: Record<string, unknown>): string {
1212
return `${header}.${body}.signature`;
1313
}
1414

15+
function stubTimeoutSignal(timeoutMs: number): void {
16+
vi.spyOn(AbortSignal, "timeout").mockImplementation((actualTimeoutMs) => {
17+
expect(actualTimeoutMs).toBe(timeoutMs);
18+
const controller = new AbortController();
19+
queueMicrotask(() => {
20+
controller.abort(new DOMException("timed out", "TimeoutError"));
21+
});
22+
return controller.signal;
23+
});
24+
}
25+
26+
function stubHangingFetch(timeoutMs: number): void {
27+
stubTimeoutSignal(timeoutMs);
28+
29+
vi.stubGlobal(
30+
"fetch",
31+
vi.fn(
32+
(_input: Parameters<typeof fetch>[0], init?: Parameters<typeof fetch>[1]) =>
33+
new Promise<Response>((_resolve, reject) => {
34+
const signal = init?.signal;
35+
if (!signal) {
36+
reject(new Error("missing abort signal"));
37+
return;
38+
}
39+
40+
const abort = () => {
41+
reject(
42+
signal.reason instanceof Error
43+
? signal.reason
44+
: new DOMException("aborted", "AbortError"),
45+
);
46+
};
47+
if (signal.aborted) {
48+
abort();
49+
return;
50+
}
51+
signal.addEventListener("abort", abort, { once: true });
52+
}),
53+
),
54+
);
55+
}
56+
1557
describe("extractOpenAICodexAccountId", () => {
1658
it("decodes URL-safe base64 JWT payloads", () => {
1759
const accessToken = createJwt({
@@ -33,6 +75,7 @@ describe("extractOpenAICodexAccountId", () => {
3375

3476
describe("streamOpenAICodexResponses transport", () => {
3577
afterEach(() => {
78+
vi.restoreAllMocks();
3679
vi.unstubAllGlobals();
3780
resetOpenAICodexWebSocketDebugStats();
3881
});
@@ -59,12 +102,16 @@ describe("streamOpenAICodexResponses transport", () => {
59102
throw new Error("fetch should not run");
60103
});
61104
vi.stubGlobal("fetch", fetchMock);
62-
vi.stubGlobal(
63-
"WebSocket",
64-
vi.fn(() => {
105+
class FailingWebSocket {
106+
constructor() {
65107
throw new Error("websocket connect failed");
66-
}),
67-
);
108+
}
109+
send(): void {}
110+
close(): void {}
111+
addEventListener(): void {}
112+
removeEventListener(): void {}
113+
}
114+
vi.stubGlobal("WebSocket", FailingWebSocket);
68115

69116
const stream = streamOpenAICodexResponses(model, context, {
70117
apiKey: createJwt({
@@ -82,4 +129,92 @@ describe("streamOpenAICodexResponses transport", () => {
82129
expect(result.stopReason).toBe("error");
83130
expect(result.errorMessage).toContain("websocket connect failed");
84131
});
132+
133+
it("honors timeoutMs for explicit SSE transport requests", async () => {
134+
stubHangingFetch(5);
135+
136+
const stream = streamOpenAICodexResponses(model, context, {
137+
apiKey: createJwt({
138+
"https://api.openai.com/auth": {
139+
chatgpt_account_id: "acct-1",
140+
},
141+
}),
142+
timeoutMs: 5,
143+
transport: "sse",
144+
});
145+
146+
const result = await stream.result();
147+
148+
expect(result.stopReason).toBe("error");
149+
expect(result.errorMessage).toContain("Request timed out after 5ms");
150+
});
151+
152+
it("honors timeoutMs for default websocket transport requests", async () => {
153+
stubTimeoutSignal(5);
154+
const fetchMock = vi.fn(async () => {
155+
throw new Error("fetch should not run before websocket timeout");
156+
});
157+
class HangingWebSocket {
158+
send = vi.fn();
159+
close = vi.fn();
160+
addEventListener(): void {}
161+
removeEventListener(): void {}
162+
}
163+
vi.stubGlobal("fetch", fetchMock);
164+
vi.stubGlobal("WebSocket", HangingWebSocket);
165+
166+
const stream = streamOpenAICodexResponses(model, context, {
167+
apiKey: createJwt({
168+
"https://api.openai.com/auth": {
169+
chatgpt_account_id: "acct-1",
170+
},
171+
}),
172+
timeoutMs: 5,
173+
});
174+
175+
const result = await stream.result();
176+
177+
expect(fetchMock).not.toHaveBeenCalled();
178+
expect(result.stopReason).toBe("error");
179+
expect(result.errorMessage).toContain("Request timed out after 5ms");
180+
});
181+
182+
it("does not send websocket payload after timeout fires during connect", async () => {
183+
let timeoutController: AbortController | undefined;
184+
vi.spyOn(AbortSignal, "timeout").mockImplementation((actualTimeoutMs) => {
185+
expect(actualTimeoutMs).toBe(5);
186+
timeoutController = new AbortController();
187+
return timeoutController.signal;
188+
});
189+
const sendMock = vi.fn();
190+
class OpeningThenTimedOutWebSocket {
191+
send = sendMock;
192+
close = vi.fn();
193+
addEventListener(type: string, listener: (event: unknown) => void): void {
194+
if (type === "open") {
195+
queueMicrotask(() => {
196+
listener({});
197+
timeoutController?.abort(new DOMException("timed out", "TimeoutError"));
198+
});
199+
}
200+
}
201+
removeEventListener(): void {}
202+
}
203+
vi.stubGlobal("WebSocket", OpeningThenTimedOutWebSocket);
204+
205+
const stream = streamOpenAICodexResponses(model, context, {
206+
apiKey: createJwt({
207+
"https://api.openai.com/auth": {
208+
chatgpt_account_id: "acct-1",
209+
},
210+
}),
211+
timeoutMs: 5,
212+
});
213+
214+
const result = await stream.result();
215+
216+
expect(sendMock).not.toHaveBeenCalled();
217+
expect(result.stopReason).toBe("error");
218+
expect(result.errorMessage).toContain("Request timed out after 5ms");
219+
});
85220
});

src/llm/providers/openai-codex-responses.ts

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,52 @@ function sleep(ms: number, signal?: AbortSignal): Promise<void> {
133133
});
134134
}
135135

136+
function resolveRequestTimeoutMs(options?: OpenAICodexResponsesOptions): number | undefined {
137+
const timeoutMs = options?.timeoutMs;
138+
return typeof timeoutMs === "number" && Number.isFinite(timeoutMs) && timeoutMs > 0
139+
? Math.floor(timeoutMs)
140+
: undefined;
141+
}
142+
143+
function buildRequestSignal(
144+
baseSignal: AbortSignal | undefined,
145+
timeoutMs: number | undefined,
146+
): AbortSignal | undefined {
147+
if (timeoutMs === undefined) {
148+
return baseSignal;
149+
}
150+
const timeoutSignal = AbortSignal.timeout(timeoutMs);
151+
if (!baseSignal) {
152+
return timeoutSignal;
153+
}
154+
return AbortSignal.any([baseSignal, timeoutSignal]);
155+
}
156+
157+
function isRequestTimeoutError(
158+
error: unknown,
159+
callerSignal: AbortSignal | undefined,
160+
requestSignal: AbortSignal | undefined,
161+
timeoutMs: number | undefined,
162+
): boolean {
163+
if (timeoutMs === undefined || callerSignal?.aborted || !requestSignal?.aborted) {
164+
return false;
165+
}
166+
if (!(error instanceof Error)) {
167+
return false;
168+
}
169+
return (
170+
error.name === "AbortError" ||
171+
error.name === "TimeoutError" ||
172+
error.message === "Request was aborted"
173+
);
174+
}
175+
176+
function formatRequestTimeoutError(timeoutMs: number, cause: unknown): Error {
177+
return new Error(`Request timed out after ${timeoutMs}ms`, {
178+
cause: cause instanceof Error ? cause : undefined,
179+
});
180+
}
181+
136182
// ============================================================================
137183
// Main Stream Function
138184
// ============================================================================
@@ -148,6 +194,8 @@ export const streamOpenAICodexResponses: StreamFunction<
148194
const stream = new AssistantMessageEventStream();
149195

150196
void (async () => {
197+
let requestTimeoutMs: number | undefined;
198+
let activeSignal: AbortSignal | undefined;
151199
const output: AssistantMessage = {
152200
role: "assistant",
153201
content: [],
@@ -194,6 +242,10 @@ export const streamOpenAICodexResponses: StreamFunction<
194242
websocketRequestId,
195243
);
196244
const bodyJson = JSON.stringify(body);
245+
requestTimeoutMs = resolveRequestTimeoutMs(options);
246+
activeSignal = buildRequestSignal(options?.signal, requestTimeoutMs);
247+
const requestOptions =
248+
activeSignal === options?.signal ? options : { ...options, signal: activeSignal };
197249
const transport = options?.transport || "auto";
198250
const websocketDisabledForSession =
199251
transport === "auto" && isWebSocketSseFallbackActive(options?.sessionId);
@@ -214,10 +266,10 @@ export const streamOpenAICodexResponses: StreamFunction<
214266
() => {
215267
websocketStarted = true;
216268
},
217-
options,
269+
requestOptions,
218270
);
219271

220-
if (options?.signal?.aborted) {
272+
if (activeSignal?.aborted) {
221273
throw new Error("Request was aborted");
222274
}
223275
stream.push({
@@ -228,7 +280,7 @@ export const streamOpenAICodexResponses: StreamFunction<
228280
stream.end();
229281
return;
230282
} catch (error) {
231-
const aborted = options?.signal?.aborted;
283+
const aborted = activeSignal?.aborted;
232284
if (aborted || isCodexNonTransportError(error)) {
233285
throw error;
234286
}
@@ -259,7 +311,7 @@ export const streamOpenAICodexResponses: StreamFunction<
259311
let lastError: Error | undefined;
260312

261313
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
262-
if (options?.signal?.aborted) {
314+
if (activeSignal?.aborted) {
263315
throw new Error("Request was aborted");
264316
}
265317

@@ -268,7 +320,7 @@ export const streamOpenAICodexResponses: StreamFunction<
268320
method: "POST",
269321
headers: sseHeaders,
270322
body: bodyJson,
271-
signal: options?.signal,
323+
signal: activeSignal,
272324
});
273325
await options?.onResponse?.(
274326
{ status: response.status, headers: headersToRecord(response.headers) },
@@ -304,7 +356,7 @@ export const streamOpenAICodexResponses: StreamFunction<
304356
}
305357
}
306358

307-
await sleep(delayMs, options?.signal);
359+
await sleep(delayMs, activeSignal);
308360
continue;
309361
}
310362

@@ -317,15 +369,24 @@ export const streamOpenAICodexResponses: StreamFunction<
317369
throw new Error(info.friendlyMessage || info.message);
318370
} catch (error) {
319371
if (error instanceof Error) {
372+
if (
373+
isRequestTimeoutError(error, options?.signal, activeSignal, requestTimeoutMs) &&
374+
requestTimeoutMs !== undefined
375+
) {
376+
throw formatRequestTimeoutError(requestTimeoutMs, error);
377+
}
320378
if (error.name === "AbortError" || error.message === "Request was aborted") {
321379
throw new Error("Request was aborted", { cause: error });
322380
}
381+
if (error.name === "TimeoutError" && requestTimeoutMs !== undefined) {
382+
throw new Error(`Request timed out after ${requestTimeoutMs}ms`, { cause: error });
383+
}
323384
}
324385
lastError = error instanceof Error ? error : new Error(String(error));
325386
// Network errors are retryable
326387
if (attempt < MAX_RETRIES && !lastError.message.includes("usage limit")) {
327388
const delayMs = BASE_DELAY_MS * 2 ** attempt;
328-
await sleep(delayMs, options?.signal);
389+
await sleep(delayMs, activeSignal);
329390
continue;
330391
}
331392
throw lastError;
@@ -343,7 +404,7 @@ export const streamOpenAICodexResponses: StreamFunction<
343404
stream.push({ type: "start", partial: output });
344405
await processStream(response, output, stream, model, options);
345406

346-
if (options?.signal?.aborted) {
407+
if (activeSignal?.aborted) {
347408
throw new Error("Request was aborted");
348409
}
349410

@@ -354,12 +415,18 @@ export const streamOpenAICodexResponses: StreamFunction<
354415
});
355416
stream.end();
356417
} catch (error) {
418+
const normalizedError =
419+
isRequestTimeoutError(error, options?.signal, activeSignal, requestTimeoutMs) &&
420+
requestTimeoutMs !== undefined
421+
? formatRequestTimeoutError(requestTimeoutMs, error)
422+
: error;
357423
for (const block of output.content) {
358424
// partialJson is only a streaming scratch buffer; never persist it.
359425
delete (block as { partialJson?: string }).partialJson;
360426
}
361427
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
362-
output.errorMessage = error instanceof Error ? error.message : String(error);
428+
output.errorMessage =
429+
normalizedError instanceof Error ? normalizedError.message : String(normalizedError);
363430
stream.push({ type: "error", reason: output.stopReason, error: output });
364431
stream.end();
365432
}
@@ -975,6 +1042,11 @@ async function connectWebSocket(
9751042
signal?.removeEventListener("abort", onAbort);
9761043
};
9771044

1045+
if (signal?.aborted) {
1046+
onAbort();
1047+
return;
1048+
}
1049+
9781050
socket.addEventListener("open", onOpen);
9791051
socket.addEventListener("error", onError);
9801052
socket.addEventListener("close", onClose);
@@ -1374,6 +1446,9 @@ async function processWebSocketStream(
13741446
}
13751447
}
13761448
try {
1449+
if (options?.signal?.aborted) {
1450+
throw new Error("Request was aborted");
1451+
}
13771452
socket.send(JSON.stringify({ type: "response.create", ...requestBody }));
13781453
await processResponsesStream(
13791454
startWebSocketOutputOnFirstEvent(

0 commit comments

Comments
 (0)