Skip to content

Commit ee4beee

Browse files
feat(ai): add onStepFinish callback to createUIMessageStream (#12448)
## Background We expose an `onFinish` callback for createUIMessageStream() but not `onStepFinish`, which is inconsistent with the pattern. Also reported in #12383 ## Summary Added the `onStepFinish` callback to the function and added a new callback type. ## Manual Verification - Verified by running `localhost:3000/use-chat-human-in-the-loop` ## Checklist - [x] Tests have been added / updated (for bug fixes / features) - [ ] Documentation has been added / updated (for bug fixes / features) - [ ] A _patch_ changeset for relevant packages has been added (for bug fixes / features - run `pnpm changeset` in the project root) - [x] I have reviewed this pull request (self-review) ## Future Work The callback might also need to be added for `createAgentUIStreamResponse`. ## Related Issues fixes #12383
1 parent bc009ee commit ee4beee

File tree

7 files changed

+354
-3
lines changed

7 files changed

+354
-3
lines changed

.changeset/great-goats-double.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
feat(ai): add onStepFinish callback to createUIMessageStream

examples/next-openai/app/api/use-chat-human-in-the-loop/route.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,18 @@ export async function POST(req: Request) {
4343
model: openai('gpt-4o'),
4444
messages: await convertToModelMessages(processedMessages),
4545
tools,
46-
stopWhen: stepCountIs(5),
46+
stopWhen: stepCountIs(20),
4747
});
4848

