Skip to content

Commit

Permalink
test: unit test case fix (#1057)
Browse files Browse the repository at this point in the history
* test: unit test case fix

* feat(spanner): lint

---------

Co-authored-by: Sri Harsha CH <[email protected]>
Co-authored-by: Sri Harsha CH <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2024
1 parent d3fe937 commit 07a0202
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions tests/unit/test_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
MODE = 2
RETRY = gapic_v1.method.DEFAULT
TIMEOUT = gapic_v1.method.DEFAULT
REQUEST_OPTIONS = RequestOptions()
insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)"
insert_params = {"pkey": 12345, "desc": "DESCRIPTION"}
insert_param_types = {"pkey": param_types.INT64, "desc": param_types.STRING}
Expand Down Expand Up @@ -142,7 +141,7 @@ def _execute_update_helper(
PARAM_TYPES,
query_mode=MODE,
query_options=query_options,
request_options=REQUEST_OPTIONS,
request_options=RequestOptions(),
retry=RETRY,
timeout=TIMEOUT,
)
Expand All @@ -167,7 +166,7 @@ def _execute_update_expected_request(
expected_query_options = _merge_query_options(
expected_query_options, query_options
)
expected_request_options = REQUEST_OPTIONS
expected_request_options = RequestOptions()
expected_request_options.transaction_tag = self.TRANSACTION_TAG

expected_request = ExecuteSqlRequest(
Expand Down Expand Up @@ -226,7 +225,7 @@ def _execute_sql_helper(
PARAM_TYPES,
query_mode=MODE,
query_options=query_options,
request_options=REQUEST_OPTIONS,
request_options=RequestOptions(),
partition=partition,
retry=RETRY,
timeout=TIMEOUT,
Expand All @@ -240,7 +239,13 @@ def _execute_sql_helper(
self.assertEqual(transaction._execute_sql_count, sql_count + 1)

def _execute_sql_expected_request(
self, database, partition=None, query_options=None, begin=True, sql_count=0
self,
database,
partition=None,
query_options=None,
begin=True,
sql_count=0,
transaction_tag=False,
):
if begin is True:
expected_transaction = TransactionSelector(
Expand All @@ -259,8 +264,12 @@ def _execute_sql_expected_request(
expected_query_options, query_options
)

expected_request_options = REQUEST_OPTIONS
expected_request_options.transaction_tag = self.TRANSACTION_TAG
expected_request_options = RequestOptions()

if transaction_tag is True:
expected_request_options.transaction_tag = self.TRANSACTION_TAG
else:
expected_request_options.transaction_tag = None

expected_request = ExecuteSqlRequest(
session=self.SESSION_NAME,
Expand Down Expand Up @@ -320,7 +329,7 @@ def _read_helper(
partition=partition,
retry=RETRY,
timeout=TIMEOUT,
request_options=REQUEST_OPTIONS,
request_options=RequestOptions(),
)
else:
result_set = transaction.read(
Expand All @@ -331,7 +340,7 @@ def _read_helper(
limit=LIMIT,
retry=RETRY,
timeout=TIMEOUT,
request_options=REQUEST_OPTIONS,
request_options=RequestOptions(),
)

self.assertEqual(transaction._read_request_count, count + 1)
Expand All @@ -342,7 +351,9 @@ def _read_helper(
self.assertEqual(result_set.metadata, metadata_pb)
self.assertEqual(result_set.stats, stats_pb)

def _read_helper_expected_request(self, partition=None, begin=True, count=0):
def _read_helper_expected_request(
self, partition=None, begin=True, count=0, transaction_tag=False
):
if begin is True:
expected_transaction = TransactionSelector(
begin=TransactionOptions(read_write=TransactionOptions.ReadWrite())
Expand All @@ -356,8 +367,12 @@ def _read_helper_expected_request(self, partition=None, begin=True, count=0):
expected_limit = LIMIT

# Transaction tag is ignored for read request.
expected_request_options = REQUEST_OPTIONS
expected_request_options.transaction_tag = self.TRANSACTION_TAG
expected_request_options = RequestOptions()

if transaction_tag is True:
expected_request_options.transaction_tag = self.TRANSACTION_TAG
else:
expected_request_options.transaction_tag = None

expected_request = ReadRequest(
session=self.SESSION_NAME,
Expand Down Expand Up @@ -410,7 +425,7 @@ def _batch_update_helper(
transaction._execute_sql_count = count

status, row_counts = transaction.batch_update(
dml_statements, request_options=REQUEST_OPTIONS
dml_statements, request_options=RequestOptions()
)

self.assertEqual(status, expected_status)
Expand Down Expand Up @@ -440,7 +455,7 @@ def _batch_update_expected_request(self, begin=True, count=0):
ExecuteBatchDmlRequest.Statement(sql=delete_dml),
]

expected_request_options = REQUEST_OPTIONS
expected_request_options = RequestOptions()
expected_request_options.transaction_tag = self.TRANSACTION_TAG

expected_request = ExecuteBatchDmlRequest(
Expand Down Expand Up @@ -595,7 +610,9 @@ def test_transaction_should_use_transaction_id_returned_by_first_update(self):

self._execute_sql_helper(transaction=transaction, api=api)
api.execute_streaming_sql.assert_called_once_with(
request=self._execute_sql_expected_request(database=database, begin=False),
request=self._execute_sql_expected_request(
database=database, begin=False, transaction_tag=True
),
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
metadata=[
Expand Down Expand Up @@ -644,7 +661,9 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se
)
self._read_helper(transaction=transaction, api=api)
api.streaming_read.assert_called_once_with(
request=self._read_helper_expected_request(begin=False),
request=self._read_helper_expected_request(
begin=False, transaction_tag=True
),
metadata=[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
Expand Down

0 comments on commit 07a0202

Please sign in to comment.