Skip to content

Commit 78de20d

Browse files
authored
fix(datastore): fix context leak in Iterator and Transaction spans (#14478)
Fixes #9528. This PR fixes a context leak in the `datastore` library where several methods of `Iterator` and `Transaction` were overwriting their internal context field with a new span context on every call (e.g., `t.ctx = trace.StartSpan(t.ctx, ...)`). When these methods are called repeatedly (e.g., in a loop over millions of objects calling `Cursor()` or `Get()`), the context chain grows linearly, eventually leading to a stack overflow in `context.Done()` or `context.Value()` due to deep recursion. The fix is to keep each span's context in a local variable instead of overwriting the stored context field.
1 parent fbb543e commit 78de20d

4 files changed

Lines changed: 216 additions & 24 deletions

File tree

datastore/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ require (
77
cloud.google.com/go/longrunning v0.9.0
88
github.com/google/go-cmp v0.7.0
99
github.com/googleapis/gax-go/v2 v2.21.0
10+
go.opentelemetry.io/otel/sdk v1.43.0
1011
google.golang.org/api v0.274.0
1112
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7
1213
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9
@@ -30,7 +31,6 @@ require (
3031
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
3132
go.opentelemetry.io/otel v1.43.0 // indirect
3233
go.opentelemetry.io/otel/metric v1.43.0 // indirect
33-
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
3434
go.opentelemetry.io/otel/trace v1.43.0 // indirect
3535
golang.org/x/crypto v0.49.0 // indirect
3636
golang.org/x/net v0.52.0 // indirect

datastore/query.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ func (c *Client) Count(ctx context.Context, q *Query) (n int, err error) {
619619
// directly.
620620
it := c.Run(ctx, newQ)
621621
for {
622-
err := it.nextBatch()
622+
err := it.nextBatch(it.ctx)
623623
if err == iterator.Done {
624624
return n, nil
625625
}
@@ -1085,7 +1085,7 @@ func (t *Iterator) Next(dst interface{}) (k *Key, err error) {
10851085
func (t *Iterator) next() (*Key, *pb.Entity, error) {
10861086
// Fetch additional batches while there are no more results.
10871087
for t.err == nil && len(t.results) == 0 {
1088-
t.err = t.nextBatch()
1088+
t.err = t.nextBatch(t.ctx)
10891089
}
10901090
if t.err != nil {
10911091
return nil, nil, t.err
@@ -1110,7 +1110,7 @@ func (t *Iterator) next() (*Key, *pb.Entity, error) {
11101110
}
11111111

11121112
// nextBatch makes a single call to the server for a batch of results.
1113-
func (t *Iterator) nextBatch() error {
1113+
func (t *Iterator) nextBatch(ctx context.Context) error {
11141114
if t.err != nil {
11151115
return t.err
11161116
}
@@ -1138,7 +1138,7 @@ func (t *Iterator) nextBatch() error {
11381138
}
11391139

11401140
// Run the query.
1141-
resp, err := t.client.client.RunQuery(t.ctx, t.req)
1141+
resp, err := t.client.client.RunQuery(ctx, t.req)
11421142
if err != nil {
11431143
return err
11441144
}
@@ -1250,12 +1250,12 @@ func fromPbExecutionStats(pbstats *pb.ExecutionStats) *ExecutionStats {
12501250

12511251
// Cursor returns a cursor for the iterator's current location.
12521252
func (t *Iterator) Cursor() (c Cursor, err error) {
1253-
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Query.Cursor")
1254-
defer func() { trace.EndSpan(t.ctx, err) }()
1253+
ctx := trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Query.Cursor")
1254+
defer func() { trace.EndSpan(ctx, err) }()
12551255

12561256
// If there is still an offset, we need to skip those results first.
12571257
for t.err == nil && t.offset > 0 {
1258-
t.err = t.nextBatch()
1258+
t.err = t.nextBatch(ctx)
12591259
}
12601260

12611261
if t.err != nil && t.err != iterator.Done {

datastore/trace_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package datastore
16+
17+
import (
18+
"context"
19+
"testing"
20+
21+
pb "cloud.google.com/go/datastore/apiv1/datastorepb"
22+
"cloud.google.com/go/internal/testutil"
23+
"go.opentelemetry.io/otel/sdk/trace/tracetest"
24+
)
25+
26+
func TestTransactionTracingContextPropagation(t *testing.T) {
27+
ctx := context.Background()
28+
te := testutil.NewOpenTelemetryTestExporter()
29+
t.Cleanup(func() {
30+
te.Unregister(ctx)
31+
})
32+
33+
client, srv, cleanup := newMock(t)
34+
defer cleanup()
35+
36+
mockTxnID := []byte("tid")
37+
mockKey := NameKey("Kind", "name", nil)
38+
mockEntity := &pb.Entity{
39+
Key: keyToProto(mockKey),
40+
}
41+
42+
srv.addRPC(&pb.BeginTransactionRequest{ProjectId: mockProjectID}, &pb.BeginTransactionResponse{Transaction: mockTxnID})
43+
srv.addRPC(&pb.LookupRequest{
44+
ProjectId: mockProjectID,
45+
Keys: []*pb.Key{keyToProto(mockKey)},
46+
ReadOptions: &pb.ReadOptions{
47+
ConsistencyType: &pb.ReadOptions_Transaction{Transaction: mockTxnID},
48+
},
49+
}, &pb.LookupResponse{Found: []*pb.EntityResult{{Entity: mockEntity}}})
50+
srv.addRPC(&pb.LookupRequest{
51+
ProjectId: mockProjectID,
52+
Keys: []*pb.Key{keyToProto(mockKey)},
53+
ReadOptions: &pb.ReadOptions{
54+
ConsistencyType: &pb.ReadOptions_Transaction{Transaction: mockTxnID},
55+
},
56+
}, &pb.LookupResponse{Found: []*pb.EntityResult{{Entity: mockEntity}}})
57+
srv.addRPC(&pb.CommitRequest{
58+
ProjectId: mockProjectID,
59+
TransactionSelector: &pb.CommitRequest_Transaction{Transaction: mockTxnID},
60+
Mode: pb.CommitRequest_TRANSACTIONAL,
61+
}, &pb.CommitResponse{})
62+
63+
tx, err := client.NewTransaction(ctx)
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
68+
var dst struct{}
69+
if err := tx.Get(mockKey, &dst); err != nil {
70+
t.Fatal(err)
71+
}
72+
if err := tx.Get(mockKey, &dst); err != nil {
73+
t.Fatal(err)
74+
}
75+
if _, err := tx.Commit(); err != nil {
76+
t.Fatal(err)
77+
}
78+
79+
spans := te.Spans()
80+
81+
var newTxnSpan, beginTxnSpan tracetest.SpanStub
82+
var getSpans []tracetest.SpanStub
83+
var commitSpan tracetest.SpanStub
84+
var hasNewTxn, hasBeginTxn, hasCommit bool
85+
86+
for _, s := range spans {
87+
switch s.Name {
88+
case "cloud.google.com/go/datastore.NewTransaction":
89+
newTxnSpan = s
90+
hasNewTxn = true
91+
case "cloud.google.com/go/datastore.Transaction.BeginTransaction":
92+
beginTxnSpan = s
93+
hasBeginTxn = true
94+
case "cloud.google.com/go/datastore.Transaction.Get":
95+
getSpans = append(getSpans, s)
96+
case "cloud.google.com/go/datastore.Transaction.Commit":
97+
commitSpan = s
98+
hasCommit = true
99+
}
100+
}
101+
102+
if !hasNewTxn {
103+
t.Fatal("missing NewTransaction span")
104+
}
105+
newTxnSpanID := newTxnSpan.SpanContext.SpanID()
106+
107+
if !hasBeginTxn {
108+
t.Fatal("missing BeginTransaction span")
109+
}
110+
if beginTxnSpan.Parent.SpanID() != newTxnSpanID {
111+
t.Errorf("BeginTransaction span parent = %v, want %v", beginTxnSpan.Parent.SpanID(), newTxnSpanID)
112+
}
113+
114+
if len(getSpans) != 2 {
115+
t.Fatalf("got %d Get spans, want 2", len(getSpans))
116+
}
117+
118+
for i, s := range getSpans {
119+
if s.Parent.SpanID() != newTxnSpanID {
120+
t.Errorf("Get span %d parent = %v, want %v (parent span: %v)", i, s.Parent.SpanID(), newTxnSpanID, s.Parent)
121+
}
122+
}
123+
124+
if !hasCommit {
125+
t.Fatal("missing Commit span")
126+
}
127+
if commitSpan.Parent.SpanID() != newTxnSpanID {
128+
t.Errorf("Commit span parent = %v, want %v", commitSpan.Parent.SpanID(), newTxnSpanID)
129+
}
130+
}
131+
132+
func TestIteratorTracingContextPropagation(t *testing.T) {
133+
ctx := context.Background()
134+
te := testutil.NewOpenTelemetryTestExporter()
135+
t.Cleanup(func() {
136+
te.Unregister(ctx)
137+
})
138+
139+
client, srv, cleanup := newMock(t)
140+
defer cleanup()
141+
142+
mockKind := "Kind"
143+
srv.addRPC(&pb.RunQueryRequest{
144+
ProjectId: mockProjectID,
145+
QueryType: &pb.RunQueryRequest_Query{Query: &pb.Query{Kind: []*pb.KindExpression{{Name: mockKind}}}},
146+
}, &pb.RunQueryResponse{})
147+
148+
q := NewQuery(mockKind)
149+
it := client.Run(ctx, q)
150+
151+
_, _ = it.Cursor()
152+
_, _ = it.Cursor()
153+
154+
spans := te.Spans()
155+
156+
var runSpan tracetest.SpanStub
157+
var cursorSpans []tracetest.SpanStub
158+
var hasRun bool
159+
160+
for _, s := range spans {
161+
switch s.Name {
162+
case "cloud.google.com/go/datastore.Query.Run":
163+
runSpan = s
164+
hasRun = true
165+
case "cloud.google.com/go/datastore.Query.Cursor":
166+
cursorSpans = append(cursorSpans, s)
167+
}
168+
}
169+
170+
if !hasRun {
171+
t.Fatal("missing Query.Run span")
172+
}
173+
runSpanID := runSpan.SpanContext.SpanID()
174+
175+
if len(cursorSpans) != 2 {
176+
t.Fatalf("got %d Cursor spans, want 2", len(cursorSpans))
177+
}
178+
179+
for i, s := range cursorSpans {
180+
if s.Parent.SpanID() != runSpanID {
181+
t.Errorf("Cursor span %d parent = %v, want %v", i, s.Parent.SpanID(), runSpanID)
182+
}
183+
}
184+
}

datastore/transaction.go

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,16 @@ func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption)
179179
return c.newTransaction(ctx, newTransactionSettings(opts))
180180
}
181181

182+
// parseTransactionOptions returns the protobuf TransactionOptions and the span name
183+
// to be used for tracing the transaction.
182184
func (t *Transaction) parseTransactionOptions() (*pb.TransactionOptions, string) {
185+
const (
186+
defaultSpanName = "cloud.google.com/go/datastore.Transaction.BeginTransaction"
187+
readOnlySpanName = "cloud.google.com/go/datastore.Transaction.ReadOnlyTransaction"
188+
readWriteSpanName = "cloud.google.com/go/datastore.Transaction.ReadWriteTransaction"
189+
)
183190
if t.settings == nil {
184-
return nil, ""
191+
return nil, defaultSpanName
185192
}
186193

187194
if t.settings.readOnly {
@@ -192,17 +199,17 @@ func (t *Transaction) parseTransactionOptions() (*pb.TransactionOptions, string)
192199

193200
return &pb.TransactionOptions{
194201
Mode: &pb.TransactionOptions_ReadOnly_{ReadOnly: ro},
195-
}, "cloud.google.com/go/datastore.Transaction.ReadOnlyTransaction"
202+
}, readOnlySpanName
196203
}
197204

198205
if t.settings.prevID != nil {
199206
return &pb.TransactionOptions{
200207
Mode: &pb.TransactionOptions_ReadWrite_{ReadWrite: &pb.TransactionOptions_ReadWrite{
201208
PreviousTransaction: t.settings.prevID,
202209
}},
203-
}, "cloud.google.com/go/datastore.Transaction.ReadWriteTransaction"
210+
}, readWriteSpanName
204211
}
205-
return nil, ""
212+
return nil, defaultSpanName
206213
}
207214

208215
// beginTransaction makes BeginTransaction rpc
@@ -214,13 +221,14 @@ func (t *Transaction) beginTransaction() (txnID []byte, err error) {
214221
}
215222

216223
txOptionsPb, spanName := t.parseTransactionOptions()
224+
ctx := trace.StartSpan(t.ctx, spanName)
225+
defer func() { trace.EndSpan(ctx, err) }()
226+
217227
if txOptionsPb != nil {
218-
t.ctx = trace.StartSpan(t.ctx, spanName)
219-
defer func() { trace.EndSpan(t.ctx, err) }()
220228
req.TransactionOptions = txOptionsPb
221229
}
222230

223-
resp, err := t.client.client.BeginTransaction(t.ctx, req)
231+
resp, err := t.client.client.BeginTransaction(ctx, req)
224232
if err != nil {
225233
return nil, err
226234
}
@@ -412,8 +420,8 @@ func grpcStatusCode(err error) (codes.Code, error) {
412420

413421
// Commit applies the enqueued operations atomically.
414422
func (t *Transaction) Commit() (c *Commit, err error) {
415-
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.Commit")
416-
defer func() { trace.EndSpan(t.ctx, err) }()
423+
ctx := trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.Commit")
424+
defer func() { trace.EndSpan(ctx, err) }()
417425

418426
t.stateLock.Lock()
419427
if t.state == transactionStateExpired {
@@ -434,7 +442,7 @@ func (t *Transaction) Commit() (c *Commit, err error) {
434442
Mutations: t.mutations,
435443
Mode: pb.CommitRequest_TRANSACTIONAL,
436444
}
437-
resp, err := t.client.client.Commit(t.ctx, req)
445+
resp, err := t.client.client.Commit(ctx, req)
438446
if status.Code(err) == codes.Aborted {
439447
return nil, ErrConcurrentTransaction
440448
}
@@ -484,8 +492,8 @@ func (t *Transaction) rollbackWithRetry() error {
484492

485493
// Rollback abandons a pending transaction.
486494
func (t *Transaction) Rollback() (err error) {
487-
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.Rollback")
488-
defer func() { trace.EndSpan(t.ctx, err) }()
495+
ctx := trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.Rollback")
496+
defer func() { trace.EndSpan(ctx, err) }()
489497

490498
if t.state == transactionStateExpired {
491499
return errExpiredTransaction
@@ -501,7 +509,7 @@ func (t *Transaction) Rollback() (err error) {
501509
return err
502510
}
503511

504-
_, err = t.client.client.Rollback(t.ctx, &pb.RollbackRequest{
512+
_, err = t.client.client.Rollback(ctx, &pb.RollbackRequest{
505513
ProjectId: t.client.dataset,
506514
DatabaseId: t.client.databaseID,
507515
Transaction: t.id,
@@ -542,15 +550,15 @@ func (t *Transaction) parseReadOptions() (*pb.ReadOptions, error) {
542550
}
543551

544552
func (t *Transaction) get(spanName string, keys []*Key, dst interface{}) (err error) {
545-
t.ctx = trace.StartSpan(t.ctx, spanName)
546-
defer func() { trace.EndSpan(t.ctx, err) }()
553+
ctx := trace.StartSpan(t.ctx, spanName)
554+
defer func() { trace.EndSpan(ctx, err) }()
547555

548556
opts, err := t.parseReadOptions()
549557
if err != nil {
550558
return err
551559
}
552560

553-
txnID, err := t.client.get(t.ctx, keys, dst, opts)
561+
txnID, err := t.client.get(ctx, keys, dst, opts)
554562

555563
if txnID != nil {
556564
t.stateLock.Lock()

0 commit comments

Comments
 (0)