Skip to content

Commit 26a1394

Browse files
refactor: extract duplicated streaming-handler logic into a shared helper and replace magic numbers with named constants
1 parent b18b768 commit 26a1394

5 files changed

Lines changed: 107 additions & 111 deletions

File tree

internal/auditlog/constants.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package auditlog
2+
3+
// Buffer and capture limits for audit logging.
const (
	// MaxBodyCapture caps how many bytes of a request or response body are
	// recorded (1 MiB), guarding against memory exhaustion on huge payloads.
	MaxBodyCapture int64 = 1 << 20

	// MaxContentCapture caps the total streaming content (1 MiB) that
	// StreamLogWrapper accumulates while reconstructing a response body.
	MaxContentCapture = 1 << 20

	// SSEBufferSize is the size of the rolling tail buffer kept while reading
	// an SSE stream. It must be large enough to contain the final usage event
	// that carries the token counts.
	SSEBufferSize = 8192

	// BatchFlushThreshold is the number of buffered entries that triggers an
	// immediate write to storage, without waiting for the flush timer.
	BatchFlushThreshold = 100

	// APIKeyHashPrefixLength is how many hex characters of the SHA256 hash of
	// an API key are kept: 16 hex chars = 64 bits of entropy, enough to
	// identify a key without exposing it.
	APIKeyHashPrefixLength = 16
)

// contextKey is a dedicated type for request-context keys so they cannot
// collide with keys set by other packages.
type contextKey string

// Context keys for storing audit log data in request context.
const (
	// LogEntryKey is the context key under which the log entry is stored.
	LogEntryKey contextKey = "auditlog_entry"

	// LogEntryStreamingKey marks a request as streaming. When true, the
	// middleware skips logging (StreamLogWrapper handles it instead).
	LogEntryStreamingKey contextKey = "auditlog_entry_streaming"
)

