Skip to content

Commit 3784193

Browse files
Strengthen contract replay validation and provider guards
1 parent be99a6d commit 3784193

51 files changed

Lines changed: 9887 additions & 200 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cmd/recordapi/main.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,36 @@ var endpointConfigs = map[string]struct {
111111
},
112112
}
113113

114+
var providerCapabilities = map[string]map[string]bool{
115+
"openai": {
116+
"responses": true,
117+
},
118+
"anthropic": {
119+
"responses": false,
120+
},
121+
"gemini": {
122+
"responses": false,
123+
},
124+
"groq": {
125+
"responses": false,
126+
},
127+
"xai": {
128+
"responses": true,
129+
},
130+
}
131+
132+
func endpointRequiresResponsesCapability(endpoint string) bool {
133+
return endpoint == "responses" || endpoint == "responses_stream"
134+
}
135+
136+
func providerSupportsResponses(provider string) bool {
137+
capabilities, ok := providerCapabilities[provider]
138+
if !ok {
139+
return false
140+
}
141+
return capabilities["responses"]
142+
}
143+
114144
func main() {
115145
provider := flag.String("provider", "openai", "Provider to test (openai, anthropic, gemini, groq, xai)")
116146
endpoint := flag.String("endpoint", "chat", "Endpoint to test (chat, chat_stream, models, responses, responses_stream)")
@@ -135,6 +165,10 @@ func main() {
135165
fmt.Fprintf(os.Stderr, "Error: unknown endpoint %q\n", *endpoint)
136166
os.Exit(1)
137167
}
168+
if endpointRequiresResponsesCapability(*endpoint) && !providerSupportsResponses(*provider) {
169+
fmt.Fprintf(os.Stderr, "Error: provider %q is missing responses capability (/v1/responses)\n", *provider)
170+
os.Exit(1)
171+
}
138172

139173
apiKey := os.Getenv(pConfig.envKey)
140174
if apiKey == "" {

internal/providers/gemini/gemini.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ const (
3030

3131
// Provider implements the core.Provider interface for Google Gemini
3232
type Provider struct {
33-
client *llmclient.Client
34-
httpClient *http.Client
35-
hooks llmclient.Hooks
36-
apiKey string
37-
modelsURL string
33+
client *llmclient.Client
34+
httpClient *http.Client
35+
hooks llmclient.Hooks
36+
apiKey string
37+
modelsURL string
38+
modelsClientConf llmclient.Config
3839
}
3940

4041
// New creates a new Gemini provider.
@@ -44,6 +45,13 @@ func New(apiKey string, opts providers.ProviderOptions) core.Provider {
4445
apiKey: apiKey,
4546
hooks: opts.Hooks,
4647
modelsURL: defaultModelsBaseURL,
48+
modelsClientConf: llmclient.Config{
49+
ProviderName: "gemini",
50+
BaseURL: defaultModelsBaseURL,
51+
Retry: opts.Resilience.Retry,
52+
Hooks: opts.Hooks,
53+
CircuitBreaker: opts.Resilience.CircuitBreaker,
54+
},
4755
}
4856
cfg := llmclient.Config{
4957
ProviderName: "gemini",
@@ -68,6 +76,9 @@ func NewWithHTTPClient(apiKey string, httpClient *http.Client, hooks llmclient.H
6876
hooks: hooks,
6977
modelsURL: defaultModelsBaseURL,
7078
}
79+
modelsCfg := llmclient.DefaultConfig("gemini", defaultModelsBaseURL)
80+
modelsCfg.Hooks = hooks
81+
p.modelsClientConf = modelsCfg
7182
cfg := llmclient.DefaultConfig("gemini", defaultOpenAICompatibleBaseURL)
7283
cfg.Hooks = hooks
7384
p.client = llmclient.NewWithHTTPClient(httpClient, cfg, p.setHeaders)
@@ -83,6 +94,7 @@ func (p *Provider) SetBaseURL(url string) {
8394
// This is primarily useful for tests and local emulators.
8495
func (p *Provider) SetModelsURL(url string) {
8596
p.modelsURL = url
97+
p.modelsClientConf.BaseURL = url
8698
}
8799

88100
// setHeaders sets the required headers for Gemini API requests
@@ -149,16 +161,26 @@ type geminiModelsResponse struct {
149161
func (p *Provider) ListModels(ctx context.Context) (*core.ModelsResponse, error) {
150162
// Use the native Gemini API to list models
151163
// We need to create a separate client for the models endpoint since it uses a different URL
152-
modelsCfg := llmclient.DefaultConfig("gemini", p.modelsURL)
164+
modelsCfg := p.modelsClientConf
165+
modelsCfg.BaseURL = p.modelsURL
153166
modelsCfg.Hooks = p.hooks
154167
headers := func(req *http.Request) {
155168
// Add API key as query parameter.
156169
// NOTE: Passing the API key in the URL query parameter is required by Google's native Gemini API for the models endpoint.
157170
// This may be a security concern, as the API key can be logged in server access logs, proxy logs, and browser history.
158171
// See: https://cloud.google.com/vertex-ai/docs/generative-ai/model-parameters#api-key
159172
q := req.URL.Query()
160-
q.Add("key", p.apiKey)
173+
q.Set("key", p.apiKey)
161174
req.URL.RawQuery = q.Encode()
175+
176+
// Preserve request tracing across list-models requests.
177+
requestID := req.Header.Get("X-Request-Id")
178+
if requestID == "" {
179+
requestID = core.GetRequestID(req.Context())
180+
}
181+
if requestID != "" {
182+
req.Header.Set("X-Request-Id", requestID)
183+
}
162184
}
163185

164186
var modelsClient *llmclient.Client

tests/contract/README.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,26 @@ Each folder contains recorded JSON and SSE payloads used by replay tests.
3535

3636
```bash
3737
# Run contract replay tests
38-
go test -v -tags=contract ./tests/contract/...
38+
go test -v -tags=contract -timeout=5m ./tests/contract/...
3939

4040
# Make target
4141
make test-contract
4242
```
4343

4444
## Updating fixtures
4545

46-
Use `cmd/recordapi` (or provider curl calls) to refresh payloads when provider contracts change, then re-run the contract suite.
46+
Contract tests under `tests/contract/**/*_test.go` must validate full normalized output against committed golden files.
47+
48+
Use the canonical recorder target to refresh provider payload fixtures:
49+
50+
```bash
51+
make record-api
52+
```
53+
54+
Then refresh normalized contract-output goldens from replay tests:
55+
56+
```bash
57+
RECORD=1 go test -v -tags=contract -timeout=5m ./tests/contract/...
58+
```
59+
60+
Re-run the suite without `RECORD=1` before committing.

tests/contract/anthropic_test.go

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
//go:build contract
22

3+
// Contract tests in this file are intended to run with: -tags=contract -timeout=5m.
34
package contract
45

56
import (
67
"context"
78
"net/http"
89
"testing"
910

10-
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
1212

1313
"gomodel/internal/core"
@@ -53,16 +53,12 @@ func TestAnthropicReplayChatCompletion(t *testing.T) {
5353
})
5454
require.NoError(t, err)
5555
require.NotNil(t, resp)
56-
57-
assert.NotEmpty(t, resp.ID)
58-
assert.Equal(t, "chat.completion", resp.Object)
5956
require.NotEmpty(t, resp.Choices)
60-
assert.Equal(t, "assistant", resp.Choices[0].Message.Role)
61-
assert.NotEmpty(t, resp.Choices[0].FinishReason)
57+
6258
if tc.finishReason != "" {
63-
assert.Equal(t, tc.finishReason, resp.Choices[0].FinishReason)
59+
require.Equal(t, tc.finishReason, resp.Choices[0].FinishReason)
6460
}
65-
assert.NotEmpty(t, resp.Choices[0].Message.Content)
61+
compareGoldenJSON(t, goldenPathForFixture(tc.fixturePath), resp)
6662
})
6763
}
6864
}
@@ -84,9 +80,11 @@ func TestAnthropicReplayStreamChatCompletion(t *testing.T) {
8480
raw := readAllStream(t, stream)
8581
chunks, done := parseChatStream(t, raw)
8682

87-
require.True(t, done, "stream should terminate with [DONE]")
88-
require.NotEmpty(t, chunks)
89-
assert.NotEmpty(t, extractChatStreamText(chunks))
83+
compareGoldenJSON(t, goldenPathForFixture("anthropic/messages_stream.txt"), map[string]any{
84+
"done": done,
85+
"chunks": chunks,
86+
"text": extractChatStreamText(chunks),
87+
})
9088
}
9189

9290
func TestAnthropicReplayResponses(t *testing.T) {
@@ -101,13 +99,7 @@ func TestAnthropicReplayResponses(t *testing.T) {
10199
require.NoError(t, err)
102100
require.NotNil(t, resp)
103101

104-
assert.Equal(t, "response", resp.Object)
105-
assert.Equal(t, "completed", resp.Status)
106-
require.NotEmpty(t, resp.Output)
107-
require.NotEmpty(t, resp.Output[0].Content)
108-
assert.NotEmpty(t, resp.Output[0].Content[0].Text)
109-
require.NotNil(t, resp.Usage)
110-
assert.GreaterOrEqual(t, resp.Usage.TotalTokens, 0)
102+
compareGoldenJSON(t, "anthropic/responses.golden.json", resp)
111103
}
112104

113105
func TestAnthropicReplayStreamResponses(t *testing.T) {
@@ -123,12 +115,9 @@ func TestAnthropicReplayStreamResponses(t *testing.T) {
123115

124116
raw := readAllStream(t, stream)
125117
events := parseResponsesStream(t, raw)
126-
require.NotEmpty(t, events)
127-
128-
assert.True(t, hasResponsesEvent(events, "response.created"))
129-
assert.True(t, hasResponsesEvent(events, "response.output_text.delta"))
130-
assert.True(t, hasResponsesEvent(events, "response.completed"))
131-
assert.NotEmpty(t, extractResponsesStreamText(events))
118+
require.True(t, hasResponsesEvent(events, "response.created"))
119+
require.True(t, hasResponsesEvent(events, "response.output_text.delta"))
120+
require.True(t, hasResponsesEvent(events, "response.completed"))
132121

133122
hasDone := false
134123
for _, event := range events {
@@ -137,5 +126,10 @@ func TestAnthropicReplayStreamResponses(t *testing.T) {
137126
break
138127
}
139128
}
140-
assert.True(t, hasDone, "responses stream should terminate with [DONE]")
129+
require.True(t, hasDone, "responses stream should terminate with [DONE]")
130+
131+
compareGoldenJSON(t, "anthropic/responses_stream.golden.json", map[string]any{
132+
"events": events,
133+
"text": extractResponsesStreamText(events),
134+
})
141135
}

tests/contract/gemini_test.go

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,26 @@
11
//go:build contract
22

3+
// Contract tests in this file are intended to run with: -tags=contract -timeout=5m.
34
package contract
45

56
import (
67
"context"
78
"net/http"
89
"testing"
910

10-
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
1212

1313
"gomodel/internal/core"
14-
"gomodel/internal/llmclient"
15-
"gomodel/internal/providers/gemini"
1614
)
1715

18-
func newGeminiReplayProvider(t *testing.T, routes map[string]replayRoute) core.Provider {
19-
t.Helper()
20-
21-
client := newReplayHTTPClient(t, routes)
22-
provider := gemini.NewWithHTTPClient("test-api-key", client, llmclient.Hooks{})
23-
provider.SetBaseURL("https://replay.local")
24-
provider.SetModelsURL("https://replay.local")
25-
return provider
26-
}
27-
2816
func TestGeminiReplayChatCompletion(t *testing.T) {
2917
testCases := []struct {
30-
name string
31-
fixturePath string
32-
expectContent bool
18+
name string
19+
fixturePath string
3320
}{
34-
{name: "basic", fixturePath: "gemini/chat_completion.json", expectContent: true},
35-
{name: "params", fixturePath: "gemini/chat_with_params.json", expectContent: true},
36-
{name: "tools", fixturePath: "gemini/chat_with_tools.json", expectContent: false},
21+
{name: "basic", fixturePath: "gemini/chat_completion.json"},
22+
{name: "params", fixturePath: "gemini/chat_with_params.json"},
23+
{name: "tools", fixturePath: "gemini/chat_with_tools.json"},
3724
}
3825

3926
for _, tc := range testCases {
@@ -52,14 +39,7 @@ func TestGeminiReplayChatCompletion(t *testing.T) {
5239
require.NoError(t, err)
5340
require.NotNil(t, resp)
5441

55-
assert.NotEmpty(t, resp.ID)
56-
assert.Equal(t, "chat.completion", resp.Object)
57-
require.NotEmpty(t, resp.Choices)
58-
assert.Equal(t, "assistant", resp.Choices[0].Message.Role)
59-
assert.NotEmpty(t, resp.Choices[0].FinishReason)
60-
if tc.expectContent {
61-
assert.NotEmpty(t, resp.Choices[0].Message.Content)
62-
}
42+
compareGoldenJSON(t, goldenPathForFixture(tc.fixturePath), resp)
6343
})
6444
}
6545
}
@@ -81,9 +61,11 @@ func TestGeminiReplayStreamChatCompletion(t *testing.T) {
8161
raw := readAllStream(t, stream)
8262
chunks, done := parseChatStream(t, raw)
8363

84-
require.True(t, done, "stream should terminate with [DONE]")
85-
require.NotEmpty(t, chunks)
86-
assert.NotEmpty(t, extractChatStreamText(chunks))
64+
compareGoldenJSON(t, goldenPathForFixture("gemini/chat_completion_stream.txt"), map[string]any{
65+
"done": done,
66+
"chunks": chunks,
67+
"text": extractChatStreamText(chunks),
68+
})
8769
}
8870

8971
func TestGeminiReplayListModels(t *testing.T) {
@@ -95,12 +77,7 @@ func TestGeminiReplayListModels(t *testing.T) {
9577
require.NoError(t, err)
9678
require.NotNil(t, resp)
9779

98-
assert.Equal(t, "list", resp.Object)
99-
require.NotEmpty(t, resp.Data)
100-
for _, model := range resp.Data {
101-
assert.NotEmpty(t, model.ID)
102-
assert.Equal(t, "model", model.Object)
103-
}
80+
compareGoldenJSON(t, goldenPathForFixture("gemini/models.json"), resp)
10481
}
10582

10683
func TestGeminiReplayResponses(t *testing.T) {
@@ -115,13 +92,7 @@ func TestGeminiReplayResponses(t *testing.T) {
11592
require.NoError(t, err)
11693
require.NotNil(t, resp)
11794

118-
assert.Equal(t, "response", resp.Object)
119-
assert.Equal(t, "completed", resp.Status)
120-
require.NotEmpty(t, resp.Output)
121-
require.NotEmpty(t, resp.Output[0].Content)
122-
assert.NotEmpty(t, resp.Output[0].Content[0].Text)
123-
require.NotNil(t, resp.Usage)
124-
assert.GreaterOrEqual(t, resp.Usage.TotalTokens, 0)
95+
compareGoldenJSON(t, "gemini/responses.golden.json", resp)
12596
}
12697

12798
func TestGeminiReplayStreamResponses(t *testing.T) {
@@ -137,12 +108,9 @@ func TestGeminiReplayStreamResponses(t *testing.T) {
137108

138109
raw := readAllStream(t, stream)
139110
events := parseResponsesStream(t, raw)
140-
require.NotEmpty(t, events)
141-
142-
assert.True(t, hasResponsesEvent(events, "response.created"))
143-
assert.True(t, hasResponsesEvent(events, "response.output_text.delta"))
144-
assert.True(t, hasResponsesEvent(events, "response.completed"))
145-
assert.NotEmpty(t, extractResponsesStreamText(events))
111+
require.True(t, hasResponsesEvent(events, "response.created"))
112+
require.True(t, hasResponsesEvent(events, "response.output_text.delta"))
113+
require.True(t, hasResponsesEvent(events, "response.completed"))
146114

147115
hasDone := false
148116
for _, event := range events {
@@ -151,5 +119,10 @@ func TestGeminiReplayStreamResponses(t *testing.T) {
151119
break
152120
}
153121
}
154-
assert.True(t, hasDone, "responses stream should terminate with [DONE]")
122+
require.True(t, hasDone, "responses stream should terminate with [DONE]")
123+
124+
compareGoldenJSON(t, "gemini/responses_stream.golden.json", map[string]any{
125+
"events": events,
126+
"text": extractResponsesStreamText(events),
127+
})
155128
}

0 commit comments

Comments
 (0)