Skip to content

Commit f0ec33e

Browse files
authored
feat(dest-mysql): Handle all MySQL types (#11360)
#### Summary Fixes #11325. Same as #9143, but for the MySQL destination. Type handling is taken from https://github.com/cloudquery/cloudquery/pull/11214/files#diff-c4ee14fff2e1c7cbea7dfcf122879522b5d817aaa041f54564696293b1c6bc30R90. Once the MySQL source is migrated to v3, we can add support for more granular types like `char(x)` or other fixed-size types (doing that will be a breaking change).
1 parent 5e30247 commit f0ec33e

4 files changed

Lines changed: 146 additions & 66 deletions

File tree

plugins/destination/mysql/client/migrate.go

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,35 @@ import (
1010
"github.com/cloudquery/plugin-sdk/v3/schema"
1111
)
1212

13-
func (c *Client) normalizeTables(tables schema.Tables) (schema.Tables, error) {
13+
func (c *Client) normalizeTables(tables schema.Tables) schema.Tables {
1414
flattened := tables.FlattenTables()
1515
normalized := make(schema.Tables, len(flattened))
16-
var err error
1716
for i, table := range flattened {
18-
normalized[i], err = c.normalizeTable(table)
19-
if err != nil {
20-
return nil, err
21-
}
17+
normalized[i] = c.normalizeTable(table)
2218
}
23-
return normalized, nil
19+
return normalized
2420
}
2521

26-
func (c *Client) normalizeTable(table *schema.Table) (*schema.Table, error) {
22+
func (c *Client) normalizeTable(table *schema.Table) *schema.Table {
2723
columns := make([]schema.Column, len(table.Columns))
2824
for i, col := range table.Columns {
2925
if !c.pkEnabled() {
3026
col.PrimaryKey = false
3127
}
32-
normalized, err := c.normalizeField(col.ToArrowField())
33-
if err != nil {
34-
return nil, err
35-
}
28+
normalized := c.normalizeField(col.ToArrowField())
3629
columns[i] = schema.NewColumnFromArrowField(*normalized)
3730
}
38-
return &schema.Table{Name: table.Name, Columns: columns}, nil
31+
return &schema.Table{Name: table.Name, Columns: columns}
3932
}
4033

41-
func (*Client) normalizeField(field arrow.Field) (*arrow.Field, error) {
42-
normalizedType, err := mySQLTypeToArrowType("", "", arrowTypeToMySqlStr(field.Type))
43-
if err != nil {
44-
return nil, err
45-
}
34+
func (*Client) normalizeField(field arrow.Field) *arrow.Field {
35+
normalizedType := mySQLTypeToArrowType(arrowTypeToMySqlStr(field.Type))
4636
return &arrow.Field{
4737
Name: field.Name,
4838
Type: normalizedType,
4939
Nullable: field.Nullable,
5040
Metadata: field.Metadata,
51-
}, nil
41+
}
5242
}
5343

5444
func (c *Client) nonAutoMigratableTables(tables schema.Tables, mysqlTables schema.Tables) ([]string, [][]schema.TableColumnChange) {
@@ -105,11 +95,7 @@ func (c *Client) Migrate(ctx context.Context, tables schema.Tables) error {
10595
return err
10696
}
10797

108-
normalizedTables, err := c.normalizeTables(tables)
109-
if err != nil {
110-
return err
111-
}
112-
98+
normalizedTables := c.normalizeTables(tables)
11399
if c.spec.MigrateMode != specs.MigrateModeForced {
114100
nonAutoMigrtableTables, changes := c.nonAutoMigratableTables(normalizedTables, mysqlTables)
115101
if len(nonAutoMigrtableTables) > 0 {

plugins/destination/mysql/client/schema.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,7 @@ func (c *Client) getTableColumns(ctx context.Context, tableName string) ([]schem
5959
return nil, err
6060
}
6161

62-
schemaType, err := mySQLTypeToArrowType(tableName, name, typ)
63-
if err != nil {
64-
return nil, err
65-
}
62+
schemaType := mySQLTypeToArrowType(typ)
6663
var primaryKey bool
6764
if constraintType != nil && c.pkEnabled() {
6865
primaryKey = strings.Contains(*constraintType, "PRIMARY KEY")

plugins/destination/mysql/client/types.go

Lines changed: 89 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,106 @@
11
package client
22

33
import (
4-
"fmt"
5-
"sort"
4+
"strconv"
65
"strings"
76

87
"github.com/apache/arrow/go/v13/arrow"
98
"github.com/cloudquery/plugin-sdk/v3/types"
10-
"golang.org/x/exp/maps"
119
)
1210

13-
func mySQLTypeToArrowType(tableName string, columnName string, sqlType string) (arrow.DataType, error) {
14-
if strings.HasPrefix(sqlType, "datetime") {
15-
// MySQL permits up to microseconds (6 digits) precision
16-
return arrow.FixedWidthTypes.Timestamp_us, nil
11+
const defaultPrecision = 10
12+
const defaultScale = 0
13+
14+
func isUnsigned(sqlType string) bool {
15+
return strings.Contains(sqlType, "unsigned")
16+
}
17+
18+
func parseStringValue(str string, defaultValue int32) int32 {
19+
val, err := strconv.ParseInt(str, 10, 32)
20+
if err != nil {
21+
return defaultValue
22+
}
23+
return int32(val)
24+
}
25+
26+
func getPrecisionAndScale(dataType string) (precision, scale int32) {
27+
str := strings.TrimPrefix(dataType, "decimal")
28+
str = strings.TrimPrefix(str, "numeric")
29+
if str == "" {
30+
return defaultPrecision, defaultScale
1731
}
18-
sqlTypeToSchemaType := map[string]arrow.DataType{
19-
"tinyint(1)": arrow.FixedWidthTypes.Boolean,
20-
"tinyint": arrow.PrimitiveTypes.Int8,
21-
"tinyint(4)": arrow.PrimitiveTypes.Int8,
22-
"smallint": arrow.PrimitiveTypes.Int16,
23-
"smallint(5)": arrow.PrimitiveTypes.Int16,
24-
"smallint(6)": arrow.PrimitiveTypes.Int16,
25-
"smallint(5) unsigned": arrow.PrimitiveTypes.Uint16,
26-
"int": arrow.PrimitiveTypes.Int32,
27-
"int(10) unsigned": arrow.PrimitiveTypes.Uint32,
28-
"int(11)": arrow.PrimitiveTypes.Int32,
29-
"bigint": arrow.PrimitiveTypes.Int64,
30-
"bigint(20)": arrow.PrimitiveTypes.Int64,
31-
"tinyint unsigned": arrow.PrimitiveTypes.Uint8,
32-
"tinyint(3) unsigned": arrow.PrimitiveTypes.Uint8,
33-
"smallint unsigned": arrow.PrimitiveTypes.Uint16,
34-
"int unsigned": arrow.PrimitiveTypes.Uint32,
35-
"bigint unsigned": arrow.PrimitiveTypes.Uint64,
36-
"bigint(20) unsigned": arrow.PrimitiveTypes.Uint64,
37-
"float": arrow.PrimitiveTypes.Float32,
38-
"double": arrow.PrimitiveTypes.Float64,
39-
"binary(16)": types.ExtensionTypes.UUID,
40-
"blob": arrow.BinaryTypes.LargeBinary,
41-
"text": arrow.BinaryTypes.LargeString,
42-
"json": types.ExtensionTypes.JSON,
32+
str = strings.TrimPrefix(str, "(")
33+
str = strings.TrimSuffix(str, ")")
34+
parts := strings.Split(str, ",")
35+
36+
switch len(parts) {
37+
case 1:
38+
precision = parseStringValue(parts[0], defaultPrecision)
39+
scale = defaultScale
40+
case 2:
41+
precision = parseStringValue(parts[0], defaultPrecision)
42+
scale = parseStringValue(parts[1], defaultScale)
43+
default:
44+
precision = defaultPrecision
45+
scale = defaultScale
4346
}
47+
return precision, scale
48+
}
4449

45-
if v, ok := sqlTypeToSchemaType[sqlType]; ok {
46-
return v, nil
50+
func mySQLTypeToArrowType(sqlType string) arrow.DataType {
51+
if sqlType == "binary(16)" {
52+
return types.ExtensionTypes.UUID
53+
}
54+
if sqlType == "tinyint(1)" {
55+
return arrow.FixedWidthTypes.Boolean
56+
}
57+
if strings.HasPrefix(sqlType, "datetime") {
58+
return arrow.FixedWidthTypes.Timestamp_us
59+
}
60+
if strings.HasPrefix(sqlType, "decimal") || strings.HasPrefix(sqlType, "numeric") {
61+
precision, scale := getPrecisionAndScale(sqlType)
62+
return &arrow.Decimal128Type{Precision: precision, Scale: scale}
63+
}
64+
if strings.HasPrefix(sqlType, "tinyint") {
65+
if isUnsigned(sqlType) {
66+
return arrow.PrimitiveTypes.Uint8
67+
}
68+
return arrow.PrimitiveTypes.Int8
69+
}
70+
if strings.HasPrefix(sqlType, "smallint") {
71+
if isUnsigned(sqlType) {
72+
return arrow.PrimitiveTypes.Uint16
73+
}
74+
return arrow.PrimitiveTypes.Int16
75+
}
76+
if strings.HasPrefix(sqlType, "int") {
77+
if isUnsigned(sqlType) {
78+
return arrow.PrimitiveTypes.Uint32
79+
}
80+
return arrow.PrimitiveTypes.Int32
81+
}
82+
if strings.HasPrefix(sqlType, "bigint") {
83+
if isUnsigned(sqlType) {
84+
return arrow.PrimitiveTypes.Uint64
85+
}
86+
return arrow.PrimitiveTypes.Int64
87+
}
88+
switch sqlType {
89+
case "bool", "boolean":
90+
return arrow.FixedWidthTypes.Boolean
91+
case "float":
92+
return arrow.PrimitiveTypes.Float32
93+
case "double":
94+
return arrow.PrimitiveTypes.Float64
95+
case "timestamp":
96+
return arrow.FixedWidthTypes.Timestamp_us
97+
case "json":
98+
return types.ExtensionTypes.JSON
99+
case "binary", "blob":
100+
return arrow.BinaryTypes.Binary
47101
}
48102

49-
supportedTypes := maps.Keys(sqlTypeToSchemaType)
50-
supportedTypes = append(supportedTypes, "datetime")
51-
sort.Strings(supportedTypes)
52-
return nil, fmt.Errorf("got unknown MySQL type %q for column %q of table %q while trying to convert it to CloudQuery internal schema type. Supported MySQL types are %q", sqlType, columnName, tableName, supportedTypes)
103+
return arrow.BinaryTypes.String
53104
}
54105

55106
func arrowTypeToMySqlStr(t arrow.DataType) string {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package client
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func Test_getPrecisionAndScale(t *testing.T) {
8+
tests := []struct {
9+
name string
10+
dataTypes []string
11+
wantPrecision int32
12+
wantScale int32
13+
}{
14+
{
15+
name: "should return default precision and scale when not provided",
16+
dataTypes: []string{"decimal", "numeric"},
17+
wantPrecision: 10,
18+
wantScale: 0,
19+
},
20+
{
21+
name: "should return default scale when only precision is provided",
22+
dataTypes: []string{"decimal(15)", "numeric(15)"},
23+
wantPrecision: 15,
24+
wantScale: 0,
25+
},
26+
{
27+
name: "should return precision and scale when precision and scale are provided",
28+
dataTypes: []string{"decimal(15,2)", "numeric(15,2)"},
29+
wantPrecision: 15,
30+
wantScale: 2,
31+
},
32+
}
33+
for _, tt := range tests {
34+
t.Run(tt.name, func(t *testing.T) {
35+
for _, dataType := range tt.dataTypes {
36+
gotPrecision, gotScale := getPrecisionAndScale(dataType)
37+
if gotPrecision != tt.wantPrecision {
38+
t.Errorf("getPrecisionAndScale() gotPrecision = %v, want %v", gotPrecision, tt.wantPrecision)
39+
}
40+
if gotScale != tt.wantScale {
41+
t.Errorf("getPrecisionAndScale() gotScale = %v, want %v", gotScale, tt.wantScale)
42+
}
43+
}
44+
})
45+
}
46+
}

0 commit comments

Comments
 (0)