Skip to content

Commit

Permalink
congruence tests + local mode fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
coszio committed Sep 16, 2024
1 parent da8ce9f commit b8bc752
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 6 deletions.
25 changes: 20 additions & 5 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from qdrant_client import grpc as grpc
from qdrant_client._pydantic_compat import construct, to_jsonable_python as _to_jsonable_python
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import get_args_subscribed
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.http import models
from qdrant_client.http.models.models import Distance, ExtendedPointId, SparseVector, OrderValue
Expand Down Expand Up @@ -51,7 +52,7 @@
from qdrant_client.local.json_path_parser import JsonPathItem, parse_json_path
from qdrant_client.local.order_by import to_order_value
from qdrant_client.local.payload_filters import calculate_payload_mask
from qdrant_client.local.payload_value_extractor import value_by_key
from qdrant_client.local.payload_value_extractor import value_by_key, parse_uuid
from qdrant_client.local.payload_value_setter import set_value_by_key
from qdrant_client.local.persistence import CollectionPersistence
from qdrant_client.local.sparse import (
Expand Down Expand Up @@ -1102,13 +1103,27 @@ def facet(
if not isinstance(payload, dict):
continue

value = value_by_key(payload, key)
values = value_by_key(payload, key)

if value is None:
if values is None:
continue

for v in value:
if isinstance(v, get_args(models.FacetValue)):
# Only count the same value for each point once
values_set: set[models.FacetValue] = set()

for v in values:
if not isinstance(v, get_args_subscribed(models.FacetValue)):
continue

# If values are UUIDs, format with hyphens
as_uuid = parse_uuid(v)
if as_uuid:
v = str(as_uuid)

values_set.add(v)

for v in values_set:
if isinstance(v, get_args_subscribed(models.FacetValue)):
facet_hits[v] += 1

hits = [
Expand Down
13 changes: 13 additions & 0 deletions qdrant_client/local/payload_value_extractor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from typing import Any, List, Optional

from qdrant_client.local.json_path_parser import (
Expand Down Expand Up @@ -77,3 +78,15 @@ def _get_value(data: Any, k_list: List[JsonPathItem]) -> None:

_get_value(payload, keys)
return result if result else None


def parse_uuid(value: Any) -> Optional[uuid.UUID]:
"""
Parse UUID from value.
Args:
value: arbitrary value
"""
try:
return uuid.UUID(str(value))
except ValueError:
return None
2 changes: 1 addition & 1 deletion qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,7 @@ def facet(
shard_key_selector=shard_key_selector,
),
timeout=timeout if timeout is not None else self._timeout,
).result
)
return GrpcToRest.convert_facet_response(response)

if isinstance(facet_filter, grpc.Filter):
Expand Down
183 changes: 183 additions & 0 deletions tests/congruence_tests/test_facet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import random
import time
from typing import List

import pytest

from qdrant_client import QdrantClient, models
from qdrant_client.client_base import QdrantBase
from tests.congruence_tests.test_common import (
COLLECTION_NAME,
compare_client_results,
generate_fixtures,
init_client,
init_local,
init_remote,
)
from tests.fixtures.filters import one_random_filter_please

INT_KEY = "rand_digit"
INT_ID_KEY = "id"
UUID_KEY = "text_array"
STRING_ID_KEY = "id_str"
STRING_KEY = "city.name"


def all_facet_keys() -> List[str]:
return [INT_KEY, INT_ID_KEY, UUID_KEY, STRING_ID_KEY, STRING_KEY]


@pytest.fixture(scope="module")
def fixture_points() -> List[models.PointStruct]:
return generate_fixtures()


@pytest.fixture(scope="module", autouse=True)
def local_client(fixture_points) -> QdrantClient:
client = init_local()
init_client(client, fixture_points)
return client


@pytest.fixture(scope="module", autouse=True)
def http_client(fixture_points) -> QdrantClient:
client = init_remote()
init_client(client, fixture_points)
client.create_payload_index(
collection_name=COLLECTION_NAME,
field_name=INT_KEY,
field_schema=models.PayloadSchemaType.INTEGER,
)
client.create_payload_index(
collection_name=COLLECTION_NAME,
field_name=INT_ID_KEY,
field_schema=models.PayloadSchemaType.INTEGER,
)
client.create_payload_index(
collection_name=COLLECTION_NAME,
field_name=UUID_KEY,
field_schema=models.PayloadSchemaType.UUID,
)
client.create_payload_index(
collection_name=COLLECTION_NAME,
field_name=STRING_KEY,
field_schema=models.PayloadSchemaType.KEYWORD,
)
client.create_payload_index(
collection_name=COLLECTION_NAME,
field_name=STRING_ID_KEY,
field_schema=models.PayloadSchemaType.KEYWORD,
)
return client


@pytest.fixture(scope="module", autouse=True)
def grpc_client(fixture_points) -> QdrantClient:
client = init_remote(prefer_grpc=True)
return client


def test_minimal(
local_client,
http_client,
grpc_client,
):
def f(client: QdrantBase, facet_key: str, **kwargs) -> models.FacetResponse:
return client.facet(
collection_name=COLLECTION_NAME,
key=facet_key,
)

for key in all_facet_keys():
compare_client_results(grpc_client, http_client, f, facet_key=key)
compare_client_results(local_client, http_client, f, facet_key=key)


def test_limit(
local_client,
http_client,
grpc_client,
):
def f(client: QdrantBase, facet_key: str, limit: int, **kwargs) -> models.FacetResponse:
return client.facet(
collection_name=COLLECTION_NAME,
key=facet_key,
limit=limit,
)

for _ in range(10):
rand_num = random.randint(1, 100)
for key in all_facet_keys():
compare_client_results(grpc_client, http_client, f, facet_key=key, limit=rand_num)
compare_client_results(local_client, http_client, f, facet_key=key, limit=rand_num)


def test_exact(
local_client,
http_client,
grpc_client,
):
def f(client: QdrantBase, facet_key: str, **kwargs) -> models.FacetResponse:
return client.facet(
collection_name=COLLECTION_NAME,
key=facet_key,
limit=5000,
exact=True,
)

for key in all_facet_keys():
compare_client_results(grpc_client, http_client, f, facet_key=key)
compare_client_results(local_client, http_client, f, facet_key=key)


def test_filtered(
local_client,
http_client,
grpc_client,
):
def f(
client: QdrantBase, facet_key: str, facet_filter: models.Filter, **kwargs
) -> models.FacetResponse:
return client.facet(
collection_name=COLLECTION_NAME,
key=facet_key,
facet_filter=facet_filter,
exact=False,
)

for key in all_facet_keys():
filter_ = one_random_filter_please()
for _ in range(10):
compare_client_results(
grpc_client, http_client, f, facet_key=key, facet_filter=filter_
)
compare_client_results(
local_client, http_client, f, facet_key=key, facet_filter=filter_
)


def test_exact_filtered(
local_client,
http_client,
grpc_client,
):
def f(
client: QdrantBase, facet_key: str, facet_filter: models.Filter, **kwargs
) -> models.FacetResponse:
return client.facet(
collection_name=COLLECTION_NAME,
key=facet_key,
limit=5000,
exact=True,
facet_filter=facet_filter,
)

for key in all_facet_keys():
for _ in range(10):
filter_ = one_random_filter_please()
compare_client_results(
grpc_client, http_client, f, facet_key=key, facet_filter=filter_
)
compare_client_results(
local_client, http_client, f, facet_key=key, facet_filter=filter_
)

0 comments on commit b8bc752

Please sign in to comment.