internal/auditlog/logger.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,23 @@ func (l *Logger) flushLoop() {
9999
ticker := time.NewTicker(l.flushInterval)
100100
defer ticker.Stop()
101101

102-
batch := make([]*LogEntry, 0, 100)
102+
batch := make([]*LogEntry, 0, BatchFlushThreshold)
103103

104104
for {
105105
select {
106106
case entry := <-l.buffer:
107107
batch = append(batch, entry)
108-
// Flush when batch reaches 100 entries
109-
if len(batch) >= 100 {
108+
// Flush when batch reaches threshold
109+
if len(batch) >= BatchFlushThreshold {
110110
l.flushBatch(batch)
111-
batch = make([]*LogEntry, 0, 100)
111+
batch = make([]*LogEntry, 0, BatchFlushThreshold)
112112
}
113113

114114
case <-ticker.C:
115115
// Periodic flush
116116
if len(batch) > 0 {
117117
l.flushBatch(batch)
118-
batch = make([]*LogEntry, 0, 100)
118+
batch = make([]*LogEntry, 0, BatchFlushThreshold)
119119
}
120120

121121
case <-l.done:

internal/auditlog/middleware.go

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,8 @@ import (
2020
"github.com/labstack/echo/v4"
2121
)
2222

23-
// contextKey is a type for context keys to avoid collisions
24-
type contextKey string
25-
26-
const (
27-
// LogEntryKey is the context key for storing the log entry
28-
LogEntryKey contextKey = "auditlog_entry"
29-
30-
// maxBodyCapture is the maximum size of request/response bodies to capture (1MB)
31-
maxBodyCapture int64 = 1024 * 1024
32-
33-
// apiKeyHashPrefixLength is the number of hex characters to use from the SHA256 hash
34-
// of API keys. 16 hex chars = 64 bits of entropy, reducing collision risk compared to 8.
35-
apiKeyHashPrefixLength = 16
36-
)
23+
// Note: contextKey type and constants (LogEntryKey, LogEntryStreamingKey,
24+
// MaxBodyCapture, APIKeyHashPrefixLength) are defined in constants.go
3725

3826
// Middleware creates an Echo middleware for audit logging.
3927
// It captures request metadata at the start and response metadata at the end,
@@ -88,7 +76,7 @@ func Middleware(logger LoggerInterface) echo.MiddlewareFunc {
8876
// Capture request body if enabled
8977
if cfg.LogBodies && req.Body != nil && req.ContentLength > 0 {
9078
// Skip body capture if too large to prevent memory exhaustion
91-
if req.ContentLength > maxBodyCapture {
79+
if req.ContentLength > MaxBodyCapture {
9280
entry.Data.RequestBodyTooBigToHandle = true
9381
} else {
9482
bodyBytes, err := io.ReadAll(req.Body)
@@ -180,11 +168,11 @@ type responseBodyCapture struct {
180168
}
181169

182170
func (r *responseBodyCapture) Write(b []byte) (int, error) {
183-
// Write to the capture buffer (limit to maxBodyCapture to avoid memory issues)
184-
if r.body.Len() < int(maxBodyCapture) {
171+
// Write to the capture buffer (limit to MaxBodyCapture to avoid memory issues)
172+
if r.body.Len() < int(MaxBodyCapture) {
185173
r.body.Write(b)
186174
// Check if we just hit the limit
187-
if r.body.Len() >= int(maxBodyCapture) {
175+
if r.body.Len() >= int(MaxBodyCapture) {
188176
r.truncated = true
189177
}
190178
}
@@ -224,7 +212,7 @@ func extractHeaders(headers map[string][]string) map[string]string {
224212
}
225213

226214
// hashAPIKey creates a short hash of the API key for identification.
227-
// Returns first apiKeyHashPrefixLength hex characters of SHA256 hash.
215+
// Returns first APIKeyHashPrefixLength hex characters of SHA256 hash.
228216
func hashAPIKey(authHeader string) string {
229217
// Extract token from "Bearer <token>"
230218
token := strings.TrimPrefix(authHeader, "Bearer ")
@@ -234,7 +222,7 @@ func hashAPIKey(authHeader string) string {
234222
}
235223

236224
hash := sha256.Sum256([]byte(token))
237-
return hex.EncodeToString(hash[:])[:apiKeyHashPrefixLength]
225+
return hex.EncodeToString(hash[:])[:APIKeyHashPrefixLength]
238226
}
239227

240228
// EnrichEntry retrieves the log entry from context for enrichment by handlers.

internal/auditlog/stream_wrapper.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
"time"
99
)
1010

11-
// maxContentCapture is the maximum size of accumulated content (1MB)
12-
const maxContentCapture = 1024 * 1024
11+
// Note: MaxContentCapture, SSEBufferSize, and LogEntryStreamingKey
12+
// constants are defined in constants.go
1313

1414
// streamResponseBuilder accumulates data from SSE events to reconstruct a response
1515
type streamResponseBuilder struct {
@@ -97,12 +97,12 @@ func (w *StreamLogWrapper) Read(p []byte) (n int, err error) {
9797
if _, errBuf := w.buffer.Write(p[:n]); errBuf != nil {
9898
return n, errBuf
9999
}
100-
// Keep only last 8KB to find "data: [DONE]" and usage
101-
if w.buffer.Len() > 8192 {
100+
// Keep only last SSEBufferSize bytes to find "data: [DONE]" and usage
101+
if w.buffer.Len() > SSEBufferSize {
102102
// Discard old data, keep recent
103103
data := w.buffer.Bytes()
104104
w.buffer.Reset()
105-
if _, errBuf := w.buffer.Write(data[len(data)-8192:]); errBuf != nil {
105+
if _, errBuf := w.buffer.Write(data[len(data)-SSEBufferSize:]); errBuf != nil {
106106
return n, errBuf
107107
}
108108
}
@@ -201,8 +201,8 @@ func (w *StreamLogWrapper) parseChatCompletionEvent(event map[string]interface{}
201201
}
202202
// Extract and accumulate content
203203
if content, ok := delta["content"].(string); ok && content != "" {
204-
if !w.builder.truncated && w.builder.contentLen < maxContentCapture {
205-
remaining := maxContentCapture - w.builder.contentLen
204+
if !w.builder.truncated && w.builder.contentLen < MaxContentCapture {
205+
remaining := MaxContentCapture - w.builder.contentLen
206206
if len(content) > remaining {
207207
content = content[:remaining]
208208
w.builder.truncated = true
@@ -241,8 +241,8 @@ func (w *StreamLogWrapper) parseResponsesAPIEvent(event map[string]interface{})
241241
case "response.output_text.delta":
242242
// Accumulate text delta
243243
if delta, ok := event["delta"].(string); ok && delta != "" {
244-
if !w.builder.truncated && w.builder.contentLen < maxContentCapture {
245-
remaining := maxContentCapture - w.builder.contentLen
244+
if !w.builder.truncated && w.builder.contentLen < MaxContentCapture {
245+
remaining := MaxContentCapture - w.builder.contentLen
246246
if len(delta) > remaining {
247247
delta = delta[:remaining]
248248
w.builder.truncated = true
@@ -504,13 +504,12 @@ func GetStreamEntryFromContext(c interface{ Get(string) interface{} }) *LogEntry
504504
// MarkEntryAsStreaming marks the entry as a streaming request so the middleware
505505
// knows not to log it (the stream wrapper will handle logging).
506506
func MarkEntryAsStreaming(c interface{ Set(string, interface{}) }, isStreaming bool) {
507-
// We use a simple marker in the context
508-
c.Set(string(LogEntryKey)+"_streaming", isStreaming)
507+
c.Set(string(LogEntryStreamingKey), isStreaming)
509508
}
510509

511510
// IsEntryMarkedAsStreaming checks if the entry is marked as streaming.
512511
func IsEntryMarkedAsStreaming(c interface{ Get(string) interface{} }) bool {
513-
val := c.Get(string(LogEntryKey) + "_streaming")
512+
val := c.Get(string(LogEntryStreamingKey))
514513
if val == nil {
515514
return false
516515
}

internal/server/handlers.go

Lines changed: 47 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,47 @@ func NewHandler(provider core.RoutableProvider, logger auditlog.LoggerInterface)
2626
}
2727
}
2828

29+
// handleStreamingResponse handles SSE streaming responses for both ChatCompletion and Responses endpoints.
30+
// It wraps the stream with audit logging and sets appropriate SSE headers.
31+
func (h *Handler) handleStreamingResponse(c echo.Context, streamFn func() (io.ReadCloser, error)) error {
32+
// Mark as streaming so middleware doesn't log (StreamLogWrapper handles it)
33+
auditlog.MarkEntryAsStreaming(c, true)
34+
auditlog.EnrichEntryWithStream(c, true)
35+
36+
stream, err := streamFn()
37+
if err != nil {
38+
return handleError(c, err)
39+
}
40+
41+
// Get entry from context and wrap stream for logging
42+
entry := auditlog.GetStreamEntryFromContext(c)
43+
streamEntry := auditlog.CreateStreamEntry(entry)
44+
if streamEntry != nil {
45+
streamEntry.StatusCode = http.StatusOK // Streaming always starts with 200 OK
46+
}
47+
wrappedStream := auditlog.WrapStreamForLogging(stream, h.logger, streamEntry, c.Request().URL.Path)
48+
defer func() {
49+
_ = wrappedStream.Close() //nolint:errcheck
50+
}()
51+
52+
c.Response().Header().Set("Content-Type", "text/event-stream")
53+
c.Response().Header().Set("Cache-Control", "no-cache")
54+
c.Response().Header().Set("Connection", "keep-alive")
55+
56+
// Capture response headers on stream entry AFTER setting them
57+
if streamEntry != nil && streamEntry.Data != nil {
58+
streamEntry.Data.ResponseHeaders = map[string]string{
59+
"Content-Type": "text/event-stream",
60+
"Cache-Control": "no-cache",
61+
"Connection": "keep-alive",
62+
}
63+
}
64+
65+
c.Response().WriteHeader(http.StatusOK)
66+
_, _ = io.Copy(c.Response().Writer, wrappedStream)
67+
return nil
68+
}
69+
2970
// ChatCompletion handles POST /v1/chat/completions
3071
func (h *Handler) ChatCompletion(c echo.Context) error {
3172
var req core.ChatRequest
@@ -42,43 +83,9 @@ func (h *Handler) ChatCompletion(c echo.Context) error {
4283

4384
// Handle streaming: proxy the raw SSE stream
4485
if req.Stream {
45-
// Mark as streaming so middleware doesn't log (StreamLogWrapper handles it)
46-
auditlog.MarkEntryAsStreaming(c, true)
47-
auditlog.EnrichEntryWithStream(c, true)
48-
49-
stream, err := h.provider.StreamChatCompletion(c.Request().Context(), &req)
50-
if err != nil {
51-
return handleError(c, err)
52-
}
53-
54-
// Get entry from context and wrap stream for logging
55-
entry := auditlog.GetStreamEntryFromContext(c)
56-
streamEntry := auditlog.CreateStreamEntry(entry)
57-
if streamEntry != nil {
58-
streamEntry.StatusCode = http.StatusOK // Streaming always starts with 200 OK
59-
}
60-
wrappedStream := auditlog.WrapStreamForLogging(stream, h.logger, streamEntry, c.Request().URL.Path)
61-
defer func() {
62-
_ = wrappedStream.Close() //nolint:errcheck
63-
}()
64-
65-
c.Response().Header().Set("Content-Type", "text/event-stream")
66-
c.Response().Header().Set("Cache-Control", "no-cache")
67-
c.Response().Header().Set("Connection", "keep-alive")
68-
69-
// Capture response headers on stream entry AFTER setting them
70-
if streamEntry != nil && streamEntry.Data != nil {
71-
streamEntry.Data.ResponseHeaders = map[string]string{
72-
"Content-Type": "text/event-stream",
73-
"Cache-Control": "no-cache",
74-
"Connection": "keep-alive",
75-
}
76-
}
77-
78-
c.Response().WriteHeader(http.StatusOK)
79-
80-
_, _ = io.Copy(c.Response().Writer, wrappedStream)
81-
return nil
86+
return h.handleStreamingResponse(c, func() (io.ReadCloser, error) {
87+
return h.provider.StreamChatCompletion(c.Request().Context(), &req)
88+
})
8289
}
8390

8491
// Non-streaming
@@ -125,43 +132,9 @@ func (h *Handler) Responses(c echo.Context) error {
125132

126133
// Handle streaming: proxy the raw SSE stream
127134
if req.Stream {
128-
// Mark as streaming so middleware doesn't log (StreamLogWrapper handles it)
129-
auditlog.MarkEntryAsStreaming(c, true)
130-
auditlog.EnrichEntryWithStream(c, true)
131-
132-
stream, err := h.provider.StreamResponses(c.Request().Context(), &req)
133-
if err != nil {
134-
return handleError(c, err)
135-
}
136-
137-
// Get entry from context and wrap stream for logging
138-
entry := auditlog.GetStreamEntryFromContext(c)
139-
streamEntry := auditlog.CreateStreamEntry(entry)
140-
if streamEntry != nil {
141-
streamEntry.StatusCode = http.StatusOK // Streaming always starts with 200 OK
142-
}
143-
wrappedStream := auditlog.WrapStreamForLogging(stream, h.logger, streamEntry, c.Request().URL.Path)
144-
defer func() {
145-
_ = wrappedStream.Close() //nolint:errcheck
146-
}()
147-
148-
c.Response().Header().Set("Content-Type", "text/event-stream")
149-
c.Response().Header().Set("Cache-Control", "no-cache")
150-
c.Response().Header().Set("Connection", "keep-alive")
151-
152-
// Capture response headers on stream entry AFTER setting them
153-
if streamEntry != nil && streamEntry.Data != nil {
154-
streamEntry.Data.ResponseHeaders = map[string]string{
155-
"Content-Type": "text/event-stream",
156-
"Cache-Control": "no-cache",
157-
"Connection": "keep-alive",
158-
}
159-
}
160-
161-
c.Response().WriteHeader(http.StatusOK)
162-
163-
_, _ = io.Copy(c.Response().Writer, wrappedStream)
164-
return nil
135+
return h.handleStreamingResponse(c, func() (io.ReadCloser, error) {
136+
return h.provider.StreamResponses(c.Request().Context(), &req)
137+
})
165138
}
166139

167140
// Non-streaming

0 commit comments

Comments
 (0)