4949
writer.merge(
5050
result.toUIMessageStream({ originalMessages: processedMessages }),
5151
);
5252
},
53+
onStepFinish: ({ messages, responseMessage }) => {
54+
console.log('--- Step finished ---');
55+
console.log('Parts count:', responseMessage.parts.length);
56+
console.log('Messages:', JSON.stringify(messages, null, 2));
57+
},
5358
onFinish: ({}) => {
5459
// save messages here
5560
console.log('Finished!');

packages/ai/src/ui-message-stream/create-ui-message-stream.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { UIMessage } from '../ui/ui-messages';
77
import { handleUIMessageStreamFinish } from './handle-ui-message-stream-finish';
88
import { InferUIMessageChunk } from './ui-message-chunks';
99
import { UIMessageStreamOnFinishCallback } from './ui-message-stream-on-finish-callback';
10+
import { UIMessageStreamOnStepFinishCallback } from './ui-message-stream-on-step-finish-callback';
1011
import { UIMessageStreamWriter } from './ui-message-stream-writer';
1112

1213
/**
@@ -16,6 +17,7 @@ import { UIMessageStreamWriter } from './ui-message-stream-writer';
1617
* @param options.onError - A function that extracts an error message from an error. Defaults to `getErrorMessage`.
1718
* @param options.originalMessages - The original messages. If provided, persistence mode is assumed
1819
* and a message ID is provided for the response message.
20+
* @param options.onStepFinish - A callback that is called when each step finishes. Useful for persisting intermediate messages.
1921
* @param options.onFinish - A callback that is called when the stream finishes.
2022
* @param options.generateId - A function that generates a unique ID. Defaults to the built-in ID generator.
2123
*
@@ -25,6 +27,7 @@ export function createUIMessageStream<UI_MESSAGE extends UIMessage>({
2527
execute,
2628
onError = getErrorMessage,
2729
originalMessages,
30+
onStepFinish,
2831
onFinish,
2932
generateId = generateIdFunc,
3033
}: {
@@ -39,6 +42,11 @@ export function createUIMessageStream<UI_MESSAGE extends UIMessage>({
3942
*/
4043
originalMessages?: UI_MESSAGE[];
4144

45+
/**
46+
* Callback that is called when each step finishes during multi-step agent runs.
47+
*/
48+
onStepFinish?: UIMessageStreamOnStepFinishCallback<UI_MESSAGE>;
49+
4250
onFinish?: UIMessageStreamOnFinishCallback<UI_MESSAGE>;
4351

4452
generateId?: IdGenerator;
@@ -130,6 +138,7 @@ export function createUIMessageStream<UI_MESSAGE extends UIMessage>({
130138
stream,
131139
messageId: generateId(),
132140
originalMessages,
141+
onStepFinish,
133142
onFinish,
134143
onError,
135144
});

packages/ai/src/ui-message-stream/handle-ui-message-stream-finish.test.ts

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,4 +426,275 @@ describe('handleUIMessageStreamFinish', () => {
426426
expect(callArgs.responseMessage.id).toBe('msg-1');
427427
});
428428
});
429+
430+
describe('onStepFinish callback', () => {
431+
it('should call onStepFinish when finish-step chunk is encountered', async () => {
432+
const onStepFinishCallback = vi.fn();
433+
const inputChunks: UIMessageChunk[] = [
434+
{ type: 'start', messageId: 'msg-step-1' },
435+
{ type: 'text-start', id: 'text-1' },
436+
{ type: 'text-delta', id: 'text-1', delta: 'Step 1 text' },
437+
{ type: 'text-end', id: 'text-1' },
438+
{ type: 'finish-step' },
439+
{ type: 'finish' },
440+
];
441+
442+
const originalMessages: UIMessage[] = [
443+
{
444+
id: 'user-msg-1',
445+
role: 'user',
446+
parts: [{ type: 'text', text: 'Hello' }],
447+
},
448+
];
449+
450+
const stream = createUIMessageStream(inputChunks);
451+
452+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
453+
stream,
454+
messageId: 'msg-step-1',
455+
originalMessages,
456+
onError: mockErrorHandler,
457+
onStepFinish: onStepFinishCallback,
458+
});
459+
460+
const result = await convertReadableStreamToArray(resultStream);
461+
462+
expect(result).toEqual(inputChunks);
463+
expect(onStepFinishCallback).toHaveBeenCalledTimes(1);
464+
465+
const callArgs = onStepFinishCallback.mock.calls[0][0];
466+
expect(callArgs.isContinuation).toBe(false);
467+
expect(callArgs.responseMessage.id).toBe('msg-step-1');
468+
expect(callArgs.responseMessage.role).toBe('assistant');
469+
expect(callArgs.messages).toHaveLength(2);
470+
expect(callArgs.messages[0]).toEqual(originalMessages[0]);
471+
expect(callArgs.messages[1].id).toBe('msg-step-1');
472+
});
473+
474+
it('should call onStepFinish multiple times for multiple steps', async () => {
475+
const onStepFinishCallback = vi.fn();
476+
const inputChunks: UIMessageChunk[] = [
477+
{ type: 'start', messageId: 'msg-multi-step' },
478+
// Step 1
479+
{ type: 'text-start', id: 'text-1' },
480+
{ type: 'text-delta', id: 'text-1', delta: 'Step 1' },
481+
{ type: 'text-end', id: 'text-1' },
482+
{ type: 'finish-step' },
483+
// Step 2
484+
{ type: 'start-step' },
485+
{ type: 'text-start', id: 'text-2' },
486+
{ type: 'text-delta', id: 'text-2', delta: 'Step 2' },
487+
{ type: 'text-end', id: 'text-2' },
488+
{ type: 'finish-step' },
489+
// Step 3
490+
{ type: 'start-step' },
491+
{ type: 'text-start', id: 'text-3' },
492+
{ type: 'text-delta', id: 'text-3', delta: 'Step 3' },
493+
{ type: 'text-end', id: 'text-3' },
494+
{ type: 'finish-step' },
495+
{ type: 'finish' },
496+
];
497+
498+
const stream = createUIMessageStream(inputChunks);
499+
500+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
501+
stream,
502+
messageId: 'msg-multi-step',
503+
originalMessages: [],
504+
onError: mockErrorHandler,
505+
onStepFinish: onStepFinishCallback,
506+
});
507+
508+
await convertReadableStreamToArray(resultStream);
509+
510+
expect(onStepFinishCallback).toHaveBeenCalledTimes(3);
511+
512+
// Verify each step has the correct accumulated content
513+
const step1Args = onStepFinishCallback.mock.calls[0][0];
514+
expect(step1Args.responseMessage.parts).toHaveLength(1);
515+
516+
const step2Args = onStepFinishCallback.mock.calls[1][0];
517+
expect(step2Args.responseMessage.parts).toHaveLength(3); // step-start + 2 text parts
518+
519+
const step3Args = onStepFinishCallback.mock.calls[2][0];
520+
expect(step3Args.responseMessage.parts).toHaveLength(5); // 2 step-starts + 3 text parts
521+
});
522+
523+
it('should call both onStepFinish and onFinish when both are provided', async () => {
524+
const onStepFinishCallback = vi.fn();
525+
const onFinishCallback = vi.fn();
526+
const inputChunks: UIMessageChunk[] = [
527+
{ type: 'start', messageId: 'msg-both' },
528+
{ type: 'text-start', id: 'text-1' },
529+
{ type: 'text-delta', id: 'text-1', delta: 'Hello' },
530+
{ type: 'text-end', id: 'text-1' },
531+
{ type: 'finish-step' },
532+
{ type: 'finish' },
533+
];
534+
535+
const stream = createUIMessageStream(inputChunks);
536+
537+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
538+
stream,
539+
messageId: 'msg-both',
540+
originalMessages: [],
541+
onError: mockErrorHandler,
542+
onStepFinish: onStepFinishCallback,
543+
onFinish: onFinishCallback,
544+
});
545+
546+
await convertReadableStreamToArray(resultStream);
547+
548+
expect(onStepFinishCallback).toHaveBeenCalledTimes(1);
549+
expect(onFinishCallback).toHaveBeenCalledTimes(1);
550+
});
551+
552+
it('should handle onStepFinish errors by logging and continuing', async () => {
553+
const onStepFinishCallback = vi
554+
.fn()
555+
.mockRejectedValue(new Error('DB error'));
556+
const inputChunks: UIMessageChunk[] = [
557+
{ type: 'start', messageId: 'msg-error' },
558+
{ type: 'text-start', id: 'text-1' },
559+
{ type: 'text-delta', id: 'text-1', delta: 'Step 1' },
560+
{ type: 'text-end', id: 'text-1' },
561+
{ type: 'finish-step' },
562+
{ type: 'start-step' },
563+
{ type: 'text-start', id: 'text-2' },
564+
{ type: 'text-delta', id: 'text-2', delta: 'Step 2' },
565+
{ type: 'text-end', id: 'text-2' },
566+
{ type: 'finish-step' },
567+
{ type: 'finish' },
568+
];
569+
570+
const stream = createUIMessageStream(inputChunks);
571+
572+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
573+
stream,
574+
messageId: 'msg-error',
575+
originalMessages: [],
576+
onError: mockErrorHandler,
577+
onStepFinish: onStepFinishCallback,
578+
});
579+
580+
// Stream should complete without throwing
581+
const result = await convertReadableStreamToArray(resultStream);
582+
583+
expect(result).toEqual(inputChunks);
584+
// Both steps should have been attempted
585+
expect(onStepFinishCallback).toHaveBeenCalledTimes(2);
586+
// Error should have been logged twice
587+
expect(mockErrorHandler).toHaveBeenCalledTimes(2);
588+
expect(mockErrorHandler).toHaveBeenCalledWith(expect.any(Error));
589+
});
590+
591+
it('should handle continuation scenario with onStepFinish', async () => {
592+
const onStepFinishCallback = vi.fn();
593+
const inputChunks: UIMessageChunk[] = [
594+
{ type: 'start', messageId: 'assistant-msg-1' },
595+
{ type: 'text-start', id: 'text-1' },
596+
{ type: 'text-delta', id: 'text-1', delta: ' continued' },
597+
{ type: 'text-end', id: 'text-1' },
598+
{ type: 'finish-step' },
599+
{ type: 'finish' },
600+
];
601+
602+
const originalMessages: UIMessage[] = [
603+
{
604+
id: 'user-msg-1',
605+
role: 'user',
606+
parts: [{ type: 'text', text: 'Continue this' }],
607+
},
608+
{
609+
id: 'assistant-msg-1',
610+
role: 'assistant',
611+
parts: [{ type: 'text', text: 'This is' }],
612+
},
613+
];
614+
615+
const stream = createUIMessageStream(inputChunks);
616+
617+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
618+
stream,
619+
messageId: 'msg-999',
620+
originalMessages,
621+
onError: mockErrorHandler,
622+
onStepFinish: onStepFinishCallback,
623+
});
624+
625+
await convertReadableStreamToArray(resultStream);
626+
627+
expect(onStepFinishCallback).toHaveBeenCalledTimes(1);
628+
629+
const callArgs = onStepFinishCallback.mock.calls[0][0];
630+
expect(callArgs.isContinuation).toBe(true);
631+
expect(callArgs.responseMessage.id).toBe('assistant-msg-1');
632+
expect(callArgs.messages).toHaveLength(2);
633+
});
634+
635+
it('should provide deep-cloned messages in onStepFinish to prevent mutation', async () => {
636+
const onStepFinishCallback = vi.fn();
637+
const onFinishCallback = vi.fn();
638+
const inputChunks: UIMessageChunk[] = [
639+
{ type: 'start', messageId: 'msg-clone' },
640+
{ type: 'text-start', id: 'text-1' },
641+
{ type: 'text-delta', id: 'text-1', delta: 'Hello' },
642+
{ type: 'text-end', id: 'text-1' },
643+
{ type: 'finish-step' },
644+
{ type: 'finish' },
645+
];
646+
647+
const stream = createUIMessageStream(inputChunks);
648+
649+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
650+
stream,
651+
messageId: 'msg-clone',
652+
originalMessages: [],
653+
onError: mockErrorHandler,
654+
onStepFinish: event => {
655+
// Mutate the message in the callback
656+
event.responseMessage.parts.push({ type: 'text', text: 'MUTATION!' });
657+
onStepFinishCallback(event);
658+
},
659+
onFinish: onFinishCallback,
660+
});
661+
662+
await convertReadableStreamToArray(resultStream);
663+
664+
// Verify onStepFinish was called and received the mutated message
665+
expect(onStepFinishCallback).toHaveBeenCalledTimes(1);
666+
const stepMessage = onStepFinishCallback.mock.calls[0][0].responseMessage;
667+
expect(stepMessage.parts).toHaveLength(2); // Original + mutation
668+
669+
// onFinish should NOT see the mutation from onStepFinish
670+
const finishMessage = onFinishCallback.mock.calls[0][0].responseMessage;
671+
expect(finishMessage.parts).toHaveLength(1);
672+
});
673+
674+
it('should not process stream when neither onFinish nor onStepFinish is provided', async () => {
675+
const inputChunks: UIMessageChunk[] = [
676+
{ type: 'start', messageId: 'msg-passthrough' },
677+
{ type: 'text-start', id: 'text-1' },
678+
{ type: 'text-delta', id: 'text-1', delta: 'Test' },
679+
{ type: 'text-end', id: 'text-1' },
680+
{ type: 'finish-step' },
681+
{ type: 'finish' },
682+
];
683+
684+
const stream = createUIMessageStream(inputChunks);
685+
686+
const resultStream = handleUIMessageStreamFinish<UIMessage>({
687+
stream,
688+
messageId: 'msg-passthrough',
689+
originalMessages: [],
690+
onError: mockErrorHandler,
691+
// Neither onFinish nor onStepFinish provided
692+
});
693+
694+
const result = await convertReadableStreamToArray(resultStream);
695+
696+
expect(result).toEqual(inputChunks);
697+
expect(mockErrorHandler).not.toHaveBeenCalled();
698+
});
699+
});
429700
});

0 commit comments

Comments
 (0)