Skip to content

Commit b55e6f7

Browse files
authored
fix (ai/core): streamObject text stream in array mode must not include elements: prefix. (#2861)
1 parent faa45f8 commit b55e6f7

File tree

4 files changed

+209
-41
lines changed

4 files changed

+209
-41
lines changed

.changeset/neat-worms-fold.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
fix (ai/core): streamObject text stream in array mode must not include elements: prefix.

packages/ai/core/generate-object/output-strategy.ts

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,18 @@ export interface OutputStrategy<PARTIAL, RESULT, ELEMENT_STREAM> {
2323

2424
validatePartialResult({
2525
value,
26-
parseState,
26+
textDelta,
27+
isFinalDelta,
2728
}: {
2829
value: JSONValue;
29-
parseState:
30-
| 'undefined-input'
31-
| 'successful-parse'
32-
| 'repaired-parse'
33-
| 'failed-parse';
34-
}): ValidationResult<PARTIAL>;
30+
textDelta: string;
31+
isFirstDelta: boolean;
32+
isFinalDelta: boolean;
33+
latestObject: PARTIAL | undefined;
34+
}): ValidationResult<{
35+
partial: PARTIAL;
36+
textDelta: string;
37+
}>;
3538
validateFinalResult(value: JSONValue | undefined): ValidationResult<RESULT>;
3639

