Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop-google-firestore-instrum…
Browse files Browse the repository at this point in the history
…entation' into feature-firstore-async-instrumentation
  • Loading branch information
TimPansino committed Jul 31, 2023
2 parents 7bf6f49 + dcc92a9 commit 4a8a3fe
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 134 deletions.
22 changes: 3 additions & 19 deletions newrelic/hooks/datastore_firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,9 @@
from newrelic.api.datastore_trace import DatastoreTrace


def _get_object_id(obj, *args, **kwargs):
try:
return obj.id
except Exception:
return None


def _get_parent_id(obj, *args, **kwargs):
try:
return obj._parent.id
except Exception:
return None


def _get_collection_ref_id(obj, *args, **kwargs):
try:
return obj._collection_ref.id
except Exception:
return None
_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None)
_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None)
_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None)


def wrap_generator_method(module, class_name, method_name, target, is_async=False):
Expand Down
3 changes: 3 additions & 0 deletions tests/datastore_firestore/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import uuid

import pytest

from google.cloud.firestore import Client
from google.cloud.firestore import Client, AsyncClient

from testing_support.db_settings import firestore_settings
from testing_support.fixture.event_loop import event_loop as loop # noqa: F401; pylint: disable=W0611
from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611
Expand Down
42 changes: 25 additions & 17 deletions tests/datastore_firestore/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from testing_support.validators.validate_database_duration import (
validate_database_duration,
)
Expand All @@ -24,16 +26,19 @@
# ===== WriteBatch =====


def _exercise_write_batch(client, collection):
docs = [collection.document(str(x)) for x in range(1, 4)]
batch = client.batch()
for doc in docs:
batch.set(doc, {})
@pytest.fixture()
def exercise_write_batch(client, collection):
def _exercise_write_batch():
docs = [collection.document(str(x)) for x in range(1, 4)]
batch = client.batch()
for doc in docs:
batch.set(doc, {})

batch.commit()
batch.commit()
return _exercise_write_batch


def test_firestore_write_batch(client, collection):
def test_firestore_write_batch(exercise_write_batch):
_test_scoped_metrics = [
("Datastore/operation/Firestore/commit", 1),
]
Expand All @@ -52,26 +57,29 @@ def test_firestore_write_batch(client, collection):
)
@background_task(name="test_firestore_write_batch")
def _test():
_exercise_write_batch(client, collection)
exercise_write_batch()

_test()


# ===== BulkWriteBatch =====


def _exercise_bulk_write_batch(client, collection):
from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch
@pytest.fixture()
def exercise_bulk_write_batch(client, collection):
def _exercise_bulk_write_batch():
from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch

docs = [collection.document(str(x)) for x in range(1, 4)]
batch = BulkWriteBatch(client)
for doc in docs:
batch.set(doc, {})
docs = [collection.document(str(x)) for x in range(1, 4)]
batch = BulkWriteBatch(client)
for doc in docs:
batch.set(doc, {})

batch.commit()
batch.commit()
return _exercise_bulk_write_batch


def test_firestore_bulk_write_batch(client, collection):
def test_firestore_bulk_write_batch(exercise_bulk_write_batch):
_test_scoped_metrics = [
("Datastore/operation/Firestore/commit", 1),
]
Expand All @@ -90,6 +98,6 @@ def test_firestore_bulk_write_batch(client, collection):
)
@background_task(name="test_firestore_bulk_write_batch")
def _test():
_exercise_bulk_write_batch(client, collection)
exercise_bulk_write_batch()

_test()
17 changes: 10 additions & 7 deletions tests/datastore_firestore/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ def sample_data(collection):
return doc


def _exercise_client(client, collection, sample_data):
assert len([_ for _ in client.collections()])
doc = [_ for _ in client.get_all([sample_data])][0]
assert doc.to_dict()["x"] == 1
@pytest.fixture()
def exercise_client(client, sample_data):
def _exercise_client():
assert len([_ for _ in client.collections()])
doc = [_ for _ in client.get_all([sample_data])][0]
assert doc.to_dict()["x"] == 1
return _exercise_client


