Skip to content

Commit

Permalink
fix(firestore): simplify 'Collection.add', avoid spurious API call (#…
Browse files Browse the repository at this point in the history
…9634)

Closes #9629
  • Loading branch information
tseaver authored Nov 7, 2019
1 parent 8c83b52 commit 5621e9c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 43 deletions.
25 changes: 4 additions & 21 deletions firestore/google/cloud/firestore_v1/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import query as query_mod
from google.cloud.firestore_v1.proto import document_pb2
from google.cloud.firestore_v1.watch import Watch
from google.cloud.firestore_v1 import document

Expand Down Expand Up @@ -157,27 +156,11 @@ def add(self, document_data, document_id=None):
and the document already exists.
"""
if document_id is None:
parent_path, expected_prefix = self._parent_info()

document_pb = document_pb2.Document()

created_document_pb = self._client._firestore_api.create_document(
parent_path,
collection_id=self.id,
document_id=None,
document=document_pb,
mask=None,
metadata=self._client._rpc_metadata,
)
document_id = _auto_id()

new_document_id = _helpers.get_doc_id(created_document_pb, expected_prefix)
document_ref = self.document(new_document_id)
set_result = document_ref.set(document_data)
return set_result.update_time, document_ref
else:
document_ref = self.document(document_id)
write_result = document_ref.create(document_data)
return write_result.update_time, document_ref
document_ref = self.document(document_id)
write_result = document_ref.create(document_data)
return write_result.update_time, document_ref

def list_documents(self, page_size=None):
"""List all subdocuments of the current collection.
Expand Down
33 changes: 11 additions & 22 deletions firestore/tests/unit/v1/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import types
import unittest

Expand Down Expand Up @@ -193,7 +192,7 @@ def test_add_auto_assigned(self):
from google.cloud.firestore_v1.proto import document_pb2
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1 import SERVER_TIMESTAMP
from google.cloud.firestore_v1._helpers import pbs_for_set_no_merge
from google.cloud.firestore_v1._helpers import pbs_for_create

# Create a minimal fake GAPIC add attach it to a real client.
firestore_api = mock.Mock(spec=["create_document", "commit"])
Expand All @@ -214,42 +213,32 @@ def test_add_auto_assigned(self):
# Actually make a collection.
collection = self._make_one("grand-parent", "parent", "child", client=client)

# Add a dummy response for the fake GAPIC.
parent_path = collection.parent._document_path
auto_assigned_id = "cheezburger"
name = "{}/{}/{}".format(parent_path, collection.id, auto_assigned_id)
create_doc_response = document_pb2.Document(name=name)
create_doc_response.update_time.FromDatetime(datetime.datetime.utcnow())
firestore_api.create_document.return_value = create_doc_response

# Actually call add() on our collection; include a transform to make
# sure transforms during adds work.
document_data = {"been": "here", "now": SERVER_TIMESTAMP}
update_time, document_ref = collection.add(document_data)

patch = mock.patch("google.cloud.firestore_v1.collection._auto_id")
random_doc_id = "DEADBEEF"
with patch as patched:
patched.return_value = random_doc_id
update_time, document_ref = collection.add(document_data)

# Verify the response and the mocks.
self.assertIs(update_time, mock.sentinel.update_time)
self.assertIsInstance(document_ref, DocumentReference)
self.assertIs(document_ref._client, client)
expected_path = collection._path + (auto_assigned_id,)
expected_path = collection._path + (random_doc_id,)
self.assertEqual(document_ref._path, expected_path)

expected_document_pb = document_pb2.Document()
firestore_api.create_document.assert_called_once_with(
parent_path,
collection_id=collection.id,
document_id=None,
document=expected_document_pb,
mask=None,
metadata=client._rpc_metadata,
)
write_pbs = pbs_for_set_no_merge(document_ref._document_path, document_data)
write_pbs = pbs_for_create(document_ref._document_path, document_data)
firestore_api.commit.assert_called_once_with(
client._database_string,
write_pbs,
transaction=None,
metadata=client._rpc_metadata,
)
# Since we generate the ID locally, we don't call 'create_document'.
firestore_api.create_document.assert_not_called()

@staticmethod
def _write_pb_for_create(document_path, document_data):
Expand Down

0 comments on commit 5621e9c

Please sign in to comment.