|
1 | 1 | package client |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "fmt" |
5 | | - "sort" |
| 4 | + "strconv" |
6 | 5 | "strings" |
7 | 6 |
|
8 | 7 | "github.com/apache/arrow/go/v13/arrow" |
9 | 8 | "github.com/cloudquery/plugin-sdk/v3/types" |
10 | | - "golang.org/x/exp/maps" |
11 | 9 | ) |
12 | 10 |
|
13 | | -func mySQLTypeToArrowType(tableName string, columnName string, sqlType string) (arrow.DataType, error) { |
14 | | - if strings.HasPrefix(sqlType, "datetime") { |
15 | | - // MySQL permits up to microseconds (6 digits) precision |
16 | | - return arrow.FixedWidthTypes.Timestamp_us, nil |
| 11 | +const defaultPrecision = 10 |
| 12 | +const defaultScale = 0 |
| 13 | + |
| 14 | +func isUnsigned(sqlType string) bool { |
| 15 | + return strings.Contains(sqlType, "unsigned") |
| 16 | +} |
| 17 | + |
| 18 | +func parseStringValue(str string, defaultValue int32) int32 { |
| 19 | + val, err := strconv.ParseInt(str, 10, 32) |
| 20 | + if err != nil { |
| 21 | + return defaultValue |
| 22 | + } |
| 23 | + return int32(val) |
| 24 | +} |
| 25 | + |
| 26 | +func getPrecisionAndScale(dataType string) (precision, scale int32) { |
| 27 | + str := strings.TrimPrefix(dataType, "decimal") |
| 28 | + str = strings.TrimPrefix(str, "numeric") |
| 29 | + if str == "" { |
| 30 | + return defaultPrecision, defaultScale |
17 | 31 | } |
18 | | - sqlTypeToSchemaType := map[string]arrow.DataType{ |
19 | | - "tinyint(1)": arrow.FixedWidthTypes.Boolean, |
20 | | - "tinyint": arrow.PrimitiveTypes.Int8, |
21 | | - "tinyint(4)": arrow.PrimitiveTypes.Int8, |
22 | | - "smallint": arrow.PrimitiveTypes.Int16, |
23 | | - "smallint(5)": arrow.PrimitiveTypes.Int16, |
24 | | - "smallint(6)": arrow.PrimitiveTypes.Int16, |
25 | | - "smallint(5) unsigned": arrow.PrimitiveTypes.Uint16, |
26 | | - "int": arrow.PrimitiveTypes.Int32, |
27 | | - "int(10) unsigned": arrow.PrimitiveTypes.Uint32, |
28 | | - "int(11)": arrow.PrimitiveTypes.Int32, |
29 | | - "bigint": arrow.PrimitiveTypes.Int64, |
30 | | - "bigint(20)": arrow.PrimitiveTypes.Int64, |
31 | | - "tinyint unsigned": arrow.PrimitiveTypes.Uint8, |
32 | | - "tinyint(3) unsigned": arrow.PrimitiveTypes.Uint8, |
33 | | - "smallint unsigned": arrow.PrimitiveTypes.Uint16, |
34 | | - "int unsigned": arrow.PrimitiveTypes.Uint32, |
35 | | - "bigint unsigned": arrow.PrimitiveTypes.Uint64, |
36 | | - "bigint(20) unsigned": arrow.PrimitiveTypes.Uint64, |
37 | | - "float": arrow.PrimitiveTypes.Float32, |
38 | | - "double": arrow.PrimitiveTypes.Float64, |
39 | | - "binary(16)": types.ExtensionTypes.UUID, |
40 | | - "blob": arrow.BinaryTypes.LargeBinary, |
41 | | - "text": arrow.BinaryTypes.LargeString, |
42 | | - "json": types.ExtensionTypes.JSON, |
| 32 | + str = strings.TrimPrefix(str, "(") |
| 33 | + str = strings.TrimSuffix(str, ")") |
| 34 | + parts := strings.Split(str, ",") |
| 35 | + |
| 36 | + switch len(parts) { |
| 37 | + case 1: |
| 38 | + precision = parseStringValue(parts[0], defaultPrecision) |
| 39 | + scale = defaultScale |
| 40 | + case 2: |
| 41 | + precision = parseStringValue(parts[0], defaultPrecision) |
| 42 | + scale = parseStringValue(parts[1], defaultScale) |
| 43 | + default: |
| 44 | + precision = defaultPrecision |
| 45 | + scale = defaultScale |
43 | 46 | } |
| 47 | + return precision, scale |
| 48 | +} |
44 | 49 |
|
45 | | - if v, ok := sqlTypeToSchemaType[sqlType]; ok { |
46 | | - return v, nil |
| 50 | +func mySQLTypeToArrowType(sqlType string) arrow.DataType { |
| 51 | + if sqlType == "binary(16)" { |
| 52 | + return types.ExtensionTypes.UUID |
| 53 | + } |
| 54 | + if sqlType == "tinyint(1)" { |
| 55 | + return arrow.FixedWidthTypes.Boolean |
| 56 | + } |
| 57 | + if strings.HasPrefix(sqlType, "datetime") { |
| 58 | + return arrow.FixedWidthTypes.Timestamp_us |
| 59 | + } |
| 60 | + if strings.HasPrefix(sqlType, "decimal") || strings.HasPrefix(sqlType, "numeric") { |
| 61 | + precision, scale := getPrecisionAndScale(sqlType) |
| 62 | + return &arrow.Decimal128Type{Precision: precision, Scale: scale} |
| 63 | + } |
| 64 | + if strings.HasPrefix(sqlType, "tinyint") { |
| 65 | + if isUnsigned(sqlType) { |
| 66 | + return arrow.PrimitiveTypes.Uint8 |
| 67 | + } |
| 68 | + return arrow.PrimitiveTypes.Int8 |
| 69 | + } |
| 70 | + if strings.HasPrefix(sqlType, "smallint") { |
| 71 | + if isUnsigned(sqlType) { |
| 72 | + return arrow.PrimitiveTypes.Uint16 |
| 73 | + } |
| 74 | + return arrow.PrimitiveTypes.Int16 |
| 75 | + } |
| 76 | + if strings.HasPrefix(sqlType, "int") { |
| 77 | + if isUnsigned(sqlType) { |
| 78 | + return arrow.PrimitiveTypes.Uint32 |
| 79 | + } |
| 80 | + return arrow.PrimitiveTypes.Int32 |
| 81 | + } |
| 82 | + if strings.HasPrefix(sqlType, "bigint") { |
| 83 | + if isUnsigned(sqlType) { |
| 84 | + return arrow.PrimitiveTypes.Uint64 |
| 85 | + } |
| 86 | + return arrow.PrimitiveTypes.Int64 |
| 87 | + } |
| 88 | + switch sqlType { |
| 89 | + case "bool", "boolean": |
| 90 | + return arrow.FixedWidthTypes.Boolean |
| 91 | + case "float": |
| 92 | + return arrow.PrimitiveTypes.Float32 |
| 93 | + case "double": |
| 94 | + return arrow.PrimitiveTypes.Float64 |
| 95 | + case "timestamp": |
| 96 | + return arrow.FixedWidthTypes.Timestamp_us |
| 97 | + case "json": |
| 98 | + return types.ExtensionTypes.JSON |
| 99 | + case "binary", "blob": |
| 100 | + return arrow.BinaryTypes.Binary |
47 | 101 | } |
48 | 102 |
|
49 | | - supportedTypes := maps.Keys(sqlTypeToSchemaType) |
50 | | - supportedTypes = append(supportedTypes, "datetime") |
51 | | - sort.Strings(supportedTypes) |
52 | | - return nil, fmt.Errorf("got unknown MySQL type %q for column %q of table %q while trying to convert it to CloudQuery internal schema type. Supported MySQL types are %q", sqlType, columnName, tableName, supportedTypes) |
| 103 | + return arrow.BinaryTypes.String |
53 | 104 | } |
54 | 105 |
|
55 | 106 | func arrowTypeToMySqlStr(t arrow.DataType) string { |
|
0 commit comments