Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 81505cd

Browse files
authored
feat: Add support and tests for DML returning clauses (#805)
This change adds support for DML returning clauses and includes a few prerequisite changes. I would suggest reviewing commit-by-commit. The commit messages provide additional context and are reproduced below, ### feat: Support custom endpoint when running tests By setting the `GOOGLE_CLOUD_TESTS_SPANNER_HOST` environment variable you can now run tests against an alternate Spanner API endpoint. This is particularly useful for running system tests against a pre-production deployment. ### refactor(dbapi): Remove most special handling of INSERTs For historical reasons it seems the INSERT codepath and that for UPDATE/DELETE were separated, but today there appears to be no practical differences in how these DML statements are handled. This change removes most of the special handling for INSERTs and uses existing methods for UPDATEs/DELETEs instead. The one remaining exception is the automatic addition of a WHERE clause to UPDATE and DELETE statements lacking one, which does not apply to INSERT statements. ### feat(dbapi): Add full support for rowcount Previously, rowcount was only available after executing an UPDATE or DELETE in autocommit mode. This change extends this support so that a rowcount is available for all DML statements, regardless of whether autocommit is enabled. ### feat: Add support for returning clause in DML This change adds support and tests for a returning clause in DML statements. This is done by moving executing of all DML to use `execute_sql`, which is already used when not in autocommit mode.
1 parent 1922a2e commit 81505cd

10 files changed

Lines changed: 304 additions & 178 deletions

File tree

google/cloud/spanner_dbapi/_helpers.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from google.cloud.spanner_dbapi.parse_utils import get_param_types
16-
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
1715
from google.cloud.spanner_v1 import param_types
1816

1917

@@ -47,24 +45,6 @@
4745
}
4846

4947

50-
def _execute_insert_heterogenous(
51-
transaction,
52-
sql_params_list,
53-
request_options=None,
54-
):
55-
for sql, params in sql_params_list:
56-
sql, params = sql_pyformat_args_to_spanner(sql, params)
57-
transaction.execute_update(
58-
sql, params, get_param_types(params), request_options=request_options
59-
)
60-
61-
62-
def handle_insert(connection, sql, params):
63-
return connection.database.run_in_transaction(
64-
_execute_insert_heterogenous, ((sql, params),), connection.request_options
65-
)
66-
67-
6848
class ColumnInfo:
6949
"""Row column description object."""
7050

google/cloud/spanner_dbapi/connection.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from google.cloud.spanner_v1.session import _get_retry_delay
2525
from google.cloud.spanner_v1.snapshot import Snapshot
2626

27-
from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
2827
from google.cloud.spanner_dbapi.checksum import _compare_checksums
2928
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
3029
from google.cloud.spanner_dbapi.cursor import Cursor
@@ -450,15 +449,6 @@ def run_statement(self, statement, retried=False):
450449
if not retried:
451450
self._statements.append(statement)
452451

