From b8bc752ca7c603ce5ff2504addf1870843d95b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Coss=C3=ADo?= Date: Thu, 29 Aug 2024 11:10:29 -0400 Subject: [PATCH] congruence tests + local mode fixes --- qdrant_client/local/local_collection.py | 25 ++- .../local/payload_value_extractor.py | 13 ++ qdrant_client/qdrant_remote.py | 2 +- tests/congruence_tests/test_facet.py | 183 ++++++++++++++++++ 4 files changed, 217 insertions(+), 6 deletions(-) create mode 100644 tests/congruence_tests/test_facet.py diff --git a/qdrant_client/local/local_collection.py b/qdrant_client/local/local_collection.py index 1d20f25c2..c6e56d87f 100644 --- a/qdrant_client/local/local_collection.py +++ b/qdrant_client/local/local_collection.py @@ -20,6 +20,7 @@ from qdrant_client import grpc as grpc from qdrant_client._pydantic_compat import construct, to_jsonable_python as _to_jsonable_python from qdrant_client.conversions import common_types as types +from qdrant_client.conversions.common_types import get_args_subscribed from qdrant_client.conversions.conversion import GrpcToRest from qdrant_client.http import models from qdrant_client.http.models.models import Distance, ExtendedPointId, SparseVector, OrderValue @@ -51,7 +52,7 @@ from qdrant_client.local.json_path_parser import JsonPathItem, parse_json_path from qdrant_client.local.order_by import to_order_value from qdrant_client.local.payload_filters import calculate_payload_mask -from qdrant_client.local.payload_value_extractor import value_by_key +from qdrant_client.local.payload_value_extractor import value_by_key, parse_uuid from qdrant_client.local.payload_value_setter import set_value_by_key from qdrant_client.local.persistence import CollectionPersistence from qdrant_client.local.sparse import ( @@ -1102,13 +1103,27 @@ def facet( if not isinstance(payload, dict): continue - value = value_by_key(payload, key) + values = value_by_key(payload, key) - if value is None: + if values is None: continue - for v in value: - if isinstance(v, get_args(models.FacetValue)): + # Only count the same value for each point once + values_set: set[models.FacetValue] = set() + + for v in values: + if not isinstance(v, get_args_subscribed(models.FacetValue)): + continue + + # If values are UUIDs, format with hyphens + as_uuid = parse_uuid(v) + if as_uuid: + v = str(as_uuid) + + values_set.add(v) + + for v in values_set: + if isinstance(v, get_args_subscribed(models.FacetValue)): facet_hits[v] += 1 hits = [ diff --git a/qdrant_client/local/payload_value_extractor.py b/qdrant_client/local/payload_value_extractor.py index f7615c3e7..d7991ed94 100644 --- a/qdrant_client/local/payload_value_extractor.py +++ b/qdrant_client/local/payload_value_extractor.py @@ -1,3 +1,4 @@ +import uuid from typing import Any, List, Optional from qdrant_client.local.json_path_parser import ( @@ -77,3 +78,15 @@ def _get_value(data: Any, k_list: List[JsonPathItem]) -> None: _get_value(payload, keys) return result if result else None + + +def parse_uuid(value: Any) -> Optional[uuid.UUID]: + """ + Parse UUID from value. + Args: + value: arbitrary value + """ + try: + return uuid.UUID(str(value)) + except ValueError: + return None diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index 758fbd0d5..2c56a2788 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -1726,7 +1726,7 @@ def facet( shard_key_selector=shard_key_selector, ), timeout=timeout if timeout is not None else self._timeout, - ).result + ) return GrpcToRest.convert_facet_response(response) if isinstance(facet_filter, grpc.Filter): diff --git a/tests/congruence_tests/test_facet.py b/tests/congruence_tests/test_facet.py new file mode 100644 index 000000000..2353f833e --- /dev/null +++ b/tests/congruence_tests/test_facet.py @@ -0,0 +1,183 @@ +import random +import time +from typing import List + +import pytest + +from qdrant_client import QdrantClient, models +from qdrant_client.client_base import QdrantBase +from tests.congruence_tests.test_common import ( + COLLECTION_NAME, + compare_client_results, + generate_fixtures, + init_client, + init_local, + init_remote, +) +from tests.fixtures.filters import one_random_filter_please + +INT_KEY = "rand_digit" +INT_ID_KEY = "id" +UUID_KEY = "text_array" +STRING_ID_KEY = "id_str" +STRING_KEY = "city.name" + + +def all_facet_keys() -> List[str]: + return [INT_KEY, INT_ID_KEY, UUID_KEY, STRING_ID_KEY, STRING_KEY] + + +@pytest.fixture(scope="module") +def fixture_points() -> List[models.PointStruct]: + return generate_fixtures() + + +@pytest.fixture(scope="module", autouse=True) +def local_client(fixture_points) -> QdrantClient: + client = init_local() + init_client(client, fixture_points) + return client + + +@pytest.fixture(scope="module", autouse=True) +def http_client(fixture_points) -> QdrantClient: + client = init_remote() + init_client(client, fixture_points) + client.create_payload_index( + collection_name=COLLECTION_NAME, + field_name=INT_KEY, + field_schema=models.PayloadSchemaType.INTEGER, + ) + client.create_payload_index( + collection_name=COLLECTION_NAME, + field_name=INT_ID_KEY, + field_schema=models.PayloadSchemaType.INTEGER, + ) + client.create_payload_index( + collection_name=COLLECTION_NAME, + field_name=UUID_KEY, + field_schema=models.PayloadSchemaType.UUID, + ) + client.create_payload_index( + collection_name=COLLECTION_NAME, + field_name=STRING_KEY, + field_schema=models.PayloadSchemaType.KEYWORD, + ) + client.create_payload_index( + collection_name=COLLECTION_NAME, + field_name=STRING_ID_KEY, + field_schema=models.PayloadSchemaType.KEYWORD, + ) + return client + + +@pytest.fixture(scope="module", autouse=True) +def grpc_client(fixture_points) -> QdrantClient: + client = init_remote(prefer_grpc=True) + return client + + +def test_minimal( + local_client, + http_client, + grpc_client, +): + def f(client: QdrantBase, facet_key: str, **kwargs) -> models.FacetResponse: + return client.facet( + collection_name=COLLECTION_NAME, + key=facet_key, + ) + + for key in all_facet_keys(): + compare_client_results(grpc_client, http_client, f, facet_key=key) + compare_client_results(local_client, http_client, f, facet_key=key) + + +def test_limit( + local_client, + http_client, + grpc_client, +): + def f(client: QdrantBase, facet_key: str, limit: int, **kwargs) -> models.FacetResponse: + return client.facet( + collection_name=COLLECTION_NAME, + key=facet_key, + limit=limit, + ) + + for _ in range(10): + rand_num = random.randint(1, 100) + for key in all_facet_keys(): + compare_client_results(grpc_client, http_client, f, facet_key=key, limit=rand_num) + compare_client_results(local_client, http_client, f, facet_key=key, limit=rand_num) + + +def test_exact( + local_client, + http_client, + grpc_client, +): + def f(client: QdrantBase, facet_key: str, **kwargs) -> models.FacetResponse: + return client.facet( + collection_name=COLLECTION_NAME, + key=facet_key, + limit=5000, + exact=True, + ) + + for key in all_facet_keys(): + compare_client_results(grpc_client, http_client, f, facet_key=key) + compare_client_results(local_client, http_client, f, facet_key=key) + + +def test_filtered( + local_client, + http_client, + grpc_client, +): + def f( + client: QdrantBase, facet_key: str, facet_filter: models.Filter, **kwargs + ) -> models.FacetResponse: + return client.facet( + collection_name=COLLECTION_NAME, + key=facet_key, + facet_filter=facet_filter, + exact=False, + ) + + for key in all_facet_keys(): + filter_ = one_random_filter_please() + for _ in range(10): + compare_client_results( + grpc_client, http_client, f, facet_key=key, facet_filter=filter_ + ) + compare_client_results( + local_client, http_client, f, facet_key=key, facet_filter=filter_ + ) + + +def test_exact_filtered( + local_client, + http_client, + grpc_client, +): + def f( + client: QdrantBase, facet_key: str, facet_filter: models.Filter, **kwargs + ) -> models.FacetResponse: + return client.facet( + collection_name=COLLECTION_NAME, + key=facet_key, + limit=5000, + exact=True, + facet_filter=facet_filter, + ) + + for key in all_facet_keys(): + for _ in range(10): + filter_ = one_random_filter_please() + compare_client_results( + grpc_client, http_client, f, facet_key=key, facet_filter=filter_ + ) + compare_client_results( + local_client, http_client, f, facet_key=key, facet_filter=filter_ + )