Skip to content

Commit f5a9e75

Browse files
fix: guarantee usage rows for streamed responses (#141)
1 parent 8f4d0d5 commit f5a9e75

10 files changed

Lines changed: 373 additions & 24 deletions

File tree

config/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ type UsageConfig struct {
196196
// Default: true
197197
Enabled bool `yaml:"enabled" env:"USAGE_ENABLED"`
198198

199-
// EnforceReturningUsageData controls whether to enforce returning usage data in streaming responses.
200-
// When true, stream_options: {"include_usage": true} is automatically added to streaming requests.
199+
// EnforceReturningUsageData controls whether to ask streaming providers to return usage data when possible.
200+
// When true, stream_options: {"include_usage": true} is added for provider paths that support it.
201201
// Default: true
202202
EnforceReturningUsageData bool `yaml:"enforce_returning_usage_data" env:"ENFORCE_RETURNING_USAGE_DATA"`
203203

internal/core/context.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ const (
1212
ingressFrameKey contextKey = "ingress-frame"
1313
// semanticEnvelopeKey stores the best-effort semantic extraction for the request.
1414
semanticEnvelopeKey contextKey = "semantic-envelope"
15+
// enforceReturningUsageDataKey stores whether streaming requests should ask providers
16+
// to include usage when the provider supports it.
17+
enforceReturningUsageDataKey contextKey = "enforce-returning-usage-data"
1518
)
1619

1720
// WithRequestID returns a new context with the request ID attached.
@@ -59,3 +62,19 @@ func GetSemanticEnvelope(ctx context.Context) *SemanticEnvelope {
5962
}
6063
return nil
6164
}
65+
66+
// WithEnforceReturningUsageData returns a new context with the streaming usage policy attached.
// The enforce flag is stored under enforceReturningUsageDataKey on a context derived
// from ctx via context.WithValue; ctx itself is left unchanged.
67+
func WithEnforceReturningUsageData(ctx context.Context, enforce bool) context.Context {
68+
return context.WithValue(ctx, enforceReturningUsageDataKey, enforce)
69+
}
70+
71+
// GetEnforceReturningUsageData reports whether the request should ask providers
72+
// to include usage in streaming responses when possible.
// It reads the value stored under enforceReturningUsageDataKey and returns false
// when no value is present or when the stored value is not a bool.
73+
func GetEnforceReturningUsageData(ctx context.Context) bool {
74+
if v := ctx.Value(enforceReturningUsageDataKey); v != nil {
75+
// Guard the type assertion so a value of an unexpected type falls through to false.
if enforce, ok := v.(bool); ok {
76+
return enforce
77+
}
78+
}
79+
// Default: the policy is treated as disabled when nothing usable was attached.
return false
80+
}

internal/providers/responses_adapter.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func ConvertResponsesRequestToChat(req *core.ResponsesRequest) (*core.ChatReques
3737
ParallelToolCalls: req.ParallelToolCalls,
3838
Temperature: req.Temperature,
3939
Stream: req.Stream,
40-
StreamOptions: req.StreamOptions,
40+
StreamOptions: cloneStreamOptions(req.StreamOptions),
4141
Reasoning: req.Reasoning,
4242
ExtraFields: core.CloneRawJSONMap(req.ExtraFields),
4343
}
@@ -62,6 +62,14 @@ func ConvertResponsesRequestToChat(req *core.ResponsesRequest) (*core.ChatReques
6262
return chatReq, nil
6363
}
6464

65+
// cloneStreamOptions returns a pointer to a copy of *src so that the caller can
// modify the result without writing through to the original StreamOptions.
// A nil src yields nil. The copy is a plain struct assignment (shallow copy).
func cloneStreamOptions(src *core.StreamOptions) *core.StreamOptions {
66+
if src == nil {
67+
return nil
68+
}
69+
// Copy by value, then take the address of the local copy.
cloned := *src
70+
return &cloned
71+
}
72+
6573
func normalizeResponsesToolsForChat(tools []map[string]any) []map[string]any {
6674
if len(tools) == 0 {
6775
return nil
@@ -982,6 +990,12 @@ func StreamResponsesViaChat(ctx context.Context, p ChatProvider, req *core.Respo
982990
if err != nil {
983991
return nil, err
984992
}
993+
if core.GetEnforceReturningUsageData(ctx) {
994+
if chatReq.StreamOptions == nil {
995+
chatReq.StreamOptions = &core.StreamOptions{}
996+
}
997+
chatReq.StreamOptions.IncludeUsage = true
998+
}
985999

9861000
stream, err := p.StreamChatCompletion(ctx, chatReq)
9871001
if err != nil {

internal/providers/responses_adapter_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
package providers
22

33
import (
4+
"context"
45
"encoding/json"
6+
"io"
57
"math"
68
"strings"
79
"testing"
810

911
"gomodel/internal/core"
1012
)
1113

14+
// capturingChatProvider is a test double for a chat provider. Its
// StreamChatCompletion method records the request it receives and replies with
// either a canned error or a canned stream body.
type capturingChatProvider struct {
15+
// capturedReq is the request most recently passed to StreamChatCompletion.
capturedReq *core.ChatRequest
16+
// streamData is the payload returned as the stream body when streamErr is nil.
streamData string
17+
// streamErr, when non-nil, is returned by StreamChatCompletion instead of a stream.
streamErr error
18+
}
19+
20+
// ChatCompletion is a stub: it ignores both arguments and returns (nil, nil).
func (p *capturingChatProvider) ChatCompletion(_ context.Context, _ *core.ChatRequest) (*core.ChatResponse, error) {
20+
return nil, nil
21+
}
23+
24+
// StreamChatCompletion records req in p.capturedReq and then returns p.streamErr
// if it is set; otherwise it returns a no-op-closing reader over p.streamData.
// The context argument is ignored.
func (p *capturingChatProvider) StreamChatCompletion(_ context.Context, req *core.ChatRequest) (io.ReadCloser, error) {
25+
// Capture before the error check so the request is recorded even on the error path.
p.capturedReq = req
26+
if p.streamErr != nil {
27+
return nil, p.streamErr
28+
}
29+
return io.NopCloser(strings.NewReader(p.streamData)), nil
30+
}
31+
1232
func TestResponsesFunctionCallIDs(t *testing.T) {
1333
t.Run("preserve explicit call id", func(t *testing.T) {
1434
const callID = "call_123"
@@ -763,6 +783,92 @@ func TestExtractContentFromInput(t *testing.T) {
763783
}
764784
}
765785

786+
// TestConvertResponsesRequestToChat_ClonesStreamOptions verifies that converting a
// ResponsesRequest yields a chat request whose StreamOptions is a distinct pointer
// carrying the same IncludeUsage value as the input.
func TestConvertResponsesRequestToChat_ClonesStreamOptions(t *testing.T) {
787+
req := &core.ResponsesRequest{
788+
Model: "test-model",
789+
Input: "hello",
790+
Stream: true,
791+
StreamOptions: &core.StreamOptions{IncludeUsage: false},
792+
}
793+
794+
chatReq, err := ConvertResponsesRequestToChat(req)
795+
if err != nil {
796+
t.Fatalf("ConvertResponsesRequestToChat() error = %v", err)
797+
}
798+
if chatReq.StreamOptions == nil {
799+
t.Fatal("StreamOptions = nil, want cloned value")
800+
}
801+
// Pointer identity check: the converted request must not alias the input's options.
if chatReq.StreamOptions == req.StreamOptions {
802+
t.Fatal("StreamOptions pointer was reused")
803+
}
804+
// The copied value must still match the input (IncludeUsage was false).
if chatReq.StreamOptions.IncludeUsage {
805+
t.Fatalf("IncludeUsage = %v, want false", chatReq.StreamOptions.IncludeUsage)
806+
}
807+
}
808+
809+
// TestStreamResponsesViaChat_InjectsUsageWhenPolicyEnabled verifies that, when the
// context carries the enforce-usage flag set to true, the chat request handed to the
// provider has IncludeUsage forced to true while the caller's original request keeps
// IncludeUsage false.
func TestStreamResponsesViaChat_InjectsUsageWhenPolicyEnabled(t *testing.T) {
810+
provider := &capturingChatProvider{
811+
streamData: "data: [DONE]\n\n",
812+
}
813+
req := &core.ResponsesRequest{
814+
Model: "gemini-2.0-flash",
815+
Input: "hello",
816+
Stream: true,
817+
StreamOptions: &core.StreamOptions{IncludeUsage: false},
818+
}
819+
// Enable the policy through the context, which is how StreamResponsesViaChat reads it.
ctx := core.WithEnforceReturningUsageData(context.Background(), true)
820+
821+
stream, err := StreamResponsesViaChat(ctx, provider, req, "gemini")
822+
if err != nil {
823+
t.Fatalf("StreamResponsesViaChat() error = %v", err)
824+
}
825+
defer func() {
826+
// Close error is deliberately ignored; the test only inspects the captured request.
_ = stream.Close()
827+
}()
828+
829+
if provider.capturedReq == nil {
830+
t.Fatal("capturedReq = nil")
831+
}
832+
if provider.capturedReq.StreamOptions == nil {
833+
t.Fatal("captured StreamOptions = nil")
834+
}
835+
if !provider.capturedReq.StreamOptions.IncludeUsage {
836+
t.Fatal("captured IncludeUsage = false, want true")
837+
}
838+
// The remaining checks guard against the injection writing through to the caller's request.
if req.StreamOptions == nil {
839+
t.Fatal("original StreamOptions unexpectedly nil")
840+
}
841+
if req.StreamOptions.IncludeUsage {
842+
t.Fatal("original request was mutated")
843+
}
844+
}
845+
846+
// TestStreamResponsesViaChat_DoesNotInjectUsageWhenPolicyDisabled verifies that,
// with a context that carries no enforce-usage flag and a request without
// StreamOptions, the chat request handed to the provider still has nil StreamOptions.
func TestStreamResponsesViaChat_DoesNotInjectUsageWhenPolicyDisabled(t *testing.T) {
847+
provider := &capturingChatProvider{
848+
streamData: "data: [DONE]\n\n",
849+
}
850+
req := &core.ResponsesRequest{
851+
Model: "gemini-2.0-flash",
852+
Input: "hello",
853+
Stream: true,
854+
}
855+
856+
// context.Background() has no policy value attached, so the getter reports false.
stream, err := StreamResponsesViaChat(context.Background(), provider, req, "gemini")
857+
if err != nil {
858+
t.Fatalf("StreamResponsesViaChat() error = %v", err)
859+
}
860+
defer func() {
861+
// Close error is deliberately ignored; the test only inspects the captured request.
_ = stream.Close()
862+
}()
863+
864+
if provider.capturedReq == nil {
865+
t.Fatal("capturedReq = nil")
866+
}
867+
if provider.capturedReq.StreamOptions != nil {
868+
t.Fatalf("captured StreamOptions = %+v, want nil", provider.capturedReq.StreamOptions)
869+
}
870+
}
871+
766872
// boolPtr returns a pointer to a copy of v, for populating *bool fields in tests.
func boolPtr(v bool) *bool {
767873
return &v
768874
}

internal/server/handlers.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,22 @@ func (h *Handler) logUsage(model, providerType string, extractFn func(*core.Mode
228228
}
229229
}
230230

231+
// shouldEnforceReturningUsageData reports whether the handler has a usage logger
// and that logger's config has EnforceReturningUsageData set. It returns false
// when h.usageLogger is nil.
func (h *Handler) shouldEnforceReturningUsageData() bool {
232+
return h.usageLogger != nil && h.usageLogger.Config().EnforceReturningUsageData
233+
}
234+
235+
// cloneChatRequestForStreamUsage returns a copy of req whose StreamOptions (when
// non-nil) points at a separately allocated copy, so StreamOptions can be changed
// on the result without modifying req. A nil req yields nil.
//
// The request itself is copied by plain struct assignment, so any other
// pointer, slice, or map fields remain shared with req.
func cloneChatRequestForStreamUsage(req *core.ChatRequest) *core.ChatRequest {
236+
if req == nil {
237+
return nil
238+
}
239+
cloned := *req
240+
// Detach StreamOptions; without this the copy would alias req.StreamOptions.
if req.StreamOptions != nil {
241+
streamOptions := *req.StreamOptions
242+
cloned.StreamOptions = &streamOptions
243+
}
244+
return &cloned
245+
}
246+
231247
// resolveModelSelector delegates to core.NormalizeModelSelector, passing the
// semantic envelope retrieved from ctx together with the model and provider
// pointers, and returns its error unchanged.
// NOTE(review): model and provider are presumably rewritten in place by the
// normalizer — confirm against core.NormalizeModelSelector.
func resolveModelSelector(ctx context.Context, model, provider *string) error {
232248
return core.NormalizeModelSelector(core.GetSemanticEnvelope(ctx), model, provider)
233249
}
@@ -560,14 +576,16 @@ func (h *Handler) ChatCompletion(c *echo.Context) error {
560576
ctx = core.WithRequestID(ctx, requestID)
561577

562578
if req.Stream {
563-
if h.usageLogger != nil && h.usageLogger.Config().EnforceReturningUsageData {
564-
if req.StreamOptions == nil {
565-
req.StreamOptions = &core.StreamOptions{}
579+
streamReq := req
580+
if h.shouldEnforceReturningUsageData() {
581+
streamReq = cloneChatRequestForStreamUsage(req)
582+
if streamReq.StreamOptions == nil {
583+
streamReq.StreamOptions = &core.StreamOptions{}
566584
}
567-
req.StreamOptions.IncludeUsage = true
585+
streamReq.StreamOptions.IncludeUsage = true
568586
}
569-
return h.handleStreamingResponse(c, req.Model, providerType, func() (io.ReadCloser, error) {
570-
return h.provider.StreamChatCompletion(ctx, req)
587+
return h.handleStreamingResponse(c, streamReq.Model, providerType, func() (io.ReadCloser, error) {
588+
return h.provider.StreamChatCompletion(ctx, streamReq)
571589
})
572590
}
573591

@@ -1040,6 +1058,9 @@ func (h *Handler) Responses(c *echo.Context) error {
10401058
requestID := c.Request().Header.Get("X-Request-ID")
10411059

10421060
if req.Stream {
1061+
if h.shouldEnforceReturningUsageData() {
1062+
ctx = core.WithEnforceReturningUsageData(ctx, true)
1063+
}
10431064
return h.handleStreamingResponse(c, req.Model, providerType, func() (io.ReadCloser, error) {
10441065
return h.provider.StreamResponses(ctx, req)
10451066
})

0 commit comments

Comments
 (0)