453-
if statement.is_insert:
454-
_execute_insert_heterogenous(
455-
transaction, ((statement.sql, statement.params),), self.request_options
456-
)
457-
return (
458-
iter(()),
459-
ResultsChecksum() if retried else statement.checksum,
460-
)
461-
462452
return (
463453
transaction.execute_sql(
464454
statement.sql,

google/cloud/spanner_dbapi/cursor.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
_UNSET_COUNT = -1
4848

4949
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
50-
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")
50+
Statement = namedtuple("Statement", "sql, params, param_types, checksum")
5151

5252

5353
def check_not_closed(function):
@@ -137,14 +137,21 @@ def description(self):
137137

138138
@property
139139
def rowcount(self):
140-
"""The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
140+
"""The number of rows updated by the last INSERT, UPDATE, DELETE request's `execute()` call.
141141
For SELECT requests the rowcount returns -1.
142142
143143
:rtype: int
144-
:returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
144+
:returns: The number of rows updated by the last INSERT, UPDATE, DELETE request's .execute*() call.
145145
"""
146146

147-
return self._row_count
147+
if self._row_count != _UNSET_COUNT or self._result_set is None:
148+
return self._row_count
149+
150+
stats = getattr(self._result_set, "stats", None)
151+
if stats is not None and "row_count_exact" in stats:
152+
return stats.row_count_exact
153+
154+
return _UNSET_COUNT
148155

149156
@check_not_closed
150157
def callproc(self, procname, args=None):
@@ -171,17 +178,11 @@ def close(self):
171178
self._is_closed = True
172179

173180
def _do_execute_update(self, transaction, sql, params):
174-
result = transaction.execute_update(
175-
sql,
176-
params=params,
177-
param_types=get_param_types(params),
178-
request_options=self.connection.request_options,
181+
self._result_set = transaction.execute_sql(
182+
sql, params=params, param_types=get_param_types(params)
179183
)
180-
self._itr = None
181-
if type(result) == int:
182-
self._row_count = result
183-
184-
return result
184+
self._itr = PeekIterator(self._result_set)
185+
self._row_count = _UNSET_COUNT
185186

186187
def _do_batch_update(self, transaction, statements, many_result_set):
187188
status, res = transaction.batch_update(statements)
@@ -227,7 +228,9 @@ def execute(self, sql, args=None):
227228
:type args: list
228229
:param args: Additional parameters to supplement the SQL query.
229230
"""
231+
self._itr = None
230232
self._result_set = None
233+
self._row_count = _UNSET_COUNT
231234

232235
try:
233236
if self.connection.read_only:
@@ -249,18 +252,14 @@ def execute(self, sql, args=None):
249252
if class_ == parse_utils.STMT_UPDATING:
250253
sql = parse_utils.ensure_where_clause(sql)
251254

252-
if class_ != parse_utils.STMT_INSERT:
253-
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
255+
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
254256

255257
if not self.connection.autocommit:
256258
statement = Statement(
257259
sql,
258260
args,
259-
get_param_types(args or None)
260-
if class_ != parse_utils.STMT_INSERT
261-
else {},
261+
get_param_types(args or None),
262262
ResultsChecksum(),
263-
class_ == parse_utils.STMT_INSERT,
264263
)
265264

266265
(
@@ -277,8 +276,6 @@ def execute(self, sql, args=None):
277276

278277
if class_ == parse_utils.STMT_NON_UPDATING:
279278
self._handle_DQL(sql, args or None)
280-
elif class_ == parse_utils.STMT_INSERT:
281-
_helpers.handle_insert(self.connection, sql, args or None)
282279
else:
283280
self.connection.database.run_in_transaction(
284281
self._do_execute_update,
@@ -304,6 +301,10 @@ def executemany(self, operation, seq_of_params):
304301
:param seq_of_params: Sequence of additional parameters to run
305302
the query with.
306303
"""
304+
self._itr = None
305+
self._result_set = None
306+
self._row_count = _UNSET_COUNT
307+
307308
class_ = parse_utils.classify_stmt(operation)
308309
if class_ == parse_utils.STMT_DDL:
309310
raise ProgrammingError(
@@ -327,6 +328,7 @@ def executemany(self, operation, seq_of_params):
327328
)
328329
else:
329330
retried = False
331+
total_row_count = 0
330332
while True:
331333
try:
332334
transaction = self.connection.transaction_checkout()
@@ -341,12 +343,14 @@ def executemany(self, operation, seq_of_params):
341343
many_result_set.add_iter(res)
342344
res_checksum.consume_result(res)
343345
res_checksum.consume_result(status.code)
346+
total_row_count += sum([max(val, 0) for val in res])
344347

345348
if status.code == ABORTED:
346349
self.connection._transaction = None
347350
raise Aborted(status.message)
348351
elif status.code != OK:
349352
raise OperationalError(status.message)
353+
self._row_count = total_row_count
350354
break
351355
except Aborted:
352356
self.connection.retry_transaction()

tests/system/_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
INSTANCE_ID_DEFAULT = "google-cloud-python-systest"
3131
INSTANCE_ID = os.environ.get(INSTANCE_ID_ENVVAR, INSTANCE_ID_DEFAULT)
3232

33+
API_ENDPOINT_ENVVAR = "GOOGLE_CLOUD_TESTS_SPANNER_HOST"
34+
API_ENDPOINT = os.getenv(API_ENDPOINT_ENVVAR)
35+
3336
SKIP_BACKUP_TESTS_ENVVAR = "SKIP_BACKUP_TESTS"
3437
SKIP_BACKUP_TESTS = os.getenv(SKIP_BACKUP_TESTS_ENVVAR) is not None
3538

tests/system/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ def spanner_client():
8585
credentials=credentials,
8686
)
8787
else:
88-
return spanner_v1.Client() # use google.auth.default credentials
88+
client_options = {"api_endpoint": _helpers.API_ENDPOINT}
89+
return spanner_v1.Client(
90+
client_options=client_options
91+
) # use google.auth.default credentials
8992

9093

9194
@pytest.fixture(scope="session")

tests/system/test_dbapi.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,126 @@ def test_staleness(shared_instance, dbapi_database):
501501
assert len(cursor.fetchall()) == 1
502502

503503
conn.close()
504+
505+
506+
@pytest.mark.parametrize("autocommit", [False, True])
507+
def test_rowcount(shared_instance, dbapi_database, autocommit):
508+
conn = Connection(shared_instance, dbapi_database)
509+
conn.autocommit = autocommit
510+
cur = conn.cursor()
511+
512+
cur.execute(
513+
"""
514+
CREATE TABLE Singers (
515+
SingerId INT64 NOT NULL,
516+
Name STRING(1024),
517+
) PRIMARY KEY (SingerId)
518+
"""
519+
)
520+
conn.commit()
521+
522+
# executemany sets rowcount to the total modified rows
523+
rows = [(i, f"Singer {i}") for i in range(100)]
524+
cur.executemany("INSERT INTO Singers (SingerId, Name) VALUES (%s, %s)", rows[:98])
525+
assert cur.rowcount == 98
526+
527+
# execute with INSERT
528+
cur.execute(
529+
"INSERT INTO Singers (SingerId, Name) VALUES (%s, %s), (%s, %s)",
530+
[x for row in rows[98:] for x in row],
531+
)
532+
assert cur.rowcount == 2
533+
534+
# execute with UPDATE
535+
cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
536+
assert cur.rowcount == 25
537+
538+
# execute with SELECT
539+
cur.execute("SELECT Name FROM Singers WHERE SingerId < 75")
540+
assert len(cur.fetchall()) == 75
541+
# rowcount is not available for SELECT
542+
assert cur.rowcount == -1
543+
544+
# execute with DELETE
545+
cur.execute("DELETE FROM Singers")
546+
assert cur.rowcount == 100
547+
548+
# execute with UPDATE matching 0 rows
549+
cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
550+
assert cur.rowcount == 0
551+
552+
conn.commit()
553+
cur.execute("DROP TABLE Singers")
554+
conn.commit()
555+
556+
557+
@pytest.mark.parametrize("autocommit", [False, True])
558+
@pytest.mark.skipif(
559+
_helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
560+
)
561+
def test_dml_returning_insert(shared_instance, dbapi_database, autocommit):
562+
conn = Connection(shared_instance, dbapi_database)
563+
conn.autocommit = autocommit
564+
cur = conn.cursor()
565+
cur.execute(
566+
"""
567+
INSERT INTO contacts (contact_id, first_name, last_name, email)
568+
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
569+
THEN RETURN contact_id, first_name
570+
"""
571+
)
572+
assert cur.fetchone() == (1, "first-name")
573+
assert cur.rowcount == 1
574+
conn.commit()
575+
576+
577+
@pytest.mark.parametrize("autocommit", [False, True])
578+
@pytest.mark.skipif(
579+
_helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
580+
)
581+
def test_dml_returning_update(shared_instance, dbapi_database, autocommit):
582+
conn = Connection(shared_instance, dbapi_database)
583+
conn.autocommit = autocommit
584+
cur = conn.cursor()
585+
cur.execute(
586+
"""
587+
INSERT INTO contacts (contact_id, first_name, last_name, email)
588+
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
589+
"""
590+
)
591+
assert cur.rowcount == 1
592+
cur.execute(
593+
"""
594+
UPDATE contacts SET first_name = 'new-name' WHERE contact_id = 1
595+
THEN RETURN contact_id, first_name
596+
"""
597+
)
598+
assert cur.fetchone() == (1, "new-name")
599+
assert cur.rowcount == 1
600+
conn.commit()
601+
602+
603+
@pytest.mark.parametrize("autocommit", [False, True])
604+
@pytest.mark.skipif(
605+
_helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
606+
)
607+
def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
608+
conn = Connection(shared_instance, dbapi_database)
609+
conn.autocommit = autocommit
610+
cur = conn.cursor()
611+
cur.execute(
612+
"""
613+
INSERT INTO contacts (contact_id, first_name, last_name, email)
614+
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
615+
"""
616+
)
617+
assert cur.rowcount == 1
618+
cur.execute(
619+
"""
620+
DELETE FROM contacts WHERE contact_id = 1
621+
THEN RETURN contact_id, first_name
622+
"""
623+
)
624+
assert cur.fetchone() == (1, "first-name")
625+
assert cur.rowcount == 1
626+
conn.commit()

0 commit comments

Comments
 (0)