Skip to content

Commit fae6b7a

Browse files
committed
tmp table bulk
1 parent db7df16 commit fae6b7a

3 files changed

Lines changed: 77 additions & 24 deletions

File tree

plugins/destination/mssql/client/bulk_insert.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@ func (c *Client) bulkInsert(ctx context.Context, table *schema.Table, data [][]a
1919
}
2020
}()
2121

22-
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(sanitizeIdentifier(table.Name), mssql.BulkOptions{
22+
if err := c.bulkInsertTx(ctx, tx, sanitizeIdentifier(table.Name), table, data); err != nil {
23+
return err
24+
}
25+
26+
return tx.Commit()
27+
}
28+
29+
func (c *Client) bulkInsertTx(ctx context.Context, tx *sql.Tx, tableName string, table *schema.Table, data [][]any) (err error) {
30+
stmt, err := tx.PrepareContext(ctx, mssql.CopyIn(tableName, mssql.BulkOptions{
2331
KeepNulls: true,
2432
KilobytesPerBatch: c.spec.BatchSizeBytes >> 10,
2533
RowsPerBatch: c.spec.BatchSize,
@@ -36,9 +44,6 @@ func (c *Client) bulkInsert(ctx context.Context, table *schema.Table, data [][]a
3644
}
3745

3846
// send bulkInsert
39-
if _, err := stmt.ExecContext(ctx); err != nil {
40-
return err
41-
}
42-
43-
return tx.Commit()
47+
_, err = stmt.ExecContext(ctx)
48+
return err
4449
}

plugins/destination/mssql/client/merge.go

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,54 @@ package client
33
import (
44
"context"
55
"database/sql"
6-
"strconv"
76
"strings"
87

98
"github.com/cloudquery/plugin-sdk/schema"
109
"golang.org/x/sync/errgroup"
1110
)
1211

1312
func (c *Client) merge(ctx context.Context, table *schema.Table, data [][]any) error {
14-
tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
13+
// get conn
14+
conn, err := c.db.Conn(ctx)
1515
if err != nil {
1616
return err
1717
}
1818
defer func() {
19+
_ = conn.Close()
20+
}()
21+
22+
tx, err := conn.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
23+
if err != nil {
24+
return err
25+
}
26+
defer func() {
27+
if err == nil {
28+
err = tx.Commit()
29+
}
30+
// also gets tx.Commit err
1931
if err != nil {
2032
_ = tx.Rollback()
2133
}
2234
}()
2335

24-
stmt, err := tx.PrepareContext(ctx, c.mergeQuery(table))
36+
tmpTable, err := c.createTmpTable(ctx, tx, table)
37+
if err != nil {
38+
return err
39+
}
40+
defer func() {
41+
delErr := c.dropTmpTable(ctx, tx, tmpTable)
42+
if err == nil {
43+
err = delErr
44+
}
45+
}()
46+
47+
// tx bound
48+
if err := c.bulkInsertTx(ctx, tx, tmpTable, table, data); err != nil {
49+
return err
50+
}
51+
52+
// now do merge on real table
53+
stmt, err := tx.PrepareContext(ctx, c.mergeQuery(table, tmpTable))
2554
if err != nil {
2655
return err
2756
}
@@ -35,30 +64,20 @@ func (c *Client) merge(ctx context.Context, table *schema.Table, data [][]any) e
3564
})
3665
}
3766

38-
if err := eg.Wait(); err != nil {
39-
return err
40-
}
41-
42-
return tx.Commit()
67+
return eg.Wait()
4368
}
4469

45-
func (c *Client) mergeQuery(table *schema.Table) string {
70+
func (c *Client) mergeQuery(table *schema.Table, tmpTable string) string {
4671
var sb strings.Builder
4772

4873
sb.WriteString("merge into ")
4974
sb.WriteString(sanitizeIdentifier(table.Name))
5075
sb.WriteString(" as [tgt]\n")
5176

52-
sb.WriteString("using ( select\n\t")
53-
54-
columns := tableColumnsSanitized(table)
55-
selParts := make([]string, len(columns))
56-
for i, col := range columns {
57-
selParts[i] = "@p" + strconv.Itoa(i+1) + " as " + col
58-
}
59-
sb.WriteString(strings.Join(selParts, ",\n\t"))
77+
sb.WriteString("using ")
78+
sb.WriteString(tmpTable)
6079

61-
sb.WriteString("\n) as [src]\non (\n")
80+
sb.WriteString("as [src]\non (\n")
6281

6382
pk := c.getPKSanitized(table)
6483
matchParts := make([]string, len(pk))
@@ -84,6 +103,7 @@ func (c *Client) mergeQuery(table *schema.Table) string {
84103
sb.WriteString("\n")
85104

86105
// insert
106+
columns := tableColumnsSanitized(table)
87107
sb.WriteString("when not matched then insert\n(\n\t")
88108
sb.WriteString(strings.Join(columns, ",\n\t"))
89109
sb.WriteString("\n)\nvalues\n(\n\t")
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"strings"
7+
8+
"github.com/cloudquery/plugin-sdk/schema"
9+
)
10+
11+
func (c *Client) createTmpTable(ctx context.Context, tx *sql.Tx, table *schema.Table) (string, error) {
12+
tmpTableName := sanitizeIdentifier("#tmp_" + table.Name)
13+
14+
var sb strings.Builder
15+
sb.WriteString("create table ")
16+
sb.WriteString(tmpTableName)
17+
sb.WriteString("\n(\n\t")
18+
sb.WriteString(strings.Join(getColumnDefinitions(table.Columns), ",\n\t"))
19+
sb.WriteString("\n);")
20+
21+
_, err := tx.ExecContext(ctx, sb.String())
22+
return tmpTableName, err
23+
}
24+
25+
func (c *Client) dropTmpTable(ctx context.Context, tx *sql.Tx, name string) error {
26+
_, err := tx.ExecContext(ctx, "drop table "+name+";")
27+
return err
28+
}

0 commit comments

Comments
 (0)