|
1 | 1 | package providers |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "context" |
4 | 5 | "encoding/json" |
| 6 | + "io" |
5 | 7 | "math" |
6 | 8 | "strings" |
7 | 9 | "testing" |
8 | 10 |
|
9 | 11 | "gomodel/internal/core" |
10 | 12 | ) |
11 | 13 |
|
| 14 | +type capturingChatProvider struct { |
| 15 | + capturedReq *core.ChatRequest |
| 16 | + streamData string |
| 17 | + streamErr error |
| 18 | +} |
| 19 | + |
| 20 | +func (p *capturingChatProvider) ChatCompletion(_ context.Context, _ *core.ChatRequest) (*core.ChatResponse, error) { |
| 21 | + return nil, nil |
| 22 | +} |
| 23 | + |
| 24 | +func (p *capturingChatProvider) StreamChatCompletion(_ context.Context, req *core.ChatRequest) (io.ReadCloser, error) { |
| 25 | + p.capturedReq = req |
| 26 | + if p.streamErr != nil { |
| 27 | + return nil, p.streamErr |
| 28 | + } |
| 29 | + return io.NopCloser(strings.NewReader(p.streamData)), nil |
| 30 | +} |
| 31 | + |
12 | 32 | func TestResponsesFunctionCallIDs(t *testing.T) { |
13 | 33 | t.Run("preserve explicit call id", func(t *testing.T) { |
14 | 34 | const callID = "call_123" |
@@ -763,6 +783,92 @@ func TestExtractContentFromInput(t *testing.T) { |
763 | 783 | } |
764 | 784 | } |
765 | 785 |
|
| 786 | +func TestConvertResponsesRequestToChat_ClonesStreamOptions(t *testing.T) { |
| 787 | + req := &core.ResponsesRequest{ |
| 788 | + Model: "test-model", |
| 789 | + Input: "hello", |
| 790 | + Stream: true, |
| 791 | + StreamOptions: &core.StreamOptions{IncludeUsage: false}, |
| 792 | + } |
| 793 | + |
| 794 | + chatReq, err := ConvertResponsesRequestToChat(req) |
| 795 | + if err != nil { |
| 796 | + t.Fatalf("ConvertResponsesRequestToChat() error = %v", err) |
| 797 | + } |
| 798 | + if chatReq.StreamOptions == nil { |
| 799 | + t.Fatal("StreamOptions = nil, want cloned value") |
| 800 | + } |
| 801 | + if chatReq.StreamOptions == req.StreamOptions { |
| 802 | + t.Fatal("StreamOptions pointer was reused") |
| 803 | + } |
| 804 | + if chatReq.StreamOptions.IncludeUsage { |
| 805 | + t.Fatalf("IncludeUsage = %v, want false", chatReq.StreamOptions.IncludeUsage) |
| 806 | + } |
| 807 | +} |
| 808 | + |
| 809 | +func TestStreamResponsesViaChat_InjectsUsageWhenPolicyEnabled(t *testing.T) { |
| 810 | + provider := &capturingChatProvider{ |
| 811 | + streamData: "data: [DONE]\n\n", |
| 812 | + } |
| 813 | + req := &core.ResponsesRequest{ |
| 814 | + Model: "gemini-2.0-flash", |
| 815 | + Input: "hello", |
| 816 | + Stream: true, |
| 817 | + StreamOptions: &core.StreamOptions{IncludeUsage: false}, |
| 818 | + } |
| 819 | + ctx := core.WithEnforceReturningUsageData(context.Background(), true) |
| 820 | + |
| 821 | + stream, err := StreamResponsesViaChat(ctx, provider, req, "gemini") |
| 822 | + if err != nil { |
| 823 | + t.Fatalf("StreamResponsesViaChat() error = %v", err) |
| 824 | + } |
| 825 | + defer func() { |
| 826 | + _ = stream.Close() |
| 827 | + }() |
| 828 | + |
| 829 | + if provider.capturedReq == nil { |
| 830 | + t.Fatal("capturedReq = nil") |
| 831 | + } |
| 832 | + if provider.capturedReq.StreamOptions == nil { |
| 833 | + t.Fatal("captured StreamOptions = nil") |
| 834 | + } |
| 835 | + if !provider.capturedReq.StreamOptions.IncludeUsage { |
| 836 | + t.Fatal("captured IncludeUsage = false, want true") |
| 837 | + } |
| 838 | + if req.StreamOptions == nil { |
| 839 | + t.Fatal("original StreamOptions unexpectedly nil") |
| 840 | + } |
| 841 | + if req.StreamOptions.IncludeUsage { |
| 842 | + t.Fatal("original request was mutated") |
| 843 | + } |
| 844 | +} |
| 845 | + |
| 846 | +func TestStreamResponsesViaChat_DoesNotInjectUsageWhenPolicyDisabled(t *testing.T) { |
| 847 | + provider := &capturingChatProvider{ |
| 848 | + streamData: "data: [DONE]\n\n", |
| 849 | + } |
| 850 | + req := &core.ResponsesRequest{ |
| 851 | + Model: "gemini-2.0-flash", |
| 852 | + Input: "hello", |
| 853 | + Stream: true, |
| 854 | + } |
| 855 | + |
| 856 | + stream, err := StreamResponsesViaChat(context.Background(), provider, req, "gemini") |
| 857 | + if err != nil { |
| 858 | + t.Fatalf("StreamResponsesViaChat() error = %v", err) |
| 859 | + } |
| 860 | + defer func() { |
| 861 | + _ = stream.Close() |
| 862 | + }() |
| 863 | + |
| 864 | + if provider.capturedReq == nil { |
| 865 | + t.Fatal("capturedReq = nil") |
| 866 | + } |
| 867 | + if provider.capturedReq.StreamOptions != nil { |
| 868 | + t.Fatalf("captured StreamOptions = %+v, want nil", provider.capturedReq.StreamOptions) |
| 869 | + } |
| 870 | +} |
| 871 | + |
766 | 872 | func boolPtr(v bool) *bool { |
767 | 873 | return &v |
768 | 874 | } |
0 commit comments