Skip to content

Commit a5d171f

Browse files
fix(core): harden json field handling and cache guards
1 parent ae2c50f commit a5d171f

5 files changed

Lines changed: 193 additions & 96 deletions

File tree

docs/superpowers/plans/2026-03-22-chat-unknown-fields-refactor.md

Lines changed: 0 additions & 87 deletions
This file was deleted.

internal/core/batch_preparation_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,14 @@ func TestCloneBatchRequestDeepCopiesNestedFields(t *testing.T) {
3838
cloned.Requests[0].CustomID = "chat-2"
3939
cloned.Requests[0].Body[10] = 'X'
4040
itemExtra := cloned.Requests[0].ExtraFields.Lookup("x_item")
41+
if len(itemExtra) <= 9 {
42+
t.Fatalf("cloned item extra too short: %q", itemExtra)
43+
}
4144
itemExtra[9] = 'f'
4245
topExtra := cloned.ExtraFields.Lookup("x_top")
46+
if len(topExtra) <= 9 {
47+
t.Fatalf("cloned top extra too short: %q", topExtra)
48+
}
4349
topExtra[9] = 'f'
4450

4551
if got := original.Metadata["provider"]; got != "openai" {

internal/core/json_fields.go

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/json"
66
"fmt"
7+
"math"
78
"sort"
89
"strconv"
910
)
@@ -53,7 +54,12 @@ func UnknownJSONFieldsFromMap(fields map[string]json.RawMessage) UnknownJSONFiel
5354
}
5455
buf.Write(keyBody)
5556
buf.WriteByte(':')
56-
buf.Write(CloneRawJSON(fields[key]))
57+
rawValue := CloneRawJSON(fields[key])
58+
if len(rawValue) == 0 {
59+
buf.WriteString("null")
60+
continue
61+
}
62+
buf.Write(rawValue)
5763
}
5864
buf.WriteByte('}')
5965
return UnknownJSONFields{raw: buf.Bytes()}
@@ -133,7 +139,11 @@ func mergeUnknownJSONObject(baseBody, extraBody []byte) ([]byte, error) {
133139
return CloneRawJSON(extraBody), nil
134140
}
135141

