Skip to content

Commit e918c6a

Browse files
fix(hotpath): honor duplicate top-level selector keys
1 parent ca74fcc commit e918c6a

9 files changed

Lines changed: 110 additions & 101 deletions

File tree

go.mod

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ require (
1717
github.com/testcontainers/testcontainers-go v0.41.0
1818
github.com/testcontainers/testcontainers-go/modules/mongodb v0.41.0
1919
github.com/testcontainers/testcontainers-go/modules/postgres v0.41.0
20-
github.com/tidwall/gjson v1.18.0
2120
go.mongodb.org/mongo-driver/v2 v2.5.0
2221
golang.org/x/term v0.41.0
2322
gopkg.in/yaml.v3 v3.0.1
@@ -94,8 +93,6 @@ require (
9493
github.com/sv-tools/openapi v0.4.0 // indirect
9594
github.com/swaggo/files/v2 v2.0.2 // indirect
9695
github.com/swaggo/swag/v2 v2.0.0-rc5 // indirect
97-
github.com/tidwall/match v1.1.1 // indirect
98-
github.com/tidwall/pretty v1.2.0 // indirect
9996
github.com/tklauser/go-sysconf v0.3.16 // indirect
10097
github.com/tklauser/numcpus v0.11.0 // indirect
10198
github.com/urfave/cli/v2 v2.27.5 // indirect

go.sum

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,6 @@ github.com/testcontainers/testcontainers-go/modules/mongodb v0.41.0 h1:z5sroe+jX
209209
github.com/testcontainers/testcontainers-go/modules/mongodb v0.41.0/go.mod h1:pb+JZ21ixA1TlyUkhJclS/hCuZCnAkgnh+iB+pTHnSk=
210210
github.com/testcontainers/testcontainers-go/modules/postgres v0.41.0 h1:AOtFXssrDlLm84A2sTTR/AhvJiYbrIuCO59d+Ro9Tb0=
211211
github.com/testcontainers/testcontainers-go/modules/postgres v0.41.0/go.mod h1:k2a09UKhgSp6vNpliIY0QSgm4Hi7GXVTzWvWgUemu/8=
212-
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
213-
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
214-
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
215-
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
216-
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
217-
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
218212
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
219213
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
220214
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=

internal/core/json_fields.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,57 @@ func extractUnknownJSONFields(data []byte, knownFields ...string) (UnknownJSONFi
103103
return extractUnknownJSONFieldsObjectByScan(data, knownFields...)
104104
}
105105

106+
// VisitTopLevelJSONObjectFields walks members of a top-level JSON object in
107+
// source order. The callback receives raw value slices that alias data.
108+
// Returning false from the callback aborts the walk and reports failure.
109+
func VisitTopLevelJSONObjectFields(data []byte, visit func(key string, raw []byte) bool) bool {
110+
data = bytes.TrimSpace(data)
111+
if len(data) == 0 || data[0] != '{' {
112+
return false
113+
}
114+
115+
i := skipJSONWhitespace(data, 1)
116+
for i < len(data) {
117+
if data[i] == '}' {
118+
i = skipJSONWhitespace(data, i+1)
119+
return i == len(data)
120+
}
121+
122+
keyStart := i
123+
keyEnd, err := scanJSONString(data, keyStart)
124+
if err != nil {
125+
return false
126+
}
127+
key, err := decodeJSONString(data[keyStart:keyEnd])
128+
if err != nil {
129+
return false
130+
}
131+
132+
i = skipJSONWhitespace(data, keyEnd)
133+
if i >= len(data) || data[i] != ':' {
134+
return false
135+
}
136+
i = skipJSONWhitespace(data, i+1)
137+
138+
valueStart := i
139+
valueEnd, err := scanJSONValue(data, valueStart)
140+
if err != nil {
141+
return false
142+
}
143+
144+
if !visit(key, data[valueStart:valueEnd]) {
145+
return false
146+
}
147+
148+
i = skipJSONWhitespace(data, valueEnd)
149+
if i < len(data) && data[i] == ',' {
150+
i = skipJSONWhitespace(data, i+1)
151+
}
152+
}
153+
154+
return false
155+
}
156+
106157
func marshalWithUnknownJSONFields(base any, extraFields UnknownJSONFields) ([]byte, error) {
107158
baseBody, err := json.Marshal(base)
108159
if err != nil {

internal/core/semantic.go

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ package core
22

33
import (
44
"bytes"
5+
"encoding/json"
56
"net/http"
67
"strconv"
78
"strings"
8-
9-
"github.com/tidwall/gjson"
109
)
1110

1211
// RouteHints holds minimal routing-relevant request hints derived from the
@@ -303,54 +302,24 @@ func derivePassthroughRouteInfoFromTransport(snapshot *RequestSnapshot) *Passthr
303302
}
304303

305304
func deriveSnapshotSelectorHintsGJSON(body []byte) (model, provider string, stream, parsed bool) {
306-
if !gjson.ValidBytes(body) {
307-
return "", "", false, false
308-
}
309-
310-
root := gjson.ParseBytes(body)
311-
if !root.IsObject() {
312-
return "", "", false, false
313-
}
314-
315-
modelResult := root.Get("model")
316-
if !snapshotSelectorStringAllowed(modelResult) {
317-
return "", "", false, false
318-
}
319-
providerResult := root.Get("provider")
320-
if !snapshotSelectorStringAllowed(providerResult) {
321-
return "", "", false, false
322-
}
323-
streamResult := root.Get("stream")
324-
if !snapshotSelectorBoolAllowed(streamResult) {
305+
parsed = VisitTopLevelJSONObjectFields(body, func(key string, raw []byte) bool {
306+
switch key {
307+
case "model":
308+
return json.Unmarshal(raw, &model) == nil
309+
case "provider":
310+
return json.Unmarshal(raw, &provider) == nil
311+
case "stream":
312+
return json.Unmarshal(raw, &stream) == nil
313+
default:
314+
return true
315+
}
316+
})
317+
if !parsed {
325318
return "", "", false, false
326319
}
327-
328-
if modelResult.Type == gjson.String {
329-
model = modelResult.String()
330-
}
331-
if providerResult.Type == gjson.String {
332-
provider = providerResult.String()
333-
}
334-
if streamResult.Type == gjson.True || streamResult.Type == gjson.False {
335-
stream = streamResult.Bool()
336-
}
337320
return model, provider, stream, true
338321
}
339322

340-
func snapshotSelectorStringAllowed(result gjson.Result) bool {
341-
if !result.Exists() {
342-
return true
343-
}
344-
return result.Type == gjson.String || result.Type == gjson.Null
345-
}
346-
347-
func snapshotSelectorBoolAllowed(result gjson.Result) bool {
348-
if !result.Exists() {
349-
return true
350-
}
351-
return result.Type == gjson.True || result.Type == gjson.False || result.Type == gjson.Null
352-
}
353-
354323
// DeriveFileRouteInfoFromTransport derives sparse file route info from transport metadata.
355324
func DeriveFileRouteInfoFromTransport(method, path string, routeParams map[string]string, queryParams map[string][]string) *FileRouteInfo {
356325
req := &FileRouteInfo{

internal/core/semantic_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ func TestDeriveSnapshotSelectorHintsGJSON_MatchesStdlibSemantics(t *testing.T) {
298298
body string
299299
}{
300300
{name: "valid selector fields", body: `{"provider":"openai","model":"gpt-5-mini","stream":true}`},
301+
{name: "duplicate selector fields use last occurrence", body: `{"model":"blocked","model":"gpt-5-mini","provider":"x","provider":"openai","stream":false,"stream":true}`},
302+
{name: "duplicate null string keeps prior value and null stream clears prior value", body: `{"model":"gpt-5-mini","model":null,"provider":"openai","provider":null,"stream":true,"stream":null}`},
303+
{name: "duplicate invalid selector field fails parse", body: `{"model":"gpt-5-mini","model":123}`},
304+
{name: "duplicate invalid stream field fails parse", body: `{"stream":true,"stream":"yes"}`},
301305
{name: "missing selector fields", body: `{"messages":[{"role":"user","content":"hi"}]}`},
302306
{name: "null selector fields", body: `{"provider":null,"model":null,"stream":null}`},
303307
{name: "invalid json", body: `not json`},

internal/responsecache/middleware_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package responsecache
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"errors"
78
"io"
89
"net/http"
@@ -347,6 +348,10 @@ func TestIsStreamingRequest(t *testing.T) {
347348
}{
348349
{"stream true compact", "/v1/chat/completions", `{"stream":true}`, true},
349350
{"stream true with spaces", "/v1/chat/completions", `{"stream" : true}`, true},
351+
{"duplicate stream uses last occurrence", "/v1/chat/completions", `{"stream":false,"stream":true}`, true},
352+
{"duplicate stream last false", "/v1/chat/completions", `{"stream":true,"stream":false}`, false},
353+
{"duplicate null stream clears prior value", "/v1/chat/completions", `{"stream":true,"stream":null}`, false},
354+
{"duplicate invalid stream returns false", "/v1/chat/completions", `{"stream":true,"stream":"yes"}`, false},
350355
{"stream false", "/v1/chat/completions", `{"stream":false}`, false},
351356
{"stream absent", "/v1/chat/completions", `{"model":"gpt-4"}`, false},
352357
{"embeddings path always false", "/v1/embeddings", `{"stream":true}`, false},
@@ -424,6 +429,19 @@ func BenchmarkRequestBodyForCacheSnapshot(b *testing.B) {
424429
}
425430
}
426431

432+
func isStreamingRequestStdlib(path string, body []byte) bool {
433+
if path == "/v1/embeddings" {
434+
return false
435+
}
436+
var p struct {
437+
Stream *bool `json:"stream"`
438+
}
439+
if err := json.Unmarshal(body, &p); err != nil {
440+
return false
441+
}
442+
return p.Stream != nil && *p.Stream
443+
}
444+
427445
func TestSimpleCacheMiddleware_SkipsNoCache(t *testing.T) {
428446
store := cache.NewMapStore()
429447
defer store.Close()

internal/responsecache/simple.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"time"
1515

1616
"github.com/labstack/echo/v5"
17-
"github.com/tidwall/gjson"
1817

1918
"gomodel/internal/cache"
2019
"gomodel/internal/core"
@@ -217,24 +216,16 @@ func isStreamingRequestGJSON(path string, body []byte) bool {
217216
if path == "/v1/embeddings" {
218217
return false
219218
}
220-
result := gjson.GetBytes(body, "stream")
221-
if !result.Exists() || result.Type != gjson.True && result.Type != gjson.False {
222-
return false
223-
}
224-
return result.Bool()
225-
}
226-
227-
func isStreamingRequestStdlib(path string, body []byte) bool {
228-
if path == "/v1/embeddings" {
229-
return false
230-
}
231-
var p struct {
232-
Stream *bool `json:"stream"`
233-
}
234-
if err := json.Unmarshal(body, &p); err != nil {
219+
var stream *bool
220+
if !core.VisitTopLevelJSONObjectFields(body, func(key string, raw []byte) bool {
221+
if key != "stream" {
222+
return true
223+
}
224+
return json.Unmarshal(raw, &stream) == nil
225+
}) {
235226
return false
236227
}
237-
return p.Stream != nil && *p.Stream
228+
return stream != nil && *stream
238229
}
239230

240231
func hashRequest(path string, body []byte, plan *core.ExecutionPlan) string {

internal/server/model_validation.go

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package server
22

33
import (
4+
"encoding/json"
45
"strings"
56

67
"github.com/labstack/echo/v5"
7-
"github.com/tidwall/gjson"
88

99
"gomodel/internal/auditlog"
1010
"gomodel/internal/core"
@@ -147,40 +147,22 @@ func cachedCanonicalSelectorHints(env *core.WhiteBoxPrompt) (model, provider str
147147
}
148148

149149
func selectorHintsFromJSONGJSON(body []byte) (model, provider string, parsed bool) {
150-
if !gjson.ValidBytes(body) {
150+
parsed = core.VisitTopLevelJSONObjectFields(body, func(key string, raw []byte) bool {
151+
switch key {
152+
case "model":
153+
return json.Unmarshal(raw, &model) == nil
154+
case "provider":
155+
return json.Unmarshal(raw, &provider) == nil
156+
default:
157+
return true
158+
}
159+
})
160+
if !parsed {
151161
return "", "", false
152162
}
153-
154-
root := gjson.ParseBytes(body)
155-
if !root.IsObject() {
156-
return "", "", false
157-
}
158-
159-
modelResult := root.Get("model")
160-
if !selectorHintValueAllowed(modelResult) {
161-
return "", "", false
162-
}
163-
providerResult := root.Get("provider")
164-
if !selectorHintValueAllowed(providerResult) {
165-
return "", "", false
166-
}
167-
168-
if modelResult.Type == gjson.String {
169-
model = modelResult.String()
170-
}
171-
if providerResult.Type == gjson.String {
172-
provider = providerResult.String()
173-
}
174163
return model, provider, true
175164
}
176165

177-
func selectorHintValueAllowed(result gjson.Result) bool {
178-
if !result.Exists() {
179-
return true
180-
}
181-
return result.Type == gjson.String || result.Type == gjson.Null
182-
}
183-
184166
func providerPassthroughType(c *echo.Context) (string, bool) {
185167
if info := passthroughRouteInfo(c); info != nil {
186168
providerType := strings.TrimSpace(info.Provider)

internal/server/model_validation_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,9 @@ func TestSelectorHintsFromJSONGJSON_MatchesStdlibSemantics(t *testing.T) {
778778
body string
779779
}{
780780
{name: "model and provider strings", body: `{"provider":"openai","model":"gpt-4o-mini"}`},
781+
{name: "duplicate selector fields use last occurrence", body: `{"provider":"openai","provider":"anthropic","model":"blocked","model":"gpt-4o-mini"}`},
782+
{name: "duplicate null selector keeps prior string value", body: `{"provider":"openai","provider":null,"model":"gpt-4o-mini","model":null}`},
783+
{name: "duplicate invalid selector field fails parse", body: `{"provider":"openai","provider":123}`},
781784
{name: "model only", body: `{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}`},
782785
{name: "null selector fields", body: `{"provider":null,"model":null}`},
783786
{name: "missing selector fields", body: `{"messages":[{"role":"user","content":"hi"}]}`},

0 commit comments

Comments
 (0)