@@ -3,25 +3,54 @@ package client
33import (
44 "context"
55 "database/sql"
6- "strconv"
76 "strings"
87
98 "github.com/cloudquery/plugin-sdk/schema"
109 "golang.org/x/sync/errgroup"
1110)
1211
1312func (c * Client ) merge (ctx context.Context , table * schema.Table , data [][]any ) error {
14- tx , err := c .db .BeginTx (ctx , & sql.TxOptions {Isolation : sql .LevelSerializable })
13+ // get conn
14+ conn , err := c .db .Conn (ctx )
1515 if err != nil {
1616 return err
1717 }
1818 defer func () {
19+ _ = conn .Close ()
20+ }()
21+
22+ tx , err := conn .BeginTx (ctx , & sql.TxOptions {Isolation : sql .LevelSerializable })
23+ if err != nil {
24+ return err
25+ }
26+ defer func () {
27+ if err == nil {
28+ err = tx .Commit ()
29+ }
30+ // also gets tx.Commit err
1931 if err != nil {
2032 _ = tx .Rollback ()
2133 }
2234 }()
2335
24- stmt , err := tx .PrepareContext (ctx , c .mergeQuery (table ))
36+ tmpTable , err := c .createTmpTable (ctx , tx , table )
37+ if err != nil {
38+ return err
39+ }
40+ defer func () {
41+ delErr := c .dropTmpTable (ctx , tx , tmpTable )
42+ if err == nil {
43+ err = delErr
44+ }
45+ }()
46+
47+ // tx bound
48+ if err := c .bulkInsertTx (ctx , tx , tmpTable , table , data ); err != nil {
49+ return err
50+ }
51+
52+ // now do merge on real table
53+ stmt , err := tx .PrepareContext (ctx , c .mergeQuery (table , tmpTable ))
2554 if err != nil {
2655 return err
2756 }
@@ -35,30 +64,20 @@ func (c *Client) merge(ctx context.Context, table *schema.Table, data [][]any) e
3564 })
3665 }
3766
38- if err := eg .Wait (); err != nil {
39- return err
40- }
41-
42- return tx .Commit ()
67+ return eg .Wait ()
4368}
4469
45- func (c * Client ) mergeQuery (table * schema.Table ) string {
70+ func (c * Client ) mergeQuery (table * schema.Table , tmpTable string ) string {
4671 var sb strings.Builder
4772
4873 sb .WriteString ("merge into " )
4974 sb .WriteString (sanitizeIdentifier (table .Name ))
5075 sb .WriteString (" as [tgt]\n " )
5176
52- sb .WriteString ("using ( select\n \t " )
53-
54- columns := tableColumnsSanitized (table )
55- selParts := make ([]string , len (columns ))
56- for i , col := range columns {
57- selParts [i ] = "@p" + strconv .Itoa (i + 1 ) + " as " + col
58- }
59- sb .WriteString (strings .Join (selParts , ",\n \t " ))
77+ sb .WriteString ("using " )
78+ sb .WriteString (tmpTable )
6079
61- sb .WriteString ("\n ) as [src]\n on (\n " )
80+ sb .WriteString ("as [src]\n on (\n " )
6281
6382 pk := c .getPKSanitized (table )
6483 matchParts := make ([]string , len (pk ))
@@ -84,6 +103,7 @@ func (c *Client) mergeQuery(table *schema.Table) string {
84103 sb .WriteString ("\n " )
85104
86105 // insert
106+ columns := tableColumnsSanitized (table )
87107 sb .WriteString ("when not matched then insert\n (\n \t " )
88108 sb .WriteString (strings .Join (columns , ",\n \t " ))
89109 sb .WriteString ("\n )\n values\n (\n \t " )
0 commit comments