Skip to content

Commit 85bc9db

Browse files
authored
feat(spanner): PGNumeric implements Scanner and Valuer (#13722)
Implement the Scanner and Valuer interfaces for PGNumeric so this type is easier to use with the database/sql driver.
1 parent e94bd6f commit 85bc9db

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

spanner/row_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,3 +2837,11 @@ func TestSelectAll(t *testing.T) {
28372837
func stringPointer(s string) *string {
28382838
return &s
28392839
}
2840+
2841+
func float32Pointer(f float32) *float32 {
2842+
return &f
2843+
}
2844+
2845+
func float64Pointer(f float64) *float64 {
2846+
return &f
2847+
}

spanner/value.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,80 @@ func (n *PGNumeric) UnmarshalJSON(payload []byte) error {
10511051
return nil
10521052
}
10531053

1054+
// Value implements the driver.Valuer interface.
1055+
func (n PGNumeric) Value() (driver.Value, error) {
1056+
if n.IsNull() {
1057+
return nil, nil
1058+
}
1059+
return n.Numeric, nil
1060+
}
1061+
1062+
// Scan implements the sql.Scanner interface.
1063+
func (n *PGNumeric) Scan(value interface{}) error {
1064+
if value == nil {
1065+
n.Numeric, n.Valid = "", false
1066+
return nil
1067+
}
1068+
n.Valid = true
1069+
switch p := value.(type) {
1070+
default:
1071+
return spannerErrorf(codes.InvalidArgument, "invalid type for PGNumeric: %v", p)
1072+
case *big.Rat:
1073+
if p == nil {
1074+
n.Numeric, n.Valid = "", false
1075+
} else {
1076+
n.Numeric = NumericString(p)
1077+
}
1078+
case big.Rat:
1079+
n.Numeric = NumericString(&p)
1080+
case *NullNumeric:
1081+
if p == nil {
1082+
n.Numeric, n.Valid = "", false
1083+
} else {
1084+
if p.Valid {
1085+
n.Numeric = p.String()
1086+
} else {
1087+
n.Numeric = ""
1088+
}
1089+
n.Valid = p.Valid
1090+
}
1091+
case NullNumeric:
1092+
if p.Valid {
1093+
n.Numeric = p.String()
1094+
} else {
1095+
n.Numeric = ""
1096+
}
1097+
n.Valid = p.Valid
1098+
case string:
1099+
n.Numeric = p
1100+
n.Valid = true
1101+
case *string:
1102+
if p == nil {
1103+
n.Numeric, n.Valid = "", false
1104+
} else {
1105+
n.Numeric = *p
1106+
n.Valid = true
1107+
}
1108+
case float32:
1109+
n.Numeric = strconv.FormatFloat(float64(p), 'f', 9, 32)
1110+
case *float32:
1111+
if p == nil {
1112+
n.Numeric, n.Valid = "", false
1113+
} else {
1114+
n.Numeric = strconv.FormatFloat(float64(*p), 'f', 9, 32)
1115+
}
1116+
case float64:
1117+
n.Numeric = strconv.FormatFloat(p, 'f', 9, 64)
1118+
case *float64:
1119+
if p == nil {
1120+
n.Numeric, n.Valid = "", false
1121+
} else {
1122+
n.Numeric = strconv.FormatFloat(*p, 'f', 9, 64)
1123+
}
1124+
}
1125+
return nil
1126+
}
1127+
10541128
// GormDataType is used by gorm to determine the default data type for fields with this type.
10551129
func (n PGNumeric) GormDataType() string {
10561130
return "numeric"

spanner/value_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3562,6 +3562,45 @@ func TestScanNullNumeric(t *testing.T) {
35623562
}
35633563
}
35643564

3565+
func TestScanPGNumeric(t *testing.T) {
3566+
for _, test := range []struct {
3567+
name string
3568+
input any
3569+
want PGNumeric
3570+
}{
3571+
{name: "string", input: "3.14", want: PGNumeric{Numeric: "3.14", Valid: true}},
3572+
{name: "stringptr", input: stringPointer("3.14"), want: PGNumeric{Numeric: "3.14", Valid: true}},
3573+
{name: "nil", input: nil, want: PGNumeric{}},
3574+
{name: "nilstringptr", input: (*string)(nil), want: PGNumeric{}},
3575+
{name: "float32", input: float32(3.14), want: PGNumeric{Numeric: "3.140000105", Valid: true}},
3576+
{name: "float32ptr", input: float32Pointer(float32(3.14)), want: PGNumeric{Numeric: "3.140000105", Valid: true}},
3577+
{name: "float64", input: 3.14, want: PGNumeric{Numeric: "3.140000000", Valid: true}},
3578+
{name: "float64ptr", input: float64Pointer(3.14), want: PGNumeric{Numeric: "3.140000000", Valid: true}},
3579+
{name: "NullNumeric", input: NullNumeric{Numeric: *bigRatFromString("3.14"), Valid: true}, want: PGNumeric{Numeric: "3.140000000", Valid: true}},
3580+
{name: "NullNumericPtr", input: &NullNumeric{Numeric: *bigRatFromString("3.14"), Valid: true}, want: PGNumeric{Numeric: "3.140000000", Valid: true}},
3581+
{name: "NullNumericWithNullValue", input: NullNumeric{}, want: PGNumeric{}},
3582+
{name: "NullNumericPtrWithNullValue", input: &NullNumeric{}, want: PGNumeric{}},
3583+
{name: "bigrat", input: *bigRatFromString("6.626"), want: PGNumeric{Numeric: "6.626000000", Valid: true}},
3584+
{name: "bigratptr", input: bigRatFromString("9.99"), want: PGNumeric{Numeric: "9.990000000", Valid: true}},
3585+
{name: "nilbigratptr", input: (*big.Rat)(nil), want: PGNumeric{}},
3586+
} {
3587+
t.Run(test.name, func(t *testing.T) {
3588+
n := PGNumeric{Numeric: "should be overwritten", Valid: true}
3589+
if err := n.Scan(test.input); err != nil {
3590+
t.Fatal(err)
3591+
}
3592+
if g, w := n, test.want; !reflect.DeepEqual(g, w) {
3593+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3594+
}
3595+
})
3596+
}
3597+
}
3598+
3599+
func bigRatFromString(s string) *big.Rat {
3600+
r, _ := (&big.Rat{}).SetString(s)
3601+
return r
3602+
}
3603+
35653604
func TestInterval(t *testing.T) {
35663605
tests := []struct {
35673606
name string

0 commit comments

Comments
 (0)