Commit
generate async client
coszio committed Sep 16, 2024
1 parent b8bc752 commit 2f52d0c
Showing 5 changed files with 159 additions and 36 deletions.
11 changes: 11 additions & 0 deletions qdrant_client/async_client_base.py
@@ -218,6 +218,17 @@ async def count(
) -> types.CountResult:
raise NotImplementedError()

async def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
**kwargs: Any,
) -> types.FacetResponse:
raise NotImplementedError()

async def upsert(
self, collection_name: str, points: types.Points, **kwargs: Any
) -> types.UpdateResult:
60 changes: 53 additions & 7 deletions qdrant_client/async_qdrant_client.py
@@ -108,7 +108,7 @@ def __init__(
super().__init__(**kwargs)
self._init_options = {
key: value
for (key, value) in locals().items()
for key, value in locals().items()
if key not in ("self", "__class__", "kwargs")
}
self._init_options.update(deepcopy(kwargs))
@@ -504,9 +504,7 @@ async def query_points(
QueryResponse structure containing list of found close points with similarity scores.
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
(query, prefetch) = self._resolve_query_to_embedding_embeddings_and_prefetch(
query, prefetch
)
query, prefetch = self._resolve_query_to_embedding_embeddings_and_prefetch(query, prefetch)
return await self._client.query_points(
collection_name=collection_name,
query=query,
@@ -639,9 +637,7 @@ async def query_points_groups(
Each group also contains an id of the group, which is the value of the payload field.
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
(query, prefetch) = self._resolve_query_to_embedding_embeddings_and_prefetch(
query, prefetch
)
query, prefetch = self._resolve_query_to_embedding_embeddings_and_prefetch(query, prefetch)
return await self._client.query_points_groups(
collection_name=collection_name,
query=query,
@@ -1262,6 +1258,56 @@ async def count(
**kwargs,
)

async def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
read_consistency: Optional[types.ReadConsistency] = None,
timeout: Optional[int] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> types.FacetResponse:
"""Facet counts for the collection. For a specific payload key, returns unique values along with their counts.
Higher counts come first in the results.
Args:
collection_name: Name of the collection
key: Payload field to facet
facet_filter: Filter to apply
limit: Maximum number of hits to return
exact: If `True` - provide the exact count of points matching the filter. If `False` - provide the approximate count of points matching the filter. Works faster.
read_consistency:
Read consistency of the search. Defines how many replicas should be queried before returning the result. Values:
- int - number of replicas to query, values should be present in all queried replicas
- 'majority' - query all replicas, but return values present in the majority of replicas
- 'quorum' - query the majority of replicas, return values present in all of them
- 'all' - query all replicas, and return values present in all replicas
timeout: Overrides global timeout for this search. Unit is seconds.
shard_key_selector:
This parameter allows you to specify which shards should be queried.
If `None` - query all shards. Only works for collections with `custom` sharding method.
Returns:
Unique values in the facet and the amount of points that they cover.
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return await self._client.facet(
collection_name=collection_name,
key=key,
facet_filter=facet_filter,
limit=limit,
exact=exact,
read_consistency=read_consistency,
timeout=timeout,
shard_key_selector=shard_key_selector,
**kwargs,
)

async def upsert(
self,
collection_name: str,
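
For context, a minimal usage sketch of the new async facet method. The names here are hypothetical (a local Qdrant at localhost:6333, a "products" collection with "category" and "in_stock" payload fields), not part of this commit:

import asyncio

from qdrant_client import AsyncQdrantClient, models

async def main() -> None:
    client = AsyncQdrantClient(url="http://localhost:6333")
    # Count unique values of the "category" payload field, restricted to
    # points matching the optional filter; exact=False trades exact counts
    # for speed, as the docstring above describes.
    response = await client.facet(
        collection_name="products",  # hypothetical collection
        key="category",  # hypothetical payload field
        facet_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="in_stock", match=models.MatchValue(value=True)
                )
            ]
        ),
        limit=10,
        exact=False,
    )
    for hit in response.hits:  # hits come sorted by count, descending
        print(hit.value, hit.count)

asyncio.run(main())
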
10 changes: 5 additions & 5 deletions qdrant_client/async_qdrant_fastembed.py
@@ -237,7 +237,7 @@ def _embed_documents(
parallel: Optional[int] = None,
) -> Iterable[Tuple[str, List[float]]]:
embedding_model = self._get_or_init_model(model_name=embedding_model_name)
(documents_a, documents_b) = tee(documents, 2)
documents_a, documents_b = tee(documents, 2)
if embed_type == "passage":
vectors_iter = embedding_model.passage_embed(
documents_a, batch_size=batch_size, parallel=parallel
@@ -349,7 +349,7 @@ def _points_iterator(
yield models.PointStruct(id=idx, payload=payload, vector=point_vector)

def _validate_collection_info(self, collection_info: models.CollectionInfo) -> None:
(embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
vector_field_name = self.get_vector_field_name()
assert isinstance(
collection_info.config.params.vectors, dict
@@ -395,7 +395,7 @@ def get_fastembed_vector_params(
Configuration for `vectors_config` argument in `create_collection` method.
"""
vector_field_name = self.get_vector_field_name()
(embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
return {
vector_field_name: models.VectorParams(
size=embeddings_size,
@@ -651,7 +651,7 @@ async def query(
with_payload=True,
**kwargs,
)
(dense_request_response, sparse_request_response) = await self.search_batch(
dense_request_response, sparse_request_response = await self.search_batch(
collection_name=collection_name, requests=[dense_request, sparse_request]
)
return self._scored_points_to_query_responses(
@@ -728,6 +728,6 @@ async def query_batch(
sparse_responses = responses[len(query_texts) :]
responses = [
reciprocal_rank_fusion([dense_response, sparse_response], limit=limit)
for (dense_response, sparse_response) in zip(dense_responses, sparse_responses)
for dense_response, sparse_response in zip(dense_responses, sparse_responses)
]
return [self._scored_points_to_query_responses(response) for response in responses]
80 changes: 67 additions & 13 deletions qdrant_client/async_qdrant_remote.py
@@ -88,7 +88,7 @@ def __init__(
if url.startswith("localhost"):
url = f"//{url}"
parsed_url: Url = parse_url(url)
(self._host, self._port) = (parsed_url.host, parsed_url.port)
self._host, self._port = (parsed_url.host, parsed_url.port)
if parsed_url.scheme:
self._https = parsed_url.scheme == "https"
self._scheme = parsed_url.scheme
@@ -174,7 +174,7 @@ async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None
@staticmethod
def _parse_url(url: str) -> Tuple[Optional[str], str, Optional[int], Optional[str]]:
parse_result: Url = parse_url(url)
(scheme, host, port, prefix) = (
scheme, host, port, prefix = (
parse_result.scheme,
parse_result.host,
parse_result.port,
@@ -1017,7 +1017,7 @@ async def recommend_groups(
if isinstance(with_lookup, models.WithLookup):
with_lookup = RestToGrpc.convert_with_lookup(with_lookup)
if isinstance(with_lookup, str):
with_lookup = grpc.WithLookup(lookup_index=with_lookup)
with_lookup = grpc.WithLookup(collection=with_lookup)
positive_ids = RestToGrpc.convert_recommend_examples_to_ids(positive)
positive_vectors = RestToGrpc.convert_recommend_examples_to_vectors(positive)
negative_ids = RestToGrpc.convert_recommend_examples_to_ids(negative)
@@ -1389,6 +1389,58 @@ async def count(
assert count_result is not None, "Count points returned None result"
return count_result

async def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
timeout: Optional[int] = None,
read_consistency: Optional[types.ReadConsistency] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> types.FacetResponse:
if self._prefer_grpc:
if isinstance(facet_filter, models.Filter):
facet_filter = RestToGrpc.convert_filter(model=facet_filter)
if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
if isinstance(read_consistency, get_args_subscribed(models.ReadConsistency)):
read_consistency = RestToGrpc.convert_read_consistency(read_consistency)
response = await self.grpc_points.Facet(
grpc.FacetCounts(
collection_name=collection_name,
key=key,
filter=facet_filter,
limit=limit,
exact=exact,
timeout=timeout,
read_consistency=read_consistency,
shard_key_selector=shard_key_selector,
),
timeout=timeout if timeout is not None else self._timeout,
)
return GrpcToRest.convert_facet_response(response)
if isinstance(facet_filter, grpc.Filter):
facet_filter = GrpcToRest.convert_filter(model=facet_filter)
facet_result = (
await self.openapi_client.points_api.facet(
collection_name=collection_name,
consistency=read_consistency,
timeout=timeout,
facet_request=models.FacetRequest(
shard_key=shard_key_selector,
key=key,
limit=limit,
filter=facet_filter,
exact=exact,
),
)
).result
assert facet_result is not None, "Facet points returned None result"
return facet_result

async def upsert(
self,
collection_name: str,
@@ -1511,7 +1563,7 @@ async def delete_vectors(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1532,7 +1584,7 @@ async def delete_vectors(
assert grpc_result is not None, "Delete vectors returned None result"
return GrpcToRest.convert_update_result(grpc_result)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
return (
await self.openapi_client.points_api.delete_vectors(
collection_name=collection_name,
@@ -1681,6 +1733,8 @@ def _points_selector_to_points_list(
cls, points_selector: grpc.PointsSelector
) -> List[grpc.PointId]:
name = points_selector.WhichOneof("points_selector_one_of")
if name is None:
return []
val = getattr(points_selector, name)
if name == "points":
return list(val.ids)
@@ -1725,7 +1779,7 @@ async def delete(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
points_selector
)
shard_key_selector = shard_key_selector or opt_shard_key_selector
@@ -1774,7 +1828,7 @@ async def set_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1797,7 +1851,7 @@ async def set_payload(
).result
)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
result: Optional[types.UpdateResult] = (
await self.openapi_client.points_api.set_payload(
collection_name=collection_name,
@@ -1826,7 +1880,7 @@ async def overwrite_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1848,7 +1902,7 @@ async def overwrite_payload(
).result
)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
result: Optional[types.UpdateResult] = (
await self.openapi_client.points_api.overwrite_payload(
collection_name=collection_name,
@@ -1876,7 +1930,7 @@ async def delete_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1898,7 +1952,7 @@ async def delete_payload(
).result
)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
result: Optional[types.UpdateResult] = (
await self.openapi_client.points_api.delete_payload(
collection_name=collection_name,
Expand All @@ -1922,7 +1976,7 @@ async def clear_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
points_selector
)
shard_key_selector = shard_key_selector or opt_shard_key_selector
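
A side note on the new None guard in _points_selector_to_points_list: protobuf's WhichOneof returns None when no variant of the oneof is set, and getattr(selector, None) would raise a TypeError, so an empty selector now yields an empty id list instead of crashing. A minimal sketch of that behavior:

from qdrant_client import grpc

selector = grpc.PointsSelector()  # empty selector: no oneof variant set
assert selector.WhichOneof("points_selector_one_of") is None
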