136-
merged := make([]byte, 0, len(baseBody)+len(extraBody)-1)
142+
totalCap, err := mergedJSONObjectCap(len(baseBody), len(extraBody))
143+
if err != nil {
144+
return nil, err
145+
}
146+
merged := make([]byte, 0, totalCap)
137147
merged = append(merged, baseBody[:len(baseBody)-1]...)
138148
if !bytes.Equal(extraBody, []byte("{}")) {
139149
merged = append(merged, ',')
@@ -142,6 +152,16 @@ func mergeUnknownJSONObject(baseBody, extraBody []byte) ([]byte, error) {
142152
return merged, nil
143153
}
144154

155+
func mergedJSONObjectCap(baseLen, extraLen int) (int, error) {
156+
if extraLen <= 0 {
157+
return 0, fmt.Errorf("unknown JSON fields are empty")
158+
}
159+
if baseLen > math.MaxInt-(extraLen-1) {
160+
return 0, fmt.Errorf("combined JSON object too large")
161+
}
162+
return baseLen + extraLen - 1, nil
163+
}
164+
145165
func readJSONObjectKey(dec *json.Decoder) (string, bool) {
146166
keyToken, err := dec.Token()
147167
if err != nil {
@@ -205,8 +225,22 @@ func extractUnknownJSONFieldsObjectByScan(data []byte, knownFields ...string) (U
205225
}
206226

207227
i = skipJSONWhitespace(data, valueEnd)
208-
if i < len(data) && data[i] == ',' {
228+
if i >= len(data) {
229+
return UnknownJSONFields{}, fmt.Errorf("unterminated JSON object")
230+
}
231+
switch data[i] {
232+
case ',':
209233
i = skipJSONWhitespace(data, i+1)
234+
if i >= len(data) {
235+
return UnknownJSONFields{}, fmt.Errorf("unterminated JSON object")
236+
}
237+
if data[i] == '}' {
238+
return UnknownJSONFields{}, fmt.Errorf("unexpected trailing comma in JSON object")
239+
}
240+
case '}':
241+
// The next loop iteration will terminate cleanly on the closing brace.
242+
default:
243+
return UnknownJSONFields{}, fmt.Errorf("expected ',' or '}' after object value")
210244
}
211245
}
212246

@@ -233,12 +267,19 @@ func scanJSONValue(data []byte, start int) (int, error) {
233267
for i < len(data) {
234268
switch data[i] {
235269
case ',', '}', ']':
236-
return i, nil
270+
goto validateLiteral
237271
case ' ', '\n', '\r', '\t':
238-
return i, nil
272+
goto validateLiteral
239273
}
240274
i++
241275
}
276+
validateLiteral:
277+
if i == start {
278+
return 0, fmt.Errorf("expected JSON literal")
279+
}
280+
if err := validateJSONLiteral(data[start:i]); err != nil {
281+
return 0, err
282+
}
242283
return i, nil
243284
}
244285
}
@@ -267,8 +308,22 @@ func scanJSONObject(data []byte, start int) (int, error) {
267308
return 0, err
268309
}
269310
i = skipJSONWhitespace(data, valueEnd)
270-
if i < len(data) && data[i] == ',' {
271-
i++
311+
if i >= len(data) {
312+
return 0, fmt.Errorf("unterminated JSON object")
313+
}
314+
switch data[i] {
315+
case ',':
316+
i = skipJSONWhitespace(data, i+1)
317+
if i >= len(data) {
318+
return 0, fmt.Errorf("unterminated JSON object")
319+
}
320+
if data[i] == '}' {
321+
return 0, fmt.Errorf("unexpected trailing comma in JSON object")
322+
}
323+
case '}':
324+
return i + 1, nil
325+
default:
326+
return 0, fmt.Errorf("expected ',' or '}' after object value")
272327
}
273328
}
274329
return 0, fmt.Errorf("unterminated JSON object")
@@ -289,13 +344,40 @@ func scanJSONArray(data []byte, start int) (int, error) {
289344
return 0, err
290345
}
291346
i = skipJSONWhitespace(data, valueEnd)
292-
if i < len(data) && data[i] == ',' {
293-
i++
347+
if i >= len(data) {
348+
return 0, fmt.Errorf("unterminated JSON array")
349+
}
350+
switch data[i] {
351+
case ',':
352+
i = skipJSONWhitespace(data, i+1)
353+
if i >= len(data) {
354+
return 0, fmt.Errorf("unterminated JSON array")
355+
}
356+
if data[i] == ']' {
357+
return 0, fmt.Errorf("unexpected trailing comma in JSON array")
358+
}
359+
case ']':
360+
return i + 1, nil
361+
default:
362+
return 0, fmt.Errorf("expected ',' or ']' after array element")
294363
}
295364
}
296365
return 0, fmt.Errorf("unterminated JSON array")
297366
}
298367

368+
func validateJSONLiteral(raw []byte) error {
369+
var value any
370+
if err := json.Unmarshal(raw, &value); err != nil {
371+
return fmt.Errorf("invalid JSON literal: %w", err)
372+
}
373+
switch value.(type) {
374+
case nil, bool, float64:
375+
return nil
376+
default:
377+
return fmt.Errorf("invalid JSON literal")
378+
}
379+
}
380+
299381
func scanJSONString(data []byte, start int) (int, error) {
300382
if start >= len(data) || data[start] != '"' {
301383
return 0, fmt.Errorf("expected JSON string")

internal/core/json_fields_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package core
33
import (
44
"bytes"
55
"encoding/json"
6+
"math"
67
"testing"
78
)
89

@@ -54,3 +55,43 @@ func TestExtractUnknownJSONFieldsObjectByScan_HandlesEscapedStrings(t *testing.T
5455
t.Fatalf("x_json = %s", got)
5556
}
5657
}
58+
59+
func TestUnknownJSONFieldsFromMap_EmptyRawValueEncodesAsNull(t *testing.T) {
60+
fields := UnknownJSONFieldsFromMap(map[string]json.RawMessage{
61+
"x_nil": nil,
62+
"x_set": json.RawMessage(`true`),
63+
})
64+
65+
if got := fields.Lookup("x_nil"); !bytes.Equal(got, []byte("null")) {
66+
t.Fatalf("x_nil = %q, want null", got)
67+
}
68+
if got := fields.Lookup("x_set"); !bytes.Equal(got, []byte("true")) {
69+
t.Fatalf("x_set = %q, want true", got)
70+
}
71+
}
72+
73+
func TestExtractUnknownJSONFieldsObjectByScan_RejectsInvalidJSONSyntax(t *testing.T) {
74+
tests := []struct {
75+
name string
76+
body string
77+
}{
78+
{name: "invalid bare literal", body: `{"known":"value","x":wat}`},
79+
{name: "missing object comma", body: `{"known":"value" "x":1}`},
80+
{name: "trailing object comma", body: `{"known":"value","x":1,}`},
81+
{name: "trailing array comma", body: `{"known":"value","x":[1,]}`},
82+
}
83+
84+
for _, tt := range tests {
85+
t.Run(tt.name, func(t *testing.T) {
86+
if _, err := extractUnknownJSONFieldsObjectByScan([]byte(tt.body), "known"); err == nil {
87+
t.Fatalf("extractUnknownJSONFieldsObjectByScan(%q) error = nil, want syntax error", tt.body)
88+
}
89+
})
90+
}
91+
}
92+
93+
func TestMergedJSONObjectCap_Overflow(t *testing.T) {
94+
if _, err := mergedJSONObjectCap(math.MaxInt, 2); err == nil {
95+
t.Fatal("mergedJSONObjectCap() error = nil, want overflow error")
96+
}
97+
}

internal/responsecache/middleware_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,61 @@ func TestSimpleCacheMiddleware_BypassesCacheWhenBodyWasNotCaptured(t *testing.T)
339339
}
340340
}
341341