def test_firestore_client(client, collection, sample_data):
def test_firestore_client(exercise_client):
_test_scoped_metrics = [
("Datastore/operation/Firestore/collections", 1),
("Datastore/operation/Firestore/get_all", 1),
Expand All @@ -55,12 +58,12 @@ def test_firestore_client(client, collection, sample_data):
)
@background_task(name="test_firestore_client")
def _test():
_exercise_client(client, collection, sample_data)
exercise_client()

_test()


@background_task()
def test_firestore_client_generators(client, collection, sample_data, assert_trace_for_generator):
def test_firestore_client_generators(client, sample_data, assert_trace_for_generator):
assert_trace_for_generator(client.collections)
assert_trace_for_generator(client.get_all, [sample_data])
29 changes: 17 additions & 12 deletions tests/datastore_firestore/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from testing_support.validators.validate_database_duration import (
validate_database_duration,
)
Expand All @@ -22,20 +24,23 @@
from newrelic.api.background_task import background_task


def _exercise_collections(collection):
collection.document("DoesNotExist")
collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy")
collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico")
@pytest.fixture()
def exercise_collections(collection):
def _exercise_collections():
collection.document("DoesNotExist")
collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy")
collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico")

documents_get = collection.get()
assert len(documents_get) == 2
documents_stream = [_ for _ in collection.stream()]
assert len(documents_stream) == 2
documents_list = [_ for _ in collection.list_documents()]
assert len(documents_list) == 2
documents_get = collection.get()
assert len(documents_get) == 2
documents_stream = [_ for _ in collection.stream()]
assert len(documents_stream) == 2
documents_list = [_ for _ in collection.list_documents()]
assert len(documents_list) == 2
return _exercise_collections


def test_firestore_collections(collection):
def test_firestore_collections(exercise_collections, collection):
_test_scoped_metrics = [
("Datastore/statement/Firestore/%s/stream" % collection.id, 1),
("Datastore/statement/Firestore/%s/get" % collection.id, 1),
Expand All @@ -60,7 +65,7 @@ def test_firestore_collections(collection):
)
@background_task(name="test_firestore_collections")
def _test():
_exercise_collections(collection)
exercise_collections()

_test()

Expand Down
33 changes: 19 additions & 14 deletions tests/datastore_firestore/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from testing_support.validators.validate_database_duration import (
validate_database_duration,
)
Expand All @@ -22,23 +24,26 @@
from newrelic.api.background_task import background_task


def _exercise_documents(collection):
italy_doc = collection.document("Italy")
italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"})
italy_doc.get()
italian_cities = italy_doc.collection("cities")
italian_cities.add({"capital": "Rome"})
retrieved_coll = [_ for _ in italy_doc.collections()]
assert len(retrieved_coll) == 1
@pytest.fixture()
def exercise_documents(collection):
def _exercise_documents():
italy_doc = collection.document("Italy")
italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"})
italy_doc.get()
italian_cities = italy_doc.collection("cities")
italian_cities.add({"capital": "Rome"})
retrieved_coll = [_ for _ in italy_doc.collections()]
assert len(retrieved_coll) == 1

usa_doc = collection.document("USA")
usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"})
usa_doc.update({"president": "Joe Biden"})
usa_doc = collection.document("USA")
usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"})
usa_doc.update({"president": "Joe Biden"})

collection.document("USA").delete()
collection.document("USA").delete()
return _exercise_documents


def test_firestore_documents(collection):
def test_firestore_documents(exercise_documents):
_test_scoped_metrics = [
("Datastore/statement/Firestore/Italy/set", 1),
("Datastore/statement/Firestore/Italy/get", 1),
Expand Down Expand Up @@ -69,7 +74,7 @@ def test_firestore_documents(collection):
)
@background_task(name="test_firestore_documents")
def _test():
_exercise_documents(collection)
exercise_documents()

_test()

