Skip to content

Commit b29838c

Browse files
fix(guardrails): refresh startup pipelines after executor swap
1 parent e1e438b commit b29838c

6 files changed

Lines changed: 302 additions & 27 deletions

File tree

internal/app/app.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,13 @@ func New(ctx context.Context, cfg Config) (*App, error) {
415415
}
416416
return nil, fmt.Errorf("failed to wire internal guardrail executor: %w", err)
417417
}
418+
if err := executionPlanResult.Service.Refresh(ctx); err != nil {
419+
closeErr := errors.Join(app.executionPlans.Close(), app.guardrails.Close(), app.authKeys.Close(), app.aliases.Close(), app.batch.Close(), app.usage.Close(), app.audit.Close(), app.providers.Close())
420+
if closeErr != nil {
421+
return nil, fmt.Errorf("failed to refresh execution plans after wiring internal guardrail executor: %w (also: close error: %v)", err, closeErr)
422+
}
423+
return nil, fmt.Errorf("failed to refresh execution plans after wiring internal guardrail executor: %w", err)
424+
}
418425

419426
app.server = server.New(provider, serverCfg)
420427

internal/auditlog/entry_capture_test.go

Lines changed: 103 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package auditlog
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
7+
"strings"
68
"testing"
79

810
"gomodel/internal/core"
@@ -55,32 +57,110 @@ func TestCaptureInternalJSONExchange_PreservesHeadersWithoutBodies(t *testing.T)
5557
}
5658

5759
func TestCaptureInternalJSONExchange_PreservesHeadersWhenBodyMarshalFails(t *testing.T) {
58-
entry := &LogEntry{
59-
RequestID: "req_456",
60-
Data: &LogData{},
61-
}
62-
ctx := core.WithEffectiveUserPath(context.Background(), "/team/beta")
60+
t.Run("marshal failure preserves headers", func(t *testing.T) {
61+
entry := &LogEntry{
62+
RequestID: "req_456",
63+
Data: &LogData{},
64+
}
65+
ctx := core.WithEffectiveUserPath(context.Background(), "/team/beta")
6366

64-
CaptureInternalJSONExchange(entry, ctx, "POST", "/v1/chat/completions", func() {}, func() {}, nil, Config{
65-
LogHeaders: true,
66-
LogBodies: true,
67+
CaptureInternalJSONExchange(entry, ctx, "POST", "/v1/chat/completions", func() {}, func() {}, nil, Config{
68+
LogHeaders: true,
69+
LogBodies: true,
70+
})
71+
72+
if entry.Data == nil {
73+
t.Fatal("Data = nil, want populated log data")
74+
}
75+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456" {
76+
t.Fatalf("RequestHeaders[X-Request-ID] = %q, want req_456", got)
77+
}
78+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey(core.UserPathHeader)]; got != "/team/beta" {
79+
t.Fatalf("RequestHeaders[%s] = %q, want /team/beta", core.UserPathHeader, got)
80+
}
81+
if got := entry.Data.ResponseHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456" {
82+
t.Fatalf("ResponseHeaders[X-Request-ID] = %q, want req_456", got)
83+
}
84+
if entry.Data.RequestBody != nil || entry.Data.ResponseBody != nil {
85+
t.Fatal("expected marshal failures to skip bodies while preserving headers")
86+
}
6787
})
6888

