Skip to content

Commit d897b6d

Browse files
authored
feat(spanner): support Scan from string to NullUUID (#14128)
Adds support for scanning a string value into a NullUUID. Also adds typed nil checks to other Scan functions.
1 parent 23cabc2 commit d897b6d

File tree

3 files changed

+259
-28
lines changed

3 files changed

+259
-28
lines changed

spanner/value.go

Lines changed: 145 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,20 @@ func (n *NullInt64) Scan(value interface{}) error {
261261
default:
262262
return spannerErrorf(codes.InvalidArgument, "invalid type for NullInt64: %v", p)
263263
case *int64:
264-
n.Int64 = *p
264+
if p == nil {
265+
n.Int64, n.Valid = 0, false
266+
} else {
267+
n.Int64 = *p
268+
}
265269
case int64:
266270
n.Int64 = p
267271
case *NullInt64:
268-
n.Int64 = p.Int64
269-
n.Valid = p.Valid
272+
if p == nil {
273+
n.Int64, n.Valid = 0, false
274+
} else {
275+
n.Int64 = p.Int64
276+
n.Valid = p.Valid
277+
}
270278
case NullInt64:
271279
n.Int64 = p.Int64
272280
n.Valid = p.Valid
@@ -278,6 +286,10 @@ func (n *NullInt64) Scan(value interface{}) error {
278286
n.Int64 = i64
279287
n.Valid = true
280288
case *string:
289+
if p == nil {
290+
n.Int64, n.Valid = 0, false
291+
return nil
292+
}
281293
i64, err := strconv.ParseInt(*p, 10, 64)
282294
if err != nil {
283295
return err
@@ -360,12 +372,20 @@ func (n *NullString) Scan(value interface{}) error {
360372
default:
361373
return spannerErrorf(codes.InvalidArgument, "invalid type for NullString: %v", p)
362374
case *string:
363-
n.StringVal = *p
375+
if p == nil {
376+
n.StringVal, n.Valid = "", false
377+
} else {
378+
n.StringVal = *p
379+
}
364380
case string:
365381
n.StringVal = p
366382
case *NullString:
367-
n.StringVal = p.StringVal
368-
n.Valid = p.Valid
383+
if p == nil {
384+
n.StringVal, n.Valid = "", false
385+
} else {
386+
n.StringVal = p.StringVal
387+
n.Valid = p.Valid
388+
}
369389
case NullString:
370390
n.StringVal = p.StringVal
371391
n.Valid = p.Valid
@@ -440,12 +460,20 @@ func (n *NullFloat64) Scan(value interface{}) error {
440460
default:
441461
return spannerErrorf(codes.InvalidArgument, "invalid type for NullFloat64: %v", p)
442462
case *float64:
443-
n.Float64 = *p
463+
if p == nil {
464+
n.Float64, n.Valid = 0, false
465+
} else {
466+
n.Float64 = *p
467+
}
444468
case float64:
445469
n.Float64 = p
446470
case *NullFloat64:
447-
n.Float64 = p.Float64
448-
n.Valid = p.Valid
471+
if p == nil {
472+
n.Float64, n.Valid = 0, false
473+
} else {
474+
n.Float64 = p.Float64
475+
n.Valid = p.Valid
476+
}
449477
case NullFloat64:
450478
n.Float64 = p.Float64
451479
n.Valid = p.Valid
@@ -457,6 +485,10 @@ func (n *NullFloat64) Scan(value interface{}) error {
457485
n.Float64 = f
458486
n.Valid = true
459487
case *string:
488+
if p == nil {
489+
n.Float64, n.Valid = 0, false
490+
return nil
491+
}
460492
f, err := strconv.ParseFloat(*p, 64)
461493
if err != nil {
462494
return err
@@ -534,12 +566,20 @@ func (n *NullFloat32) Scan(value interface{}) error {
534566
default:
535567
return spannerErrorf(codes.InvalidArgument, "invalid type for NullFloat32: %v", p)
536568
case *float32:
537-
n.Float32 = *p
569+
if p == nil {
570+
n.Float32, n.Valid = 0, false
571+
} else {
572+
n.Float32 = *p
573+
}
538574
case float32:
539575
n.Float32 = p
540576
case *NullFloat32:
541-
n.Float32 = p.Float32
542-
n.Valid = p.Valid
577+
if p == nil {
578+
n.Float32, n.Valid = 0, false
579+
} else {
580+
n.Float32 = p.Float32
581+
n.Valid = p.Valid
582+
}
543583
case NullFloat32:
544584
n.Float32 = p.Float32
545585
n.Valid = p.Valid
@@ -551,6 +591,10 @@ func (n *NullFloat32) Scan(value interface{}) error {
551591
n.Float32 = float32(f)
552592
n.Valid = true
553593
case *string:
594+
if p == nil {
595+
n.Float32, n.Valid = 0, false
596+
return nil
597+
}
554598
f, err := strconv.ParseFloat(*p, 32)
555599
if err != nil {
556600
return err
@@ -628,12 +672,20 @@ func (n *NullBool) Scan(value interface{}) error {
628672
default:
629673
return spannerErrorf(codes.InvalidArgument, "invalid type for NullBool: %v", p)
630674
case *bool:
631-
n.Bool = *p
675+
if p == nil {
676+
n.Bool, n.Valid = false, false
677+
} else {
678+
n.Bool = *p
679+
}
632680
case bool:
633681
n.Bool = p
634682
case *NullBool:
635-
n.Bool = p.Bool
636-
n.Valid = p.Valid
683+
if p == nil {
684+
n.Bool, n.Valid = false, false
685+
} else {
686+
n.Bool = p.Bool
687+
n.Valid = p.Valid
688+
}
637689
case NullBool:
638690
n.Bool = p.Bool
639691
n.Valid = p.Valid
@@ -645,6 +697,10 @@ func (n *NullBool) Scan(value interface{}) error {
645697
n.Bool = f
646698
n.Valid = true
647699
case *string:
700+
if p == nil {
701+
n.Bool, n.Valid = false, false
702+
return nil
703+
}
648704
f, err := strconv.ParseBool(*p)
649705
if err != nil {
650706
return err
@@ -727,12 +783,20 @@ func (n *NullTime) Scan(value interface{}) error {
727783
default:
728784
return spannerErrorf(codes.InvalidArgument, "invalid type for NullTime: %v", p)
729785
case *time.Time:
730-
n.Time = *p
786+
if p == nil {
787+
n.Time, n.Valid = time.Time{}, false
788+
} else {
789+
n.Time = *p
790+
}
731791
case time.Time:
732792
n.Time = p
733793
case *NullTime:
734-
n.Time = p.Time
735-
n.Valid = p.Valid
794+
if p == nil {
795+
n.Time, n.Valid = time.Time{}, false
796+
} else {
797+
n.Time = p.Time
798+
n.Valid = p.Valid
799+
}
736800
case NullTime:
737801
n.Time = p.Time
738802
n.Valid = p.Valid
@@ -744,6 +808,10 @@ func (n *NullTime) Scan(value interface{}) error {
744808
n.Time = f
745809
n.Valid = true
746810
case *string:
811+
if p == nil {
812+
n.Time, n.Valid = time.Time{}, false
813+
return nil
814+
}
747815
f, err := time.Parse(time.RFC3339Nano, *p)
748816
if err != nil {
749817
return err
@@ -831,12 +899,20 @@ func (n *NullDate) Scan(value interface{}) error {
831899
n.Date = d
832900
n.Valid = true
833901
case *civil.Date:
834-
n.Date = *p
902+
if p == nil {
903+
n.Date, n.Valid = civil.Date{}, false
904+
} else {
905+
n.Date = *p
906+
}
835907
case civil.Date:
836908
n.Date = p
837909
case *NullDate:
838-
n.Date = p.Date
839-
n.Valid = p.Valid
910+
if p == nil {
911+
n.Date, n.Valid = civil.Date{}, false
912+
} else {
913+
n.Date = p.Date
914+
n.Valid = p.Valid
915+
}
840916
case NullDate:
841917
n.Date = p.Date
842918
n.Valid = p.Valid
@@ -916,12 +992,20 @@ func (n *NullNumeric) Scan(value interface{}) error {
916992
default:
917993
return spannerErrorf(codes.InvalidArgument, "invalid type for NullNumeric: %v", p)
918994
case *big.Rat:
919-
n.Numeric = *p
995+
if p == nil {
996+
n.Numeric, n.Valid = big.Rat{}, false
997+
} else {
998+
n.Numeric = *p
999+
}
9201000
case big.Rat:
9211001
n.Numeric = p
9221002
case *NullNumeric:
923-
n.Numeric = p.Numeric
924-
n.Valid = p.Valid
1003+
if p == nil {
1004+
n.Numeric, n.Valid = big.Rat{}, false
1005+
} else {
1006+
n.Numeric = p.Numeric
1007+
n.Valid = p.Valid
1008+
}
9251009
case NullNumeric:
9261010
n.Numeric = p.Numeric
9271011
n.Valid = p.Valid
@@ -933,6 +1017,10 @@ func (n *NullNumeric) Scan(value interface{}) error {
9331017
n.Numeric = *y
9341018
n.Valid = true
9351019
case *string:
1020+
if p == nil {
1021+
n.Numeric, n.Valid = big.Rat{}, false
1022+
return nil
1023+
}
9361024
y, ok := (&big.Rat{}).SetString(*p)
9371025
if !ok {
9381026
return errUnexpectedNumericStr(*p)
@@ -1285,21 +1373,50 @@ func (n *NullUUID) Scan(value interface{}) error {
12851373
n.Valid = true
12861374
switch p := value.(type) {
12871375
default:
1288-
return spannerErrorf(codes.InvalidArgument, "invalid type for NullUUID: %v", p)
1376+
return spannerErrorf(codes.InvalidArgument, "invalid type for NullUUID: %v (%t)", p, p)
12891377
case *uuid.UUID:
1290-
n.UUID = *p
1378+
if p == nil {
1379+
n.UUID = uuid.Nil
1380+
n.Valid = false
1381+
} else {
1382+
n.UUID = *p
1383+
}
12911384
case uuid.UUID:
12921385
n.UUID = p
12931386
case *NullUUID:
1294-
n.UUID = p.UUID
1295-
n.Valid = p.Valid
1387+
if p == nil {
1388+
n.UUID = uuid.Nil
1389+
n.Valid = false
1390+
} else {
1391+
n.UUID = p.UUID
1392+
n.Valid = p.Valid
1393+
}
12961394
case NullUUID:
12971395
n.UUID = p.UUID
12981396
n.Valid = p.Valid
1397+
case string:
1398+
return n.scanStringValue(p)
1399+
case *string:
1400+
if p == nil {
1401+
n.UUID = uuid.Nil
1402+
n.Valid = false
1403+
return nil
1404+
}
1405+
return n.scanStringValue(*p)
12991406
}
13001407
return nil
13011408
}
13021409

1410+
func (n *NullUUID) scanStringValue(s string) error {
1411+
u, err := uuid.Parse(s)
1412+
if err != nil {
1413+
return err
1414+
}
1415+
n.UUID = u
1416+
n.Valid = true
1417+
return nil
1418+
}
1419+
13031420
// GormDataType is used by gorm to determine the default data type for fields with this type.
13041421
func (n NullUUID) GormDataType() string {
13051422
return "UUID"

0 commit comments

Comments
 (0)