Skip to content

Commit 05277f6

Browse files
oauthex: use internal JSON library for decoding. (#866)
1 parent 150bca7 commit 05277f6

3 files changed

Lines changed: 55 additions & 4 deletions

File tree

internal/json/json.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,25 @@ package json
88

99
import (
1010
"bytes"
11+
"io"
1112

1213
"github.com/segmentio/encoding/json"
1314
)
1415

15-
func Unmarshal(data []byte, v any) error {
16-
dec := json.NewDecoder(bytes.NewReader(data))
16+
type Decoder struct {
17+
dec *json.Decoder
18+
}
19+
20+
func NewDecoder(r io.Reader) *Decoder {
21+
dec := json.NewDecoder(r)
1722
dec.DontMatchCaseInsensitiveStructFields()
18-
return dec.Decode(v)
23+
return &Decoder{dec: dec}
24+
}
25+
26+
func (d *Decoder) Decode(v any) error {
27+
return d.dec.Decode(v)
28+
}
29+
30+
func Unmarshal(data []byte, v any) error {
31+
return NewDecoder(bytes.NewReader(data)).Decode(v)
1932
}

internal/json/json_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"bytes"
99
"encoding/json"
1010
"fmt"
11+
"strings"
1112
"testing"
1213

1314
"github.com/google/go-cmp/cmp"
@@ -65,6 +66,43 @@ func TestUnmarshalCaseSensitivity(t *testing.T) {
6566
}
6667
}
6768

69+
func TestNewDecoderCaseSensitivity(t *testing.T) {
70+
type Target struct {
71+
Field string `json:"field"`
72+
TaggedField string `json:"custom_tag"`
73+
}
74+
75+
tests := []struct {
76+
name string
77+
json string
78+
want Target
79+
}{
80+
{
81+
name: "exact match",
82+
json: `{"field": "value", "custom_tag": "tagged"}`,
83+
want: Target{Field: "value", TaggedField: "tagged"},
84+
},
85+
{
86+
name: "case mismatch",
87+
json: `{"Field": "value", "Custom_tag": "tagged"}`,
88+
want: Target{},
89+
},
90+
}
91+
92+
for _, tt := range tests {
93+
t.Run(tt.name, func(t *testing.T) {
94+
var got Target
95+
dec := NewDecoder(strings.NewReader(tt.json))
96+
if err := dec.Decode(&got); err != nil {
97+
t.Fatalf("Decode failed: %v", err)
98+
}
99+
if diff := cmp.Diff(tt.want, got); diff != "" {
100+
t.Errorf("Decode mismatch (-want +got):\n%s", diff)
101+
}
102+
})
103+
}
104+
}
105+
68106
func TestUnmarshalNullCharacter(t *testing.T) {
69107
type Target struct {
70108
Field string `json:"field"`

oauthex/oauth2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ package oauthex
1010

1111
import (
1212
"context"
13-
"encoding/json"
1413
"fmt"
1514
"io"
1615
"mime"
1716
"net/http"
1817
"net/url"
1918
"strings"
2019

20+
"github.com/modelcontextprotocol/go-sdk/internal/json"
2121
"github.com/modelcontextprotocol/go-sdk/internal/util"
2222
)
2323

0 commit comments

Comments
 (0)