69-
if entry.Data == nil {
70-
t.Fatal("Data = nil, want populated log data")
71-
}
72-
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456" {
73-
t.Fatalf("RequestHeaders[X-Request-ID] = %q, want req_456", got)
74-
}
75-
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey(core.UserPathHeader)]; got != "/team/beta" {
76-
t.Fatalf("RequestHeaders[%s] = %q, want /team/beta", core.UserPathHeader, got)
77-
}
78-
if got := entry.Data.ResponseHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456" {
79-
t.Fatalf("ResponseHeaders[X-Request-ID] = %q, want req_456", got)
80-
}
81-
if entry.Data.RequestBody != nil || entry.Data.ResponseBody != nil {
82-
t.Fatal("expected marshal failures to skip bodies while preserving headers")
83-
}
89+
t.Run("response error preserves headers and captures error body", func(t *testing.T) {
90+
entry := &LogEntry{
91+
RequestID: "req_456_err",
92+
Data: &LogData{},
93+
}
94+
ctx := core.WithEffectiveUserPath(context.Background(), "/team/beta")
95+
responseErr := core.NewProviderError("openai", http.StatusBadGateway, "upstream failed", fmt.Errorf("boom"))
96+
97+
CaptureInternalJSONExchange(entry, ctx, "POST", "/v1/chat/completions", map[string]any{"ok": true}, nil, responseErr, Config{
98+
LogHeaders: true,
99+
LogBodies: true,
100+
})
101+
102+
if entry.Data == nil {
103+
t.Fatal("Data = nil, want populated log data")
104+
}
105+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456_err" {
106+
t.Fatalf("RequestHeaders[X-Request-ID] = %q, want req_456_err", got)
107+
}
108+
if got := entry.Data.ResponseHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456_err" {
109+
t.Fatalf("ResponseHeaders[X-Request-ID] = %q, want req_456_err", got)
110+
}
111+
body, ok := entry.Data.ResponseBody.(map[string]any)
112+
if !ok {
113+
t.Fatalf("ResponseBody = %T, want synthesized error envelope", entry.Data.ResponseBody)
114+
}
115+
errorBody, ok := body["error"].(map[string]any)
116+
if !ok {
117+
t.Fatalf("ResponseBody[error] = %#v, want object", body["error"])
118+
}
119+
if got := errorBody["message"]; got != "upstream failed" {
120+
t.Fatalf("ResponseBody.error.message = %#v, want upstream failed", got)
121+
}
122+
})
123+
124+
t.Run("oversized payload preserves headers and sets truncation flags", func(t *testing.T) {
125+
entry := &LogEntry{
126+
RequestID: "req_456_big",
127+
Data: &LogData{},
128+
}
129+
ctx := core.WithEffectiveUserPath(context.Background(), "/team/beta")
130+
large := strings.Repeat("x", int(MaxBodyCapture)+1024)
131+
132+
CaptureInternalJSONExchange(entry, ctx, "POST", "/v1/chat/completions",
133+
map[string]any{"payload": large},
134+
map[string]any{"payload": large},
135+
nil,
136+
Config{
137+
LogHeaders: true,
138+
LogBodies: true,
139+
},
140+
)
141+
142+
if entry.Data == nil {
143+
t.Fatal("Data = nil, want populated log data")
144+
}
145+
if got := entry.Data.RequestHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456_big" {
146+
t.Fatalf("RequestHeaders[X-Request-ID] = %q, want req_456_big", got)
147+
}
148+
if got := entry.Data.ResponseHeaders[http.CanonicalHeaderKey("X-Request-ID")]; got != "req_456_big" {
149+
t.Fatalf("ResponseHeaders[X-Request-ID] = %q, want req_456_big", got)
150+
}
151+
if !entry.Data.RequestBodyTooBigToHandle {
152+
t.Fatal("RequestBodyTooBigToHandle = false, want true")
153+
}
154+
if entry.Data.RequestBody != nil {
155+
t.Fatalf("RequestBody = %#v, want omitted oversized request body", entry.Data.RequestBody)
156+
}
157+
if !entry.Data.ResponseBodyTooBigToHandle {
158+
t.Fatal("ResponseBodyTooBigToHandle = false, want true")
159+
}
160+
if entry.Data.ResponseBody == nil {
161+
t.Fatal("ResponseBody = nil, want truncated captured payload")
162+
}
163+
})
84164
}
85165

86166
func TestCaptureInternalJSONExchange_DoesNotReuseIngressSnapshotOnMarshalFailure(t *testing.T) {

internal/executionplans/service_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package executionplans
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"sync"
78
"sync/atomic"
89
"testing"
910
"time"
1011

1112
"gomodel/internal/core"
13+
"gomodel/internal/guardrails"
1214
)
1315

1416
type staticStore struct {
@@ -614,6 +616,171 @@ func TestServiceEnsureDefaultGlobal_ValidatesBeforeStoreMutation(t *testing.T) {
614616
}
615617
}
616618