342+
func TestSimpleCacheMiddleware_BodyReadErrorReturnsGatewayError(t *testing.T) {
343+
store := cache.NewMapStore()
344+
defer store.Close()
345+
mw := NewResponseCacheMiddlewareWithStore(store, time.Hour)
346+
e := echo.New()
347+
348+
handler := mw.Middleware()(func(c *echo.Context) error {
349+
t.Fatal("handler should not be called when request body read fails")
350+
return nil
351+
})
352+
353+
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
354+
req.Header.Set("Content-Type", "application/json")
355+
req.Body = explodingCacheReadCloser{}
356+
rec := httptest.NewRecorder()
357+
c := e.NewContext(req, rec)
358+
359+
err := handler(c)
360+
var gatewayErr *core.GatewayError
361+
if !errors.As(err, &gatewayErr) {
362+
t.Fatalf("handler error = %T, want *core.GatewayError", err)
363+
}
364+
if gatewayErr.Type != core.ErrorTypeInvalidRequest {
365+
t.Fatalf("gateway error type = %q, want %q", gatewayErr.Type, core.ErrorTypeInvalidRequest)
366+
}
367+
}
368+
369+
func TestRequestBodyForCache_BodyNotCapturedTakesPrecedenceOverEmptySnapshotBody(t *testing.T) {
370+
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
371+
frame := core.NewRequestSnapshot(
372+
http.MethodPost,
373+
"/v1/chat/completions",
374+
nil,
375+
nil,
376+
nil,
377+
"application/json",
378+
[]byte{},
379+
true,
380+
"",
381+
nil,
382+
)
383+
req = req.WithContext(core.WithRequestSnapshot(req.Context(), frame))
384+
385+
body, cacheable, err := requestBodyForCache(req)
386+
if err != nil {
387+
t.Fatalf("requestBodyForCache() error = %v", err)
388+
}
389+
if cacheable {
390+
t.Fatalf("requestBodyForCache() cacheable = true, want false (body=%q)", body)
391+
}
392+
if body != nil {
393+
t.Fatalf("requestBodyForCache() body = %q, want nil", body)
394+
}
395+
}
396+
342397
func TestIsStreamingRequest(t *testing.T) {
343398
tests := []struct {
344399
name string

0 commit comments

Comments
 (0)