3740
createElementStream(
@@ -43,8 +46,8 @@ const noSchemaOutputStrategy: OutputStrategy<JSONValue, JSONValue, never> = {
4346
type: 'no-schema',
4447
jsonSchema: undefined,
4548

46-
validatePartialResult({ value }): ValidationResult<JSONValue> {
47-
return { success: true, value };
49+
validatePartialResult({ value, textDelta }) {
50+
return { success: true, value: { partial: value, textDelta } };
4851
},
4952

5053
validateFinalResult(
@@ -68,9 +71,15 @@ const objectOutputStrategy = <OBJECT>(
6871
type: 'object',
6972
jsonSchema: schema.jsonSchema,
7073

71-
validatePartialResult({ value }): ValidationResult<DeepPartial<OBJECT>> {
72-
// Note: currently no validation of partial results:
73-
return { success: true, value: value as DeepPartial<OBJECT> };
74+
validatePartialResult({ value, textDelta }) {
75+
return {
76+
success: true,
77+
value: {
78+
// Note: currently no validation of partial results:
79+
partial: value as DeepPartial<OBJECT>,
80+
textDelta,
81+
},
82+
};
7483
},
7584

7685
validateFinalResult(value: JSONValue | undefined): ValidationResult<OBJECT> {
@@ -91,7 +100,7 @@ const arrayOutputStrategy = <ELEMENT>(
91100
const { $schema, ...itemSchema } = schema.jsonSchema;
92101

93102
return {
94-
type: 'object',
103+
type: 'array',
95104

96105
// wrap in object that contains array of elements, since most LLMs will not
97106
// be able to generate an array directly:
@@ -106,10 +115,7 @@ const arrayOutputStrategy = <ELEMENT>(
106115
additionalProperties: false,
107116
},
108117

109-
validatePartialResult({
110-
value,
111-
parseState,
112-
}): ValidationResult<Array<ELEMENT>> {
118+
validatePartialResult({ value, latestObject, isFirstDelta, isFinalDelta }) {
113119
// check that the value is an object that contains an array of elements:
114120
if (!isJSONObject(value) || !isJSONArray(value.elements)) {
115121
return {
@@ -128,13 +134,11 @@ const arrayOutputStrategy = <ELEMENT>(
128134
const element = inputArray[i];
129135
const result = safeValidateTypes({ value: element, schema });
130136

131-
// special treatment for last element:
132-
// ignore parse failures or validation failures, since they indicate that the
133-
// last element is incomplete and should not be included in the result
134-
if (
135-
i === inputArray.length - 1 &&
136-
(!result.success || parseState !== 'successful-parse')
137-
) {
137+
// special treatment for last processed element:
138+
// ignore parse or validation failures, since they indicate that the
139+
// last element is incomplete and should not be included in the result,
140+
// unless it is the final delta
141+
if (i === inputArray.length - 1 && !isFinalDelta) {
138142
continue;
139143
}
140144

@@ -145,7 +149,35 @@ const arrayOutputStrategy = <ELEMENT>(
145149
resultArray.push(result.value);
146150
}
147151

148-
return { success: true, value: resultArray };
152+
// calculate delta:
153+
const publishedElementCount = latestObject?.length ?? 0;
154+
155+
let textDelta = '';
156+
157+
if (isFirstDelta) {
158+
textDelta += '[';
159+
}
160+
161+
if (publishedElementCount > 0) {
162+
textDelta += ',';
163+
}
164+
165+
textDelta += resultArray
166+
.slice(publishedElementCount) // only new elements
167+
.map(element => JSON.stringify(element))
168+
.join(',');
169+
170+
if (isFinalDelta) {
171+
textDelta += ']';
172+
}
173+
174+
return {
175+
success: true,
176+
value: {
177+
partial: resultArray,
178+
textDelta,
179+
},
180+
};
149181
},
150182

151183
validateFinalResult(

packages/ai/core/generate-object/stream-object.test.ts

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ describe('output = "array"', () => {
10971097
});
10981098
});
10991099

1100-
it('should stream only complete objects', async () => {
1100+
it('should stream only complete objects in partialObjectStream', async () => {
11011101
assert.deepStrictEqual(
11021102
await convertAsyncIterableToArray(result.partialObjectStream),
11031103
[
@@ -1113,6 +1113,18 @@ describe('output = "array"', () => {
11131113
);
11141114
});
11151115

1116+
it('should stream only complete objects in textStream', async () => {
1117+
assert.deepStrictEqual(
1118+
await convertAsyncIterableToArray(result.textStream),
1119+
[
1120+
'[',
1121+
'{"content":"element 1"}',
1122+
',{"content":"element 2"}',
1123+
',{"content":"element 3"}]',
1124+
],
1125+
);
1126+
});
1127+
11161128
it('should have the correct object result', async () => {
11171129
// consume stream
11181130
await convertAsyncIterableToArray(result.partialObjectStream);
@@ -1143,6 +1155,123 @@ describe('output = "array"', () => {
11431155
);
11441156
});
11451157
});
1158+
1159+
describe('array with 2 elements streamed in 1 chunk', () => {
1160+
let result: StreamObjectResult<
1161+
{ content: string }[],
1162+
{ content: string }[],
1163+
AsyncIterableStream<{ content: string }>
1164+
>;
1165+
1166+
let onFinishResult: Parameters<
1167+
Required<Parameters<typeof streamObject>[0]>['onFinish']
1168+
>[0];
1169+
1170+
beforeEach(async () => {
1171+
result = await streamObject({
1172+
model: new MockLanguageModelV1({
1173+
doStream: async ({ prompt, mode }) => {
1174+
assert.deepStrictEqual(mode, {
1175+
type: 'object-json',
1176+
name: undefined,
1177+
description: undefined,
1178+
schema: {
1179+
$schema: 'http://json-schema.org/draft-07/schema#',
1180+
additionalProperties: false,
1181+
properties: {
1182+
elements: {
1183+
type: 'array',
1184+
items: {
1185+
type: 'object',
1186+
properties: { content: { type: 'string' } },
1187+
required: ['content'],
1188+
additionalProperties: false,
1189+
},
1190+
},
1191+
},
1192+
required: ['elements'],
1193+
type: 'object',
1194+
},
1195+
});
1196+
1197+
assert.deepStrictEqual(prompt, [
1198+
{
1199+
role: 'system',
1200+
content:
1201+
'JSON schema:\n' +
1202+
`{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"type\":\"object\",\"properties\":{\"elements\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"content\":{\"type\":\"string\"}},\"required\":[\"content\"],\"additionalProperties\":false}}},\"required\":[\"elements\"],\"additionalProperties\":false}` +
1203+
`\n` +
1204+
'You MUST answer with a JSON object that matches the JSON schema above.',
1205+
},
1206+
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
1207+
]);
1208+
1209+
return {
1210+
stream: convertArrayToReadableStream([
1211+
{
1212+
type: 'text-delta',
1213+
textDelta:
1214+
'{"elements":[{"content":"element 1"},{"content":"element 2"}]}',
1215+
},
1216+
// finish
1217+
{
1218+
type: 'finish',
1219+
finishReason: 'stop',
1220+
usage: { completionTokens: 10, promptTokens: 3 },
1221+
},
1222+
]),
1223+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
1224+
};
1225+
},
1226+
}),
1227+
schema: z.object({ content: z.string() }),
1228+
output: 'array',
1229+
mode: 'json',
1230+
prompt: 'prompt',
1231+
onFinish: async event => {
1232+
onFinishResult = event as unknown as typeof onFinishResult;
1233+
},
1234+
});
1235+
});
1236+
1237+
it('should stream only complete objects in partialObjectStream', async () => {
1238+
assert.deepStrictEqual(
1239+
await convertAsyncIterableToArray(result.partialObjectStream),
1240+
[[{ content: 'element 1' }, { content: 'element 2' }]],
1241+
);
1242+
});
1243+
1244+
it('should stream only complete objects in textStream', async () => {
1245+
assert.deepStrictEqual(
1246+
await convertAsyncIterableToArray(result.textStream),
1247+
['[{"content":"element 1"},{"content":"element 2"}]'],
1248+
);
1249+
});
1250+
1251+
it('should have the correct object result', async () => {
1252+
// consume stream
1253+
await convertAsyncIterableToArray(result.partialObjectStream);
1254+
1255+
expect(await result.object).toStrictEqual([
1256+
{ content: 'element 1' },
1257+
{ content: 'element 2' },
1258+
]);
1259+
});
1260+
1261+
it('should call onFinish callback with full array', async () => {
1262+
expect(onFinishResult.object).toStrictEqual([
1263+
{ content: 'element 1' },
1264+
{ content: 'element 2' },
1265+
]);
1266+
});
1267+
1268+
it('should stream elements individually in elementStream', async () => {
1269+
assert.deepStrictEqual(
1270+
await convertAsyncIterableToArray(result.elementStream),
1271+
[{ content: 'element 1' }, { content: 'element 2' }],
1272+
);
1273+
});
1274+
});
11461275
});
11471276

11481277
describe('output = "no-schema"', () => {

packages/ai/core/generate-object/stream-object.ts

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -603,13 +603,14 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
603603

604604
// pipe chunks through a transformation stream that extracts metadata:
605605
let accumulatedText = '';
606-
let delta = '';
606+
let textDelta = '';
607607

608608
// Keep track of raw parse result before type validation, since e.g. Zod might
609609
// change the object by mapping properties.
610610
let latestObjectJson: JSONValue | undefined = undefined;
611611
let latestObject: PARTIAL | undefined = undefined;
612-
let firstChunk = true;
612+
let isFirstChunk = true;
613+
let isFirstDelta = true;
613614

614615
const self = this;
615616
this.originalStream = stream.pipeThrough(
@@ -619,10 +620,10 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
619620
>({
620621
async transform(chunk, controller): Promise<void> {
621622
// Telemetry event for first chunk:
622-
if (firstChunk) {
623+
if (isFirstChunk) {
623624
const msToFirstChunk = performance.now() - startTimestamp;
624625

625-
firstChunk = false;
626+
isFirstChunk = false;
626627

627628
doStreamSpan.addEvent('ai.stream.firstChunk', {
628629
'ai.stream.msToFirstChunk': msToFirstChunk,
@@ -636,7 +637,7 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
636637
// process partial text chunks
637638
if (typeof chunk === 'string') {
638639
accumulatedText += chunk;
639-
delta += chunk;
640+
textDelta += chunk;
640641

641642
const { value: currentObjectJson, state: parseState } =
642643
parsePartialJson(accumulatedText);
@@ -647,16 +648,19 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
647648
) {
648649
const validationResult = outputStrategy.validatePartialResult({
649650
value: currentObjectJson,
650-
parseState,
651+
textDelta,
652+
latestObject,
653+
isFirstDelta,
654+
isFinalDelta: parseState === 'successful-parse',
651655
});
652656

653657
if (
654658
validationResult.success &&
655-
!isDeepEqualData(latestObject, validationResult.value)
659+
!isDeepEqualData(latestObject, validationResult.value.partial)
656660
) {
657661
// inside inner check to correctly parse the final element in array mode:
658662
latestObjectJson = currentObjectJson;
659-
latestObject = validationResult.value;
663+
latestObject = validationResult.value.partial;
660664

661665
controller.enqueue({
662666
type: 'object',
@@ -665,10 +669,11 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
665669

666670
controller.enqueue({
667671
type: 'text-delta',
668-
textDelta: delta,
672+
textDelta: validationResult.value.textDelta,
669673
});
670674

671-
delta = '';
675+
textDelta = '';
676+
isFirstDelta = false;
672677
}
673678
}
674679

@@ -678,11 +683,8 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
678683
switch (chunk.type) {
679684
case 'finish': {
680685
// send final text delta:
681-
if (delta !== '') {
682-
controller.enqueue({
683-
type: 'text-delta',
684-
textDelta: delta,
685-
});
686+
if (textDelta !== '') {
687+
controller.enqueue({ type: 'text-delta', textDelta });
686688
}
687689

688690
// store finish reason for telemetry:

0 commit comments

Comments
 (0)