619+
func TestServiceRefresh_RebuildsCompiledGuardrailPipelinesAfterExecutorSwap(t *testing.T) {
620+
guardrailStore := &guardrailTestStore{
621+
definitions: map[string]guardrails.Definition{
622+
"privacy": {
623+
Name: "privacy",
624+
Type: "llm_based_altering",
625+
Config: mustMarshalJSON(t, map[string]any{
626+
"model": "gpt-4o-mini",
627+
"roles": []string{"user"},
628+
}),
629+
},
630+
},
631+
}
632+
guardrailService, err := guardrails.NewService(guardrailStore, guardrailExecutorFunc(func(_ context.Context, _ *core.ChatRequest) (*core.ChatResponse, error) {
633+
return &core.ChatResponse{
634+
Choices: []core.Choice{
635+
{Message: core.ResponseMessage{Role: "assistant", Content: "[|---|](PERSON_1)"}},
636+
},
637+
}, nil
638+
}))
639+
if err != nil {
640+
t.Fatalf("guardrails.NewService() error = %v", err)
641+
}
642+
if err := guardrailService.Refresh(context.Background()); err != nil {
643+
t.Fatalf("guardrailService.Refresh() error = %v", err)
644+
}
645+
646+
store := &staticStore{
647+
versions: []Version{
648+
{
649+
ID: "global-v1",
650+
Scope: Scope{},
651+
ScopeKey: "global",
652+
Version: 1,
653+
Active: true,
654+
Name: "global",
655+
Payload: Payload{
656+
SchemaVersion: 1,
657+
Features: FeatureFlags{
658+
Cache: false,
659+
Audit: true,
660+
Usage: true,
661+
Guardrails: true,
662+
},
663+
Guardrails: []GuardrailStep{
664+
{Ref: "privacy", Step: 10},
665+
},
666+
},
667+
},
668+
},
669+
}
670+
service, err := NewService(store, NewCompilerWithFeatureCaps(guardrailService, core.DefaultExecutionFeatures()))
671+
if err != nil {
672+
t.Fatalf("NewService() error = %v", err)
673+
}
674+
if err := service.Refresh(context.Background()); err != nil {
675+
t.Fatalf("service.Refresh() error = %v", err)
676+
}
677+
678+
selector := core.NewExecutionPlanSelector("", "", "/")
679+
policy, err := service.Match(selector)
680+
if err != nil {
681+
t.Fatalf("service.Match() error = %v", err)
682+
}
683+
plan := &core.ExecutionPlan{Policy: policy}
684+
685+
assertPipelineRewrite(t, service.PipelineForExecutionPlan(plan), "[|---|](PERSON_1)")
686+
687+
if err := guardrailService.SetExecutor(context.Background(), guardrailExecutorFunc(func(_ context.Context, _ *core.ChatRequest) (*core.ChatResponse, error) {
688+
return &core.ChatResponse{
689+
Choices: []core.Choice{
690+
{Message: core.ResponseMessage{Role: "assistant", Content: "[|---|](PERSON_2)"}},
691+
},
692+
}, nil
693+
})); err != nil {
694+
t.Fatalf("guardrailService.SetExecutor() error = %v", err)
695+
}
696+
697+
assertPipelineRewrite(t, service.PipelineForExecutionPlan(plan), "[|---|](PERSON_1)")
698+
699+
if err := service.Refresh(context.Background()); err != nil {
700+
t.Fatalf("service.Refresh() after SetExecutor error = %v", err)
701+
}
702+
assertPipelineRewrite(t, service.PipelineForExecutionPlan(plan), "[|---|](PERSON_2)")
703+
}
704+
705+
type guardrailTestStore struct {
706+
definitions map[string]guardrails.Definition
707+
}
708+
709+
func (s *guardrailTestStore) List(context.Context) ([]guardrails.Definition, error) {
710+
result := make([]guardrails.Definition, 0, len(s.definitions))
711+
for _, definition := range s.definitions {
712+
result = append(result, definition)
713+
}
714+
return result, nil
715+
}
716+
717+
func (s *guardrailTestStore) Get(_ context.Context, name string) (*guardrails.Definition, error) {
718+
definition, ok := s.definitions[name]
719+
if !ok {
720+
return nil, guardrails.ErrNotFound
721+
}
722+
copy := definition
723+
return &copy, nil
724+
}
725+
726+
func (s *guardrailTestStore) Upsert(_ context.Context, definition guardrails.Definition) error {
727+
if s.definitions == nil {
728+
s.definitions = make(map[string]guardrails.Definition)
729+
}
730+
s.definitions[definition.Name] = definition
731+
return nil
732+
}
733+
734+
func (s *guardrailTestStore) UpsertMany(_ context.Context, definitions []guardrails.Definition) error {
735+
if s.definitions == nil {
736+
s.definitions = make(map[string]guardrails.Definition)
737+
}
738+
for _, definition := range definitions {
739+
s.definitions[definition.Name] = definition
740+
}
741+
return nil
742+
}
743+
744+
func (s *guardrailTestStore) Delete(_ context.Context, name string) error {
745+
delete(s.definitions, name)
746+
return nil
747+
}
748+
749+
func (s *guardrailTestStore) Close() error { return nil }
750+
751+
type guardrailExecutorFunc func(context.Context, *core.ChatRequest) (*core.ChatResponse, error)
752+
753+
func (f guardrailExecutorFunc) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) {
754+
return f(ctx, req)
755+
}
756+
757+
func mustMarshalJSON(t *testing.T, value any) []byte {
758+
t.Helper()
759+
raw, err := json.Marshal(value)
760+
if err != nil {
761+
t.Fatalf("json.Marshal() error = %v", err)
762+
}
763+
return raw
764+
}
765+
766+
func assertPipelineRewrite(t *testing.T, pipeline *guardrails.Pipeline, want string) {
767+
t.Helper()
768+
if pipeline == nil {
769+
t.Fatal("pipeline = nil, want non-nil")
770+
}
771+
772+
msgs, err := pipeline.Process(context.Background(), []guardrails.Message{{Role: "user", Content: "John Smith"}})
773+
if err != nil {
774+
t.Fatalf("pipeline.Process() error = %v", err)
775+
}
776+
if len(msgs) != 1 {
777+
t.Fatalf("len(msgs) = %d, want 1", len(msgs))
778+
}
779+
if msgs[0].Content != want {
780+
t.Fatalf("msgs[0].Content = %q, want %q", msgs[0].Content, want)
781+
}
782+
}
783+
617784
func TestServiceCreate_RefreshesSnapshot(t *testing.T) {
618785
store := &staticStore{
619786
versions: []Version{

internal/guardrails/provider.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,7 @@ func applyMessagesToResponses(req *core.ResponsesRequest, msgs []Message) (*core
997997
}
998998

999999
func applyMessagesToResponsesInput(original any, msgs []Message) (any, error) {
1000-
switch typed := original.(type) {
1000+
switch original.(type) {
10011001
case nil:
10021002
if len(msgs) != 0 {
10031003
return nil, core.NewInvalidRequestError("guardrails cannot add or remove responses input items", nil)
@@ -1014,8 +1014,6 @@ func applyMessagesToResponsesInput(original any, msgs []Message) (any, error) {
10141014
return "", nil
10151015
}
10161016
return msgs[0].Content, nil
1017-
default:
1018-
_ = typed
10191017
}
10201018

10211019
elements, err := coerceResponsesInputElements(original)

internal/responsecache/handle_request_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,28 @@ func TestHandleInternalRequest_RejectsNilContext(t *testing.T) {
112112
}
113113
}
114114

115+
func TestHandleInternalRequest_RejectsNilMiddleware(t *testing.T) {
116+
var m *ResponseCacheMiddleware
117+
118+
_, err := m.HandleInternalRequest(context.Background(), http.MethodPost, "/v1/chat/completions", []byte(`{}`), func(c *echo.Context) error {
119+
return c.JSON(http.StatusOK, map[string]string{"ok": "1"})
120+
})
121+
if err == nil {
122+
t.Fatal("HandleInternalRequest() error = nil, want provider error")
123+
}
124+
125+
gatewayErr, ok := err.(*core.GatewayError)
126+
if !ok {
127+
t.Fatalf("HandleInternalRequest() error = %T, want *core.GatewayError", err)
128+
}
129+
if gatewayErr.Type != core.ErrorTypeProvider {
130+
t.Fatalf("error type = %q, want %q", gatewayErr.Type, core.ErrorTypeProvider)
131+
}
132+
if gatewayErr.HTTPStatusCode() != http.StatusInternalServerError {
133+
t.Fatalf("status code = %d, want %d", gatewayErr.HTTPStatusCode(), http.StatusInternalServerError)
134+
}
135+
}
136+
115137
func TestHandleInternalRequest_RejectsUninitializedEcho(t *testing.T) {
116138
m := &ResponseCacheMiddleware{}
117139

0 commit comments

Comments
 (0)