Skip to content

Commit 697b12d

Browse files
fix(providers): correct azure resource routing
1 parent 0606096 commit 697b12d

8 files changed

Lines changed: 431 additions & 32 deletions

File tree

internal/providers/azure/azure.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http"
66
"net/url"
77
"strconv"
8+
"strings"
89

910
"gomodel/internal/core"
1011
"gomodel/internal/llmclient"
@@ -22,34 +23,46 @@ var Registration = providers.Registration{
2223

2324
type Provider struct {
2425
*openai.CompatibleProvider
25-
apiVersion string
26+
resourceProvider *openai.CompatibleProvider
27+
apiVersion string
2628
}
2729

2830
func New(apiKey string, opts providers.ProviderOptions) core.Provider {
2931
p := &Provider{apiVersion: defaultAPIVersion}
30-
p.CompatibleProvider = openai.NewCompatibleProvider(apiKey, opts, openai.CompatibleProviderConfig{
32+
cfg := openai.CompatibleProviderConfig{
3133
ProviderName: "azure",
3234
DefaultBaseURL: "https://example.invalid",
3335
SetHeaders: setHeaders,
34-
})
36+
}
37+
p.CompatibleProvider = openai.NewCompatibleProvider(apiKey, opts, cfg)
38+
p.resourceProvider = openai.NewCompatibleProvider(apiKey, opts, cfg)
3539
p.SetRequestMutator(p.mutateRequest)
40+
p.resourceProvider.SetRequestMutator(p.mutateRequest)
3641
return p
3742
}
3843

3944
func NewWithHTTPClient(apiKey string, httpClient *http.Client, hooks llmclient.Hooks) *Provider {
4045
p := &Provider{apiVersion: defaultAPIVersion}
41-
p.CompatibleProvider = openai.NewCompatibleProviderWithHTTPClient(apiKey, httpClient, hooks, openai.CompatibleProviderConfig{
46+
cfg := openai.CompatibleProviderConfig{
4247
ProviderName: "azure",
4348
DefaultBaseURL: "https://example.invalid",
4449
SetHeaders: setHeaders,
45-
})
50+
}
51+
p.CompatibleProvider = openai.NewCompatibleProviderWithHTTPClient(apiKey, httpClient, hooks, cfg)
52+
p.resourceProvider = openai.NewCompatibleProviderWithHTTPClient(apiKey, httpClient, hooks, cfg)
4653
p.SetRequestMutator(p.mutateRequest)
54+
p.resourceProvider.SetRequestMutator(p.mutateRequest)
4755
return p
4856
}
4957

58+
func (p *Provider) SetBaseURL(baseURL string) {
59+
p.CompatibleProvider.SetBaseURL(baseURL)
60+
p.resourceProvider.SetBaseURL(resourceRootBaseURL(baseURL))
61+
}
62+
5063
func (p *Provider) ListModels(ctx context.Context) (*core.ModelsResponse, error) {
5164
var resp core.ModelsResponse
52-
if err := p.Do(ctx, llmclient.Request{
65+
if err := p.resourceProvider.Do(ctx, llmclient.Request{
5366
Method: http.MethodGet,
5467
Endpoint: "/openai/models",
5568
}, &resp); err != nil {
@@ -63,7 +76,7 @@ func (p *Provider) CreateBatch(ctx context.Context, req *core.BatchRequest) (*co
6376
return nil, core.NewInvalidRequestError("batch request is required", nil)
6477
}
6578
var resp core.BatchResponse
66-
if err := p.Do(ctx, llmclient.Request{
79+
if err := p.resourceProvider.Do(ctx, llmclient.Request{
6780
Method: http.MethodPost,
6881
Endpoint: "/openai/batches",
6982
Body: req,
@@ -78,7 +91,7 @@ func (p *Provider) CreateBatch(ctx context.Context, req *core.BatchRequest) (*co
7891

7992
func (p *Provider) GetBatch(ctx context.Context, id string) (*core.BatchResponse, error) {
8093
var resp core.BatchResponse
81-
if err := p.Do(ctx, llmclient.Request{
94+
if err := p.resourceProvider.Do(ctx, llmclient.Request{
8295
Method: http.MethodGet,
8396
Endpoint: "/openai/batches/" + url.PathEscape(id),
8497
}, &resp); err != nil {
@@ -105,7 +118,7 @@ func (p *Provider) ListBatches(ctx context.Context, limit int, after string) (*c
105118
}
106119

107120
var resp core.BatchListResponse
108-
if err := p.Do(ctx, llmclient.Request{
121+
if err := p.resourceProvider.Do(ctx, llmclient.Request{
109122
Method: http.MethodGet,
110123
Endpoint: endpoint,
111124
}, &resp); err != nil {
@@ -121,7 +134,7 @@ func (p *Provider) ListBatches(ctx context.Context, limit int, after string) (*c
121134

122135
func (p *Provider) CancelBatch(ctx context.Context, id string) (*core.BatchResponse, error) {
123136
var resp core.BatchResponse
124-
if err := p.Do(ctx, llmclient.Request{
137+
if err := p.resourceProvider.Do(ctx, llmclient.Request{
125138
Method: http.MethodPost,
126139
Endpoint: "/openai/batches/" + url.PathEscape(id) + "/cancel",
127140
}, &resp); err != nil {
@@ -169,3 +182,25 @@ func isValidClientRequestID(id string) bool {
169182
}
170183
return true
171184
}
185+
186+
func resourceRootBaseURL(baseURL string) string {
187+
parsed, err := url.Parse(strings.TrimSpace(baseURL))
188+
if err != nil {
189+
return strings.TrimRight(strings.TrimSpace(baseURL), "/")
190+
}
191+
192+
path := strings.TrimRight(parsed.Path, "/")
193+
for _, marker := range []string{"/openai/deployments/", "/deployments/"} {
194+
if idx := strings.Index(path, marker); idx >= 0 {
195+
path = path[:idx]
196+
break
197+
}
198+
}
199+
200+
parsed.Path = path
201+
parsed.RawPath = ""
202+
parsed.RawQuery = ""
203+
parsed.Fragment = ""
204+
205+
return strings.TrimRight(parsed.String(), "/")
206+
}

internal/providers/azure/azure_test.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,133 @@ func TestBatchEndpoints_UseAzureOpenAIPaths(t *testing.T) {
239239
})
240240
}
241241
}
242+
243+
func TestListModels_UsesAzureResourceRootForDeploymentScopedBaseURL(t *testing.T) {
244+
var gotPath string
245+
246+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
247+
gotPath = r.URL.Path
248+
w.Header().Set("Content-Type", "application/json")
249+
_, _ = w.Write([]byte(`{"object":"list","data":[]}`))
250+
}))
251+
defer server.Close()
252+
253+
provider := NewWithHTTPClient("test-api-key", server.Client(), llmclient.Hooks{})
254+
provider.SetBaseURL(server.URL + "/openai/deployments/gpt-4o")
255+
256+
_, err := provider.ListModels(context.Background())
257+
if err != nil {
258+
t.Fatalf("unexpected error: %v", err)
259+
}
260+
if gotPath != "/openai/models" {
261+
t.Fatalf("path = %q, want /openai/models", gotPath)
262+
}
263+
}
264+
265+
func TestBatchEndpoints_UseAzureResourceRootForDeploymentScopedBaseURL(t *testing.T) {
266+
tests := []struct {
267+
name string
268+
call func(*Provider) error
269+
wantPath string
270+
wantMethod string
271+
responseBody string
272+
}{
273+
{
274+
name: "create",
275+
call: func(p *Provider) error {
276+
_, err := p.CreateBatch(context.Background(), &core.BatchRequest{
277+
InputFileID: "file-123",
278+
Endpoint: "/v1/chat/completions",
279+
CompletionWindow: "24h",
280+
})
281+
return err
282+
},
283+
wantPath: "/openai/batches",
284+
wantMethod: http.MethodPost,
285+
responseBody: `{
286+
"id":"batch_123",
287+
"object":"batch",
288+
"endpoint":"/v1/chat/completions",
289+
"status":"validating",
290+
"created_at":1677652288,
291+
"request_counts":{"total":1,"completed":0,"failed":0}
292+
}`,
293+
},
294+
{
295+
name: "get",
296+
call: func(p *Provider) error {
297+
_, err := p.GetBatch(context.Background(), "batch_123")
298+
return err
299+
},
300+
wantPath: "/openai/batches/batch_123",
301+
wantMethod: http.MethodGet,
302+
responseBody: `{
303+
"id":"batch_123",
304+
"object":"batch",
305+
"endpoint":"/v1/chat/completions",
306+
"status":"validating",
307+
"created_at":1677652288,
308+
"request_counts":{"total":1,"completed":0,"failed":0}
309+
}`,
310+
},
311+
{
312+
name: "list",
313+
call: func(p *Provider) error {
314+
_, err := p.ListBatches(context.Background(), 10, "batch_122")
315+
return err
316+
},
317+
wantPath: "/openai/batches",
318+
wantMethod: http.MethodGet,
319+
responseBody: `{
320+
"object":"list",
321+
"data":[],
322+
"has_more":false
323+
}`,
324+
},
325+
{
326+
name: "cancel",
327+
call: func(p *Provider) error {
328+
_, err := p.CancelBatch(context.Background(), "batch_123")
329+
return err
330+
},
331+
wantPath: "/openai/batches/batch_123/cancel",
332+
wantMethod: http.MethodPost,
333+
responseBody: `{
334+
"id":"batch_123",
335+
"object":"batch",
336+
"endpoint":"/v1/chat/completions",
337+
"status":"cancelling",
338+
"created_at":1677652288,
339+
"request_counts":{"total":1,"completed":0,"failed":0}
340+
}`,
341+
},
342+
}
343+
344+
for _, tt := range tests {
345+
t.Run(tt.name, func(t *testing.T) {
346+
var gotPath string
347+
var gotMethod string
348+
349+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
350+
gotPath = r.URL.Path
351+
gotMethod = r.Method
352+
w.Header().Set("Content-Type", "application/json")
353+
_, _ = w.Write([]byte(tt.responseBody))
354+
}))
355+
defer server.Close()
356+
357+
provider := NewWithHTTPClient("test-api-key", server.Client(), llmclient.Hooks{})
358+
provider.SetBaseURL(server.URL + "/openai/deployments/gpt-4o")
359+
360+
if err := tt.call(provider); err != nil {
361+
t.Fatalf("unexpected error: %v", err)
362+
}
363+
if gotPath != tt.wantPath {
364+
t.Fatalf("path = %q, want %q", gotPath, tt.wantPath)
365+
}
366+
if gotMethod != tt.wantMethod {
367+
t.Fatalf("method = %q, want %q", gotMethod, tt.wantMethod)
368+
}
369+
})
370+
}
371+
}

internal/providers/batch_results_file_adapter.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,21 @@ type openAICompatibleBatchLine struct {
3131

3232
// FetchBatchResultsFromOutputFile adapts OpenAI-compatible batch output files to gateway batch results.
3333
func FetchBatchResultsFromOutputFile(ctx context.Context, client *llmclient.Client, providerName, batchID string) (*core.BatchResultsResponse, error) {
34+
return FetchBatchResultsFromOutputFileWithPreparer(ctx, client, providerName, batchID, nil)
35+
}
36+
37+
func FetchBatchResultsFromOutputFileWithPreparer(ctx context.Context, client *llmclient.Client, providerName, batchID string, prepare openAICompatibleRequestPreparer) (*core.BatchResultsResponse, error) {
3438
if strings.TrimSpace(batchID) == "" {
3539
return nil, core.NewInvalidRequestError("batch id is required", nil)
3640
}
3741
if client == nil {
3842
return nil, core.NewInvalidRequestError("provider client is not configured", nil)
3943
}
4044

41-
batchRaw, err := client.DoRaw(ctx, llmclient.Request{
45+
batchRaw, err := client.DoRaw(ctx, prepareOpenAICompatibleRequest(prepare, llmclient.Request{
4246
Method: http.MethodGet,
4347
Endpoint: "/batches/" + url.PathEscape(batchID),
44-
})
48+
}))
4549
if err != nil {
4650
return nil, err
4751
}
@@ -58,10 +62,10 @@ func FetchBatchResultsFromOutputFile(ctx context.Context, client *llmclient.Clie
5862
return nil, core.NewProviderError(providerName, http.StatusBadGateway, "provider batch response missing output file id", nil)
5963
}
6064

61-
fileResp, err := client.DoPassthrough(ctx, llmclient.Request{
65+
fileResp, err := client.DoPassthrough(ctx, prepareOpenAICompatibleRequest(prepare, llmclient.Request{
6266
Method: http.MethodGet,
6367
Endpoint: "/files/" + url.PathEscape(outputFileID) + "/content",
64-
})
68+
}))
6569
if err != nil {
6670
return nil, err
6771
}

internal/providers/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func applyProviderEnvVars(raw map[string]config.RawProviderConfig) map[string]co
6767
baseURL = kp.defaultBase
6868
}
6969

70-
if apiKey == "" && baseURL == "" {
70+
if apiKey == "" && baseURL == "" && apiVersion == "" {
7171
continue
7272
}
7373

@@ -110,7 +110,7 @@ func filterEmptyProviders(raw map[string]config.RawProviderConfig) map[string]co
110110
result[name] = p
111111
continue
112112
}
113-
if name == "azure" && strings.TrimSpace(p.BaseURL) == "" {
113+
if p.Type == "azure" && strings.TrimSpace(p.BaseURL) == "" {
114114
continue
115115
}
116116
if p.APIKey != "" && !strings.Contains(p.APIKey, "${") {

internal/providers/config_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,18 @@ func TestFilterEmptyProviders_EmptyMap(t *testing.T) {
253253
}
254254
}
255255

256+
func TestFilterEmptyProviders_RemovesAzureByTypeWithoutBaseURL(t *testing.T) {
257+
raw := map[string]config.RawProviderConfig{
258+
"my-azure": {Type: "azure", APIKey: "sk-azure"},
259+
}
260+
261+
got := filterEmptyProviders(raw)
262+
263+
if _, exists := got["my-azure"]; exists {
264+
t.Fatal("expected azure provider without base URL to be removed regardless of map key")
265+
}
266+
}
267+
256268
// --- applyProviderEnvVars ---
257269

258270
func TestApplyProviderEnvVars_DiscoversFromAPIKey(t *testing.T) {
@@ -343,6 +355,25 @@ func TestApplyProviderEnvVars_AzureAPIVersionEnvWins(t *testing.T) {
343355
}
344356
}
345357

358+
func TestApplyProviderEnvVars_AzureAPIVersionEnvWinsWithoutOtherAzureEnvVars(t *testing.T) {
359+
t.Setenv("AZURE_API_VERSION", "2025-04-01-preview")
360+
361+
raw := map[string]config.RawProviderConfig{
362+
"azure": {
363+
Type: "azure",
364+
APIKey: "sk-yaml-azure",
365+
BaseURL: "https://example-resource.openai.azure.com/openai/deployments/gpt-4o",
366+
APIVersion: "2024-10-21",
367+
},
368+
}
369+
370+
got := applyProviderEnvVars(raw)
371+
372+
if got["azure"].APIVersion != "2025-04-01-preview" {
373+
t.Fatalf("APIVersion = %q, want 2025-04-01-preview", got["azure"].APIVersion)
374+
}
375+
}
376+
346377
func TestApplyProviderEnvVars_DoesNotDiscoverAzureWithoutBaseURL(t *testing.T) {
347378
t.Setenv("AZURE_API_KEY", "sk-azure")
348379

0 commit comments

Comments
 (0)