1515"""Model a set of read-only queries to a database as a snapshot."""
1616
1717import functools
18-
18+ import threading
1919from google .protobuf .struct_pb2 import Struct
2020from google .cloud .spanner_v1 import ExecuteSqlRequest
2121from google .cloud .spanner_v1 import ReadRequest
2727
2828from google .api_core .exceptions import InternalServerError
2929from google .api_core .exceptions import ServiceUnavailable
30+ from google .api_core .exceptions import InvalidArgument
3031from google .api_core import gapic_v1
3132from google .cloud .spanner_v1 ._helpers import _make_value_pb
3233from google .cloud .spanner_v1 ._helpers import _merge_query_options
4344
4445
4546def _restart_on_unavailable (
46- method , request , trace_name = None , session = None , attributes = None
47+ method ,
48+ request ,
49+ trace_name = None ,
50+ session = None ,
51+ attributes = None ,
52+ transaction = None ,
53+ transaction_selector = None ,
4754):
4855 """Restart iteration after :exc:`.ServiceUnavailable`.
4956
@@ -52,22 +59,51 @@ def _restart_on_unavailable(
5259
5360 :type request: proto
5461 :param request: request proto to call the method with
62+
63+ :type transaction: :class:`google.cloud.spanner_v1.snapshot._SnapshotBase`
64+ :param transaction: Snapshot or Transaction class object based on the type of transaction
65+
66+ :type transaction_selector: :class:`transaction_pb2.TransactionSelector`
67+ :param transaction_selector: Transaction selector object to be used in request if transaction is not passed,
68+ if both transaction_selector and transaction are passed, then transaction is given priority.
5569 """
70+
5671 resume_token = b""
5772 item_buffer = []
73+
74+ if transaction is not None :
75+ transaction_selector = transaction ._make_txn_selector ()
76+ elif transaction_selector is None :
77+ raise InvalidArgument (
78+ "Either transaction or transaction_selector should be set"
79+ )
80+
81+ request .transaction = transaction_selector
5882 with trace_call (trace_name , session , attributes ):
5983 iterator = method (request = request )
6084 while True :
6185 try :
6286 for item in iterator :
6387 item_buffer .append (item )
88+ # Setting the transaction id because the transaction begin was inlined for first rpc.
89+ if (
90+ transaction is not None
91+ and transaction ._transaction_id is None
92+ and item .metadata is not None
93+ and item .metadata .transaction is not None
94+ and item .metadata .transaction .id is not None
95+ ):
96+ transaction ._transaction_id = item .metadata .transaction .id
6497 if item .resume_token :
6598 resume_token = item .resume_token
6699 break
67100 except ServiceUnavailable :
68101 del item_buffer [:]
69102 with trace_call (trace_name , session , attributes ):
70103 request .resume_token = resume_token
104+ if transaction is not None :
105+ transaction_selector = transaction ._make_txn_selector ()
106+ request .transaction = transaction_selector
71107 iterator = method (request = request )
72108 continue
73109 except InternalServerError as exc :
@@ -80,6 +116,9 @@ def _restart_on_unavailable(
80116 del item_buffer [:]
81117 with trace_call (trace_name , session , attributes ):
82118 request .resume_token = resume_token
119+ if transaction is not None :
120+ transaction_selector = transaction ._make_txn_selector ()
121+ request .transaction = transaction_selector
83122 iterator = method (request = request )
84123 continue
85124
@@ -106,6 +145,7 @@ class _SnapshotBase(_SessionWrapper):
106145 _transaction_id = None
107146 _read_request_count = 0
108147 _execute_sql_count = 0
148+ _lock = threading .Lock ()
109149
110150 def _make_txn_selector (self ):
111151 """Helper for :meth:`read` / :meth:`execute_sql`.
@@ -180,13 +220,12 @@ def read(
180220 if self ._read_request_count > 0 :
181221 if not self ._multi_use :
182222 raise ValueError ("Cannot re-use single-use snapshot." )
183- if self ._transaction_id is None :
223+ if self ._transaction_id is None and self . _read_only :
184224 raise ValueError ("Transaction ID pending." )
185225
186226 database = self ._session ._database
187227 api = database .spanner_api
188228 metadata = _metadata_with_prefix (database .name )
189- transaction = self ._make_txn_selector ()
190229
191230 if request_options is None :
192231 request_options = RequestOptions ()
@@ -204,7 +243,6 @@ def read(
204243 table = table ,
205244 columns = columns ,
206245 key_set = keyset ._to_pb (),
207- transaction = transaction ,
208246 index = index ,
209247 limit = limit ,
210248 partition_token = partition ,
@@ -219,13 +257,32 @@ def read(
219257 )
220258
221259 trace_attributes = {"table_id" : table , "columns" : columns }
222- iterator = _restart_on_unavailable (
223- restart ,
224- request ,
225- "CloudSpanner.ReadOnlyTransaction" ,
226- self ._session ,
227- trace_attributes ,
228- )
260+
261+ if self ._transaction_id is None :
262+ # lock is added to handle the inline begin for first rpc
263+ with self ._lock :
264+ iterator = _restart_on_unavailable (
265+ restart ,
266+ request ,
267+ "CloudSpanner.ReadOnlyTransaction" ,
268+ self ._session ,
269+ trace_attributes ,
270+ transaction = self ,
271+ )
272+ self ._read_request_count += 1
273+ if self ._multi_use :
274+ return StreamedResultSet (iterator , source = self )
275+ else :
276+ return StreamedResultSet (iterator )
277+ else :
278+ iterator = _restart_on_unavailable (
279+ restart ,
280+ request ,
281+ "CloudSpanner.ReadOnlyTransaction" ,
282+ self ._session ,
283+ trace_attributes ,
284+ transaction = self ,
285+ )
229286
230287 self ._read_request_count += 1
231288
@@ -301,7 +358,7 @@ def execute_sql(
301358 if self ._read_request_count > 0 :
302359 if not self ._multi_use :
303360 raise ValueError ("Cannot re-use single-use snapshot." )
304- if self ._transaction_id is None :
361+ if self ._transaction_id is None and self . _read_only :
305362 raise ValueError ("Transaction ID pending." )
306363
307364 if params is not None :
@@ -315,7 +372,7 @@ def execute_sql(
315372
316373 database = self ._session ._database
317374 metadata = _metadata_with_prefix (database .name )
318- transaction = self . _make_txn_selector ()
375+
319376 api = database .spanner_api
320377
321378 # Query-level options have higher precedence than client-level and
@@ -336,7 +393,6 @@ def execute_sql(
336393 request = ExecuteSqlRequest (
337394 session = self ._session .name ,
338395 sql = sql ,
339- transaction = transaction ,
340396 params = params_pb ,
341397 param_types = param_types ,
342398 query_mode = query_mode ,
@@ -354,13 +410,34 @@ def execute_sql(
354410 )
355411
356412 trace_attributes = {"db.statement" : sql }
357- iterator = _restart_on_unavailable (
358- restart ,
359- request ,
360- "CloudSpanner.ReadWriteTransaction" ,
361- self ._session ,
362- trace_attributes ,
363- )
413+
414+ if self ._transaction_id is None :
415+ # lock is added to handle the inline begin for first rpc
416+ with self ._lock :
417+ iterator = _restart_on_unavailable (
418+ restart ,
419+ request ,
420+ "CloudSpanner.ReadWriteTransaction" ,
421+ self ._session ,
422+ trace_attributes ,
423+ transaction = self ,
424+ )
425+ self ._read_request_count += 1
426+ self ._execute_sql_count += 1
427+
428+ if self ._multi_use :
429+ return StreamedResultSet (iterator , source = self )
430+ else :
431+ return StreamedResultSet (iterator )
432+ else :
433+ iterator = _restart_on_unavailable (
434+ restart ,
435+ request ,
436+ "CloudSpanner.ReadWriteTransaction" ,
437+ self ._session ,
438+ trace_attributes ,
439+ transaction = self ,
440+ )
364441
365442 self ._read_request_count += 1
366443 self ._execute_sql_count += 1
0 commit comments