@@ -2,13 +2,15 @@ package executionplans
22
33import (
44 "context"
5+ "encoding/json"
56 "errors"
67 "sync"
78 "sync/atomic"
89 "testing"
910 "time"
1011
1112 "gomodel/internal/core"
13+ "gomodel/internal/guardrails"
1214)
1315
1416type staticStore struct {
@@ -614,6 +616,171 @@ func TestServiceEnsureDefaultGlobal_ValidatesBeforeStoreMutation(t *testing.T) {
614616 }
615617}
616618
619+ func TestServiceRefresh_RebuildsCompiledGuardrailPipelinesAfterExecutorSwap (t * testing.T ) {
620+ guardrailStore := & guardrailTestStore {
621+ definitions : map [string ]guardrails.Definition {
622+ "privacy" : {
623+ Name : "privacy" ,
624+ Type : "llm_based_altering" ,
625+ Config : mustMarshalJSON (t , map [string ]any {
626+ "model" : "gpt-4o-mini" ,
627+ "roles" : []string {"user" },
628+ }),
629+ },
630+ },
631+ }
632+ guardrailService , err := guardrails .NewService (guardrailStore , guardrailExecutorFunc (func (_ context.Context , _ * core.ChatRequest ) (* core.ChatResponse , error ) {
633+ return & core.ChatResponse {
634+ Choices : []core.Choice {
635+ {Message : core.ResponseMessage {Role : "assistant" , Content : "[|---|](PERSON_1)" }},
636+ },
637+ }, nil
638+ }))
639+ if err != nil {
640+ t .Fatalf ("guardrails.NewService() error = %v" , err )
641+ }
642+ if err := guardrailService .Refresh (context .Background ()); err != nil {
643+ t .Fatalf ("guardrailService.Refresh() error = %v" , err )
644+ }
645+
646+ store := & staticStore {
647+ versions : []Version {
648+ {
649+ ID : "global-v1" ,
650+ Scope : Scope {},
651+ ScopeKey : "global" ,
652+ Version : 1 ,
653+ Active : true ,
654+ Name : "global" ,
655+ Payload : Payload {
656+ SchemaVersion : 1 ,
657+ Features : FeatureFlags {
658+ Cache : false ,
659+ Audit : true ,
660+ Usage : true ,
661+ Guardrails : true ,
662+ },
663+ Guardrails : []GuardrailStep {
664+ {Ref : "privacy" , Step : 10 },
665+ },
666+ },
667+ },
668+ },
669+ }
670+ service , err := NewService (store , NewCompilerWithFeatureCaps (guardrailService , core .DefaultExecutionFeatures ()))
671+ if err != nil {
672+ t .Fatalf ("NewService() error = %v" , err )
673+ }
674+ if err := service .Refresh (context .Background ()); err != nil {
675+ t .Fatalf ("service.Refresh() error = %v" , err )
676+ }
677+
678+ selector := core .NewExecutionPlanSelector ("" , "" , "/" )
679+ policy , err := service .Match (selector )
680+ if err != nil {
681+ t .Fatalf ("service.Match() error = %v" , err )
682+ }
683+ plan := & core.ExecutionPlan {Policy : policy }
684+
685+ assertPipelineRewrite (t , service .PipelineForExecutionPlan (plan ), "[|---|](PERSON_1)" )
686+
687+ if err := guardrailService .SetExecutor (context .Background (), guardrailExecutorFunc (func (_ context.Context , _ * core.ChatRequest ) (* core.ChatResponse , error ) {
688+ return & core.ChatResponse {
689+ Choices : []core.Choice {
690+ {Message : core.ResponseMessage {Role : "assistant" , Content : "[|---|](PERSON_2)" }},
691+ },
692+ }, nil
693+ })); err != nil {
694+ t .Fatalf ("guardrailService.SetExecutor() error = %v" , err )
695+ }
696+
697+ assertPipelineRewrite (t , service .PipelineForExecutionPlan (plan ), "[|---|](PERSON_1)" )
698+
699+ if err := service .Refresh (context .Background ()); err != nil {
700+ t .Fatalf ("service.Refresh() after SetExecutor error = %v" , err )
701+ }
702+ assertPipelineRewrite (t , service .PipelineForExecutionPlan (plan ), "[|---|](PERSON_2)" )
703+ }
704+
705+ type guardrailTestStore struct {
706+ definitions map [string ]guardrails.Definition
707+ }
708+
709+ func (s * guardrailTestStore ) List (context.Context ) ([]guardrails.Definition , error ) {
710+ result := make ([]guardrails.Definition , 0 , len (s .definitions ))
711+ for _ , definition := range s .definitions {
712+ result = append (result , definition )
713+ }
714+ return result , nil
715+ }
716+
717+ func (s * guardrailTestStore ) Get (_ context.Context , name string ) (* guardrails.Definition , error ) {
718+ definition , ok := s .definitions [name ]
719+ if ! ok {
720+ return nil , guardrails .ErrNotFound
721+ }
722+ copy := definition
723+ return & copy , nil
724+ }
725+
726+ func (s * guardrailTestStore ) Upsert (_ context.Context , definition guardrails.Definition ) error {
727+ if s .definitions == nil {
728+ s .definitions = make (map [string ]guardrails.Definition )
729+ }
730+ s .definitions [definition .Name ] = definition
731+ return nil
732+ }
733+
734+ func (s * guardrailTestStore ) UpsertMany (_ context.Context , definitions []guardrails.Definition ) error {
735+ if s .definitions == nil {
736+ s .definitions = make (map [string ]guardrails.Definition )
737+ }
738+ for _ , definition := range definitions {
739+ s .definitions [definition .Name ] = definition
740+ }
741+ return nil
742+ }
743+
744+ func (s * guardrailTestStore ) Delete (_ context.Context , name string ) error {
745+ delete (s .definitions , name )
746+ return nil
747+ }
748+
749+ func (s * guardrailTestStore ) Close () error { return nil }
750+
751+ type guardrailExecutorFunc func (context.Context , * core.ChatRequest ) (* core.ChatResponse , error )
752+
753+ func (f guardrailExecutorFunc ) ChatCompletion (ctx context.Context , req * core.ChatRequest ) (* core.ChatResponse , error ) {
754+ return f (ctx , req )
755+ }
756+
757+ func mustMarshalJSON (t * testing.T , value any ) []byte {
758+ t .Helper ()
759+ raw , err := json .Marshal (value )
760+ if err != nil {
761+ t .Fatalf ("json.Marshal() error = %v" , err )
762+ }
763+ return raw
764+ }
765+
766+ func assertPipelineRewrite (t * testing.T , pipeline * guardrails.Pipeline , want string ) {
767+ t .Helper ()
768+ if pipeline == nil {
769+ t .Fatal ("pipeline = nil, want non-nil" )
770+ }
771+
772+ msgs , err := pipeline .Process (context .Background (), []guardrails.Message {{Role : "user" , Content : "John Smith" }})
773+ if err != nil {
774+ t .Fatalf ("pipeline.Process() error = %v" , err )
775+ }
776+ if len (msgs ) != 1 {
777+ t .Fatalf ("len(msgs) = %d, want 1" , len (msgs ))
778+ }
779+ if msgs [0 ].Content != want {
780+ t .Fatalf ("msgs[0].Content = %q, want %q" , msgs [0 ].Content , want )
781+ }
782+ }
783+
617784func TestServiceCreate_RefreshesSnapshot (t * testing.T ) {
618785 store := & staticStore {
619786 versions : []Version {
0 commit comments