Skip to content

Commit

Permalink
feature: Facets (#759)
Browse files Browse the repository at this point in the history
* facet in local_collection.py

* replace usages of calculate_payload_mask

* - qdrant remote
- qdrant base
- qdrant client
- conversions

* congruence tests + local mode fixes

* generate async client

* add type stubs and misc fixes

* fix mypy in Python 3.8

* generate async client

* review remarks

* gen async client

* update for bool facets
  • Loading branch information
coszio authored Oct 2, 2024
1 parent 1c600f5 commit 9f736ac
Show file tree
Hide file tree
Showing 22 changed files with 934 additions and 301 deletions.
11 changes: 11 additions & 0 deletions qdrant_client/async_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,17 @@ async def count(
) -> types.CountResult:
raise NotImplementedError()

async def facet(
    self,
    collection_name: str,
    key: str,
    facet_filter: Optional[types.Filter] = None,
    limit: int = 10,
    exact: bool = False,
    **kwargs: Any,
) -> types.FacetResponse:
    """Placeholder for the facet operation; this version only raises.

    Args:
        collection_name: Name of the collection.
        key: Payload field to facet.
        facet_filter: Filter to apply.
        limit: Maximum number of hits to return.
        exact: Whether exact (``True``) or approximate (``False``) counts
            are requested.
        **kwargs: Extra keyword arguments; not inspected here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

async def upsert(
self, collection_name: str, points: types.Points, **kwargs: Any
) -> types.UpdateResult:
Expand Down
50 changes: 50 additions & 0 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,56 @@ async def count(
**kwargs,
)

async def facet(
    self,
    collection_name: str,
    key: str,
    facet_filter: Optional[types.Filter] = None,
    limit: int = 10,
    exact: bool = False,
    consistency: Optional[types.ReadConsistency] = None,
    timeout: Optional[int] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.FacetResponse:
    """Count the unique values of a payload key across the collection.

    For the given payload key, returns its unique values together with how
    many points carry each value. Higher counts come first in the results.

    Args:
        collection_name: Name of the collection.
        key: Payload field to facet.
        facet_filter: Filter to apply.
        limit: Maximum number of hits to return.
        exact: If `True`, provide the exact count of points matching the
            filter. If `False`, provide an approximate count, which works
            faster.
        consistency:
            Read consistency of the search. Defines how many replicas should
            be queried before returning the result. Values:

            - int - number of replicas to query, values should be present in
              all queried replicas
            - 'majority' - query all replicas, but return values present in
              the majority of replicas
            - 'quorum' - query the majority of replicas, return values
              present in all of them
            - 'all' - query all replicas, and return values present in all
              replicas
        timeout: Overrides the global timeout for this request. Unit is
            seconds.
        shard_key_selector:
            Specifies which shards should be queried. If `None`, all shards
            are queried. Only works for collections with the `custom`
            sharding method.
        **kwargs: Must be empty; any extra keyword argument fails the
            assertion below.

    Returns:
        Unique values in the facet and the number of points that they cover.
    """
    # No extra keyword arguments are supported; reject them up front.
    assert not kwargs, f"Unknown arguments: {list(kwargs.keys())}"

    facet_response = await self._client.facet(
        collection_name=collection_name,
        key=key,
        facet_filter=facet_filter,
        limit=limit,
        exact=exact,
        consistency=consistency,
        timeout=timeout,
        shard_key_selector=shard_key_selector,
        **kwargs,
    )
    return facet_response

async def upsert(
self,
collection_name: str,
Expand Down
58 changes: 57 additions & 1 deletion qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ async def recommend_groups(
if isinstance(with_lookup, models.WithLookup):
with_lookup = RestToGrpc.convert_with_lookup(with_lookup)
if isinstance(with_lookup, str):
with_lookup = grpc.WithLookup(lookup_index=with_lookup)
with_lookup = grpc.WithLookup(collection=with_lookup)
positive_ids = RestToGrpc.convert_recommend_examples_to_ids(positive)
positive_vectors = RestToGrpc.convert_recommend_examples_to_vectors(positive)
negative_ids = RestToGrpc.convert_recommend_examples_to_ids(negative)
Expand Down Expand Up @@ -1389,6 +1389,60 @@ async def count(
assert count_result is not None, "Count points returned None result"
return count_result

async def facet(
    self,
    collection_name: str,
    key: str,
    facet_filter: Optional[types.Filter] = None,
    limit: int = 10,
    exact: bool = False,
    timeout: Optional[int] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.FacetResponse:
    """Run a facet request over REST or gRPC, depending on ``self._prefer_grpc``.

    REST path (``self._prefer_grpc`` is falsy): a ``grpc.Filter`` is converted
    to its REST model, the request is sent through
    ``self.openapi_client.points_api.facet`` and the ``result`` attribute of
    the response is returned.

    gRPC path: REST-model filter, shard key selector and read consistency are
    converted with ``RestToGrpc``, the request is sent through
    ``self.grpc_points.Facet`` and the hits are converted back with
    ``GrpcToRest.convert_facet_value_hit``.

    Args:
        collection_name: Name of the collection.
        key: Payload field to facet.
        facet_filter: Optional filter, in REST or gRPC form.
        limit: Forwarded as the request's ``limit``.
        exact: Forwarded as the request's ``exact`` flag.
        timeout: Forwarded in the request. On the gRPC path it is also the
            call timeout, falling back to ``self._timeout`` when ``None``.
        consistency: Optional read consistency, in REST or gRPC form.
        shard_key_selector: Optional shard key selector, in REST or gRPC form.
        **kwargs: Accepted but not used by this method.

    Returns:
        The facet response with the hits found for ``key``.

    Raises:
        AssertionError: On the REST path, if the response carries no result.
    """
    if not self._prefer_grpc:
        # REST transport: the HTTP client needs the REST filter model.
        if isinstance(facet_filter, grpc.Filter):
            facet_filter = GrpcToRest.convert_filter(model=facet_filter)

        rest_request = models.FacetRequest(
            shard_key=shard_key_selector,
            key=key,
            limit=limit,
            filter=facet_filter,
            exact=exact,
        )
        rest_response = await self.openapi_client.points_api.facet(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            facet_request=rest_request,
        )
        facet_result = rest_response.result
        assert facet_result is not None, "Facet points returned None result"
        return facet_result

    # gRPC transport: convert any REST-model arguments to gRPC messages.
    if isinstance(facet_filter, models.Filter):
        facet_filter = RestToGrpc.convert_filter(model=facet_filter)
    if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
        shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
    if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
        consistency = RestToGrpc.convert_read_consistency(consistency)

    grpc_request = grpc.FacetCounts(
        collection_name=collection_name,
        key=key,
        filter=facet_filter,
        limit=limit,
        exact=exact,
        timeout=timeout,
        read_consistency=consistency,
        shard_key_selector=shard_key_selector,
    )
    # A per-call timeout wins over the client-wide default.
    call_timeout = self._timeout if timeout is None else timeout
    grpc_response = await self.grpc_points.Facet(grpc_request, timeout=call_timeout)

    rest_hits = [GrpcToRest.convert_facet_value_hit(hit) for hit in grpc_response.hits]
    return types.FacetResponse(hits=rest_hits)

async def upsert(
self,
collection_name: str,
Expand Down Expand Up @@ -1681,6 +1735,8 @@ def _points_selector_to_points_list(
cls, points_selector: grpc.PointsSelector
) -> List[grpc.PointId]:
name = points_selector.WhichOneof("points_selector_one_of")
if name is None:
return []
val = getattr(points_selector, name)
if name == "points":
return list(val.ids)
Expand Down
11 changes: 11 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,17 @@ def count(
) -> types.CountResult:
raise NotImplementedError()

def facet(
    self,
    collection_name: str,
    key: str,
    facet_filter: Optional[types.Filter] = None,
    limit: int = 10,
    exact: bool = False,
    **kwargs: Any,
) -> types.FacetResponse:
    """Placeholder for the facet operation; this version only raises.

    Args:
        collection_name: Name of the collection.
        key: Payload field to facet.
        facet_filter: Filter to apply.
        limit: Maximum number of hits to return.
        exact: Whether exact (``True``) or approximate (``False``) counts
            are requested.
        **kwargs: Extra keyword arguments; not inspected here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

def upsert(
self,
collection_name: str,
Expand Down
16 changes: 13 additions & 3 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing import List, Union, get_args, Sequence

from qdrant_client import grpc as grpc
from qdrant_client import grpc
from qdrant_client.http import models as rest

typing_remap = {
Expand Down Expand Up @@ -58,12 +58,19 @@ def get_args_subscribed(tp): # type: ignore
QuantizationConfig = Union[rest.QuantizationConfig, grpc.QuantizationConfig]
PointId = Union[int, str, grpc.PointId]
PayloadSchemaType = Union[
rest.PayloadSchemaType, rest.PayloadSchemaParams, int, grpc.PayloadIndexParams
rest.PayloadSchemaType,
rest.PayloadSchemaParams,
int,
grpc.PayloadIndexParams,
] # type(grpc.PayloadSchemaType) == int
PointStruct: TypeAlias = rest.PointStruct
Points = Union[rest.Batch, Sequence[Union[rest.PointStruct, grpc.PointStruct]]]
PointsSelector = Union[
List[PointId], rest.Filter, grpc.Filter, rest.PointsSelector, grpc.PointsSelector
List[PointId],
rest.Filter,
grpc.Filter,
rest.PointsSelector,
grpc.PointsSelector,
]
LookupLocation = Union[rest.LookupLocation, grpc.LookupLocation]
RecommendStrategy: TypeAlias = rest.RecommendStrategy
Expand Down Expand Up @@ -120,6 +127,9 @@ def get_args_subscribed(tp): # type: ignore
GroupsResult: TypeAlias = rest.GroupsResult
QueryResponse: TypeAlias = rest.QueryResponse

FacetValue: TypeAlias = rest.FacetValue
FacetResponse: TypeAlias = rest.FacetResponse

VersionInfo: TypeAlias = rest.VersionInfo

# we can't use `nptyping` package due to numpy/python-version incompatibilities
Expand Down
Loading

0 comments on commit 9f736ac

Please sign in to comment.