Skip to content

Commit 189a8c3

Browse files
authored
fix: Concurrent map access when reporting remote progress (#21108)
#### Summary

Our race detection tests didn't catch this, since we only report remote progress for sync runs that use API keys (see https://github.com/cloudquery/cloudquery/blob/0e0deec32b95fac6af9f27e53e8fd1bb75921da2/cli/cmd/sync_v3.go#L89). I'll follow up with a test later.
1 parent 9f5b730 commit 189a8c3

2 files changed

Lines changed: 48 additions & 8 deletions

File tree

cli/cmd/sync_v3.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/cloudquery/cloudquery/cli/v6/internal/tablenamechanger"
3535
"github.com/cloudquery/cloudquery/cli/v6/internal/transformer"
3636
"github.com/cloudquery/cloudquery/cli/v6/internal/transformerpipeline"
37+
"github.com/cloudquery/cloudquery/cli/v6/internal/utils"
3738
)
3839

3940
type v3source struct {
@@ -149,7 +150,7 @@ func syncConnectionV3(ctx context.Context, syncOptions syncV3Options) (syncErr e
149150
syncTimeTook time.Duration
150151
totalResources = int64(0)
151152
totals = sourceClient.Metrics()
152-
statsPerTable = cloudquery_api.SyncRunTableProgress{}
153+
statsPerTable = utils.NewConcurrentMap[string, cloudquery_api.SyncRunTableProgressValue]()
153154
)
154155
defer func() {
155156
analytics.TrackSyncCompleted(ctx, invocationUUID.UUID, analytics.SyncFinishedEvent{
@@ -363,10 +364,11 @@ func syncConnectionV3(ctx context.Context, syncOptions syncV3Options) (syncErr e
363364
}
364365
// Pre init stats per table
365366
for _, table := range sourceTables {
366-
statsPerTable[table.Name] = cloudquery_api.SyncRunTableProgressValue{
367+
initialStats := cloudquery_api.SyncRunTableProgressValue{
367368
Rows: 0,
368369
Errors: 0,
369370
}
371+
statsPerTable.Add(table.Name, initialStats)
370372
}
371373

372374
var remoteProgressReporter *godebouncer.Debouncer
@@ -390,12 +392,13 @@ func syncConnectionV3(ctx context.Context, syncOptions syncV3Options) (syncErr e
390392
if atomic.LoadInt64(&isComplete) == 1 {
391393
status = cloudquery_api.SyncRunStatusCompleted
392394
}
395+
tableProgress := cloudquery_api.SyncRunTableProgress(statsPerTable.GetAll())
393396
obj := cloudquery_api.CreateSyncRunProgressJSONRequestBody{
394397
Rows: atomic.LoadInt64(&totalResources),
395398
Errors: int64(totals.Errors),
396399
Warnings: int64(totals.Warnings),
397400
Status: &status,
398-
TableProgress: &statsPerTable,
401+
TableProgress: &tableProgress,
399402
}
400403
if shard != nil {
401404
obj.ShardNum = &shard.num
@@ -507,9 +510,9 @@ func syncConnectionV3(ctx context.Context, syncOptions syncV3Options) (syncErr e
507510
atomic.AddInt64(&newResources, record.NumRows())
508511
atomic.AddInt64(&totalResources, record.NumRows())
509512
tableName, _ := record.Schema().Metadata().GetValue(schema.MetadataTableName)
510-
stats := statsPerTable[tableName]
513+
stats, _ := statsPerTable.Get(tableName)
511514
stats.Rows += record.NumRows()
512-
statsPerTable[tableName] = stats
515+
statsPerTable.Add(tableName, stats)
513516
if remoteProgressReporter != nil {
514517
remoteProgressReporter.SendSignal()
515518
}
@@ -623,9 +626,9 @@ func syncConnectionV3(ctx context.Context, syncOptions syncV3Options) (syncErr e
623626
}
624627
case *plugin.Sync_Response_Error:
625628
log.Error().Str("table", m.Error.TableName).Msg(m.Error.Error)
626-
stats := statsPerTable[m.Error.TableName]
629+
stats, _ := statsPerTable.Get(m.Error.TableName)
627630
stats.Errors++
628-
statsPerTable[m.Error.TableName] = stats
631+
statsPerTable.Add(m.Error.TableName, stats)
629632
default:
630633
return fmt.Errorf("unknown message type: %T", m)
631634
}
@@ -643,7 +646,7 @@ func syncConnectionV3(ctx context.Context, syncOptions syncV3Options) (syncErr e
643646
totals = sourceClient.Metrics()
644647
sourceWarnings := totals.Warnings
645648
var sourceErrors uint64
646-
for _, val := range statsPerTable {
649+
for _, val := range statsPerTable.GetAll() {
647650
sourceErrors += uint64(val.Errors)
648651
}
649652
if totals.Errors > sourceErrors {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package utils
2+
3+
import (
4+
"maps"
5+
"sync"
6+
)
7+
8+
// ConcurrentMap is a thread-safe map of values.
9+
type ConcurrentMap[T comparable, U any] struct {
10+
mu sync.RWMutex
11+
m map[T]U
12+
}
13+
14+
func NewConcurrentMap[T comparable, U any]() *ConcurrentMap[T, U] {
15+
return &ConcurrentMap[T, U]{
16+
m: make(map[T]U),
17+
}
18+
}
19+
20+
func (s *ConcurrentMap[T, U]) Add(key T, value U) {
21+
s.mu.Lock()
22+
defer s.mu.Unlock()
23+
s.m[key] = value
24+
}
25+
26+
func (s *ConcurrentMap[T, U]) Get(key T) (U, bool) {
27+
s.mu.RLock()
28+
defer s.mu.RUnlock()
29+
v, ok := s.m[key]
30+
return v, ok
31+
}
32+
33+
func (s *ConcurrentMap[T, U]) GetAll() map[T]U {
34+
s.mu.RLock()
35+
defer s.mu.RUnlock()
36+
return maps.Clone(s.m)
37+
}

0 commit comments

Comments
 (0)