Skip to content

Commit

Permalink
add type stubs and misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
coszio committed Sep 16, 2024
1 parent 2f52d0c commit 743fea2
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 25 deletions.
6 changes: 3 additions & 3 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ async def facet(
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
read_consistency: Optional[types.ReadConsistency] = None,
consistency: Optional[types.ReadConsistency] = None,
timeout: Optional[int] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
Expand All @@ -1280,7 +1280,7 @@ async def facet(
limit: Maximum number of hits to return
exact: If `True` - provide the exact count of points matching the filter. If `False` - provide the approximate count of points matching the filter. Works faster.
read_consistency:
consistency:
Read consistency of the search. Defines how many replicas should be queried before returning the result. Values:
- int - number of replicas to query, values should present in all queried replicas
Expand All @@ -1302,7 +1302,7 @@ async def facet(
facet_filter=facet_filter,
limit=limit,
exact=exact,
read_consistency=read_consistency,
consistency=consistency,
timeout=timeout,
shard_key_selector=shard_key_selector,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ async def facet(
limit: int = 10,
exact: bool = False,
timeout: Optional[int] = None,
read_consistency: Optional[types.ReadConsistency] = None,
consistency: Optional[types.ReadConsistency] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> types.FacetResponse:
Expand All @@ -1406,8 +1406,8 @@ async def facet(
facet_filter = RestToGrpc.convert_filter(model=facet_filter)
if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
if isinstance(read_consistency, get_args_subscribed(models.ReadConsistency)):
read_consistency = RestToGrpc.convert_read_consistency(read_consistency)
if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
consistency = RestToGrpc.convert_read_consistency(consistency)
response = await self.grpc_points.Facet(
grpc.FacetCounts(
collection_name=collection_name,
Expand All @@ -1416,7 +1416,7 @@ async def facet(
limit=limit,
exact=exact,
timeout=timeout,
read_consistency=read_consistency,
read_consistency=consistency,
shard_key_selector=shard_key_selector,
),
timeout=timeout if timeout is not None else self._timeout,
Expand All @@ -1427,7 +1427,7 @@ async def facet(
facet_result = (
await self.openapi_client.points_api.facet(
collection_name=collection_name,
consistency=read_consistency,
consistency=consistency,
timeout=timeout,
facet_request=models.FacetRequest(
shard_key=shard_key_selector,
Expand Down
1 change: 1 addition & 0 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def get_args_subscribed(tp): # type: ignore
GroupsResult: TypeAlias = rest.GroupsResult
QueryResponse: TypeAlias = rest.QueryResponse

FacetValue: TypeAlias = rest.FacetValue
FacetResponse: TypeAlias = rest.FacetResponse

VersionInfo: TypeAlias = rest.VersionInfo
Expand Down
2 changes: 1 addition & 1 deletion qdrant_client/local/async_qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ async def facet(
limit: int = 10,
exact: bool = False,
**kwargs: Any,
):
) -> types.FacetResponse:
collection = self._get_collection(collection_name)
return collection.facet(key=key, facet_filter=facet_filter, limit=limit)

Expand Down
14 changes: 7 additions & 7 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,8 +1091,8 @@ def facet(
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
) -> models.FacetResponse:
facet_hits = defaultdict(int)
) -> types.FacetResponse:
facet_hits: Dict[types.FacetValue, int] = defaultdict(int)

mask = self._payload_and_non_deleted_mask(facet_filter)

Expand All @@ -1109,10 +1109,11 @@ def facet(
continue

# Only count the same value for each point once
values_set: set[models.FacetValue] = set()
values_set: set[types.FacetValue] = set()

# Sanitize to use only valid values
for v in values:
if not isinstance(v, get_args_subscribed(models.FacetValue)):
if not isinstance(v, get_args_subscribed(types.FacetValue)):
continue

# If values are UUIDs, format with hyphens
Expand All @@ -1123,8 +1124,7 @@ def facet(
values_set.add(v)

for v in values_set:
if isinstance(v, get_args_subscribed(models.FacetValue)):
facet_hits[v] += 1
facet_hits[v] += 1

hits = [
models.FacetValueHit(value=value, count=count)
Expand All @@ -1135,7 +1135,7 @@ def facet(
)[:limit]
]

return models.FacetResponse(hits=hits)
return types.FacetResponse(hits=hits)

def retrieve(
self,
Expand Down
2 changes: 1 addition & 1 deletion qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def facet(
limit: int = 10,
exact: bool = False,
**kwargs: Any,
):
) -> types.FacetResponse:
collection = self._get_collection(collection_name)
return collection.facet(key=key, facet_filter=facet_filter, limit=limit)

Expand Down
6 changes: 3 additions & 3 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def facet(
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
read_consistency: Optional[types.ReadConsistency] = None,
consistency: Optional[types.ReadConsistency] = None,
timeout: Optional[int] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
Expand All @@ -1346,7 +1346,7 @@ def facet(
limit: Maximum number of hits to return
exact: If `True` - provide the exact count of points matching the filter. If `False` - provide the approximate count of points matching the filter. Works faster.
read_consistency:
consistency:
Read consistency of the search. Defines how many replicas should be queried before returning the result. Values:
- int - number of replicas to query, values should present in all queried replicas
Expand All @@ -1369,7 +1369,7 @@ def facet(
facet_filter=facet_filter,
limit=limit,
exact=exact,
read_consistency=read_consistency,
consistency=consistency,
timeout=timeout,
shard_key_selector=shard_key_selector,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,7 +1700,7 @@ def facet(
limit: int = 10,
exact: bool = False,
timeout: Optional[int] = None,
read_consistency: Optional[types.ReadConsistency] = None,
consistency: Optional[types.ReadConsistency] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> types.FacetResponse:
Expand All @@ -1711,8 +1711,8 @@ def facet(
if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

if isinstance(read_consistency, get_args_subscribed(models.ReadConsistency)):
read_consistency = RestToGrpc.convert_read_consistency(read_consistency)
if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
consistency = RestToGrpc.convert_read_consistency(consistency)

response = self.grpc_points.Facet(
grpc.FacetCounts(
Expand All @@ -1722,7 +1722,7 @@ def facet(
limit=limit,
exact=exact,
timeout=timeout,
read_consistency=read_consistency,
read_consistency=consistency,
shard_key_selector=shard_key_selector,
),
timeout=timeout if timeout is not None else self._timeout,
Expand All @@ -1734,7 +1734,7 @@ def facet(

facet_result = self.openapi_client.points_api.facet(
collection_name=collection_name,
consistency=read_consistency,
consistency=consistency,
timeout=timeout,
facet_request=models.FacetRequest(
shard_key=shard_key_selector,
Expand Down
10 changes: 10 additions & 0 deletions tests/type_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,13 @@
shard_key_selector=None,
timeout=1,
)
qdrant_client.facet(
collection_name="collection",
key="field",
facet_filter=rest_models.Filter(),
exact=True,
limit=10,
consistency=None,
shard_key_selector=None,
timeout=1,
)

0 comments on commit 743fea2

Please sign in to comment.