Skip to content

Commit 072a120

Browse files
fix(guardrails): resolve audit and cache review issues
1 parent 25ebca3 commit 072a120

8 files changed

Lines changed: 259 additions & 57 deletions

File tree

internal/auditlog/entry_capture.go

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,11 @@ func CaptureInternalJSONExchange(
9393
return
9494
}
9595

96-
if req := internalJSONAuditRequest(ctx, method, path, requestIDForEntry(entry), requestBody); req != nil {
96+
if req := internalJSONAuditRequest(ctx, method, path, requestIDForEntry(entry), requestBody, cfg.LogBodies); req != nil {
9797
PopulateRequestData(entry, req, cfg)
9898
}
99-
headers, body, truncated, ok := internalJSONAuditResponse(responseBody, responseErr, requestIDForEntry(entry))
100-
if ok {
101-
PopulateResponseData(entry, headers, body, truncated, cfg)
102-
}
99+
headers, body, truncated := internalJSONAuditResponse(responseBody, responseErr, requestIDForEntry(entry), cfg.LogBodies)
100+
PopulateResponseData(entry, headers, body, truncated, cfg)
103101
}
104102

105103
func ensureLogData(entry *LogEntry) *LogData {
@@ -116,40 +114,47 @@ func requestIDForEntry(entry *LogEntry) string {
116114
return strings.TrimSpace(entry.RequestID)
117115
}
118116

119-
func internalJSONAuditRequest(ctx context.Context, method, path, requestID string, bodyValue any) *http.Request {
120-
if bodyValue == nil {
121-
return nil
122-
}
123-
body, err := json.Marshal(bodyValue)
124-
if err != nil {
125-
return nil
126-
}
127-
117+
func internalJSONAuditRequest(ctx context.Context, method, path, requestID string, bodyValue any, logBodies bool) *http.Request {
128118
headers := internalJSONAuditHeaders(ctx, requestID)
129-
capturedBody, bodyTooBig := boundedAuditBody(body, false)
130-
snapshot := core.NewRequestSnapshot(
131-
method,
132-
path,
133-
nil,
134-
nil,
135-
headers,
136-
headers.Get("Content-Type"),
137-
capturedBody,
138-
bodyTooBig,
139-
requestID,
140-
nil,
141-
core.UserPathFromContext(ctx),
142-
)
143-
reqCtx := core.WithRequestSnapshot(ctx, snapshot)
144119
req := &http.Request{
145120
Method: method,
146121
URL: &url.URL{Path: path},
147122
Header: headers,
148123
}
124+
reqCtx := ctx
125+
if logBodies && bodyValue != nil {
126+
if body, err := json.Marshal(bodyValue); err == nil {
127+
capturedBody, bodyTooBig := boundedAuditBody(body, false)
128+
snapshot := core.NewRequestSnapshot(
129+
method,
130+
path,
131+
nil,
132+
nil,
133+
headers,
134+
headers.Get("Content-Type"),
135+
capturedBody,
136+
bodyTooBig,
137+
requestID,
138+
nil,
139+
core.UserPathFromContext(ctx),
140+
)
141+
reqCtx = core.WithRequestSnapshot(ctx, snapshot)
142+
}
143+
}
149144
return req.WithContext(reqCtx)
150145
}
151146

152-
func internalJSONAuditResponse(bodyValue any, responseErr error, requestID string) (http.Header, []byte, bool, bool) {
147+
func internalJSONAuditResponse(bodyValue any, responseErr error, requestID string, logBodies bool) (http.Header, []byte, bool) {
148+
headers := make(http.Header)
149+
headers.Set("Content-Type", "application/json")
150+
if requestID != "" {
151+
headers.Set("X-Request-ID", requestID)
152+
}
153+
154+
if !logBodies {
155+
return headers, nil, false
156+
}
157+
153158
var (
154159
body []byte
155160
err error
@@ -165,19 +170,13 @@ func internalJSONAuditResponse(bodyValue any, responseErr error, requestID strin
165170
body, err = json.Marshal(core.NewProviderError("", http.StatusInternalServerError, responseErr.Error(), responseErr).ToJSON())
166171
}
167172
default:
168-
return nil, nil, false, false
173+
return headers, nil, false
169174
}
170175
if err != nil {
171-
return nil, nil, false, false
172-
}
173-
174-
headers := make(http.Header)
175-
headers.Set("Content-Type", "application/json")
176-
if requestID != "" {
177-
headers.Set("X-Request-ID", requestID)
176+
return headers, nil, false
178177
}
179178
capturedBody, truncated := boundedAuditBody(body, true)
180-
return headers, capturedBody, truncated, true
179+
return headers, capturedBody, truncated
181180
}
182181

183182
func internalJSONAuditHeaders(ctx context.Context, requestID string) http.Header {
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package auditlog
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"testing"
7+
8+
"gomodel/internal/core"
9+
)
10+
11+
func TestCaptureInternalJSONExchange_PreservesHeadersWithoutBodies(t *testing.T) {
12+
entry := &LogEntry{
13+
RequestID: "req_123",
14+
Data: &LogData{},
15+
}
16+
ctx := core.WithRequestSnapshot(context.Background(), core.NewRequestSnapshot(
17+
"POST",
18+
"/v1/chat/completions",
19+
nil,
20+
nil,
21+
map[string][]string{
22+
"Traceparent": {`00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00`},
23+
},
24+
"application/json",
25+
nil,
26+
false,
27+
"req_123",
28+
nil,
29+
"/team/alpha",
30+
))
31+
32+
CaptureInternalJSONExchange(entry, ctx, "POST", "/v1/chat/completions", nil, nil, nil, Config{
33+
LogHeaders: true,
34+
LogBodies: false,
35+
})
36+
37+
if entry.Data == nil {
38+
t.Fatal("Data = nil, want populated log data")
39+
}
40+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_123" {
41+
t.Fatalf("RequestHeaders[X-Request-ID] = %q, want req_123", got)
42+
}
43+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey(core.UserPathHeader)]; got != "/team/alpha" {
44+
t.Fatalf("RequestHeaders[%s] = %q, want /team/alpha", core.UserPathHeader, got)
45+
}
46+
if got := entry.Data.RequestHeaders["Traceparent"]; got == "" {
47+
t.Fatal("RequestHeaders[Traceparent] = empty, want propagated trace header")
48+
}
49+
if got := entry.Data.ResponseHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_123" {
50+
t.Fatalf("ResponseHeaders[X-Request-ID] = %q, want req_123", got)
51+
}
52+
if entry.Data.RequestBody != nil || entry.Data.ResponseBody != nil {
53+
t.Fatal("expected no bodies when body logging is disabled")
54+
}
55+
}
56+
57+
func TestCaptureInternalJSONExchange_PreservesHeadersWhenBodyMarshalFails(t *testing.T) {
58+
entry := &LogEntry{
59+
RequestID: "req_456",
60+
Data: &LogData{},
61+
}
62+
ctx := core.WithEffectiveUserPath(context.Background(), "/team/beta")
63+
64+
CaptureInternalJSONExchange(entry, ctx, "POST", "/v1/chat/completions", func() {}, func() {}, nil, Config{
65+
LogHeaders: true,
66+
LogBodies: true,
67+
})
68+
69+
if entry.Data == nil {
70+
t.Fatal("Data = nil, want populated log data")
71+
}
72+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456" {
73+
t.Fatalf("RequestHeaders[X-Request-ID] = %q, want req_456", got)
74+
}
75+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey(core.UserPathHeader)]; got != "/team/beta" {
76+
t.Fatalf("RequestHeaders[%s] = %q, want /team/beta", core.UserPathHeader, got)
77+
}
78+
if got := entry.Data.ResponseHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456" {
79+
t.Fatalf("ResponseHeaders[X-Request-ID] = %q, want req_456", got)
80+
}
81+
if entry.Data.RequestBody != nil || entry.Data.ResponseBody != nil {
82+
t.Fatal("expected marshal failures to skip bodies while preserving headers")
83+
}
84+
}

