Skip to content

Commit 2957beb

Browse files
Rewrite file-backed batch sources before submission
1 parent c1e1a47 commit 2957beb

10 files changed

Lines changed: 847 additions & 117 deletions

File tree

internal/aliases/provider.go

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package aliases
33
import (
44
"context"
55
"encoding/json"
6-
"fmt"
76
"io"
87
"sort"
98

@@ -148,11 +147,12 @@ func (p *Provider) CreateBatch(ctx context.Context, providerType string, req *co
148147
if err != nil {
149148
return nil, err
150149
}
151-
forward, err := p.rewriteBatchRequest(req)
150+
result, err := p.rewriteBatchSource(ctx, providerType, req)
152151
if err != nil {
153152
return nil, err
154153
}
155-
return native.CreateBatch(ctx, providerType, forward)
154+
p.recordBatchPreparation(ctx, req, result.Request)
155+
return native.CreateBatch(ctx, providerType, result.Request)
156156
}
157157

158158
func (p *Provider) GetBatch(ctx context.Context, providerType, id string) (*core.BatchResponse, error) {
@@ -192,11 +192,16 @@ func (p *Provider) CreateBatchWithHints(ctx context.Context, providerType string
192192
if err != nil {
193193
return nil, nil, err
194194
}
195-
forward, err := p.rewriteBatchRequest(req)
195+
result, err := p.rewriteBatchSource(ctx, providerType, req)
196196
if err != nil {
197197
return nil, nil, err
198198
}
199-
return hinted.CreateBatchWithHints(ctx, providerType, forward)
199+
p.recordBatchPreparation(ctx, req, result.Request)
200+
resp, hints, err := hinted.CreateBatchWithHints(ctx, providerType, result.Request)
201+
if err != nil {
202+
return nil, nil, err
203+
}
204+
return resp, mergeBatchHints(result.RequestEndpointHints, hints), nil
200205
}
201206

202207
func (p *Provider) GetBatchResultsWithHints(ctx context.Context, providerType, id string, endpointByCustomID map[string]string) (*core.BatchResultsResponse, error) {
@@ -305,59 +310,81 @@ func (p *Provider) rewriteEmbeddingRequest(req *core.EmbeddingRequest, mode requ
305310
return &forward, nil
306311
}
307312

308-
func (p *Provider) rewriteBatchRequest(req *core.BatchRequest) (*core.BatchRequest, error) {
309-
if req == nil || len(req.Requests) == 0 {
310-
return req, nil
313+
func (p *Provider) rewriteBatchSource(ctx context.Context, providerType string, req *core.BatchRequest) (*core.BatchRewriteResult, error) {
314+
files, err := p.nativeFileRouter()
315+
if err != nil {
316+
files = nil
311317
}
318+
return core.RewriteBatchSource(ctx, providerType, req, files, []string{"chat_completions", "responses", "embeddings"}, p.rewriteBatchDecodedItem)
319+
}
312320

313-
forward := *req
314-
forward.Requests = make([]core.BatchRequestItem, len(req.Requests))
315-
copy(forward.Requests, req.Requests)
316-
317-
for i, item := range forward.Requests {
318-
decoded, handled, err := core.MaybeDecodeKnownBatchItemRequest(req.Endpoint, item, "chat_completions", "responses", "embeddings")
321+
func (p *Provider) rewriteBatchDecodedItem(_ context.Context, _ core.BatchRequestItem, decoded *core.DecodedBatchItemRequest) (json.RawMessage, error) {
322+
switch typed := decoded.Request.(type) {
323+
case *core.ChatRequest:
324+
modified, err := p.rewriteChatRequest(typed, rewriteForUpstream)
319325
if err != nil {
320-
return nil, core.NewInvalidRequestError(fmt.Sprintf("batch item %d: %s", i, err.Error()), err)
326+
return nil, err
321327
}
322-
if !handled {
323-
continue
328+
body, err := json.Marshal(modified)
329+
if err != nil {
330+
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
331+
}
332+
return body, nil
333+
case *core.ResponsesRequest:
334+
modified, err := p.rewriteResponsesRequest(typed, rewriteForUpstream)
335+
if err != nil {
336+
return nil, err
337+
}
338+
body, err := json.Marshal(modified)
339+
if err != nil {
340+
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
341+
}
342+
return body, nil
343+
case *core.EmbeddingRequest:
344+
modified, err := p.rewriteEmbeddingRequest(typed, rewriteForUpstream)
345+
if err != nil {
346+
return nil, err
324347
}
348+
body, err := json.Marshal(modified)
349+
if err != nil {
350+
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
351+
}
352+
return body, nil
353+
default:
354+
return nil, core.NewInvalidRequestError("unsupported batch item url: "+decoded.Endpoint, nil)
355+
}
356+
}
357+
358+
func (p *Provider) recordBatchPreparation(ctx context.Context, original, rewritten *core.BatchRequest) {
359+
if ctx == nil || original == nil || rewritten == nil {
360+
return
361+
}
362+
metadata := core.GetBatchPreparationMetadata(ctx)
363+
if metadata == nil {
364+
return
365+
}
366+
metadata.RecordInputFileRewrite(original.InputFileID, rewritten.InputFileID)
367+
}
325368

326-
var body []byte
327-
switch typed := decoded.Request.(type) {
328-
case *core.ChatRequest:
329-
modified, err := p.rewriteChatRequest(typed, rewriteForUpstream)
330-
if err != nil {
331-
return nil, err
332-
}
333-
body, err = json.Marshal(modified)
334-
if err != nil {
335-
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
336-
}
337-
case *core.ResponsesRequest:
338-
modified, err := p.rewriteResponsesRequest(typed, rewriteForUpstream)
339-
if err != nil {
340-
return nil, err
341-
}
342-
body, err = json.Marshal(modified)
343-
if err != nil {
344-
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
345-
}
346-
case *core.EmbeddingRequest:
347-
modified, err := p.rewriteEmbeddingRequest(typed, rewriteForUpstream)
348-
if err != nil {
349-
return nil, err
350-
}
351-
body, err = json.Marshal(modified)
352-
if err != nil {
353-
return nil, core.NewInvalidRequestError("failed to encode batch item", err)
354-
}
355-
default:
356-
continue
369+
func mergeBatchHints(left, right map[string]string) map[string]string {
370+
if len(left) == 0 {
371+
if len(right) == 0 {
372+
return nil
373+
}
374+
merged := make(map[string]string, len(right))
375+
for key, value := range right {
376+
merged[key] = value
357377
}
358-
forward.Requests[i].Body = body
378+
return merged
359379
}
360-
return &forward, nil
380+
merged := make(map[string]string, len(left)+len(right))
381+
for key, value := range left {
382+
merged[key] = value
383+
}
384+
for key, value := range right {
385+
merged[key] = value
386+
}
387+
return merged
361388
}
362389

363390
func (p *Provider) resolveRequestSelector(model, provider string) (core.ModelSelector, error) {

internal/aliases/provider_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"io"
7+
"strings"
78
"testing"
89

910
"gomodel/internal/core"
@@ -14,6 +15,9 @@ type providerMock struct {
1415
responsesReq *core.ResponsesRequest
1516
embeddingReq *core.EmbeddingRequest
1617
batchReq *core.BatchRequest
18+
fileContent *core.FileContentResponse
19+
fileCreates []*core.FileCreateRequest
20+
fileObject *core.FileObject
1721
modelsResp *core.ModelsResponse
1822
supported map[string]bool
1923
providerType map[string]string
@@ -85,6 +89,35 @@ func (m *providerMock) GetBatchResults(_ context.Context, _ string, _ string) (*
8589
return &core.BatchResultsResponse{Object: "list", BatchID: "batch_1"}, nil
8690
}
8791

92+
func (m *providerMock) CreateFile(_ context.Context, _ string, req *core.FileCreateRequest) (*core.FileObject, error) {
93+
copy := *req
94+
copy.Content = append([]byte(nil), req.Content...)
95+
m.fileCreates = append(m.fileCreates, &copy)
96+
if m.fileObject != nil {
97+
return m.fileObject, nil
98+
}
99+
return &core.FileObject{ID: "file_rewritten", Object: "file", Filename: req.Filename, Purpose: req.Purpose}, nil
100+
}
101+
102+
func (m *providerMock) ListFiles(_ context.Context, _ string, _ string, _ int, _ string) (*core.FileListResponse, error) {
103+
return &core.FileListResponse{Object: "list"}, nil
104+
}
105+
106+
func (m *providerMock) GetFile(_ context.Context, _ string, id string) (*core.FileObject, error) {
107+
return &core.FileObject{ID: id, Object: "file"}, nil
108+
}
109+
110+
func (m *providerMock) DeleteFile(_ context.Context, _ string, id string) (*core.FileDeleteResponse, error) {
111+
return &core.FileDeleteResponse{ID: id, Object: "file", Deleted: true}, nil
112+
}
113+
114+
func (m *providerMock) GetFileContent(_ context.Context, _ string, id string) (*core.FileContentResponse, error) {
115+
if m.fileContent != nil {
116+
return m.fileContent, nil
117+
}
118+
return &core.FileContentResponse{ID: id, Filename: "batch.jsonl", Data: []byte("{}\n")}, nil
119+
}
120+
88121
func TestProviderResolvesRequestsAndExposesAliasModels(t *testing.T) {
89122
catalog := newTestCatalog()
90123
catalog.add("gpt-4o", "openai", core.Model{ID: "gpt-4o", Object: "model", OwnedBy: "openai"})
@@ -221,3 +254,83 @@ func TestProviderRewritesBatchItemBodies(t *testing.T) {
221254
t.Fatalf("rewritten batch item unexpectedly preserved provider hint: %s", inner.batchReq.Requests[0].Body)
222255
}
223256
}
257+
258+
func TestProviderRewritesBatchInputFiles(t *testing.T) {
259+
catalog := newTestCatalog()
260+
catalog.add("openai/gpt-4o", "openai", core.Model{ID: "gpt-4o", Object: "model"})
261+
262+
service, err := NewService(newMemoryStore(Alias{Name: "smart", TargetModel: "gpt-4o", TargetProvider: "openai", Enabled: true}), catalog)
263+
if err != nil {
264+
t.Fatalf("NewService() error = %v", err)
265+
}
266+
if err := service.Refresh(context.Background()); err != nil {
267+
t.Fatalf("Refresh() error = %v", err)
268+
}
269+
270+
inner := newProviderMock()
271+
inner.fileContent = &core.FileContentResponse{
272+
ID: "file_source",
273+
Filename: "batch.jsonl",
274+
Data: []byte("{\"custom_id\":\"1\",\"method\":\"POST\",\"url\":\"/v1/chat/completions\",\"body\":{\"model\":\"smart\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}}\n"),
275+
}
276+
inner.fileObject = &core.FileObject{ID: "file_rewritten", Object: "file", Filename: "batch.jsonl", Purpose: "batch"}
277+
provider := NewProvider(inner, service)
278+
279+
_, err = provider.CreateBatch(context.Background(), "openai", &core.BatchRequest{
280+
InputFileID: "file_source",
281+
Endpoint: "/v1/chat/completions",
282+
})
283+
if err != nil {
284+
t.Fatalf("CreateBatch() error = %v", err)
285+
}
286+
if inner.batchReq == nil {
287+
t.Fatal("captured batch request = nil")
288+
}
289+
if inner.batchReq.InputFileID != "file_rewritten" {
290+
t.Fatalf("rewritten input_file_id = %q, want file_rewritten", inner.batchReq.InputFileID)
291+
}
292+
if len(inner.fileCreates) != 1 {
293+
t.Fatalf("len(fileCreates) = %d, want 1", len(inner.fileCreates))
294+
}
295+
if got := string(inner.fileCreates[0].Content); !strings.Contains(got, "\"model\":\"gpt-4o\"") {
296+
t.Fatalf("rewritten file content = %s, want concrete model", got)
297+
}
298+
}
299+
300+
func TestProviderBatchInputFileSkipsUploadWhenUnchanged(t *testing.T) {
301+
catalog := newTestCatalog()
302+
catalog.add("openai/gpt-4o", "openai", core.Model{ID: "gpt-4o", Object: "model"})
303+
304+
service, err := NewService(newMemoryStore(Alias{Name: "smart", TargetModel: "gpt-4o", TargetProvider: "openai", Enabled: true}), catalog)
305+
if err != nil {
306+
t.Fatalf("NewService() error = %v", err)
307+
}
308+
if err := service.Refresh(context.Background()); err != nil {
309+
t.Fatalf("Refresh() error = %v", err)
310+
}
311+
312+
inner := newProviderMock()
313+
inner.fileContent = &core.FileContentResponse{
314+
ID: "file_source",
315+
Filename: "batch.jsonl",
316+
Data: []byte("{\"custom_id\":\"1\",\"method\":\"POST\",\"url\":\"/v1/chat/completions\",\"body\":{\"model\":\"gpt-4o\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}}\n"),
317+
}
318+
provider := NewProvider(inner, service)
319+
320+
_, err = provider.CreateBatch(context.Background(), "openai", &core.BatchRequest{
321+
InputFileID: "file_source",
322+
Endpoint: "/v1/chat/completions",
323+
})
324+
if err != nil {
325+
t.Fatalf("CreateBatch() error = %v", err)
326+
}
327+
if inner.batchReq == nil {
328+
t.Fatal("captured batch request = nil")
329+
}
330+
if inner.batchReq.InputFileID != "file_source" {
331+
t.Fatalf("input_file_id = %q, want file_source", inner.batchReq.InputFileID)
332+
}
333+
if len(inner.fileCreates) != 0 {
334+
t.Fatalf("len(fileCreates) = %d, want 0", len(inner.fileCreates))
335+
}
336+
}

internal/batch/store.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ const (
2828
type StoredBatch struct {
2929
Batch *core.BatchResponse `json:"batch"`
3030
RequestEndpointByCustomID map[string]string `json:"request_endpoint_by_custom_id,omitempty"`
31+
OriginalInputFileID string `json:"original_input_file_id,omitempty"`
32+
RewrittenInputFileID string `json:"rewritten_input_file_id,omitempty"`
3133
RequestID string `json:"request_id,omitempty"`
3234
UsageLoggedAt *time.Time `json:"usage_logged_at,omitempty"`
3335
}

internal/batch/store_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func TestSerializeBatchPreservesRequestEndpointHints(t *testing.T) {
3939
"resp-1": "/v1/responses",
4040
"chat-1": "/v1/chat/completions",
4141
},
42+
OriginalInputFileID: "file_original",
43+
RewrittenInputFileID: "file_rewritten",
4244
})
4345
if err != nil {
4446
t.Fatalf("serializeBatch() error = %v", err)
@@ -57,6 +59,12 @@ func TestSerializeBatchPreservesRequestEndpointHints(t *testing.T) {
5759
if got := decoded.RequestEndpointByCustomID["chat-1"]; got != "/v1/chat/completions" {
5860
t.Fatalf("RequestEndpointByCustomID[chat-1] = %q, want /v1/chat/completions", got)
5961
}
62+
if decoded.OriginalInputFileID != "file_original" {
63+
t.Fatalf("OriginalInputFileID = %q, want file_original", decoded.OriginalInputFileID)
64+
}
65+
if decoded.RewrittenInputFileID != "file_rewritten" {
66+
t.Fatalf("RewrittenInputFileID = %q, want file_rewritten", decoded.RewrittenInputFileID)
67+
}
6068
}
6169

6270
func TestSerializeBatchStripsGatewayOnlyMetadata(t *testing.T) {
@@ -65,9 +73,9 @@ func TestSerializeBatchStripsGatewayOnlyMetadata(t *testing.T) {
6573
Batch: &core.BatchResponse{
6674
ID: "batch_123",
6775
Metadata: map[string]string{
68-
"visible": "keep",
69-
RequestIDMetadataKey: "req_123",
70-
UsageLoggedAtMetadataKey: strconv.FormatInt(loggedAt.Unix(), 10),
76+
"visible": "keep",
77+
RequestIDMetadataKey: "req_123",
78+
UsageLoggedAtMetadataKey: strconv.FormatInt(loggedAt.Unix(), 10),
7179
},
7280
},
7381
})

0 commit comments

Comments
 (0)