@@ -426,4 +426,275 @@ describe('handleUIMessageStreamFinish', () => {
426426 expect ( callArgs . responseMessage . id ) . toBe ( 'msg-1' ) ;
427427 } ) ;
428428 } ) ;
429+
430+ describe ( 'onStepFinish callback' , ( ) => {
431+ it ( 'should call onStepFinish when finish-step chunk is encountered' , async ( ) => {
432+ const onStepFinishCallback = vi . fn ( ) ;
433+ const inputChunks : UIMessageChunk [ ] = [
434+ { type : 'start' , messageId : 'msg-step-1' } ,
435+ { type : 'text-start' , id : 'text-1' } ,
436+ { type : 'text-delta' , id : 'text-1' , delta : 'Step 1 text' } ,
437+ { type : 'text-end' , id : 'text-1' } ,
438+ { type : 'finish-step' } ,
439+ { type : 'finish' } ,
440+ ] ;
441+
442+ const originalMessages : UIMessage [ ] = [
443+ {
444+ id : 'user-msg-1' ,
445+ role : 'user' ,
446+ parts : [ { type : 'text' , text : 'Hello' } ] ,
447+ } ,
448+ ] ;
449+
450+ const stream = createUIMessageStream ( inputChunks ) ;
451+
452+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
453+ stream,
454+ messageId : 'msg-step-1' ,
455+ originalMessages,
456+ onError : mockErrorHandler ,
457+ onStepFinish : onStepFinishCallback ,
458+ } ) ;
459+
460+ const result = await convertReadableStreamToArray ( resultStream ) ;
461+
462+ expect ( result ) . toEqual ( inputChunks ) ;
463+ expect ( onStepFinishCallback ) . toHaveBeenCalledTimes ( 1 ) ;
464+
465+ const callArgs = onStepFinishCallback . mock . calls [ 0 ] [ 0 ] ;
466+ expect ( callArgs . isContinuation ) . toBe ( false ) ;
467+ expect ( callArgs . responseMessage . id ) . toBe ( 'msg-step-1' ) ;
468+ expect ( callArgs . responseMessage . role ) . toBe ( 'assistant' ) ;
469+ expect ( callArgs . messages ) . toHaveLength ( 2 ) ;
470+ expect ( callArgs . messages [ 0 ] ) . toEqual ( originalMessages [ 0 ] ) ;
471+ expect ( callArgs . messages [ 1 ] . id ) . toBe ( 'msg-step-1' ) ;
472+ } ) ;
473+
474+ it ( 'should call onStepFinish multiple times for multiple steps' , async ( ) => {
475+ const onStepFinishCallback = vi . fn ( ) ;
476+ const inputChunks : UIMessageChunk [ ] = [
477+ { type : 'start' , messageId : 'msg-multi-step' } ,
478+ // Step 1
479+ { type : 'text-start' , id : 'text-1' } ,
480+ { type : 'text-delta' , id : 'text-1' , delta : 'Step 1' } ,
481+ { type : 'text-end' , id : 'text-1' } ,
482+ { type : 'finish-step' } ,
483+ // Step 2
484+ { type : 'start-step' } ,
485+ { type : 'text-start' , id : 'text-2' } ,
486+ { type : 'text-delta' , id : 'text-2' , delta : 'Step 2' } ,
487+ { type : 'text-end' , id : 'text-2' } ,
488+ { type : 'finish-step' } ,
489+ // Step 3
490+ { type : 'start-step' } ,
491+ { type : 'text-start' , id : 'text-3' } ,
492+ { type : 'text-delta' , id : 'text-3' , delta : 'Step 3' } ,
493+ { type : 'text-end' , id : 'text-3' } ,
494+ { type : 'finish-step' } ,
495+ { type : 'finish' } ,
496+ ] ;
497+
498+ const stream = createUIMessageStream ( inputChunks ) ;
499+
500+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
501+ stream,
502+ messageId : 'msg-multi-step' ,
503+ originalMessages : [ ] ,
504+ onError : mockErrorHandler ,
505+ onStepFinish : onStepFinishCallback ,
506+ } ) ;
507+
508+ await convertReadableStreamToArray ( resultStream ) ;
509+
510+ expect ( onStepFinishCallback ) . toHaveBeenCalledTimes ( 3 ) ;
511+
512+ // Verify each step has the correct accumulated content
513+ const step1Args = onStepFinishCallback . mock . calls [ 0 ] [ 0 ] ;
514+ expect ( step1Args . responseMessage . parts ) . toHaveLength ( 1 ) ;
515+
516+ const step2Args = onStepFinishCallback . mock . calls [ 1 ] [ 0 ] ;
517+ expect ( step2Args . responseMessage . parts ) . toHaveLength ( 3 ) ; // step-start + 2 text parts
518+
519+ const step3Args = onStepFinishCallback . mock . calls [ 2 ] [ 0 ] ;
520+ expect ( step3Args . responseMessage . parts ) . toHaveLength ( 5 ) ; // 2 step-starts + 3 text parts
521+ } ) ;
522+
523+ it ( 'should call both onStepFinish and onFinish when both are provided' , async ( ) => {
524+ const onStepFinishCallback = vi . fn ( ) ;
525+ const onFinishCallback = vi . fn ( ) ;
526+ const inputChunks : UIMessageChunk [ ] = [
527+ { type : 'start' , messageId : 'msg-both' } ,
528+ { type : 'text-start' , id : 'text-1' } ,
529+ { type : 'text-delta' , id : 'text-1' , delta : 'Hello' } ,
530+ { type : 'text-end' , id : 'text-1' } ,
531+ { type : 'finish-step' } ,
532+ { type : 'finish' } ,
533+ ] ;
534+
535+ const stream = createUIMessageStream ( inputChunks ) ;
536+
537+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
538+ stream,
539+ messageId : 'msg-both' ,
540+ originalMessages : [ ] ,
541+ onError : mockErrorHandler ,
542+ onStepFinish : onStepFinishCallback ,
543+ onFinish : onFinishCallback ,
544+ } ) ;
545+
546+ await convertReadableStreamToArray ( resultStream ) ;
547+
548+ expect ( onStepFinishCallback ) . toHaveBeenCalledTimes ( 1 ) ;
549+ expect ( onFinishCallback ) . toHaveBeenCalledTimes ( 1 ) ;
550+ } ) ;
551+
552+ it ( 'should handle onStepFinish errors by logging and continuing' , async ( ) => {
553+ const onStepFinishCallback = vi
554+ . fn ( )
555+ . mockRejectedValue ( new Error ( 'DB error' ) ) ;
556+ const inputChunks : UIMessageChunk [ ] = [
557+ { type : 'start' , messageId : 'msg-error' } ,
558+ { type : 'text-start' , id : 'text-1' } ,
559+ { type : 'text-delta' , id : 'text-1' , delta : 'Step 1' } ,
560+ { type : 'text-end' , id : 'text-1' } ,
561+ { type : 'finish-step' } ,
562+ { type : 'start-step' } ,
563+ { type : 'text-start' , id : 'text-2' } ,
564+ { type : 'text-delta' , id : 'text-2' , delta : 'Step 2' } ,
565+ { type : 'text-end' , id : 'text-2' } ,
566+ { type : 'finish-step' } ,
567+ { type : 'finish' } ,
568+ ] ;
569+
570+ const stream = createUIMessageStream ( inputChunks ) ;
571+
572+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
573+ stream,
574+ messageId : 'msg-error' ,
575+ originalMessages : [ ] ,
576+ onError : mockErrorHandler ,
577+ onStepFinish : onStepFinishCallback ,
578+ } ) ;
579+
580+ // Stream should complete without throwing
581+ const result = await convertReadableStreamToArray ( resultStream ) ;
582+
583+ expect ( result ) . toEqual ( inputChunks ) ;
584+ // Both steps should have been attempted
585+ expect ( onStepFinishCallback ) . toHaveBeenCalledTimes ( 2 ) ;
586+ // Error should have been logged twice
587+ expect ( mockErrorHandler ) . toHaveBeenCalledTimes ( 2 ) ;
588+ expect ( mockErrorHandler ) . toHaveBeenCalledWith ( expect . any ( Error ) ) ;
589+ } ) ;
590+
591+ it ( 'should handle continuation scenario with onStepFinish' , async ( ) => {
592+ const onStepFinishCallback = vi . fn ( ) ;
593+ const inputChunks : UIMessageChunk [ ] = [
594+ { type : 'start' , messageId : 'assistant-msg-1' } ,
595+ { type : 'text-start' , id : 'text-1' } ,
596+ { type : 'text-delta' , id : 'text-1' , delta : ' continued' } ,
597+ { type : 'text-end' , id : 'text-1' } ,
598+ { type : 'finish-step' } ,
599+ { type : 'finish' } ,
600+ ] ;
601+
602+ const originalMessages : UIMessage [ ] = [
603+ {
604+ id : 'user-msg-1' ,
605+ role : 'user' ,
606+ parts : [ { type : 'text' , text : 'Continue this' } ] ,
607+ } ,
608+ {
609+ id : 'assistant-msg-1' ,
610+ role : 'assistant' ,
611+ parts : [ { type : 'text' , text : 'This is' } ] ,
612+ } ,
613+ ] ;
614+
615+ const stream = createUIMessageStream ( inputChunks ) ;
616+
617+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
618+ stream,
619+ messageId : 'msg-999' ,
620+ originalMessages,
621+ onError : mockErrorHandler ,
622+ onStepFinish : onStepFinishCallback ,
623+ } ) ;
624+
625+ await convertReadableStreamToArray ( resultStream ) ;
626+
627+ expect ( onStepFinishCallback ) . toHaveBeenCalledTimes ( 1 ) ;
628+
629+ const callArgs = onStepFinishCallback . mock . calls [ 0 ] [ 0 ] ;
630+ expect ( callArgs . isContinuation ) . toBe ( true ) ;
631+ expect ( callArgs . responseMessage . id ) . toBe ( 'assistant-msg-1' ) ;
632+ expect ( callArgs . messages ) . toHaveLength ( 2 ) ;
633+ } ) ;
634+
635+ it ( 'should provide deep-cloned messages in onStepFinish to prevent mutation' , async ( ) => {
636+ const onStepFinishCallback = vi . fn ( ) ;
637+ const onFinishCallback = vi . fn ( ) ;
638+ const inputChunks : UIMessageChunk [ ] = [
639+ { type : 'start' , messageId : 'msg-clone' } ,
640+ { type : 'text-start' , id : 'text-1' } ,
641+ { type : 'text-delta' , id : 'text-1' , delta : 'Hello' } ,
642+ { type : 'text-end' , id : 'text-1' } ,
643+ { type : 'finish-step' } ,
644+ { type : 'finish' } ,
645+ ] ;
646+
647+ const stream = createUIMessageStream ( inputChunks ) ;
648+
649+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
650+ stream,
651+ messageId : 'msg-clone' ,
652+ originalMessages : [ ] ,
653+ onError : mockErrorHandler ,
654+ onStepFinish : event => {
655+ // Mutate the message in the callback
656+ event . responseMessage . parts . push ( { type : 'text' , text : 'MUTATION!' } ) ;
657+ onStepFinishCallback ( event ) ;
658+ } ,
659+ onFinish : onFinishCallback ,
660+ } ) ;
661+
662+ await convertReadableStreamToArray ( resultStream ) ;
663+
664+ // Verify onStepFinish was called and received the mutated message
665+ expect ( onStepFinishCallback ) . toHaveBeenCalledTimes ( 1 ) ;
666+ const stepMessage = onStepFinishCallback . mock . calls [ 0 ] [ 0 ] . responseMessage ;
667+ expect ( stepMessage . parts ) . toHaveLength ( 2 ) ; // Original + mutation
668+
669+ // onFinish should NOT see the mutation from onStepFinish
670+ const finishMessage = onFinishCallback . mock . calls [ 0 ] [ 0 ] . responseMessage ;
671+ expect ( finishMessage . parts ) . toHaveLength ( 1 ) ;
672+ } ) ;
673+
674+ it ( 'should not process stream when neither onFinish nor onStepFinish is provided' , async ( ) => {
675+ const inputChunks : UIMessageChunk [ ] = [
676+ { type : 'start' , messageId : 'msg-passthrough' } ,
677+ { type : 'text-start' , id : 'text-1' } ,
678+ { type : 'text-delta' , id : 'text-1' , delta : 'Test' } ,
679+ { type : 'text-end' , id : 'text-1' } ,
680+ { type : 'finish-step' } ,
681+ { type : 'finish' } ,
682+ ] ;
683+
684+ const stream = createUIMessageStream ( inputChunks ) ;
685+
686+ const resultStream = handleUIMessageStreamFinish < UIMessage > ( {
687+ stream,
688+ messageId : 'msg-passthrough' ,
689+ originalMessages : [ ] ,
690+ onError : mockErrorHandler ,
691+ // Neither onFinish nor onStepFinish provided
692+ } ) ;
693+
694+ const result = await convertReadableStreamToArray ( resultStream ) ;
695+
696+ expect ( result ) . toEqual ( inputChunks ) ;
697+ expect ( mockErrorHandler ) . not . toHaveBeenCalled ( ) ;
698+ } ) ;
699+ } ) ;
429700} ) ;
0 commit comments