diff --git a/spanner/google/cloud/spanner_v1/pool.py b/spanner/google/cloud/spanner_v1/pool.py index 8af3e566cab5..ce7a196b6bb8 100644 --- a/spanner/google/cloud/spanner_v1/pool.py +++ b/spanner/google/cloud/spanner_v1/pool.py @@ -503,7 +503,7 @@ def put(self, session): raise queue.Full txn = session._transaction - if txn is None or txn.committed() or txn._rolled_back: + if txn is None or txn.committed or txn._rolled_back: session.transaction() self._pending_sessions.put(session) else: diff --git a/spanner/tests/unit/test_pool.py b/spanner/tests/unit/test_pool.py index c5e243e6373c..2d4a9d882291 100644 --- a/spanner/tests/unit/test_pool.py +++ b/spanner/tests/unit/test_pool.py @@ -656,7 +656,7 @@ def test_bind(self): for session in SESSIONS: session.create.assert_not_called() txn = session._transaction - self.assertTrue(txn._begun) + txn.begin.assert_called_once_with() self.assertTrue(pool._pending_sessions.empty()) @@ -685,7 +685,7 @@ def test_bind_w_timestamp_race(self): for session in SESSIONS: session.create.assert_not_called() txn = session._transaction - self.assertTrue(txn._begun) + txn.begin.assert_called_once_with() self.assertTrue(pool._pending_sessions.empty()) @@ -718,7 +718,7 @@ def test_put_non_full_w_active_txn(self): self.assertIs(queued, session) self.assertEqual(len(pending._items), 0) - self.assertFalse(txn._begun) + txn.begin.assert_not_called() def test_put_non_full_w_committed_txn(self): pool = self._make_one(size=1) @@ -727,7 +727,7 @@ def test_put_non_full_w_committed_txn(self): database = _Database("name") session = _Session(database) committed = session.transaction() - committed._committed = True + committed.committed = True pool.put(session) @@ -736,7 +736,7 @@ def test_put_non_full_w_committed_txn(self): self.assertEqual(len(pending._items), 1) self.assertIs(pending._items[0], session) self.assertIsNot(session._transaction, committed) - self.assertFalse(session._transaction._begun) + session._transaction.begin.assert_not_called() def test_put_non_full(self): pool = self._make_one(size=1) @@ -762,7 +762,7 @@ def test_begin_pending_transactions_non_empty(self): pool._sessions = _Queue() database = _Database("name") - TRANSACTIONS = [_Transaction()] + TRANSACTIONS = [_make_transaction(object())] PENDING_SESSIONS = [_Session(database, transaction=txn) for txn in TRANSACTIONS] pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS) @@ -771,7 +771,7 @@ def test_begin_pending_transactions_non_empty(self): pool.begin_pending_transactions() # no raise for txn in TRANSACTIONS: - self.assertTrue(txn._begun) + txn.begin.assert_called_once_with() self.assertTrue(pending.empty()) @@ -832,17 +832,13 @@ def test_context_manager_w_kwargs(self): self.assertEqual(pool._got, {"foo": "bar"}) -class _Transaction(object): +def _make_transaction(*args, **kw): + from google.cloud.spanner_v1.transaction import Transaction - _begun = False - _committed = False - _rolled_back = False - - def begin(self): - self._begun = True - - def committed(self): - return self._committed + txn = mock.create_autospec(Transaction)(*args, **kw) + txn.committed = None + txn._rolled_back = False + return txn @total_ordering @@ -873,7 +869,7 @@ def delete(self): raise NotFound("unknown session") def transaction(self): - txn = self._transaction = _Transaction() + txn = self._transaction = _make_transaction(self) return txn