Expand Down
71 changes: 38 additions & 33 deletions tests/datastore_firestore/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@ def sample_data(collection):
# ===== Query =====


def _exercise_query(collection):
query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3)
assert len(query.get()) == 3
assert len([_ for _ in query.stream()]) == 3
@pytest.fixture()
def exercise_query(collection):
def _exercise_query():
query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3)
assert len(query.get()) == 3
assert len([_ for _ in query.stream()]) == 3
return _exercise_query


def test_firestore_query(collection):
def test_firestore_query(exercise_query, collection):
_test_scoped_metrics = [
("Datastore/statement/Firestore/%s/stream" % collection.id, 1),
("Datastore/statement/Firestore/%s/get" % collection.id, 1),
Expand All @@ -64,7 +67,7 @@ def test_firestore_query(collection):
)
@background_task(name="test_firestore_query")
def _test():
_exercise_query(collection)
exercise_query()

_test()

Expand All @@ -78,13 +81,16 @@ def test_firestore_query_generators(collection, assert_trace_for_generator):
# ===== AggregationQuery =====


def _exercise_aggregation_query(collection):
aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count()
assert aggregation_query.get()[0][0].value == 3
assert [_ for _ in aggregation_query.stream()][0][0].value == 3
@pytest.fixture()
def exercise_aggregation_query(collection):
def _exercise_aggregation_query():
aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count()
assert aggregation_query.get()[0][0].value == 3
assert [_ for _ in aggregation_query.stream()][0][0].value == 3
return _exercise_aggregation_query


def test_firestore_aggregation_query(collection):
def test_firestore_aggregation_query(exercise_aggregation_query, collection):
_test_scoped_metrics = [
("Datastore/statement/Firestore/%s/stream" % collection.id, 1),
("Datastore/statement/Firestore/%s/get" % collection.id, 1),
Expand All @@ -106,7 +112,7 @@ def test_firestore_aggregation_query(collection):
)
@background_task(name="test_firestore_aggregation_query")
def _test():
_exercise_aggregation_query(collection)
exercise_aggregation_query()

_test()

Expand Down Expand Up @@ -143,22 +149,23 @@ def mock_partition_query(*args, **kwargs):
yield


def _exercise_collection_group(collection):
from google.cloud.firestore import CollectionGroup

collection_group = CollectionGroup(collection)
assert len(collection_group.get())
assert len([d for d in collection_group.stream()])

partitions = [p for p in collection_group.get_partitions(1)]
assert len(partitions) == 2
documents = []
while partitions:
documents.extend(partitions.pop().query().get())
assert len(documents) == 6


def test_firestore_collection_group(collection, patch_partition_queries):
@pytest.fixture()
def exercise_collection_group(client, collection):
def _exercise_collection_group():
collection_group = client.collection_group(collection.id)
assert len(collection_group.get())
assert len([d for d in collection_group.stream()])

partitions = [p for p in collection_group.get_partitions(1)]
assert len(partitions) == 2
documents = []
while partitions:
documents.extend(partitions.pop().query().get())
assert len(documents) == 6
return _exercise_collection_group


def test_firestore_collection_group(exercise_collection_group, client, collection, patch_partition_queries):
_test_scoped_metrics = [
("Datastore/statement/Firestore/%s/get" % collection.id, 3),
("Datastore/statement/Firestore/%s/stream" % collection.id, 1),
Expand All @@ -182,14 +189,12 @@ def test_firestore_collection_group(collection, patch_partition_queries):
)
@background_task(name="test_firestore_collection_group")
def _test():
_exercise_collection_group(collection)
exercise_collection_group()

_test()


@background_task()
def test_firestore_collection_group_generators(collection, assert_trace_for_generator, patch_partition_queries):
from google.cloud.firestore import CollectionGroup

collection_group = CollectionGroup(collection)
def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries):
collection_group = client.collection_group(collection.id)
assert_trace_for_generator(collection_group.get_partitions, 1)
Loading

0 comments on commit 4a8a3fe

Please sign in to comment.