# Copyright 2021 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import collections
import datetime
import decimal
import math
import struct
import threading
import time
import uuid

from google.api_core import datetime_helpers, exceptions
from google.rpc import code_pb2
import grpc
import pytest

from google.cloud import spanner_v1
from google.cloud._helpers import UTC
from google.cloud.spanner_admin_database_v1 import DatabaseDialect
from google.cloud.spanner_v1 import _opentelemetry_tracing
from google.cloud.spanner_v1._helpers import AtomicCounter, _get_cloud_region
from google.cloud.spanner_v1.data_types import JsonObject
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
from google.cloud.spanner_v1.request_id_header import (
    REQ_RAND_PROCESS_ID,
    build_request_id,
    parse_request_id,
)
from tests import _helpers as ot_helpers
from tests._helpers import is_multiplexed_enabled
from . import _helpers, _sample_data
from .testdata import singer_pb2

SOME_DATE = datetime.date(2011, 1, 17)
SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612)
NANO_TIME = datetime_helpers.DatetimeWithNanoseconds(1995, 8, 31, nanosecond=987654321)
POS_INF = float("+inf")
NEG_INF = float("-inf")
(OTHER_NAN,) = struct.unpack("

    # --Session.run_in_transaction----------|
    # |---------DMLTransaction-------|
    #
    # |>----Transaction.commit---|
    #
    # CreateSession should have a trace of its own, with no children
    # nor being a child of any other span.
    session_span = span_list[0]
    test_span = span_list[2]
    # assert session_span.context.trace_id != test_span.context.trace_id
    for span in span_list[1:]:
        if span.parent:
            assert span.parent.span_id != session_span.context.span_id

    def assert_parent_and_children(parent_span, children):
        for span in children:
            assert span.context.trace_id == parent_span.context.trace_id
            assert span.parent.span_id == parent_span.context.span_id

    # [CreateSession --> Batch] should have their own trace.
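    # The remaining spans belong to the test span's trace: span_list[3] is
    # the Session.run_in_transaction span (a child of the test span), and
    # span_list[4] / span_list[5] (the DML transaction and the transaction
    # commit) must be parented to span_list[3], as asserted below.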
    session_run_in_txn_span = span_list[3]
    children_of_test_span = [session_run_in_txn_span]
    assert_parent_and_children(test_span, children_of_test_span)

    dml_txn_span = span_list[4]
    batch_commit_txn_span = span_list[5]
    children_of_session_run_in_txn_span = [dml_txn_span, batch_commit_txn_span]
    assert_parent_and_children(
        session_run_in_txn_span, children_of_session_run_in_txn_span
    )


def test_execute_partitioned_dml(
    not_postgres_emulator, sessions_database, database_dialect
):
    # [START spanner_test_dml_partioned_dml_update]
    sd = _sample_data
    param_types = spanner_v1.param_types

    delete_statement = f"DELETE FROM {sd.TABLE} WHERE true"

    def _setup_table(txn):
        txn.execute_update(delete_statement)
        for insert_statement in _generate_insert_statements():
            txn.execute_update(insert_statement)

    committed = sessions_database.run_in_transaction(_setup_table)

    with sessions_database.snapshot(read_timestamp=committed) as snapshot:
        before_pdml = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL))

    sd._check_rows_data(before_pdml)

    keys = (
        ["p1", "p2"]
        if database_dialect == DatabaseDialect.POSTGRESQL
        else ["email", "target"]
    )
    placeholders = (
        ["$1", "$2"]
        if database_dialect == DatabaseDialect.POSTGRESQL
        else [f"@{key}" for key in keys]
    )

    nonesuch = "nonesuch@example.com"
    target = "phred@example.com"
    update_statement = (
        f"UPDATE contacts SET email = {placeholders[0]} WHERE email = {placeholders[1]}"
    )

    row_count = sessions_database.execute_partitioned_dml(
        update_statement,
        params={keys[0]: nonesuch, keys[1]: target},
        param_types={keys[0]: param_types.STRING, keys[1]: param_types.STRING},
        request_options=spanner_v1.RequestOptions(
            priority=spanner_v1.RequestOptions.Priority.PRIORITY_MEDIUM
        ),
    )
    assert row_count == 1

    row = sd.ROW_DATA[0]
    updated = [row[:3] + (nonesuch,)] + list(sd.ROW_DATA[1:])

    # The partitioned DML commits after `committed`, so read the current
    # state instead of a snapshot bound to the setup commit timestamp.
    with sessions_database.snapshot() as snapshot:
        after_update = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL))
    sd._check_rows_data(after_update, updated)

    row_count = sessions_database.execute_partitioned_dml(delete_statement)
    assert row_count == len(sd.ROW_DATA)

    with sessions_database.snapshot() as snapshot:
        after_delete = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL))
    sd._check_rows_data(after_delete, [])
    # [END spanner_test_dml_partioned_dml_update]


def _transaction_concurrency_helper(
    sessions_database, unit_of_work, pkey, database_dialect=None
):
    initial_value = 123
    num_threads = 3  # conforms to equivalent Java systest.

    with sessions_database.batch() as batch:
        batch.insert_or_update(
            COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, initial_value]]
        )

    # We don't want to run the threads' transactions in the current
    # session, which would fail.
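    # Each worker thread runs `unit_of_work` through
    # Database.run_in_transaction, which retries transactions that Spanner
    # aborts, so the counter normally ends up incremented once per thread;
    # the multiplexed-session branch further down tolerates partial success.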
txn_sessions = [] for _ in range(num_threads): txn_sessions.append(sessions_database) args = ( (unit_of_work, pkey, database_dialect) if database_dialect else (unit_of_work, pkey) ) threads = [ threading.Thread(target=txn_session.run_in_transaction, args=args) for txn_session in txn_sessions ] for thread in threads: thread.start() for thread in threads: thread.join() with sessions_database.snapshot() as snapshot: keyset = spanner_v1.KeySet(keys=[(pkey,)]) rows = list(snapshot.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset)) assert len(rows) == 1 _, value = rows[0] multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) if multiplexed_enabled: # Allow for partial success due to transaction aborts assert initial_value < value <= initial_value + num_threads else: assert value == initial_value + num_threads def _read_w_concurrent_update(transaction, pkey): keyset = spanner_v1.KeySet(keys=[(pkey,)]) rows = list(transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset)) assert len(rows) == 1 pkey, value = rows[0] transaction.update(COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]]) def test_transaction_read_w_concurrent_updates( sessions_database, # TODO: Re-enable when the Emulator returns pre-commit tokens for streaming reads. not_emulator, ): pkey = "read_w_concurrent_updates" _transaction_concurrency_helper(sessions_database, _read_w_concurrent_update, pkey) def _query_w_concurrent_update(transaction, pkey, database_dialect): param_types = spanner_v1.param_types key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "name" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" sql = f"SELECT * FROM {COUNTERS_TABLE} WHERE name = {placeholder}" rows = list( transaction.execute_sql( sql, params={key: pkey}, param_types={key: param_types.STRING} ) ) assert len(rows) == 1 pkey, value = rows[0] transaction.update(COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]]) def test_transaction_query_w_concurrent_updates(sessions_database, database_dialect): pkey = "query_w_concurrent_updates" _transaction_concurrency_helper( sessions_database, _query_w_concurrent_update, pkey, database_dialect ) def test_transaction_read_w_abort(not_emulator, sessions_database): sd = _sample_data trigger = _ReadAbortTrigger() with sessions_database.batch() as batch: batch.delete(COUNTERS_TABLE, sd.ALL) batch.insert( COUNTERS_TABLE, COUNTERS_COLUMNS, [[trigger.KEY1, 0], [trigger.KEY2, 0]] ) provoker = threading.Thread(target=trigger.provoke_abort, args=(sessions_database,)) handler = threading.Thread(target=trigger.handle_abort, args=(sessions_database,)) provoker.start() trigger.provoker_started.wait() handler.start() trigger.handler_done.wait() provoker.join() handler.join() with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(COUNTERS_TABLE, COUNTERS_COLUMNS, sd.ALL)) sd._check_row_data(rows, expected=[[trigger.KEY1, 1], [trigger.KEY2, 1]]) def _row_data(max_index): for index in range(max_index): yield ( index, f"First{index:09}", f"Last{max_index - index:09}", f"test-{index:09}@example.com", ) def _set_up_table(database, row_count): sd = _sample_data def _unit_of_work(transaction): transaction.delete(sd.TABLE, sd.ALL) transaction.insert(sd.TABLE, sd.COLUMNS, _row_data(row_count)) committed = database.run_in_transaction(_unit_of_work) return committed def _set_up_proto_table(database): sd = _sample_data def _unit_of_work(transaction): transaction.delete(sd.SINGERS_PROTO_TABLE, sd.ALL) transaction.insert( sd.SINGERS_PROTO_TABLE, 
sd.SINGERS_PROTO_COLUMNS, sd.SINGERS_PROTO_ROW_DATA ) committed = database.run_in_transaction(_unit_of_work) return committed def test_read_with_single_keys_index(sessions_database): # [START spanner_test_single_key_index_read] sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) expected = [[row[1], row[2]] for row in _row_data(row_count)] row = 5 keyset = [[expected[row][0], expected[row][1]]] with sessions_database.snapshot() as snapshot: results_iter = snapshot.read( sd.TABLE, columns, spanner_v1.KeySet(keys=keyset), index="name" ) rows = list(results_iter) assert rows == [expected[row]] # [END spanner_test_single_key_index_read] def test_empty_read_with_single_keys_index(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) keyset = [["Non", "Existent"]] with sessions_database.snapshot() as snapshot: results_iter = snapshot.read( sd.TABLE, columns, spanner_v1.KeySet(keys=keyset), index="name" ) rows = list(results_iter) assert rows == [] def test_read_with_multiple_keys_index(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) expected = [[row[1], row[2]] for row in _row_data(row_count)] with sessions_database.snapshot() as snapshot: rows = list( snapshot.read( sd.TABLE, columns, spanner_v1.KeySet(keys=expected), index="name", ) ) assert rows == expected def test_snapshot_read_w_various_staleness(sessions_database): sd = _sample_data row_count = 400 committed = _set_up_table(sessions_database, row_count) all_data_rows = list(_row_data(row_count)) before_reads = datetime.datetime.utcnow().replace(tzinfo=UTC) # Test w/ read timestamp with sessions_database.snapshot(read_timestamp=committed) as read_tx: rows = list(read_tx.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(rows, all_data_rows) # Test w/ min read timestamp with sessions_database.snapshot(min_read_timestamp=committed) as min_read_ts: rows = list(min_read_ts.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(rows, all_data_rows) staleness = datetime.datetime.utcnow().replace(tzinfo=UTC) - before_reads # Test w/ max staleness with sessions_database.snapshot(max_staleness=staleness) as max_staleness: rows = list(max_staleness.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(rows, all_data_rows) # Test w/ exact staleness with sessions_database.snapshot(exact_staleness=staleness) as exact_staleness: rows = list(exact_staleness.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(rows, all_data_rows) # Test w/ strong with sessions_database.snapshot() as strong: rows = list(strong.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(rows, all_data_rows) def test_multiuse_snapshot_read_isolation_strong(sessions_database): sd = _sample_data row_count = 40 _set_up_table(sessions_database, row_count) all_data_rows = list(_row_data(row_count)) with sessions_database.snapshot(multi_use=True) as strong: before = list(strong.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(before, all_data_rows) with sessions_database.batch() as batch: batch.delete(sd.TABLE, sd.ALL) after = list(strong.read(sd.TABLE, sd.COLUMNS, sd.ALL)) sd._check_row_data(after, all_data_rows) def test_multiuse_snapshot_read_isolation_read_timestamp(sessions_database): sd = _sample_data row_count = 40 committed = _set_up_table(sessions_database, row_count) all_data_rows = list(_row_data(row_count)) with 
sessions_database.snapshot(
        read_timestamp=committed, multi_use=True
    ) as read_ts:
        before = list(read_ts.read(sd.TABLE, sd.COLUMNS, sd.ALL))
        sd._check_row_data(before, all_data_rows)

        with sessions_database.batch() as batch:
            batch.delete(sd.TABLE, sd.ALL)

        after = list(read_ts.read(sd.TABLE, sd.COLUMNS, sd.ALL))
        sd._check_row_data(after, all_data_rows)


def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database):
    sd = _sample_data
    row_count = 40

    _set_up_table(sessions_database, row_count)

    all_data_rows = list(_row_data(row_count))

    time.sleep(1)
    delta = datetime.timedelta(microseconds=1000)

    with sessions_database.snapshot(exact_staleness=delta, multi_use=True) as exact:
        before = list(exact.read(sd.TABLE, sd.COLUMNS, sd.ALL))
        sd._check_row_data(before, all_data_rows)

        with sessions_database.batch() as batch:
            batch.delete(sd.TABLE, sd.ALL)

        after = list(exact.read(sd.TABLE, sd.COLUMNS, sd.ALL))
        sd._check_row_data(after, all_data_rows)


def test_read_w_index(
    shared_instance,
    database_operation_timeout,
    databases_to_delete,
    database_dialect,
    proto_descriptor_file,
):
    # Indexed reads cannot return non-indexed columns
    sd = _sample_data
    row_count = 2000
    my_columns = sd.COLUMNS[0], sd.COLUMNS[2]

    # Create an alternate database w/ index.
    extra_ddl = ["CREATE INDEX contacts_by_last_name ON contacts(last_name)"]
    pool = spanner_v1.BurstyPool(labels={"testcase": "read_w_index"})

    if database_dialect == DatabaseDialect.POSTGRESQL:
        temp_db = shared_instance.database(
            _helpers.unique_id("test_read", separator="_"),
            pool=pool,
            database_dialect=database_dialect,
        )
        operation = temp_db.create()
        operation.result(database_operation_timeout)

        operation = temp_db.update_ddl(
            ddl_statements=_helpers.DDL_STATEMENTS + extra_ddl,
        )
        operation.result(database_operation_timeout)
    else:
        temp_db = shared_instance.database(
            _helpers.unique_id("test_read", separator="_"),
            ddl_statements=_helpers.DDL_STATEMENTS
            + extra_ddl
            + _helpers.PROTO_COLUMNS_DDL_STATEMENTS,
            pool=pool,
            database_dialect=database_dialect,
            proto_descriptors=proto_descriptor_file,
        )
        operation = temp_db.create()
        operation.result(database_operation_timeout)  # raises on failure / timeout.
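    # The temporary database is registered with `databases_to_delete` below
    # so the fixture drops it after the test. Reads through
    # contacts_by_last_name come back ordered by last_name, and _row_data
    # assigns last names in descending order, so the expected rows are the
    # reverse of primary-key order (hence the reversed() further down).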
databases_to_delete.append(temp_db) committed = _set_up_table(temp_db, row_count) with temp_db.snapshot(read_timestamp=committed) as snapshot: rows = list( snapshot.read(sd.TABLE, my_columns, sd.ALL, index="contacts_by_last_name") ) expected = list(reversed([(row[0], row[2]) for row in _row_data(row_count)])) sd._check_rows_data(rows, expected) # Test indexes on proto column types if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: # Indexed reads cannot return non-indexed columns my_columns = ( sd.SINGERS_PROTO_COLUMNS[0], sd.SINGERS_PROTO_COLUMNS[1], sd.SINGERS_PROTO_COLUMNS[4], ) committed = _set_up_proto_table(temp_db) with temp_db.snapshot(read_timestamp=committed) as snapshot: rows = list( snapshot.read( sd.SINGERS_PROTO_TABLE, my_columns, spanner_v1.KeySet(keys=[[singer_pb2.Genre.ROCK]]), index="SingerByGenre", ) ) row = sd.SINGERS_PROTO_ROW_DATA[0] expected = list([(row[0], row[1], row[4])]) sd._check_rows_data(rows, expected) def test_read_w_single_key(sessions_database): # [START spanner_test_single_key_read] sd = _sample_data row_count = 40 committed = _set_up_table(sessions_database, row_count) with sessions_database.snapshot(read_timestamp=committed) as snapshot: rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, spanner_v1.KeySet(keys=[(0,)]))) all_data_rows = list(_row_data(row_count)) expected = [all_data_rows[0]] sd._check_row_data(rows, expected) # [END spanner_test_single_key_read] def test_empty_read(sessions_database): # [START spanner_test_empty_read] sd = _sample_data row_count = 40 _set_up_table(sessions_database, row_count) with sessions_database.snapshot() as snapshot: rows = list( snapshot.read(sd.TABLE, sd.COLUMNS, spanner_v1.KeySet(keys=[(40,)])) ) sd._check_row_data(rows, []) # [END spanner_test_empty_read] def test_read_w_multiple_keys(sessions_database): sd = _sample_data row_count = 40 indices = [0, 5, 17] committed = _set_up_table(sessions_database, row_count) with sessions_database.snapshot(read_timestamp=committed) as snapshot: rows = list( snapshot.read( sd.TABLE, sd.COLUMNS, spanner_v1.KeySet(keys=[(index,) for index in indices]), ) ) all_data_rows = list(_row_data(row_count)) expected = [row for row in all_data_rows if row[0] in indices] sd._check_row_data(rows, expected) def test_read_w_limit(sessions_database): sd = _sample_data row_count = 3000 limit = 100 committed = _set_up_table(sessions_database, row_count) with sessions_database.snapshot(read_timestamp=committed) as snapshot: rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL, limit=limit)) all_data_rows = list(_row_data(row_count)) expected = all_data_rows[:limit] sd._check_row_data(rows, expected) def test_read_w_ranges(sessions_database): sd = _sample_data row_count = 3000 start = 1000 end = 2000 committed = _set_up_table(sessions_database, row_count) with sessions_database.snapshot( read_timestamp=committed, multi_use=True, ) as snapshot: all_data_rows = list(_row_data(row_count)) single_key = spanner_v1.KeyRange(start_closed=[start], end_open=[start + 1]) keyset = spanner_v1.KeySet(ranges=(single_key,)) rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = all_data_rows[start : start + 1] sd._check_rows_data(rows, expected) closed_closed = spanner_v1.KeyRange(start_closed=[start], end_closed=[end]) keyset = spanner_v1.KeySet(ranges=(closed_closed,)) rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = all_data_rows[start : end + 1] sd._check_row_data(rows, expected) closed_open = spanner_v1.KeyRange(start_closed=[start], end_open=[end]) keyset = 
spanner_v1.KeySet(ranges=(closed_open,)) rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = all_data_rows[start:end] sd._check_row_data(rows, expected) open_open = spanner_v1.KeyRange(start_open=[start], end_open=[end]) keyset = spanner_v1.KeySet(ranges=(open_open,)) rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = all_data_rows[start + 1 : end] sd._check_row_data(rows, expected) open_closed = spanner_v1.KeyRange(start_open=[start], end_closed=[end]) keyset = spanner_v1.KeySet(ranges=(open_closed,)) rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = all_data_rows[start + 1 : end + 1] sd._check_row_data(rows, expected) def test_read_partial_range_until_end(sessions_database): sd = _sample_data row_count = 3000 start = 1000 committed = _set_up_table(sessions_database, row_count) with sessions_database.snapshot( read_timestamp=committed, multi_use=True, ) as snapshot: all_data_rows = list(_row_data(row_count)) expected_map = { ("start_closed", "end_closed"): all_data_rows[start:], ("start_closed", "end_open"): [], ("start_open", "end_closed"): all_data_rows[start + 1 :], ("start_open", "end_open"): [], } for start_arg in ("start_closed", "start_open"): for end_arg in ("end_closed", "end_open"): range_kwargs = {start_arg: [start], end_arg: []} keyset = spanner_v1.KeySet( ranges=(spanner_v1.KeyRange(**range_kwargs),) ) rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = expected_map[(start_arg, end_arg)] sd._check_row_data(rows, expected) def test_read_partial_range_from_beginning(sessions_database): sd = _sample_data row_count = 3000 end = 2000 committed = _set_up_table(sessions_database, row_count) all_data_rows = list(_row_data(row_count)) expected_map = { ("start_closed", "end_closed"): all_data_rows[: end + 1], ("start_closed", "end_open"): all_data_rows[:end], ("start_open", "end_closed"): [], ("start_open", "end_open"): [], } for start_arg in ("start_closed", "start_open"): for end_arg in ("end_closed", "end_open"): range_kwargs = {start_arg: [], end_arg: [end]} keyset = spanner_v1.KeySet(ranges=(spanner_v1.KeyRange(**range_kwargs),)) with sessions_database.snapshot( read_timestamp=committed, multi_use=True, ) as snapshot: rows = list(snapshot.read(sd.TABLE, sd.COLUMNS, keyset)) expected = expected_map[(start_arg, end_arg)] sd._check_row_data(rows, expected) def test_read_with_range_keys_index_single_key(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start = 3 krange = spanner_v1.KeyRange(start_closed=data[start], end_open=data[start + 1]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) assert rows == data[start : start + 1] def test_read_with_range_keys_index_closed_closed(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end = 3, 7 krange = spanner_v1.KeyRange(start_closed=data[start], end_closed=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) assert rows == data[start : end + 1] def test_read_with_range_keys_index_closed_open(sessions_database): sd = _sample_data row_count = 10 columns = 
sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end = 3, 7 krange = spanner_v1.KeyRange(start_closed=data[start], end_open=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) assert rows == data[start:end] def test_read_with_range_keys_index_open_closed(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end = 3, 7 krange = spanner_v1.KeyRange(start_open=data[start], end_closed=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) assert rows == data[start + 1 : end + 1] def test_read_with_range_keys_index_open_open(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end = 3, 7 krange = spanner_v1.KeyRange(start_open=data[start], end_open=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) assert rows == data[start + 1 : end] def test_read_with_range_keys_index_limit_closed_closed(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end, limit = 3, 7, 2 krange = spanner_v1.KeyRange(start_closed=data[start], end_closed=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name", limit=limit)) expected = data[start : end + 1] assert rows == expected[:limit] def test_read_with_range_keys_index_limit_closed_open(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end, limit = 3, 7, 2 krange = spanner_v1.KeyRange(start_closed=data[start], end_open=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name", limit=limit)) expected = data[start:end] assert rows == expected[:limit] def test_read_with_range_keys_index_limit_open_closed(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end, limit = 3, 7, 2 krange = spanner_v1.KeyRange(start_open=data[start], end_closed=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name", limit=limit)) expected = data[start + 1 : end + 1] assert rows == expected[:limit] def test_read_with_range_keys_index_limit_open_open(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] data = [[row[1], row[2]] for row in _row_data(row_count)] _set_up_table(sessions_database, row_count) start, end, limit = 3, 7, 2 krange = 
spanner_v1.KeyRange(start_open=data[start], end_open=data[end]) keyset = spanner_v1.KeySet(ranges=(krange,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name", limit=limit)) expected = data[start + 1 : end] assert rows == expected[:limit] def test_read_with_range_keys_and_index_closed_closed(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) data = [[row[1], row[2]] for row in _row_data(row_count)] keyrow, start, end = 1, 3, 7 closed_closed = spanner_v1.KeyRange(start_closed=data[start], end_closed=data[end]) keys = [data[keyrow]] keyset = spanner_v1.KeySet(keys=keys, ranges=(closed_closed,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) expected = [data[keyrow]] + data[start : end + 1] assert rows == expected def test_read_with_range_keys_and_index_closed_open(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) data = [[row[1], row[2]] for row in _row_data(row_count)] keyrow, start, end = 1, 3, 7 closed_open = spanner_v1.KeyRange(start_closed=data[start], end_open=data[end]) keys = [data[keyrow]] keyset = spanner_v1.KeySet(keys=keys, ranges=(closed_open,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) expected = [data[keyrow]] + data[start:end] assert rows == expected def test_read_with_range_keys_and_index_open_closed(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) data = [[row[1], row[2]] for row in _row_data(row_count)] keyrow, start, end = 1, 3, 7 open_closed = spanner_v1.KeyRange(start_open=data[start], end_closed=data[end]) keys = [data[keyrow]] keyset = spanner_v1.KeySet(keys=keys, ranges=(open_closed,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) expected = [data[keyrow]] + data[start + 1 : end + 1] assert rows == expected def test_read_with_range_keys_and_index_open_open(sessions_database): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] _set_up_table(sessions_database, row_count) data = [[row[1], row[2]] for row in _row_data(row_count)] keyrow, start, end = 1, 3, 7 open_open = spanner_v1.KeyRange(start_open=data[start], end_open=data[end]) keys = [data[keyrow]] keyset = spanner_v1.KeySet(keys=keys, ranges=(open_open,)) with sessions_database.snapshot() as snapshot: rows = list(snapshot.read(sd.TABLE, columns, keyset, index="name")) expected = [data[keyrow]] + data[start + 1 : end] assert rows == expected def test_partition_read_w_index(sessions_database, not_emulator, not_experimental_host): sd = _sample_data row_count = 10 columns = sd.COLUMNS[1], sd.COLUMNS[2] committed = _set_up_table(sessions_database, row_count) expected = [[row[1], row[2]] for row in _row_data(row_count)] union = [] batch_txn = sessions_database.batch_snapshot(read_timestamp=committed) batches = batch_txn.generate_read_batches( sd.TABLE, columns, spanner_v1.KeySet(all_=True), index="name", data_boost_enabled=True, ) for batch in batches: p_results_iter = batch_txn.process(batch) union.extend(list(p_results_iter)) assert union == expected batch_txn.close() def test_execute_sql_w_manual_consume(sessions_database): sd = _sample_data row_count = 3000 committed = 
_set_up_table(sessions_database, row_count) for lazy_decode in [False, True]: with sessions_database.snapshot(read_timestamp=committed) as snapshot: streamed = snapshot.execute_sql(sd.SQL, lazy_decode=lazy_decode) keyset = spanner_v1.KeySet(all_=True) with sessions_database.snapshot(read_timestamp=committed) as snapshot: rows = list( snapshot.read(sd.TABLE, sd.COLUMNS, keyset, lazy_decode=lazy_decode) ) assert list(streamed) == rows assert streamed._current_row == [] assert streamed._pending_chunk is None def test_execute_sql_w_to_dict_list(sessions_database): sd = _sample_data row_count = 40 _set_up_table(sessions_database, row_count) with sessions_database.snapshot() as snapshot: rows = snapshot.execute_sql(sd.SQL).to_dict_list() all_data_rows = list(_row_data(row_count)) row_data = [list(row.values()) for row in rows] sd._check_row_data(row_data, all_data_rows) assert all(set(row.keys()) == set(sd.COLUMNS) for row in rows) def _check_sql_results( database, sql, params, param_types=None, expected=None, order=True, recurse_into_lists=True, column_info=None, ): if order and "ORDER" not in sql: sql += " ORDER BY pkey" for lazy_decode in [False, True]: with database.snapshot() as snapshot: iterator = snapshot.execute_sql( sql, params=params, param_types=param_types, column_info=column_info, lazy_decode=lazy_decode, ) rows = list(iterator) if lazy_decode: for index, row in enumerate(rows): rows[index] = iterator.decode_row(row) _sample_data._check_rows_data( rows, expected=expected, recurse_into_lists=recurse_into_lists ) def test_multiuse_snapshot_execute_sql_isolation_strong(sessions_database): sd = _sample_data row_count = 40 _set_up_table(sessions_database, row_count) all_data_rows = list(_row_data(row_count)) with sessions_database.snapshot(multi_use=True) as strong: before = list(strong.execute_sql(sd.SQL)) sd._check_row_data(before, all_data_rows) with sessions_database.batch() as batch: batch.delete(sd.TABLE, sd.ALL) after = list(strong.execute_sql(sd.SQL)) sd._check_row_data(after, all_data_rows) def test_execute_sql_returning_array_of_struct(sessions_database, not_postgres): sql = ( "SELECT ARRAY(SELECT AS STRUCT C1, C2 " "FROM (SELECT 'a' AS C1, 1 AS C2 " "UNION ALL SELECT 'b' AS C1, 2 AS C2) " "ORDER BY C1 ASC)" ) _check_sql_results( sessions_database, sql=sql, params=None, param_types=None, expected=[[[["a", 1], ["b", 2]]]], ) def test_execute_sql_returning_empty_array_of_struct(sessions_database, not_postgres): sql = ( "SELECT ARRAY(SELECT AS STRUCT C1, C2 " "FROM (SELECT 2 AS C1) X " "JOIN (SELECT 1 AS C2) Y " "ON X.C1 = Y.C2 " "ORDER BY C1 ASC)" ) sessions_database.snapshot(multi_use=True) _check_sql_results( sessions_database, sql=sql, params=None, param_types=None, expected=[[[]]] ) def test_invalid_type(sessions_database): sd = _sample_data table = "counters" columns = ("name", "value") valid_input = (("", 0),) with sessions_database.batch() as batch: batch.delete(table, sd.ALL) batch.insert(table, columns, valid_input) invalid_input = ((0, ""),) with pytest.raises(exceptions.FailedPrecondition): with sessions_database.batch() as batch: batch.delete(table, sd.ALL) batch.insert(table, columns, invalid_input) def test_execute_sql_select_1(sessions_database): sessions_database.snapshot(multi_use=True) # Hello, world query _check_sql_results( sessions_database, sql="SELECT 1", params=None, param_types=None, expected=[(1,)], order=False, ) def _bind_test_helper( database, database_dialect, param_type, single_value, array_value, expected_array_value=None, recurse_into_lists=True, 
column_info=None, expected_single_value=None, ): database.snapshot(multi_use=True) key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "v" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" if expected_single_value is None: expected_single_value = single_value # Bind a non-null _check_sql_results( database, sql=f"SELECT {placeholder} as column", params={key: single_value}, param_types={key: param_type}, expected=[(expected_single_value,)], order=False, recurse_into_lists=recurse_into_lists, column_info=column_info, ) # Bind a null _check_sql_results( database, sql=f"SELECT {placeholder} as column", params={key: None}, param_types={key: param_type}, expected=[(None,)], order=False, recurse_into_lists=recurse_into_lists, column_info=column_info, ) # Bind an array of array_element_type = param_type array_type = spanner_v1.Type( code=spanner_v1.TypeCode.ARRAY, array_element_type=array_element_type ) if expected_array_value is None: expected_array_value = array_value _check_sql_results( database, sql=f"SELECT {placeholder} as column", params={key: array_value}, param_types={key: array_type}, expected=[(expected_array_value,)], order=False, recurse_into_lists=recurse_into_lists, column_info=column_info, ) # Bind an empty array of _check_sql_results( database, sql=f"SELECT {placeholder} as column", params={key: []}, param_types={key: array_type}, expected=[([],)], order=False, recurse_into_lists=recurse_into_lists, column_info=column_info, ) # Bind a null array of _check_sql_results( database, sql=f"SELECT {placeholder} as column", params={key: None}, param_types={key: array_type}, expected=[(None,)], order=False, recurse_into_lists=recurse_into_lists, column_info=column_info, ) def test_execute_sql_w_string_bindings(sessions_database, database_dialect): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.STRING, "Phred", ["Phred", "Bharney"], ) def test_execute_sql_w_bool_bindings(sessions_database, database_dialect): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.BOOL, True, [True, False, True], ) def test_execute_sql_w_int64_bindings(sessions_database, database_dialect): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.INT64, 42, [123, 456, 789], ) def test_execute_sql_w_float64_bindings(sessions_database, database_dialect): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.FLOAT64, 42.3, [12.3, 456.0, 7.89], ) def test_execute_sql_w_float_bindings_transfinite(sessions_database, database_dialect): key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "neg_inf" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" # Find -inf _check_sql_results( sessions_database, sql=f"SELECT {placeholder}", params={key: NEG_INF}, param_types={key: spanner_v1.param_types.FLOAT64}, expected=[(NEG_INF,)], order=False, ) key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "pos_inf" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" # Find +inf _check_sql_results( sessions_database, sql=f"SELECT {placeholder}", params={key: POS_INF}, param_types={key: spanner_v1.param_types.FLOAT64}, expected=[(POS_INF,)], order=False, ) def test_execute_sql_w_float32_bindings(sessions_database, database_dialect): pytest.skip("float32 is not yet supported in production.") _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.FLOAT32, 42.3, [12.3, 456.0, 
7.89], ) def test_execute_sql_w_float32_bindings_transfinite( sessions_database, database_dialect ): pytest.skip("float32 is not yet supported in production.") key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "neg_inf" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" # Find -inf _check_sql_results( sessions_database, sql=f"SELECT {placeholder}", params={key: NEG_INF}, param_types={key: spanner_v1.param_types.FLOAT32}, expected=[(NEG_INF,)], order=False, ) key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "pos_inf" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" # Find +inf _check_sql_results( sessions_database, sql=f"SELECT {placeholder}", params={key: POS_INF}, param_types={key: spanner_v1.param_types.FLOAT32}, expected=[(POS_INF,)], order=False, ) def test_execute_sql_w_bytes_bindings(sessions_database, database_dialect): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.BYTES, b"DEADBEEF", [b"FACEDACE", b"DEADBEEF"], ) def test_execute_sql_w_timestamp_bindings(sessions_database, database_dialect): timestamp_1 = datetime_helpers.DatetimeWithNanoseconds( 1989, 1, 17, 17, 59, 12, nanosecond=345612789 ) timestamp_2 = datetime_helpers.DatetimeWithNanoseconds( 1989, 1, 17, 17, 59, 13, nanosecond=456127893 ) timestamps = [timestamp_1, timestamp_2] # In round-trip, timestamps acquire a timezone value. expected_timestamps = [timestamp.replace(tzinfo=UTC) for timestamp in timestamps] _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.TIMESTAMP, timestamp_1, timestamps, expected_timestamps, recurse_into_lists=False, ) def test_execute_sql_w_date_bindings(sessions_database, not_postgres, database_dialect): dates = [SOME_DATE, SOME_DATE + datetime.timedelta(days=1)] _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.DATE, SOME_DATE, dates, ) def test_execute_sql_w_numeric_bindings( not_emulator, sessions_database, database_dialect ): if database_dialect == DatabaseDialect.POSTGRESQL: _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.PG_NUMERIC, NUMERIC_1, [NUMERIC_1, NUMERIC_2], ) else: _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.NUMERIC, NUMERIC_1, [NUMERIC_1, NUMERIC_2], ) def test_execute_sql_w_json_bindings( not_emulator, not_postgres, sessions_database, database_dialect ): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.JSON, JSON_1, [JSON_1, JSON_2], ) def test_execute_sql_w_jsonb_bindings( not_google_standard_sql, sessions_database, database_dialect ): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.PG_JSONB, JSON_1, [JSON_1, JSON_2], ) def test_execute_sql_w_oid_bindings( not_google_standard_sql, sessions_database, database_dialect ): _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.PG_OID, 123, [123, 456], ) def test_execute_sql_w_query_param_struct(sessions_database, not_postgres): name = "Phred" count = 123 size = 23.456 height = 188.0 weight = 97.6 param_types = spanner_v1.param_types record_type = param_types.Struct( [ param_types.StructField("name", param_types.STRING), param_types.StructField("count", param_types.INT64), param_types.StructField("size", param_types.FLOAT64), param_types.StructField( "nested", param_types.Struct( [ param_types.StructField("height", param_types.FLOAT64), param_types.StructField("weight", param_types.FLOAT64), ] ), 
), ] ) # Query with null struct, explicit type _check_sql_results( sessions_database, sql="SELECT @r.name, @r.count, @r.size, @r.nested.weight", params={"r": None}, param_types={"r": record_type}, expected=[(None, None, None, None)], order=False, ) # Query with non-null struct, explicit type, NULL values _check_sql_results( sessions_database, sql="SELECT @r.name, @r.count, @r.size, @r.nested.weight", params={"r": (None, None, None, None)}, param_types={"r": record_type}, expected=[(None, None, None, None)], order=False, ) # Query with non-null struct, explicit type, nested NULL values _check_sql_results( sessions_database, sql="SELECT @r.nested.weight", params={"r": (None, None, None, (None, None))}, param_types={"r": record_type}, expected=[(None,)], order=False, ) # Query with non-null struct, explicit type _check_sql_results( sessions_database, sql="SELECT @r.name, @r.count, @r.size, @r.nested.weight", params={"r": (name, count, size, (height, weight))}, param_types={"r": record_type}, expected=[(name, count, size, weight)], order=False, ) # Query with empty struct, explicitly empty type empty_type = param_types.Struct([]) _check_sql_results( sessions_database, sql="SELECT @r IS NULL", params={"r": ()}, param_types={"r": empty_type}, expected=[(False,)], order=False, ) # Query with null struct, explicitly empty type _check_sql_results( sessions_database, sql="SELECT @r IS NULL", params={"r": None}, param_types={"r": empty_type}, expected=[(True,)], order=False, ) # Query with equality check for struct value struct_equality_query = ( "SELECT " '@struct_param=STRUCT(1,"bob")' ) struct_type = param_types.Struct( [ param_types.StructField("threadf", param_types.INT64), param_types.StructField("userf", param_types.STRING), ] ) _check_sql_results( sessions_database, sql=struct_equality_query, params={"struct_param": (1, "bob")}, param_types={"struct_param": struct_type}, expected=[(True,)], order=False, ) # Query with nullness test for struct _check_sql_results( sessions_database, sql="SELECT @struct_param IS NULL", params={"struct_param": None}, param_types={"struct_param": struct_type}, expected=[(True,)], order=False, ) # Query with null array-of-struct array_elem_type = param_types.Struct( [param_types.StructField("threadid", param_types.INT64)] ) array_type = param_types.Array(array_elem_type) _check_sql_results( sessions_database, sql="SELECT a.threadid FROM UNNEST(@struct_arr_param) a", params={"struct_arr_param": None}, param_types={"struct_arr_param": array_type}, expected=[], order=False, ) # Query with non-null array-of-struct _check_sql_results( sessions_database, sql="SELECT a.threadid FROM UNNEST(@struct_arr_param) a", params={"struct_arr_param": [(123,), (456,)]}, param_types={"struct_arr_param": array_type}, expected=[(123,), (456,)], order=False, ) # Query with null array-of-struct field struct_type_with_array_field = param_types.Struct( [ param_types.StructField("intf", param_types.INT64), param_types.StructField("arraysf", array_type), ] ) _check_sql_results( sessions_database, sql="SELECT a.threadid FROM UNNEST(@struct_param.arraysf) a", params={"struct_param": (123, None)}, param_types={"struct_param": struct_type_with_array_field}, expected=[], order=False, ) # Query with non-null array-of-struct field _check_sql_results( sessions_database, sql="SELECT a.threadid FROM UNNEST(@struct_param.arraysf) a", params={"struct_param": (123, ((456,), (789,)))}, param_types={"struct_param": struct_type_with_array_field}, expected=[(456,), (789,)], order=False, ) # Query with 
anonymous / repeated-name fields anon_repeated_array_elem_type = param_types.Struct( [ param_types.StructField("", param_types.INT64), param_types.StructField("", param_types.STRING), ] ) anon_repeated_array_type = param_types.Array(anon_repeated_array_elem_type) _check_sql_results( sessions_database, sql="SELECT CAST(t as STRUCT).* " "FROM UNNEST(@struct_param) t", params={"struct_param": [(123, "abcdef")]}, param_types={"struct_param": anon_repeated_array_type}, expected=[(123, "abcdef")], order=False, ) # Query and return a struct parameter value_type = param_types.Struct( [ param_types.StructField("message", param_types.STRING), param_types.StructField("repeat", param_types.INT64), ] ) value_query = ( "SELECT ARRAY(SELECT AS STRUCT message, repeat " "FROM (SELECT @value.message AS message, " "@value.repeat AS repeat)) AS value" ) _check_sql_results( sessions_database, sql=value_query, params={"value": ("hello", 1)}, param_types={"value": value_type}, expected=[([["hello", 1]],)], order=False, ) def test_execute_sql_w_proto_message_bindings( not_postgres, sessions_database, database_dialect ): singer_info = _sample_data.SINGER_INFO_1 singer_info_bytes = base64.b64encode(singer_info.SerializeToString()) _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.ProtoMessage(singer_info), singer_info, [singer_info, None], column_info={"column": singer_pb2.SingerInfo()}, ) # Tests compatibility between proto message and bytes column types _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.ProtoMessage(singer_info), singer_info_bytes, [singer_info_bytes, None], expected_single_value=singer_info, expected_array_value=[singer_info, None], column_info={"column": singer_pb2.SingerInfo()}, ) # Tests compatibility between proto message and bytes column types _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.BYTES, singer_info, [singer_info, None], expected_single_value=singer_info_bytes, expected_array_value=[singer_info_bytes, None], ) def test_execute_sql_w_proto_enum_bindings( not_emulator, not_postgres, sessions_database, database_dialect ): singer_genre = _sample_data.SINGER_GENRE_1 _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.ProtoEnum(singer_pb2.Genre), singer_genre, [singer_genre, None], ) # Tests compatibility between proto enum and int64 column types _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.ProtoEnum(singer_pb2.Genre), 3, [3, None], expected_single_value="ROCK", expected_array_value=["ROCK", None], column_info={"column": singer_pb2.Genre}, ) # Tests compatibility between proto enum and int64 column types _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.INT64, singer_genre, [singer_genre, None], ) def test_execute_sql_returning_transfinite_floats(sessions_database, not_postgres): with sessions_database.snapshot(multi_use=True) as snapshot: # Query returning -inf, +inf, NaN as column values rows = list( snapshot.execute_sql( "SELECT " 'CAST("-inf" AS FLOAT64), ' 'CAST("+inf" AS FLOAT64), ' 'CAST("NaN" AS FLOAT64)' ) ) assert len(rows) == 1 assert rows[0][0] == float("-inf") assert rows[0][1] == float("+inf") # NaNs cannot be compared by equality. 
assert math.isnan(rows[0][2]) # Query returning array of -inf, +inf, NaN as one column rows = list( snapshot.execute_sql( "SELECT" ' [CAST("-inf" AS FLOAT64),' ' CAST("+inf" AS FLOAT64),' ' CAST("NaN" AS FLOAT64)]' ) ) assert len(rows) == 1 float_array = rows[0][0] assert float_array[0] == float("-inf") assert float_array[1] == float("+inf") # NaNs cannot be searched for by equality. assert math.isnan(float_array[2]) def test_execute_sql_w_uuid_bindings(sessions_database, database_dialect): if database_dialect == DatabaseDialect.POSTGRESQL: pytest.skip("UUID parameter type is not yet supported in PostgreSQL dialect.") _bind_test_helper( sessions_database, database_dialect, spanner_v1.param_types.UUID, uuid.uuid4(), [uuid.uuid4(), uuid.uuid4()], ) def test_partition_query(sessions_database, not_emulator, not_experimental_host): row_count = 40 sql = f"SELECT * FROM {_sample_data.TABLE}" committed = _set_up_table(sessions_database, row_count) # Paritioned query does not support ORDER BY all_data_rows = set(_row_data(row_count)) union = set() batch_txn = sessions_database.batch_snapshot(read_timestamp=committed) for batch in batch_txn.generate_query_batches(sql, data_boost_enabled=True): p_results_iter = batch_txn.process(batch) # Lists aren't hashable so the results need to be converted rows = [tuple(result) for result in p_results_iter] union.update(set(rows)) assert union == all_data_rows batch_txn.close() def test_run_partition_query(sessions_database, not_emulator, not_experimental_host): row_count = 40 sql = f"SELECT * FROM {_sample_data.TABLE}" committed = _set_up_table(sessions_database, row_count) # Paritioned query does not support ORDER BY all_data_rows = set(_row_data(row_count)) union = set() batch_txn = sessions_database.batch_snapshot(read_timestamp=committed) p_results_iter = batch_txn.run_partitioned_query(sql, data_boost_enabled=True) # Lists aren't hashable so the results need to be converted rows = [tuple(result) for result in p_results_iter] union.update(set(rows)) assert union == all_data_rows batch_txn.close() def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_database): sd = _sample_data num_groups = 3 num_mutations_per_group = len(sd.BATCH_WRITE_ROW_DATA) // num_groups with sessions_database.batch() as batch: batch.delete(sd.TABLE, sd.ALL) with sessions_database.mutation_groups() as groups: for i in range(num_groups): group = groups.group() for j in range(num_mutations_per_group): group.insert_or_update( sd.TABLE, sd.COLUMNS, [sd.BATCH_WRITE_ROW_DATA[i * num_mutations_per_group + j]], ) # Response indexes received seen = collections.Counter() for response in groups.batch_write(): _check_batch_status(response.status.code) assert response.commit_timestamp is not None assert len(response.indexes) > 0 seen.update(response.indexes) # All indexes must be in the range [0, num_groups-1] and seen exactly once assert len(seen) == num_groups assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items()) # Verify the writes by reading from the database with sessions_database.snapshot() as snapshot: rows = list(snapshot.execute_sql(sd.SQL)) sd._check_rows_data(rows, sd.BATCH_WRITE_ROW_DATA) def _check_batch_status(status_code, expected=code_pb2.OK): if status_code != expected: _status_code_to_grpc_status_code = { member.value[0]: member for member in grpc.StatusCode } grpc_status_code = _status_code_to_grpc_status_code[status_code] call = _helpers.FauxCall(status_code) raise exceptions.from_grpc_status( grpc_status_code, "batch_update 
failed", errors=[call] ) def get_param_info(param_names, database_dialect): keys = [f"p{i + 1}" for i in range(len(param_names))] if database_dialect == DatabaseDialect.POSTGRESQL: placeholders = [f"${i + 1}" for i in range(len(param_names))] else: placeholders = [f"@p{i + 1}" for i in range(len(param_names))] return keys, placeholders def test_interval(sessions_database, database_dialect, not_emulator): from google.cloud.spanner_v1 import Interval def setup_table(): if database_dialect == DatabaseDialect.POSTGRESQL: sessions_database.update_ddl( [ """ CREATE TABLE IntervalTable ( key text primary key, create_time timestamptz, expiry_time timestamptz, expiry_within_month bool GENERATED ALWAYS AS (expiry_time - create_time < INTERVAL '30' DAY) STORED, interval_array_len bigint GENERATED ALWAYS AS (ARRAY_LENGTH(ARRAY[INTERVAL '1-2 3 4:5:6'], 1)) STORED ) """ ] ).result() else: sessions_database.update_ddl( [ """ CREATE TABLE IntervalTable ( key STRING(MAX), create_time TIMESTAMP, expiry_time TIMESTAMP, expiry_within_month bool AS (expiry_time - create_time < INTERVAL 30 DAY), interval_array_len INT64 AS (ARRAY_LENGTH(ARRAY[INTERVAL '1-2 3 4:5:6' YEAR TO SECOND])) ) PRIMARY KEY (key) """ ] ).result() def insert_test1(transaction): keys, placeholders = get_param_info( ["key", "create_time", "expiry_time"], database_dialect ) transaction.execute_update( f""" INSERT INTO IntervalTable (key, create_time, expiry_time) VALUES ({placeholders[0]}, {placeholders[1]}, {placeholders[2]}) """, params={ keys[0]: "test1", keys[1]: datetime.datetime(2004, 11, 30, 4, 53, 54, tzinfo=UTC), keys[2]: datetime.datetime(2004, 12, 15, 4, 53, 54, tzinfo=UTC), }, param_types={ keys[0]: spanner_v1.param_types.STRING, keys[1]: spanner_v1.param_types.TIMESTAMP, keys[2]: spanner_v1.param_types.TIMESTAMP, }, ) def insert_test2(transaction): keys, placeholders = get_param_info( ["key", "create_time", "expiry_time"], database_dialect ) transaction.execute_update( f""" INSERT INTO IntervalTable (key, create_time, expiry_time) VALUES ({placeholders[0]}, {placeholders[1]}, {placeholders[2]}) """, params={ keys[0]: "test2", keys[1]: datetime.datetime(2004, 8, 30, 4, 53, 54, tzinfo=UTC), keys[2]: datetime.datetime(2004, 12, 15, 4, 53, 54, tzinfo=UTC), }, param_types={ keys[0]: spanner_v1.param_types.STRING, keys[1]: spanner_v1.param_types.TIMESTAMP, keys[2]: spanner_v1.param_types.TIMESTAMP, }, ) def test_computed_columns(transaction): keys, placeholders = get_param_info(["key"], database_dialect) results = list( transaction.execute_sql( f""" SELECT expiry_within_month, interval_array_len FROM IntervalTable WHERE key = {placeholders[0]}""", params={keys[0]: "test1"}, param_types={keys[0]: spanner_v1.param_types.STRING}, ) ) assert len(results) == 1 row = results[0] assert row[0] is True # expiry_within_month assert row[1] == 1 # interval_array_len def test_interval_arithmetic(transaction): results = list( transaction.execute_sql( "SELECT INTERVAL '1' DAY + INTERVAL '1' MONTH AS Col1" ) ) assert len(results) == 1 row = results[0] interval = row[0] assert interval.months == 1 assert interval.days == 1 assert interval.nanos == 0 def test_interval_timestamp_comparison(transaction): timestamp = "2004-11-30T10:23:54+0530" keys, placeholders = get_param_info(["interval"], database_dialect) if database_dialect == DatabaseDialect.POSTGRESQL: query = f"SELECT COUNT(*) FROM IntervalTable WHERE create_time < TIMESTAMPTZ '%s' - {placeholders[0]}" else: query = f"SELECT COUNT(*) FROM IntervalTable WHERE create_time < TIMESTAMP('%s') - 
{placeholders[0]}" results = list( transaction.execute_sql( query % timestamp, params={keys[0]: Interval(days=30)}, param_types={keys[0]: spanner_v1.param_types.INTERVAL}, ) ) assert len(results) == 1 assert results[0][0] == 1 def test_interval_array_param(transaction): intervals = [ Interval(months=14, days=3, nanos=14706000000000), Interval(), Interval(months=-14, days=-3, nanos=-14706000000000), None, ] keys, placeholders = get_param_info(["intervals"], database_dialect) array_type = spanner_v1.Type( code=spanner_v1.TypeCode.ARRAY, array_element_type=spanner_v1.param_types.INTERVAL, ) results = list( transaction.execute_sql( f"SELECT {placeholders[0]}", params={keys[0]: intervals}, param_types={keys[0]: array_type}, ) ) assert len(results) == 1 row = results[0] intervals = row[0] assert len(intervals) == 4 assert intervals[0].months == 14 assert intervals[0].days == 3 assert intervals[0].nanos == 14706000000000 assert intervals[1].months == 0 assert intervals[1].days == 0 assert intervals[1].nanos == 0 assert intervals[2].months == -14 assert intervals[2].days == -3 assert intervals[2].nanos == -14706000000000 assert intervals[3] is None def test_interval_array_cast(transaction): results = list( transaction.execute_sql( """ SELECT ARRAY[ CAST('P1Y2M3DT4H5M6.789123S' AS INTERVAL), null, CAST('P-1Y-2M-3DT-4H-5M-6.789123S' AS INTERVAL) ] AS Col1 """ ) ) assert len(results) == 1 row = results[0] intervals = row[0] assert len(intervals) == 3 assert intervals[0].months == 14 # 1 year + 2 months assert intervals[0].days == 3 assert intervals[0].nanos == 14706789123000 # 4h5m6.789123s in nanos assert intervals[1] is None assert intervals[2].months == -14 assert intervals[2].days == -3 assert intervals[2].nanos == -14706789123000 setup_table() sessions_database.run_in_transaction(insert_test1) sessions_database.run_in_transaction(test_computed_columns) sessions_database.run_in_transaction(test_interval_arithmetic) sessions_database.run_in_transaction(insert_test2) sessions_database.run_in_transaction(test_interval_timestamp_comparison) sessions_database.run_in_transaction(test_interval_array_param) sessions_database.run_in_transaction(test_interval_array_cast) def test_session_id_and_multiplexed_flag_behavior(sessions_database, ot_exporter): sd = _sample_data with sessions_database.batch() as batch: batch.delete(sd.TABLE, sd.ALL) batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA) multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) snapshot1_session_id = None snapshot2_session_id = None snapshot1_is_multiplexed = None snapshot2_is_multiplexed = None snapshot1 = sessions_database.snapshot() snapshot2 = sessions_database.snapshot() try: with snapshot1 as snap1, snapshot2 as snap2: rows1 = list(snap1.read(sd.TABLE, sd.COLUMNS, sd.ALL)) rows2 = list(snap2.read(sd.TABLE, sd.COLUMNS, sd.ALL)) snapshot1_session_id = snap1._session.name snapshot1_is_multiplexed = snap1._session.is_multiplexed snapshot2_session_id = snap2._session.name snapshot2_is_multiplexed = snap2._session.is_multiplexed except Exception: with sessions_database.snapshot() as snap1: rows1 = list(snap1.read(sd.TABLE, sd.COLUMNS, sd.ALL)) snapshot1_session_id = snap1._session.name snapshot1_is_multiplexed = snap1._session.is_multiplexed with sessions_database.snapshot() as snap2: rows2 = list(snap2.read(sd.TABLE, sd.COLUMNS, sd.ALL)) snapshot2_session_id = snap2._session.name snapshot2_is_multiplexed = snap2._session.is_multiplexed sd._check_rows_data(rows1) sd._check_rows_data(rows2) assert rows1 == rows2 assert 
snapshot1_session_id is not None assert snapshot2_session_id is not None assert snapshot1_is_multiplexed is not None assert snapshot2_is_multiplexed is not None if multiplexed_enabled: assert snapshot1_session_id == snapshot2_session_id assert snapshot1_is_multiplexed is True assert snapshot2_is_multiplexed is True else: assert snapshot1_is_multiplexed is False assert snapshot2_is_multiplexed is False if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() session_spans = [] read_spans = [] for span in span_list: if ( "CreateSession" in span.name or "CreateMultiplexedSession" in span.name or "GetSession" in span.name ): session_spans.append(span) elif "Snapshot.read" in span.name: read_spans.append(span) assert len(read_spans) == 2 if multiplexed_enabled: multiplexed_session_spans = [ s for s in session_spans if "CreateMultiplexedSession" in s.name ] read_only_multiplexed_sessions = [ s for s in multiplexed_session_spans if s.start_time > span_list[1].end_time ] # Allow for session reuse - if no new multiplexed sessions were created, # it means an existing one was reused (which is valid behavior) if len(read_only_multiplexed_sessions) == 0: # Verify that multiplexed sessions are actually being used by checking # that the snapshots themselves are multiplexed assert snapshot1_is_multiplexed is True assert snapshot2_is_multiplexed is True assert snapshot1_session_id == snapshot2_session_id else: # New multiplexed session was created assert len(read_only_multiplexed_sessions) >= 1 # Note: We don't need to assert specific counts for regular/get sessions # as the key validation is that multiplexed sessions are being used properly else: read_only_session_spans = [ s for s in session_spans if s.start_time > span_list[1].end_time ] assert len(read_only_session_spans) >= 1 multiplexed_session_spans = [ s for s in session_spans if "CreateMultiplexedSession" in s.name ] assert len(multiplexed_session_spans) == 0