internal/guardrails/llm_based_altering_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,13 @@ func TestLLMBasedAltering_Process_UsesInternalGuardrailOriginAndUserPath(t *test
152152
}
153153

154154
func TestLLMBasedAltering_Process_SkipsPrefix(t *testing.T) {
155+
called := false
155156
g, err := NewLLMBasedAlteringGuardrail("privacy", LLMBasedAlteringConfig{
156157
Model: "gpt-4o-mini",
157158
SkipContentPrefix: "### safe",
158159
}, mockChatCompletionExecutor{
159160
chatFn: func(_ context.Context, _ *core.ChatRequest) (*core.ChatResponse, error) {
161+
called = true
160162
return nil, fmt.Errorf("should not be called")
161163
},
162164
})
@@ -172,6 +174,9 @@ func TestLLMBasedAltering_Process_SkipsPrefix(t *testing.T) {
172174
if got[0].Content != msgs[0].Content {
173175
t.Fatalf("Content = %q, want unchanged", got[0].Content)
174176
}
177+
if called {
178+
t.Fatal("expected skip prefix to bypass auxiliary executor")
179+
}
175180
}
176181

177182
func TestLLMBasedAltering_Process_FailsOpenOnProviderError(t *testing.T) {

internal/guardrails/provider.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,9 +1149,33 @@ func patchResponsesInputMap(original map[string]any, patched core.ResponsesInput
11491149
for key, value := range updated {
11501150
cloned[key] = value
11511151
}
1152+
if patched.Type == "function_call_output" {
1153+
cloned["output"] = restoreResponsesInputOutputValue(original["output"], patched.Output)
1154+
}
11521155
return cloned, nil
11531156
}
11541157

1158+
func restoreResponsesInputOutputValue(original any, rewritten string) any {
1159+
if _, ok := original.(string); ok {
1160+
return rewritten
1161+
}
1162+
if strings.TrimSpace(rewritten) == "" {
1163+
if original == nil {
1164+
return nil
1165+
}
1166+
return original
1167+
}
1168+
1169+
var decoded any
1170+
if err := json.Unmarshal([]byte(rewritten), &decoded); err == nil {
1171+
return decoded
1172+
}
1173+
if original == nil {
1174+
return nil
1175+
}
1176+
return original
1177+
}
1178+
11551179
func responsesInputElementAsMap(element core.ResponsesInputElement) (map[string]any, error) {
11561180
value, err := responsesInputElementAsAny(element)
11571181
if err != nil {

internal/guardrails/provider_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2253,6 +2253,82 @@ func TestApplyMessagesToResponses_PreservesSystemRoleInputItems(t *testing.T) {
22532253
}
22542254
}
22552255

2256+
func TestApplyMessagesToResponses_PreservesTypedFunctionCallOutputMapEnvelope(t *testing.T) {
2257+
req := &core.ResponsesRequest{
2258+
Model: "gpt-4",
2259+
Input: []map[string]any{
2260+
{
2261+
"type": "function_call_output",
2262+
"call_id": "call_123",
2263+
"output": map[string]any{
2264+
"patient": "John Smith",
2265+
"score": float64(1),
2266+
},
2267+
},
2268+
},
2269+
}
2270+
msgs := []Message{{
2271+
Role: "tool",
2272+
ToolCallID: "call_123",
2273+
Content: `{"patient":"[|---|](PERSON_1)","score":1}`,
2274+
}}
2275+
2276+
result, err := applyMessagesToResponses(req, msgs)
2277+
if err != nil {
2278+
t.Fatalf("applyMessagesToResponses() error = %v", err)
2279+
}
2280+
2281+
input, ok := result.Input.([]map[string]any)
2282+
if !ok || len(input) != 1 {
2283+
t.Fatalf("Input = %#v, want []map[string]any len=1", result.Input)
2284+
}
2285+
output, ok := input[0]["output"].(map[string]any)
2286+
if !ok {
2287+
t.Fatalf("output = %#v, want map[string]any", input[0]["output"])
2288+
}
2289+
if output["patient"] != "[|---|](PERSON_1)" {
2290+
t.Fatalf("patient = %#v, want rewritten redaction", output["patient"])
2291+
}
2292+
if output["score"] != float64(1) {
2293+
t.Fatalf("score = %#v, want 1", output["score"])
2294+
}
2295+
}
2296+
2297+
func TestApplyMessagesToResponses_PreservesStringFunctionCallOutputMapEnvelope(t *testing.T) {
2298+
req := &core.ResponsesRequest{
2299+
Model: "gpt-4",
2300+
Input: []map[string]any{
2301+
{
2302+
"type": "function_call_output",
2303+
"call_id": "call_123",
2304+
"output": `{"patient":"John Smith"}`,
2305+
},
2306+
},
2307+
}
2308+
msgs := []Message{{
2309+
Role: "tool",
2310+
ToolCallID: "call_123",
2311+
Content: `{"patient":"[|---|](PERSON_1)"}`,
2312+
}}
2313+
2314+
result, err := applyMessagesToResponses(req, msgs)
2315+
if err != nil {
2316+
t.Fatalf("applyMessagesToResponses() error = %v", err)
2317+
}
2318+
2319+
input, ok := result.Input.([]map[string]any)
2320+
if !ok || len(input) != 1 {
2321+
t.Fatalf("Input = %#v, want []map[string]any len=1", result.Input)
2322+
}
2323+
output, ok := input[0]["output"].(string)
2324+
if !ok {
2325+
t.Fatalf("output = %#v, want string", input[0]["output"])
2326+
}
2327+
if output != `{"patient":"[|---|](PERSON_1)"}` {
2328+
t.Fatalf("Output = %q, want rewritten JSON string preserved as string", output)
2329+
}
2330+
}
2331+
22562332
func TestApplyMessagesToResponses_PreservesArrayEnvelopeForAnyInput(t *testing.T) {
22572333
req := &core.ResponsesRequest{
22582334
Model: "gpt-4",

internal/responsecache/responsecache.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ func internalRequestHeaders(ctx context.Context) http.Header {
242242
func internalCacheType(headerValue string) string {
243243
headerValue = strings.TrimSpace(headerValue)
244244
switch headerValue {
245-
case "HIT (exact)":
245+
case CacheHeaderExact:
246246
return CacheTypeExact
247-
case "HIT (semantic)":
247+
case CacheHeaderSemantic:
248248
return CacheTypeSemantic
249249
default:
250250
return ""

internal/responsecache/semantic.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -387,17 +387,17 @@ func extractTextFromContent(content any) string {
387387
// (e.g. "/v1/chat/completions") and isolates entries across distinct endpoints.
388388
func computeParamsHash(body []byte, endpointPath string, plan *core.ExecutionPlan, guardrailsHash, embedderIdentity string) string {
389389
var req struct {
390-
Model string `json:"model"`
391-
Temperature *float64 `json:"temperature"`
392-
TopP *float64 `json:"top_p"`
393-
MaxTokens *int `json:"max_tokens"`
394-
MaxOutputTokens *int `json:"max_output_tokens"`
395-
Tools []map[string]any `json:"tools"`
396-
ResponseFormat any `json:"response_format"`
397-
Stream bool `json:"stream,omitempty"`
398-
StreamOptions *core.StreamOptions `json:"stream_options"`
399-
Reasoning json.RawMessage `json:"reasoning"`
400-
Instructions string `json:"instructions"`
390+
Model string `json:"model"`
391+
Temperature *float64 `json:"temperature"`
392+
TopP *float64 `json:"top_p"`
393+
MaxTokens *int `json:"max_tokens"`
394+
MaxOutputTokens *int `json:"max_output_tokens"`
395+
Tools []map[string]any `json:"tools"`
396+
ResponseFormat any `json:"response_format"`
397+
Stream bool `json:"stream,omitempty"`
398+
StreamOptions *core.StreamOptions `json:"stream_options"`
399+
Reasoning json.RawMessage `json:"reasoning"`
400+
Instructions string `json:"instructions"`
401401
}
402402
_ = json.Unmarshal(body, &req)
403403

@@ -576,9 +576,11 @@ func WithGuardrailsHash(ctx context.Context, hash string) context.Context {
576576

577577
// CacheTypeHeader values for X-Cache-Type.
578578
const (
579-
CacheTypeExact = "exact"
580-
CacheTypeSemantic = "semantic"
581-
CacheTypeBoth = "both"
579+
CacheTypeExact = "exact"
580+
CacheTypeSemantic = "semantic"
581+
CacheTypeBoth = "both"
582+
CacheHeaderExact = "HIT (exact)"
583+
CacheHeaderSemantic = "HIT (semantic)"
582584
)
583585

584586
// ShouldSkipExactCache reports whether the X-Cache-Type header requests semantic-only mode.

0 commit comments

Comments
 (0)