Skip to content

Commit 36cbafc

Browse files
Persist Anthropic batch result hints
1 parent 4356ee8 commit 36cbafc

9 files changed

Lines changed: 336 additions & 9 deletions

File tree

internal/batch/store_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,30 @@ func TestSerializeBatchValidatesID(t *testing.T) {
2727
})
2828
}
2929

30+
func TestSerializeBatchPreservesRequestEndpointHints(t *testing.T) {
31+
raw, err := serializeBatch(&core.BatchResponse{
32+
ID: "batch_123",
33+
RequestEndpointByCustomID: map[string]string{
34+
"resp-1": "/v1/responses",
35+
"chat-1": "/v1/chat/completions",
36+
},
37+
})
38+
if err != nil {
39+
t.Fatalf("serializeBatch() error = %v", err)
40+
}
41+
42+
decoded, err := deserializeBatch(raw)
43+
if err != nil {
44+
t.Fatalf("deserializeBatch() error = %v", err)
45+
}
46+
if got := decoded.RequestEndpointByCustomID["resp-1"]; got != "/v1/responses" {
47+
t.Fatalf("RequestEndpointByCustomID[resp-1] = %q, want /v1/responses", got)
48+
}
49+
if got := decoded.RequestEndpointByCustomID["chat-1"]; got != "/v1/chat/completions" {
50+
t.Fatalf("RequestEndpointByCustomID[chat-1] = %q, want /v1/chat/completions", got)
51+
}
52+
}
53+
3054
func TestNewRequiresConfig(t *testing.T) {
3155
_, err := New(context.Background(), nil)
3256
if err == nil {

internal/core/batch.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ type BatchResponse struct {
8686
RequestCounts BatchRequestCounts `json:"request_counts"`
8787
Metadata map[string]string `json:"metadata,omitempty"`
8888

89+
// Gateway-internal batch item endpoint hints persisted so providers can
90+
// re-shape native results after process restarts without leaking them in
91+
// API responses.
92+
RequestEndpointByCustomID map[string]string `json:"request_endpoint_by_custom_id,omitempty" swaggerignore:"true"`
93+
8994
// Gateway extension: optional usage/result snapshots persisted by the gateway.
9095
Usage BatchUsageSummary `json:"usage"`
9196
Results []BatchResultItem `json:"results,omitempty"`

internal/core/interfaces.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ type NativeBatchProvider interface {
3838
GetBatchResults(ctx context.Context, id string) (*BatchResultsResponse, error)
3939
}
4040

41+
// BatchResultHintAwareProvider is an optional native batch extension for
42+
// providers that need persisted per-item endpoint hints to shape results.
43+
type BatchResultHintAwareProvider interface {
44+
GetBatchResultsWithHints(ctx context.Context, id string, endpointByCustomID map[string]string) (*BatchResultsResponse, error)
45+
ClearBatchResultHints(batchID string)
46+
}
47+
4148
// NativeBatchRoutableProvider extends routing with native batch operations.
4249
type NativeBatchRoutableProvider interface {
4350
CreateBatch(ctx context.Context, providerType string, req *BatchRequest) (*BatchResponse, error)
@@ -47,6 +54,13 @@ type NativeBatchRoutableProvider interface {
4754
GetBatchResults(ctx context.Context, providerType, id string) (*BatchResultsResponse, error)
4855
}
4956

57+
// NativeBatchHintRoutableProvider is an optional routing extension for
58+
// providers that can consume persisted per-item endpoint hints.
59+
type NativeBatchHintRoutableProvider interface {
60+
GetBatchResultsWithHints(ctx context.Context, providerType, id string, endpointByCustomID map[string]string) (*BatchResultsResponse, error)
61+
ClearBatchResultHints(providerType, batchID string)
62+
}
63+
5064
// NativeFileProvider is implemented by providers that support OpenAI-compatible files APIs.
5165
type NativeFileProvider interface {
5266
CreateFile(ctx context.Context, req *FileCreateRequest) (*FileObject, error)

internal/providers/anthropic/anthropic.go

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,9 @@ func (p *Provider) SetBaseURL(url string) {
9090
p.client.SetBaseURL(url)
9191
}
9292

93-
func (p *Provider) setBatchResultEndpoints(batchID string, endpoints map[string]string) {
94-
batchID = strings.TrimSpace(batchID)
95-
if batchID == "" || len(endpoints) == 0 {
96-
return
93+
func cloneBatchResultEndpoints(endpoints map[string]string) map[string]string {
94+
if len(endpoints) == 0 {
95+
return nil
9796
}
9897
cloned := make(map[string]string, len(endpoints))
9998
for customID, endpoint := range endpoints {
@@ -104,6 +103,18 @@ func (p *Provider) setBatchResultEndpoints(batchID string, endpoints map[string]
104103
}
105104
cloned[customID] = endpoint
106105
}
106+
if len(cloned) == 0 {
107+
return nil
108+
}
109+
return cloned
110+
}
111+
112+
func (p *Provider) setBatchResultEndpoints(batchID string, endpoints map[string]string) {
113+
batchID = strings.TrimSpace(batchID)
114+
if batchID == "" || len(endpoints) == 0 {
115+
return
116+
}
117+
cloned := cloneBatchResultEndpoints(endpoints)
107118
if len(cloned) == 0 {
108119
return
109120
}
@@ -115,6 +126,18 @@ func (p *Provider) setBatchResultEndpoints(batchID string, endpoints map[string]
115126
p.batchEndpointsMu.Unlock()
116127
}
117128

129+
func (p *Provider) clearBatchResultEndpoints(batchID string) {
130+
batchID = strings.TrimSpace(batchID)
131+
if batchID == "" {
132+
return
133+
}
134+
p.batchEndpointsMu.Lock()
135+
if p.batchResultEndpoints != nil {
136+
delete(p.batchResultEndpoints, batchID)
137+
}
138+
p.batchEndpointsMu.Unlock()
139+
}
140+
118141
func (p *Provider) getBatchResultEndpoints(batchID string) map[string]string {
119142
batchID = strings.TrimSpace(batchID)
120143
if batchID == "" {
@@ -1130,6 +1153,7 @@ func (p *Provider) CreateBatch(ctx context.Context, req *core.BatchRequest) (*co
11301153
return nil, core.NewProviderError("anthropic", http.StatusBadGateway, "failed to map anthropic batch response", nil)
11311154
}
11321155
mapped.ProviderBatchID = mapped.ID
1156+
mapped.RequestEndpointByCustomID = cloneBatchResultEndpoints(endpointByCustomID)
11331157
p.setBatchResultEndpoints(mapped.ProviderBatchID, endpointByCustomID)
11341158
return mapped, nil
11351159
}
@@ -1214,8 +1238,7 @@ func (p *Provider) CancelBatch(ctx context.Context, id string) (*core.BatchRespo
12141238
return mapped, nil
12151239
}
12161240

1217-
// GetBatchResults retrieves Anthropic native message batch results.
1218-
func (p *Provider) GetBatchResults(ctx context.Context, id string) (*core.BatchResultsResponse, error) {
1241+
func (p *Provider) getBatchResults(ctx context.Context, id string, endpointByCustomID map[string]string) (*core.BatchResultsResponse, error) {
12191242
resp, err := p.client.DoPassthrough(ctx, llmclient.Request{
12201243
Method: http.MethodGet,
12211244
Endpoint: "/messages/batches/" + url.PathEscape(id) + "/results",
@@ -1236,7 +1259,11 @@ func (p *Provider) GetBatchResults(ctx context.Context, id string) (*core.BatchR
12361259
scanner := bufio.NewScanner(resp.Body)
12371260
// Allow larger result lines than Scanner's default 64K.
12381261
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
1239-
endpointByCustomID := p.getBatchResultEndpoints(id)
1262+
if len(endpointByCustomID) == 0 {
1263+
endpointByCustomID = p.getBatchResultEndpoints(id)
1264+
} else {
1265+
endpointByCustomID = cloneBatchResultEndpoints(endpointByCustomID)
1266+
}
12401267

12411268
results := make([]core.BatchResultItem, 0)
12421269
index := 0
@@ -1322,6 +1349,23 @@ func (p *Provider) GetBatchResults(ctx context.Context, id string) (*core.BatchR
13221349
}, nil
13231350
}
13241351

1352+
// GetBatchResults retrieves Anthropic native message batch results.
1353+
func (p *Provider) GetBatchResults(ctx context.Context, id string) (*core.BatchResultsResponse, error) {
1354+
return p.getBatchResults(ctx, id, nil)
1355+
}
1356+
1357+
// GetBatchResultsWithHints retrieves Anthropic native batch results using
1358+
// persisted per-item endpoint hints instead of transient in-memory state.
1359+
func (p *Provider) GetBatchResultsWithHints(ctx context.Context, id string, endpointByCustomID map[string]string) (*core.BatchResultsResponse, error) {
1360+
return p.getBatchResults(ctx, id, endpointByCustomID)
1361+
}
1362+
1363+
// ClearBatchResultHints clears transient per-batch endpoint hints once they
1364+
// have been persisted by the gateway.
1365+
func (p *Provider) ClearBatchResultHints(batchID string) {
1366+
p.clearBatchResultEndpoints(batchID)
1367+
}
1368+
13251369
// Embeddings returns an error because Anthropic does not natively support embeddings.
13261370
// Voyage AI (Anthropic's recommended embedding provider) may be added in the future.
13271371
func (p *Provider) Embeddings(_ context.Context, _ *core.EmbeddingRequest) (*core.EmbeddingResponse, error) {

internal/providers/anthropic/anthropic_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,52 @@ func TestGetBatchResults(t *testing.T) {
153153
}
154154
}
155155

156+
func TestGetBatchResultsWithHints(t *testing.T) {
157+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
158+
if r.URL.Path != "/messages/batches/batch_1/results" {
159+
http.NotFound(w, r)
160+
return
161+
}
162+
163+
w.WriteHeader(http.StatusOK)
164+
_, _ = w.Write([]byte(
165+
`{"custom_id":"ok-1","result":{"type":"succeeded","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[{"type":"text","text":"hi"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}}}`,
166+
))
167+
}))
168+
defer server.Close()
169+
170+
provider := NewWithHTTPClient("test-api-key", nil, llmclient.Hooks{})
171+
provider.SetBaseURL(server.URL)
172+
173+
resp, err := provider.GetBatchResultsWithHints(context.Background(), "batch_1", map[string]string{
174+
"ok-1": "/v1/responses",
175+
})
176+
if err != nil {
177+
t.Fatalf("unexpected error: %v", err)
178+
}
179+
if len(resp.Data) != 1 {
180+
t.Fatalf("len(Data) = %d, want 1", len(resp.Data))
181+
}
182+
if resp.Data[0].URL != "/v1/responses" {
183+
t.Fatalf("URL = %q, want /v1/responses", resp.Data[0].URL)
184+
}
185+
}
186+
187+
func TestClearBatchResultHints(t *testing.T) {
188+
provider := &Provider{
189+
batchResultEndpoints: map[string]map[string]string{
190+
"batch_1": {
191+
"resp-1": "/v1/responses",
192+
},
193+
},
194+
}
195+
196+
provider.ClearBatchResultHints("batch_1")
197+
if got := provider.getBatchResultEndpoints("batch_1"); got != nil {
198+
t.Fatalf("batch_1 hints should be cleared, got %#v", got)
199+
}
200+
}
201+
156202
func TestChatCompletion(t *testing.T) {
157203
tests := []struct {
158204
name string

internal/providers/router.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"strings"
1011

1112
"gomodel/internal/core"
1213
)
@@ -439,6 +440,34 @@ func (r *Router) GetBatchResults(ctx context.Context, providerType, id string) (
439440
})
440441
}
441442

443+
// GetBatchResultsWithHints routes native batch results lookup with persisted
444+
// per-item endpoint hints when the provider supports them.
445+
func (r *Router) GetBatchResultsWithHints(ctx context.Context, providerType, id string, endpointByCustomID map[string]string) (*core.BatchResultsResponse, error) {
446+
return routeNativeBatchCall(r, ctx, providerType, func(ctx context.Context, bp core.NativeBatchProvider) (*core.BatchResultsResponse, error) {
447+
if hinted, ok := bp.(core.BatchResultHintAwareProvider); ok && len(endpointByCustomID) > 0 {
448+
return hinted.GetBatchResultsWithHints(ctx, id, endpointByCustomID)
449+
}
450+
return bp.GetBatchResults(ctx, id)
451+
})
452+
}
453+
454+
// ClearBatchResultHints clears transient provider-side batch result hints once
455+
// they have been persisted by the gateway.
456+
func (r *Router) ClearBatchResultHints(providerType, batchID string) {
457+
if strings.TrimSpace(batchID) == "" {
458+
return
459+
}
460+
bp, err := r.resolveNativeBatchProvider(providerType)
461+
if err != nil {
462+
return
463+
}
464+
hinted, ok := bp.(core.BatchResultHintAwareProvider)
465+
if !ok {
466+
return
467+
}
468+
hinted.ClearBatchResultHints(batchID)
469+
}
470+
442471
// CreateFile routes file upload to a provider type.
443472
func (r *Router) CreateFile(ctx context.Context, providerType string, req *core.FileCreateRequest) (*core.FileObject, error) {
444473
resp, err := routeNativeFileCall(r, ctx, providerType, func(ctx context.Context, fp core.NativeFileProvider) (*core.FileObject, error) {

internal/providers/router_test.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ func (m *mockProvider) Passthrough(_ context.Context, req *core.PassthroughReque
142142

143143
type mockBatchProvider struct {
144144
mockProvider
145-
listBatchesResp *core.BatchListResponse
145+
listBatchesResp *core.BatchListResponse
146+
hintedBatchResults *core.BatchResultsResponse
147+
capturedBatchHints map[string]string
148+
clearedBatchHintID string
146149
}
147150

148151
func (m *mockBatchProvider) CreateBatch(_ context.Context, _ *core.BatchRequest) (*core.BatchResponse, error) {
@@ -168,6 +171,23 @@ func (m *mockBatchProvider) GetBatchResults(_ context.Context, _ string) (*core.
168171
return &core.BatchResultsResponse{Object: "list", BatchID: "provider-batch-1"}, nil
169172
}
170173

174+
func (m *mockBatchProvider) GetBatchResultsWithHints(_ context.Context, _ string, endpointByCustomID map[string]string) (*core.BatchResultsResponse, error) {
175+
if len(endpointByCustomID) > 0 {
176+
m.capturedBatchHints = make(map[string]string, len(endpointByCustomID))
177+
for customID, endpoint := range endpointByCustomID {
178+
m.capturedBatchHints[customID] = endpoint
179+
}
180+
}
181+
if m.hintedBatchResults != nil {
182+
return m.hintedBatchResults, nil
183+
}
184+
return m.GetBatchResults(context.Background(), "")
185+
}
186+
187+
func (m *mockBatchProvider) ClearBatchResultHints(batchID string) {
188+
m.clearedBatchHintID = batchID
189+
}
190+
171191
func (m *mockBatchProvider) CreateFile(_ context.Context, req *core.FileCreateRequest) (*core.FileObject, error) {
172192
return &core.FileObject{
173193
ID: "file_1",
@@ -676,6 +696,39 @@ func TestRouterListBatchesSetsProviderOnItems(t *testing.T) {
676696
}
677697
}
678698

699+
func TestRouterGetBatchResultsWithHintsUsesHintAwareProvider(t *testing.T) {
700+
provider := &mockBatchProvider{
701+
hintedBatchResults: &core.BatchResultsResponse{
702+
Object: "list",
703+
BatchID: "provider-batch-1",
704+
Data: []core.BatchResultItem{
705+
{Index: 0, URL: "/v1/responses"},
706+
},
707+
},
708+
}
709+
lookup := newMockLookup()
710+
lookup.addModel("claude-sonnet", provider, "anthropic")
711+
712+
router, _ := NewRouter(lookup)
713+
resp, err := router.GetBatchResultsWithHints(context.Background(), "anthropic", "provider-batch-1", map[string]string{
714+
"resp-1": "/v1/responses",
715+
})
716+
if err != nil {
717+
t.Fatalf("unexpected error: %v", err)
718+
}
719+
if resp == nil || len(resp.Data) != 1 {
720+
t.Fatalf("unexpected response: %+v", resp)
721+
}
722+
if got := provider.capturedBatchHints["resp-1"]; got != "/v1/responses" {
723+
t.Fatalf("capturedBatchHints[resp-1] = %q, want /v1/responses", got)
724+
}
725+
726+
router.ClearBatchResultHints("anthropic", "provider-batch-1")
727+
if provider.clearedBatchHintID != "provider-batch-1" {
728+
t.Fatalf("clearedBatchHintID = %q, want provider-batch-1", provider.clearedBatchHintID)
729+
}
730+
}
731+
679732
func TestRouterEmbeddings(t *testing.T) {
680733
expectedResp := &core.EmbeddingResponse{
681734
Object: "list",

0 commit comments

Comments
 (0)