diff --git a/firestore/google/cloud/firestore_v1/query.py b/firestore/google/cloud/firestore_v1/query.py index 6f4c498c0725..d4e1f7f07324 100644 --- a/firestore/google/cloud/firestore_v1/query.py +++ b/firestore/google/cloud/firestore_v1/query.py @@ -43,6 +43,8 @@ ">=": _operator_enum.GREATER_THAN_OR_EQUAL, ">": _operator_enum.GREATER_THAN, "array_contains": _operator_enum.ARRAY_CONTAINS, + "in": _operator_enum.IN, + "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, } _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." _BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' diff --git a/firestore/tests/system/test_system.py b/firestore/tests/system/test_system.py index f2d30c94a171..71ac07fcee74 100644 --- a/firestore/tests/system/test_system.py +++ b/firestore/tests/system/test_system.py @@ -492,11 +492,13 @@ def test_collection_add(client, cleanup): assert set(collection3.list_documents()) == {document_ref5} -def test_query_stream(client, cleanup): +@pytest.fixture +def query_docs(client): collection_id = "qs" + UNIQUE_RESOURCE_ID sub_collection = "child" + UNIQUE_RESOURCE_ID collection = client.collection(collection_id, "doc", sub_collection) + cleanup = [] stored = {} num_vals = 5 allowed_vals = six.moves.xrange(num_vals) @@ -505,38 +507,82 @@ def test_query_stream(client, cleanup): document_data = { "a": a_val, "b": b_val, + "c": [a_val, num_vals * 100], "stats": {"sum": a_val + b_val, "product": a_val * b_val}, } _, doc_ref = collection.add(document_data) # Add to clean-up. - cleanup(doc_ref.delete) + cleanup.append(doc_ref.delete) stored[doc_ref.id] = document_data - # 0. Limit to snapshots where ``a==1``. - query0 = collection.where("a", "==", 1) - values0 = {snapshot.id: snapshot.to_dict() for snapshot in query0.stream()} - assert len(values0) == num_vals - for key, value in six.iteritems(values0): + yield collection, stored, allowed_vals + + for operation in cleanup: + operation() + + +def test_query_stream_w_simple_field_eq_op(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("a", "==", 1) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): + assert stored[key] == value + assert value["a"] == 1 + + +def test_query_stream_w_simple_field_array_contains_op(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("c", "array_contains", 1) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): + assert stored[key] == value + assert value["a"] == 1 + + +def test_query_stream_w_simple_field_in_op(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("a", "in", [1, num_vals + 100]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): assert stored[key] == value assert value["a"] == 1 - # 1. Order by ``b``. - query1 = collection.order_by("b", direction=query0.DESCENDING) - values1 = [(snapshot.id, snapshot.to_dict()) for snapshot in query1.stream()] - assert len(values1) == len(stored) - b_vals1 = [] - for key, value in values1: + +def test_query_stream_w_simple_field_array_contains_any_op(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("c", "array_contains_any", [1, num_vals * 200]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): assert stored[key] == value - b_vals1.append(value["b"]) + assert value["a"] == 1 + + +def test_query_stream_w_order_by(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.order_by("b", direction=firestore.Query.DESCENDING) + values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] + assert len(values) == len(stored) + b_vals = [] + for key, value in values: + assert stored[key] == value + b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. - assert sorted(b_vals1, reverse=True) == b_vals1 + assert sorted(b_vals, reverse=True) == b_vals + - # 2. Limit to snapshots where ``stats.sum > 1`` (a field path). - query2 = collection.where("stats.sum", ">", 4) - values2 = {snapshot.id: snapshot.to_dict() for snapshot in query2.stream()} - assert len(values2) == 10 +def test_query_stream_w_field_path(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("stats.sum", ">", 4) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == 10 ab_pairs2 = set() - for key, value in six.iteritems(values2): + for key, value in six.iteritems(values): assert stored[key] == value ab_pairs2.add((value["a"], value["b"])) @@ -550,63 +596,72 @@ def test_query_stream(client, cleanup): ) assert expected_ab_pairs == ab_pairs2 - # 3. Use a start and end cursor. - query3 = ( + +def test_query_stream_w_start_end_cursor(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = ( collection.order_by("a") .start_at({"a": num_vals - 2}) .end_before({"a": num_vals - 1}) ) - values3 = [(snapshot.id, snapshot.to_dict()) for snapshot in query3.stream()] - assert len(values3) == num_vals - for key, value in values3: + values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] + assert len(values) == num_vals + for key, value in values: assert stored[key] == value assert value["a"] == num_vals - 2 - b_vals1.append(value["b"]) - - # 4. Send a query with no results. - query4 = collection.where("b", "==", num_vals + 100) - values4 = list(query4.stream()) - assert len(values4) == 0 - - # 5. Select a subset of fields. - query5 = collection.where("b", "<=", 1) - query5 = query5.select(["a", "stats.product"]) - values5 = {snapshot.id: snapshot.to_dict() for snapshot in query5.stream()} - assert len(values5) == num_vals * 2 # a ANY, b in (0, 1) - for key, value in six.iteritems(values5): + + +def test_query_stream_wo_results(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("b", "==", num_vals + 100) + values = list(query.stream()) + assert len(values) == 0 + + +def test_query_stream_w_projection(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("b", "<=", 1).select(["a", "stats.product"]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == num_vals * 2 # a ANY, b in (0, 1) + for key, value in six.iteritems(values): expected = { "a": stored[key]["a"], "stats": {"product": stored[key]["stats"]["product"]}, } assert expected == value - # 6. Add multiple filters via ``where()``. - query6 = collection.where("stats.product", ">", 5) - query6 = query6.where("stats.product", "<", 10) - values6 = {snapshot.id: snapshot.to_dict() for snapshot in query6.stream()} +def test_query_stream_w_multiple_filters(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("stats.product", ">", 5).where("stats.product", "<", 10) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} matching_pairs = [ (a_val, b_val) for a_val in allowed_vals for b_val in allowed_vals if 5 < a_val * b_val < 10 ] - assert len(values6) == len(matching_pairs) - for key, value in six.iteritems(values6): + assert len(values) == len(matching_pairs) + for key, value in six.iteritems(values): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs - # 7. Skip the first three results, when ``b==2`` - query7 = collection.where("b", "==", 2) + +def test_query_stream_w_offset(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) offset = 3 - query7 = query7.offset(offset) - values7 = {snapshot.id: snapshot.to_dict() for snapshot in query7.stream()} + query = collection.where("b", "==", 2).offset(offset) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} # NOTE: We don't check the ``a``-values, since that would require # an ``order_by('a')``, which combined with the ``b == 2`` # filter would necessitate an index. - assert len(values7) == num_vals - offset - for key, value in six.iteritems(values7): + assert len(values) == num_vals - offset + for key, value in six.iteritems(values): assert stored[key] == value assert value["b"] == 2 diff --git a/firestore/tests/unit/v1/test_query.py b/firestore/tests/unit/v1/test_query.py index a4911fecb44f..bdb0e922d00b 100644 --- a/firestore/tests/unit/v1/test_query.py +++ b/firestore/tests/unit/v1/test_query.py @@ -1464,18 +1464,47 @@ def _call_fut(op_string): return _enum_from_op_string(op_string) - def test_success(self): + @staticmethod + def _get_op_class(): from google.cloud.firestore_v1.gapic import enums - op_class = enums.StructuredQuery.FieldFilter.Operator + return enums.StructuredQuery.FieldFilter.Operator + + def test_lt(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) + + def test_le(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) + + def test_eq(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("=="), op_class.EQUAL) + + def test_ge(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) + + def test_gt(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) + + def test_array_contains(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) - def test_failure(self): + def test_in(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("in"), op_class.IN) + + def test_array_contains_any(self): + op_class = self._get_op_class() + self.assertEqual( + self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY + ) + + def test_invalid(self): with self.assertRaises(ValueError): self._call_fut("?")