Skip to content

Commit b57e57c

Browse files
fix(auditlog): restore capture and stabilize integration tests
1 parent a65071f commit b57e57c

9 files changed

Lines changed: 188 additions & 89 deletions

File tree

internal/auditlog/auditlog_test.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,18 @@ func (t *trackingReadCloser) Close() error {
12111211
return nil
12121212
}
12131213

1214+
type chainReadCloser struct {
1215+
io.Reader
1216+
closer io.Closer
1217+
}
1218+
1219+
func (c *chainReadCloser) Close() error {
1220+
if c == nil || c.closer == nil {
1221+
return nil
1222+
}
1223+
return c.closer.Close()
1224+
}
1225+
12141226
// discardWriter implements http.ResponseWriter but discards all output.
12151227
type discardWriter struct{}
12161228

@@ -1304,9 +1316,9 @@ func TestLimitedReaderRequestBodyCapture(t *testing.T) {
13041316
}
13051317

13061318
origBody := req.Body
1307-
req.Body = &combinedReadCloser{
1319+
req.Body = &chainReadCloser{
13081320
Reader: io.MultiReader(bytes.NewReader(bodyBytes), origBody),
1309-
rc: origBody,
1321+
closer: origBody,
13101322
}
13111323

13121324
// Read full body from reconstructed reader

internal/auditlog/entry_capture.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package auditlog
2+
3+
import (
4+
"net/http"
5+
6+
"gomodel/internal/core"
7+
)
8+
9+
// PopulateRequestData copies the configured request capture fields into the log entry.
10+
// Streaming handlers call this before creating the detached stream entry so request
11+
// metadata is preserved even though the middleware finishes later.
12+
func PopulateRequestData(entry *LogEntry, req *http.Request, cfg Config) {
13+
if entry == nil || req == nil {
14+
return
15+
}
16+
17+
data := ensureLogData(entry)
18+
19+
if cfg.LogHeaders {
20+
data.RequestHeaders = extractHeaders(req.Header)
21+
}
22+
23+
if !cfg.LogBodies {
24+
return
25+
}
26+
27+
snapshot := core.GetRequestSnapshot(req.Context())
28+
if snapshot == nil {
29+
return
30+
}
31+
32+
switch body := snapshot.CapturedBody(); {
33+
case snapshot.BodyNotCaptured:
34+
data.RequestBodyTooBigToHandle = true
35+
case body != nil:
36+
captureLoggedRequestBody(entry, body)
37+
}
38+
}
39+
40+
// PopulateResponseHeaders copies response headers into the log entry when header logging is enabled.
41+
func PopulateResponseHeaders(entry *LogEntry, headers http.Header) {
42+
if entry == nil || headers == nil {
43+
return
44+
}
45+
46+
data := ensureLogData(entry)
47+
data.ResponseHeaders = extractHeaders(headers)
48+
}
49+
50+
func ensureLogData(entry *LogEntry) *LogData {
51+
if entry.Data == nil {
52+
entry.Data = &LogData{}
53+
}
54+
return entry.Data
55+
}

internal/auditlog/middleware.go

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,6 @@ func Middleware(logger LoggerInterface) echo.MiddlewareFunc {
7272
entry.Data.APIKeyHash = hashAPIKey(authHeader)
7373
}
7474

75-
// Log request headers if enabled
76-
if cfg.LogHeaders {
77-
entry.Data.RequestHeaders = extractHeaders(req.Header)
78-
}
79-
80-
// Capture request body if enabled
81-
if cfg.LogBodies && req.Body != nil {
82-
if snapshot := core.GetRequestSnapshot(req.Context()); snapshot != nil {
83-
body := snapshot.CapturedBody()
84-
switch {
85-
case snapshot.BodyNotCaptured:
86-
entry.Data.RequestBodyTooBigToHandle = true
87-
case body != nil:
88-
captureLoggedRequestBody(entry, body)
89-
default:
90-
captureRequestBodyForLogging(entry, req)
91-
}
92-
} else {
93-
captureRequestBodyForLogging(entry, req)
94-
}
95-
}
96-
9775
// Store entry in context for potential enrichment by handlers
9876
c.Set(string(LogEntryKey), entry)
9977

@@ -104,7 +82,7 @@ func Middleware(logger LoggerInterface) echo.MiddlewareFunc {
10482
ResponseWriter: c.Response(),
10583
body: &bytes.Buffer{},
10684
shouldCapture: func() bool {
107-
return shouldCaptureResponseBody(c)
85+
return auditEnabledForContext(c.Request().Context()) && shouldCaptureResponseBody(c)
10886
},
10987
}
11088
c.SetResponse(responseCapture)
@@ -115,16 +93,24 @@ func Middleware(logger LoggerInterface) echo.MiddlewareFunc {
11593

11694
applyExecutionPlan(entry, c.Request().Context())
11795

96+
if !auditEnabledForContext(c.Request().Context()) {
97+
return err
98+
}
99+
118100
// Calculate duration
119101
entry.DurationNs = time.Since(start).Nanoseconds()
120102

121103
// ResolveResponseStatus applies Echo v5 precedence rules for committed responses,
122104
// suggested status codes, and errors implementing HTTPStatusCoder.
123105
_, entry.StatusCode = echo.ResolveResponseStatus(c.Response(), err)
124106

107+
// Request capture is deferred until after next so a later-resolved
108+
// Audit=false plan can skip it entirely.
109+
PopulateRequestData(entry, req, cfg)
110+
125111
// Log response headers if enabled
126112
if cfg.LogHeaders {
127-
entry.Data.ResponseHeaders = extractHeaders(c.Response().Header())
113+
PopulateResponseHeaders(entry, c.Response().Header())
128114
}
129115

130116
// Capture response body if enabled
@@ -205,35 +191,6 @@ func enrichEntryWithExecutionPlan(entry *LogEntry, plan *core.ExecutionPlan) {
205191
}
206192
}
207193

208-
func captureRequestBodyForLogging(entry *LogEntry, req *http.Request) {
209-
if req.ContentLength > MaxBodyCapture {
210-
entry.Data.RequestBodyTooBigToHandle = true
211-
return
212-
}
213-
214-
// Read up to MaxBodyCapture+1 to detect overflow safely.
215-
// Uses io.LimitReader to enforce the cap regardless of
216-
// Content-Length (handles chunked/unknown-length requests).
217-
limitedReader := io.LimitReader(req.Body, MaxBodyCapture+1)
218-
bodyBytes, err := io.ReadAll(limitedReader)
219-
if err != nil {
220-
return
221-
}
222-
if int64(len(bodyBytes)) > MaxBodyCapture {
223-
entry.Data.RequestBodyTooBigToHandle = true
224-
// Reconstruct full body for downstream: read bytes + unread remainder
225-
origBody := req.Body
226-
req.Body = &combinedReadCloser{
227-
Reader: io.MultiReader(bytes.NewReader(bodyBytes), origBody),
228-
rc: origBody,
229-
}
230-
return
231-
}
232-
233-
captureLoggedRequestBody(entry, bodyBytes)
234-
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
235-
}
236-
237194
func captureLoggedRequestBody(entry *LogEntry, bodyBytes []byte) {
238195
if len(bodyBytes) == 0 {
239196
return
@@ -250,17 +207,6 @@ func captureLoggedRequestBody(entry *LogEntry, bodyBytes []byte) {
250207
entry.Data.RequestBody = toValidUTF8String(bodyBytes)
251208
}
252209

253-
// combinedReadCloser delegates Read to an io.Reader and Close to an io.ReadCloser.
254-
// Used to reconstruct a request body that preserves the original closer.
255-
type combinedReadCloser struct {
256-
io.Reader
257-
rc io.ReadCloser
258-
}
259-
260-
func (c *combinedReadCloser) Close() error {
261-
return c.rc.Close()
262-
}
263-
264210
// responseBodyCapture wraps http.ResponseWriter to capture the response body.
265211
// It implements http.Flusher and http.Hijacker by delegating to the underlying
266212
// ResponseWriter if it supports those interfaces.

internal/server/http.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,19 @@ func New(provider core.RoutableProvider, cfg *Config) *Server {
200200
e.Use(PassthroughSemanticEnrichment(cfg.PassthroughSemanticEnrichers, passthroughV1PrefixNormalizationEnabled(cfg)))
201201
}
202202

203-
// Request planning must run before audit so per-request feature flags can
204-
// suppress audit capture entirely.
205-
e.Use(ExecutionPlanningWithResolverAndPolicy(provider, modelResolver, executionPolicyResolver))
206-
207-
// Audit logging middleware runs after planning so Audit=false can bypass
208-
// request/response capture, but still before authentication so rejected model
209-
// requests are logged when auditing is enabled for the matched plan.
203+
// Audit logging runs before request planning so early planning/validation
204+
// failures are still logged. The middleware defers request capture and
205+
// dynamically gates response capture on the final resolved execution plan, so
206+
// Audit=false still suppresses per-request capture work.
210207
if cfg != nil && cfg.AuditLogger != nil && cfg.AuditLogger.Config().Enabled {
211208
e.Use(auditlog.Middleware(cfg.AuditLogger))
212209
}
213210

211+
// Request planning resolves the request-scoped execution plan before auth and
212+
// handler execution. This keeps rejected model requests loggable and lets
213+
// downstream stages consume a shared policy decision.
214+
e.Use(ExecutionPlanningWithResolverAndPolicy(provider, modelResolver, executionPolicyResolver))
215+
214216
// Authentication (skips public paths)
215217
if cfg != nil && cfg.MasterKey != "" {
216218
e.Use(AuthMiddleware(cfg.MasterKey, authSkipPaths))

internal/server/model_validation.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ func ExecutionPlanningWithResolverAndPolicy(
5454
}
5555
}
5656

57-
func deriveExecutionPlan(c *echo.Context, provider core.RoutableProvider, resolver RequestModelResolver) (*core.ExecutionPlan, error) {
58-
return deriveExecutionPlanWithPolicy(c, provider, resolver, nil)
59-
}
60-
6157
func deriveExecutionPlanWithPolicy(
6258
c *echo.Context,
6359
provider core.RoutableProvider,

internal/server/passthrough_support.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,20 @@ func (s *passthroughService) proxyPassthroughResponse(c *echo.Context, providerT
244244
if isSSEContentType(resp.Headers) {
245245
auditlog.MarkEntryAsStreaming(c, true)
246246
auditlog.EnrichEntryWithStream(c, true)
247+
plan := core.GetExecutionPlan(c.Request().Context())
248+
auditEnabled := s.logger != nil && s.logger.Config().Enabled && (plan == nil || plan.AuditEnabled())
247249

248250
entry := auditlog.GetStreamEntryFromContext(c)
251+
if auditEnabled && entry != nil {
252+
auditlog.PopulateRequestData(entry, c.Request(), s.logger.Config())
253+
}
249254
streamEntry := auditlog.CreateStreamEntry(entry)
250255
if streamEntry != nil {
251256
streamEntry.StatusCode = resp.StatusCode
252257
}
258+
if auditEnabled && streamEntry != nil && s.logger.Config().LogHeaders {
259+
auditlog.PopulateResponseHeaders(streamEntry, c.Response().Header())
260+
}
253261

254262
requestID := requestIDFromContextOrHeader(c.Request())
255263
auditPath := passthroughAuditPath(c, providerType, endpoint, info)
@@ -261,11 +269,10 @@ func (s *passthroughService) proxyPassthroughResponse(c *echo.Context, providerT
261269
if info != nil {
262270
model = strings.TrimSpace(info.Model)
263271
}
264-
model = resolvedModelFromPlan(core.GetExecutionPlan(c.Request().Context()), model)
272+
model = resolvedModelFromPlan(plan, model)
265273

266274
observers := make([]streaming.Observer, 0, 2)
267-
plan := core.GetExecutionPlan(c.Request().Context())
268-
if (plan == nil || plan.AuditEnabled()) && streamEntry != nil {
275+
if auditEnabled && streamEntry != nil {
269276
if observer := auditlog.NewStreamLogObserver(s.logger, streamEntry, auditPath); observer != nil {
270277
observers = append(observers, observer)
271278
}

internal/server/translated_inference_service.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ func (s *translatedInferenceService) handleStreamingResponse(
332332
auditlog.EnrichEntryWithStream(c, true)
333333

334334
entry := auditlog.GetStreamEntryFromContext(c)
335+
auditEnabled := s.logger != nil && s.logger.Config().Enabled && (plan == nil || plan.AuditEnabled())
336+
if auditEnabled && entry != nil {
337+
auditlog.PopulateRequestData(entry, c.Request(), s.logger.Config())
338+
}
335339
streamEntry := auditlog.CreateStreamEntry(entry)
336340
if streamEntry != nil {
337341
streamEntry.StatusCode = http.StatusOK
@@ -340,7 +344,7 @@ func (s *translatedInferenceService) handleStreamingResponse(
340344
requestID := requestIDFromContextOrHeader(c.Request())
341345
endpoint := c.Request().URL.Path
342346
observers := make([]streaming.Observer, 0, 2)
343-
if s.logger != nil && s.logger.Config().Enabled && streamEntry != nil && (plan == nil || plan.AuditEnabled()) {
347+
if auditEnabled && streamEntry != nil {
344348
observers = append(observers, auditlog.NewStreamLogObserver(s.logger, streamEntry, endpoint))
345349
}
346350
if s.usageLogger != nil && s.usageLogger.Config().Enabled && (plan == nil || plan.UsageEnabled()) {
@@ -356,12 +360,8 @@ func (s *translatedInferenceService) handleStreamingResponse(
356360
c.Response().Header().Set("Cache-Control", "no-cache")
357361
c.Response().Header().Set("Connection", "keep-alive")
358362

359-
if streamEntry != nil && streamEntry.Data != nil {
360-
streamEntry.Data.ResponseHeaders = map[string]string{
361-
"Content-Type": "text/event-stream",
362-
"Cache-Control": "no-cache",
363-
"Connection": "keep-alive",
364-
}
363+
if auditEnabled && streamEntry != nil && s.logger.Config().LogHeaders {
364+
auditlog.PopulateResponseHeaders(streamEntry, c.Response().Header())
365365
}
366366

367367
c.Response().WriteHeader(http.StatusOK)

tests/integration/main_test.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"fmt"
1111
"log"
12+
"net/url"
1213
"os"
1314
"testing"
1415
"time"
@@ -123,7 +124,7 @@ func setupMongoDB(ctx context.Context) error {
123124
var err error
124125

125126
log.Println("Starting MongoDB container...")
126-
mongoContainer, err = mongodb.Run(ctx, "mongo:7")
127+
mongoContainer, err = mongodb.Run(ctx, "mongo:7", mongodb.WithReplicaSet("rs"))
127128
if err != nil {
128129
return fmt.Errorf("failed to start MongoDB container: %w", err)
129130
}
@@ -133,11 +134,15 @@ func setupMongoDB(ctx context.Context) error {
133134
if err != nil {
134135
return fmt.Errorf("failed to get MongoDB connection string: %w", err)
135136
}
137+
mongoURL, err = withDirectMongoConnection(mongoURL)
138+
if err != nil {
139+
return fmt.Errorf("failed to normalize MongoDB connection string: %w", err)
140+
}
136141

137142
log.Printf("MongoDB URL: %s", mongoURL)
138143

139144
// Create client
140-
mongoClient, err = mongo.Connect(options.Client().ApplyURI(mongoURL))
145+
mongoClient, err = mongo.Connect(options.Client().ApplyURI(mongoURL).SetDirect(true))
141146
if err != nil {
142147
return fmt.Errorf("failed to create MongoDB client: %w", err)
143148
}
@@ -213,3 +218,14 @@ func GetMongoURL() string {
213218
func GetTestContext() context.Context {
214219
return testCtx
215220
}
221+
222+
func withDirectMongoConnection(rawURL string) (string, error) {
223+
parsed, err := url.Parse(rawURL)
224+
if err != nil {
225+
return "", err
226+
}
227+
query := parsed.Query()
228+
query.Set("directConnection", "true")
229+
parsed.RawQuery = query.Encode()
230+
return parsed.String(), nil
231+
}

0 commit comments

Comments
 (0)