Skip to content

Commit 148db14

Browse files
fix(google): omit request config with cached content
Fix Gemini cached-content GenerateContent payloads so cached requests no longer resend request-level systemInstruction, tools, or toolConfig. Covers explicit cachedContent and managed cacheRetention prompt caching; fixes #84919. Proof: Real behavior proof passed on PR head 198a42b after live Gemini repro/fix evidence was added to the PR body. Focused tests and check:changed were already green. Thanks @neeravmakwana.
1 parent 5a9673e commit 148db14

4 files changed

Lines changed: 151 additions & 17 deletions

File tree

extensions/google/transport-stream.test.ts

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -396,19 +396,14 @@ describe("google transport stream", () => {
396396
});
397397

398398
const payload = parseRequestJsonBody(init);
399-
expect(payload.systemInstruction).toEqual({
400-
parts: [{ text: "Follow policy." }],
401-
});
402399
expect(payload.cachedContent).toBe("cachedContents/request-cache");
400+
expect(payload.systemInstruction).toBeUndefined();
401+
expect(payload.tools).toBeUndefined();
402+
expect(payload.toolConfig).toBeUndefined();
403403
expect((payload.generationConfig as { thinkingConfig?: unknown }).thinkingConfig).toEqual({
404404
includeThoughts: true,
405405
thinkingLevel: "HIGH",
406406
});
407-
expect(
408-
(payload.toolConfig as { functionCallingConfig?: unknown }).functionCallingConfig,
409-
).toEqual({
410-
mode: "AUTO",
411-
});
412407
expect(result.api).toBe("google-generative-ai");
413408
expect(result.provider).toBe("google");
414409
expect(result.responseId).toBe("resp_1");
@@ -1640,6 +1635,36 @@ describe("google transport stream", () => {
16401635
expect(params.cachedContent).toBe("cachedContents/prebuilt-context");
16411636
});
16421637

