From de3aca0e78b68f66eb76bc679c6e95b0746ad590 Mon Sep 17 00:00:00 2001 From: HemangChothani <50404902+HemangChothani@users.noreply.github.com> Date: Fri, 21 Feb 2020 22:39:02 +0530 Subject: [PATCH] fix(firestore): fix get and getall method of transaction (#16) --- google/cloud/firestore_v1/transaction.py | 6 +++--- tests/unit/v1/test_transaction.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 9d4068c75a88..04485a84c2e3 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -213,7 +213,7 @@ def get_all(self, references): .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ - return self._client.get_all(references, transaction=self._id) + return self._client.get_all(references, transaction=self) def get(self, ref_or_query): """ @@ -225,9 +225,9 @@ def get(self, ref_or_query): query, or :data:`None` if the document does not exist. """ if isinstance(ref_or_query, DocumentReference): - return self._client.get_all([ref_or_query], transaction=self._id) + return self._client.get_all([ref_or_query], transaction=self) elif isinstance(ref_or_query, Query): - return ref_or_query.stream(transaction=self._id) + return ref_or_query.stream(transaction=self) else: raise ValueError( 'Value for argument "ref_or_query" must be a DocumentReference or a Query.' diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index 8cae24a23831..da3c2d0b027d 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -333,7 +333,7 @@ def test_get_all(self): transaction = self._make_one(client) ref1, ref2 = mock.Mock(), mock.Mock() result = transaction.get_all([ref1, ref2]) - client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction.id) + client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) self.assertIs(result, client.get_all.return_value) def test_get_document_ref(self): @@ -343,7 +343,7 @@ def test_get_document_ref(self): transaction = self._make_one(client) ref = DocumentReference("documents", "doc-id") result = transaction.get(ref) - client.get_all.assert_called_once_with([ref], transaction=transaction.id) + client.get_all.assert_called_once_with([ref], transaction=transaction) self.assertIs(result, client.get_all.return_value) def test_get_w_query(self): @@ -354,7 +354,7 @@ def test_get_w_query(self): query = Query(parent=mock.Mock(spec=[])) query.stream = mock.MagicMock() result = transaction.get(query) - query.stream.assert_called_once_with(transaction=transaction.id) + query.stream.assert_called_once_with(transaction=transaction) self.assertIs(result, query.stream.return_value) def test_get_failure(self):