1638+
it("omits per-request system and tool settings when using cachedContent", () => {
1639+
const params = buildGoogleGenerativeAiParams(
1640+
buildGeminiModel(),
1641+
{
1642+
systemPrompt: "Follow policy.",
1643+
messages: [{ role: "user", content: "hello", timestamp: 0 }],
1644+
tools: [
1645+
{
1646+
name: "lookup",
1647+
description: "Look up a value",
1648+
parameters: {
1649+
type: "object",
1650+
properties: { q: { type: "string" } },
1651+
required: ["q"],
1652+
},
1653+
},
1654+
],
1655+
} as never,
1656+
{
1657+
cachedContent: " cachedContents/prebuilt-context ",
1658+
toolChoice: "auto",
1659+
},
1660+
);
1661+
1662+
expect(params.cachedContent).toBe("cachedContents/prebuilt-context");
1663+
expect(params.systemInstruction).toBeUndefined();
1664+
expect(params.tools).toBeUndefined();
1665+
expect(params.toolConfig).toBeUndefined();
1666+
});
1667+
16431668
it("uses a non-empty text placeholder for empty user text", () => {
16441669
const params = buildGoogleGenerativeAiParams(buildGeminiModel(), {
16451670
messages: [

extensions/google/transport-stream.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,13 +705,15 @@ export function buildGoogleGenerativeAiParams(
705705
const params: GoogleGenerateContentRequest = {
706706
contents: convertGoogleMessages(model, context),
707707
};
708-
if (typeof options?.cachedContent === "string" && options.cachedContent.trim()) {
709-
params.cachedContent = options.cachedContent.trim();
708+
const cachedContent =
709+
typeof options?.cachedContent === "string" ? options.cachedContent.trim() : "";
710+
if (cachedContent) {
711+
params.cachedContent = cachedContent;
710712
}
711713
if (Object.keys(generationConfig).length > 0) {
712714
params.generationConfig = generationConfig;
713715
}
714-
if (context.systemPrompt) {
716+
if (!cachedContent && context.systemPrompt) {
715717
params.systemInstruction = {
716718
parts: [
717719
{
@@ -720,7 +722,7 @@ export function buildGoogleGenerativeAiParams(
720722
],
721723
};
722724
}
723-
if (context.tools?.length) {
725+
if (!cachedContent && context.tools?.length) {
724726
params.tools = convertGoogleTools(context.tools);
725727
const toolChoice = mapToolChoice(options?.toolChoice);
726728
if (toolChoice) {

src/agents/pi-embedded-runner/google-prompt-cache.test.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ describe("google prompt cache", () => {
179179
},
180180
],
181181
} as never,
182-
{ temperature: 0.2 } as never,
182+
{ temperature: 0.2, toolChoice: "auto" } as never,
183183
),
184184
);
185185

@@ -200,11 +200,28 @@ describe("google prompt cache", () => {
200200
systemInstruction: {
201201
parts: [{ text: "Follow policy." }],
202202
},
203+
tools: [
204+
{
205+
functionDeclarations: [
206+
{
207+
name: "lookup",
208+
description: "Look up a value",
209+
parametersJsonSchema: { type: "object" },
210+
},
211+
],
212+
},
213+
],
214+
toolConfig: {
215+
functionCallingConfig: {
216+
mode: "AUTO",
217+
},
218+
},
203219
});
204220
expect(innerStreamFn).toHaveBeenCalledTimes(1);
205221
expect(streamContext(innerStreamFn).systemPrompt).toBeUndefined();
206-
expect(Array.isArray(streamContext(innerStreamFn).tools)).toBe(true);
222+
expect(streamContext(innerStreamFn).tools).toBeUndefined();
207223
expect(streamOptions(innerStreamFn).temperature).toBe(0.2);
224+
expect(streamOptions(innerStreamFn).toolChoice).toBe("auto");
208225
expect(getCapturedPayload()?.cachedContent).toBe("cachedContents/system-cache-1");
209226
expect(entries).toEqual([
210227
{
@@ -221,6 +238,7 @@ describe("google prompt cache", () => {
221238
modelApi: "google-generative-ai",
222239
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
223240
systemPromptDigest,
241+
cacheConfigDigest: expect.any(String),
224242
cacheRetention: "long",
225243
cachedContent: "cachedContents/system-cache-1",
226244
expireTime,

src/agents/pi-embedded-runner/google-prompt-cache.ts

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ type GooglePromptCacheModel = Model<Api> & {
3131
headers?: Record<string, string>;
3232
provider: string;
3333
};
34+
type GooglePromptCacheContext = Parameters<StreamFn>[1];
35+
type GooglePromptCacheOptions = Parameters<StreamFn>[2];
3436

3537
type GooglePromptCacheEntry = {
3638
timestamp: number;
@@ -39,6 +41,7 @@ type GooglePromptCacheEntry = {
3941
modelApi?: string | null;
4042
baseUrl: string;
4143
systemPromptDigest: string;
44+
cacheConfigDigest?: string;
4245
cacheRetention: CacheRetention;
4346
} & (
4447
| {
@@ -111,6 +114,7 @@ function buildGooglePromptCacheMatchKey(params: {
111114
modelApi?: string | null;
112115
baseUrl: string;
113116
systemPromptDigest: string;
117+
cacheConfigDigest?: string;
114118
}) {
115119
return stableStringify(params);
116120
}
@@ -150,6 +154,8 @@ function readLatestGooglePromptCacheEntry(
150154
: null,
151155
baseUrl: stringifyGooglePromptCacheKeyPart(cacheData.baseUrl),
152156
systemPromptDigest: stringifyGooglePromptCacheKeyPart(cacheData.systemPromptDigest),
157+
cacheConfigDigest:
158+
typeof cacheData.cacheConfigDigest === "string" ? cacheData.cacheConfigDigest : undefined,
153159
});
154160
if (candidateKey === matchKey) {
155161
return data as GooglePromptCacheEntry;
@@ -183,13 +189,79 @@ function parseExpireTimeMs(expireTime: string | undefined): number | null {
183189
return Number.isFinite(timestamp) ? timestamp : null;
184190
}
185191

186-
function buildManagedContextWithoutSystemPrompt(context: Parameters<StreamFn>[1]) {
187-
if (!context.systemPrompt) {
192+
function convertManagedGoogleTools(tools: NonNullable<GooglePromptCacheContext["tools"]>) {
193+
if (tools.length === 0) {
194+
return undefined;
195+
}
196+
return [
197+
{
198+
functionDeclarations: tools.map((tool) => ({
199+
name: tool.name,
200+
description: tool.description,
201+
parametersJsonSchema: tool.parameters,
202+
})),
203+
},
204+
];
205+
}
206+
207+
function mapManagedGoogleToolChoice(
208+
choice: unknown,
209+
): { mode: "AUTO" | "NONE" | "ANY"; allowedFunctionNames?: string[] } | undefined {
210+
if (!choice) {
211+
return undefined;
212+
}
213+
if (
214+
typeof choice === "object" &&
215+
choice !== null &&
216+
(choice as { type?: unknown }).type === "function"
217+
) {
218+
const functionName = (choice as { function?: { name?: unknown } }).function?.name;
219+
return typeof functionName === "string"
220+
? { mode: "ANY", allowedFunctionNames: [functionName] }
221+
: { mode: "ANY" };
222+
}
223+
switch (choice) {
224+
case "none":
225+
return { mode: "NONE" };
226+
case "any":
227+
case "required":
228+
return { mode: "ANY" };
229+
default:
230+
return { mode: "AUTO" };
231+
}
232+
}
233+
234+
function buildManagedGooglePromptCacheConfig(
235+
context: GooglePromptCacheContext,
236+
options: GooglePromptCacheOptions,
237+
) {
238+
const tools = context.tools?.length ? convertManagedGoogleTools(context.tools) : undefined;
239+
const toolChoice = tools
240+
? mapManagedGoogleToolChoice((options as { toolChoice?: unknown } | undefined)?.toolChoice)
241+
: undefined;
242+
const toolConfig = toolChoice ? { functionCallingConfig: toolChoice } : undefined;
243+
const cacheConfigDigest =
244+
tools || toolConfig
245+
? stableStringify({
246+
tools,
247+
toolConfig,
248+
})
249+
: undefined;
250+
return {
251+
cacheConfigDigest,
252+
tools,
253+
toolConfig,
254+
};
255+
}
256+
257+
function buildManagedContextForCachedContent(context: GooglePromptCacheContext) {
258+
if (!context.systemPrompt && !context.tools?.length) {
188259
return context;
189260
}
190261
return {
191262
...context,
192263
systemPrompt: undefined,
264+
tools: undefined,
193265
};
194266
}
195267

@@ -229,6 +301,8 @@ async function createGooglePromptCache(params: {
229301
modelId: string;
230302
signal?: AbortSignal;
231303
systemPrompt: string;
304+
tools?: unknown;
305+
toolConfig?: unknown;
232306
}): Promise<{ cachedContent: string; expireTime?: string } | null> {
233307
const response = await params.fetchImpl(`${params.baseUrl}/cachedContents`, {
234308
method: "POST",
@@ -239,6 +313,8 @@ async function createGooglePromptCache(params: {
239313
systemInstruction: {
240314
parts: [{ text: params.systemPrompt }],
241315
},
316+
...(params.tools ? { tools: params.tools } : {}),
317+
...(params.toolConfig ? { toolConfig: params.toolConfig } : {}),
242318
}),
243319
signal: params.signal,
244320
});
@@ -256,9 +332,12 @@ async function ensureGooglePromptCache(
256332
cacheRetention: CacheRetention;
257333
model: GooglePromptCacheModel;
258334
provider: string;
335+
cacheConfigDigest?: string;
259336
sessionManager: GooglePromptCacheSessionManager;
260337
signal?: AbortSignal;
261338
systemPrompt: string;
339+
tools?: unknown;
340+
toolConfig?: unknown;
262341
},
263342
deps: GooglePromptCacheDeps,
264343
): Promise<string | null> {
@@ -271,6 +350,7 @@ async function ensureGooglePromptCache(
271350
modelApi: params.model.api,
272351
baseUrl,
273352
systemPromptDigest,
353+
cacheConfigDigest: params.cacheConfigDigest,
274354
});
275355
const latestEntry = readLatestGooglePromptCacheEntry(params.sessionManager, matchKey);
276356

@@ -306,6 +386,7 @@ async function ensureGooglePromptCache(
306386
modelApi: params.model.api,
307387
baseUrl,
308388
systemPromptDigest,
389+
cacheConfigDigest: params.cacheConfigDigest,
309390
cacheRetention: params.cacheRetention,
310391
cachedContent: latestEntry.cachedContent,
311392
expireTime: refreshed.expireTime ?? latestEntry.expireTime,
@@ -325,6 +406,8 @@ async function ensureGooglePromptCache(
325406
modelId: params.model.id,
326407
signal: params.signal,
327408
systemPrompt: params.systemPrompt,
409+
tools: params.tools,
410+
toolConfig: params.toolConfig,
328411
});
329412
if (!created) {
330413
await appendGooglePromptCacheEntry(params.sessionManager, {
@@ -335,6 +418,7 @@ async function ensureGooglePromptCache(
335418
modelApi: params.model.api,
336419
baseUrl,
337420
systemPromptDigest,
421+
cacheConfigDigest: params.cacheConfigDigest,
338422
cacheRetention: params.cacheRetention,
339423
retryAfter: now + GOOGLE_PROMPT_CACHE_RETRY_BACKOFF_MS,
340424
});
@@ -349,6 +433,7 @@ async function ensureGooglePromptCache(
349433
modelApi: params.model.api,
350434
baseUrl,
351435
systemPromptDigest,
436+
cacheConfigDigest: params.cacheConfigDigest,
352437
cacheRetention: params.cacheRetention,
353438
cachedContent: created.cachedContent,
354439
expireTime: created.expireTime,
@@ -386,15 +471,19 @@ export async function prepareGooglePromptCacheStreamFn(
386471

387472
const inner = params.streamFn;
388473
return async (model, context, options) => {
474+
const cacheConfig = buildManagedGooglePromptCacheConfig(context, options);
389475
const cachedContent = await ensureGooglePromptCache(
390476
{
391477
apiKey,
478+
cacheConfigDigest: cacheConfig.cacheConfigDigest,
392479
cacheRetention: resolvedRetention,
393480
model: params.model,
394481
provider: params.provider,
395482
sessionManager: params.sessionManager,
396483
signal: params.signal,
397484
systemPrompt,
485+
tools: cacheConfig.tools,
486+
toolConfig: cacheConfig.toolConfig,
398487
},
399488
deps,
400489
);
@@ -408,7 +497,7 @@ export async function prepareGooglePromptCacheStreamFn(
408497
return streamWithPayloadPatch(
409498
inner,
410499
model,
411-
buildManagedContextWithoutSystemPrompt(context),
500+
buildManagedContextForCachedContent(context),
412501
options,
413502
(payload) => {
414503
payload.cachedContent = cachedContent;

0 commit comments

Comments
 (0)