generate async client
coszio committed Sep 16, 2024
1 parent 909d55a commit d8ccef6
Showing 4 changed files with 35 additions and 31 deletions.
10 changes: 7 additions & 3 deletions qdrant_client/async_qdrant_client.py
@@ -108,7 +108,7 @@ def __init__(
         super().__init__(**kwargs)
         self._init_options = {
             key: value
-            for key, value in locals().items()
+            for (key, value) in locals().items()
             if key not in ("self", "__class__", "kwargs")
         }
         self._init_options.update(deepcopy(kwargs))
@@ -504,7 +504,9 @@ async def query_points(
             QueryResponse structure containing list of found close points with similarity scores.
         """
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
-        query, prefetch = self._resolve_query_to_embedding_embeddings_and_prefetch(query, prefetch)
+        (query, prefetch) = self._resolve_query_to_embedding_embeddings_and_prefetch(
+            query, prefetch
+        )
         return await self._client.query_points(
             collection_name=collection_name,
             query=query,
@@ -637,7 +639,9 @@ async def query_points_groups(
             Each group also contains an id of the group, which is the value of the payload field.
         """
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
-        query, prefetch = self._resolve_query_to_embedding_embeddings_and_prefetch(query, prefetch)
+        (query, prefetch) = self._resolve_query_to_embedding_embeddings_and_prefetch(
+            query, prefetch
+        )
         return await self._client.query_points_groups(
             collection_name=collection_name,
             query=query,
10 changes: 5 additions & 5 deletions qdrant_client/async_qdrant_fastembed.py
@@ -237,7 +237,7 @@ def _embed_documents(
         parallel: Optional[int] = None,
     ) -> Iterable[Tuple[str, List[float]]]:
         embedding_model = self._get_or_init_model(model_name=embedding_model_name)
-        documents_a, documents_b = tee(documents, 2)
+        (documents_a, documents_b) = tee(documents, 2)
         if embed_type == "passage":
             vectors_iter = embedding_model.passage_embed(
                 documents_a, batch_size=batch_size, parallel=parallel
@@ -349,7 +349,7 @@ def _points_iterator(
             yield models.PointStruct(id=idx, payload=payload, vector=point_vector)

     def _validate_collection_info(self, collection_info: models.CollectionInfo) -> None:
-        embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
+        (embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
         vector_field_name = self.get_vector_field_name()
         assert isinstance(
             collection_info.config.params.vectors, dict
@@ -395,7 +395,7 @@ def get_fastembed_vector_params(
             Configuration for `vectors_config` argument in `create_collection` method.
         """
         vector_field_name = self.get_vector_field_name()
-        embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
+        (embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
         return {
             vector_field_name: models.VectorParams(
                 size=embeddings_size,
@@ -651,7 +651,7 @@ async def query(
             with_payload=True,
             **kwargs,
         )
-        dense_request_response, sparse_request_response = await self.search_batch(
+        (dense_request_response, sparse_request_response) = await self.search_batch(
             collection_name=collection_name, requests=[dense_request, sparse_request]
         )
         return self._scored_points_to_query_responses(
@@ -728,6 +728,6 @@ async def query_batch(
         sparse_responses = responses[len(query_texts) :]
         responses = [
             reciprocal_rank_fusion([dense_response, sparse_response], limit=limit)
-            for dense_response, sparse_response in zip(dense_responses, sparse_responses)
+            for (dense_response, sparse_response) in zip(dense_responses, sparse_responses)
         ]
         return [self._scored_points_to_query_responses(response) for response in responses]
24 changes: 12 additions & 12 deletions qdrant_client/async_qdrant_remote.py
@@ -88,7 +88,7 @@ def __init__(
             if url.startswith("localhost"):
                 url = f"//{url}"
             parsed_url: Url = parse_url(url)
-            self._host, self._port = (parsed_url.host, parsed_url.port)
+            (self._host, self._port) = (parsed_url.host, parsed_url.port)
             if parsed_url.scheme:
                 self._https = parsed_url.scheme == "https"
                 self._scheme = parsed_url.scheme
@@ -174,7 +174,7 @@ async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None
     @staticmethod
     def _parse_url(url: str) -> Tuple[Optional[str], str, Optional[int], Optional[str]]:
         parse_result: Url = parse_url(url)
-        scheme, host, port, prefix = (
+        (scheme, host, port, prefix) = (
             parse_result.scheme,
             parse_result.host,
             parse_result.port,
@@ -1563,7 +1563,7 @@ async def delete_vectors(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1584,7 +1584,7 @@ async def delete_vectors(
             assert grpc_result is not None, "Delete vectors returned None result"
             return GrpcToRest.convert_update_result(grpc_result)
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             return (
                 await self.openapi_client.points_api.delete_vectors(
                     collection_name=collection_name,
@@ -1779,7 +1779,7 @@ async def delete(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
                 points_selector
             )
             shard_key_selector = shard_key_selector or opt_shard_key_selector
@@ -1828,7 +1828,7 @@ async def set_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1851,7 +1851,7 @@ async def set_payload(
                 ).result
             )
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             result: Optional[types.UpdateResult] = (
                 await self.openapi_client.points_api.set_payload(
                     collection_name=collection_name,
@@ -1880,7 +1880,7 @@ async def overwrite_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1902,7 +1902,7 @@ async def overwrite_payload(
                 ).result
             )
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             result: Optional[types.UpdateResult] = (
                 await self.openapi_client.points_api.overwrite_payload(
                     collection_name=collection_name,
@@ -1930,7 +1930,7 @@ async def delete_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1952,7 +1952,7 @@ async def delete_payload(
                 ).result
             )
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             result: Optional[types.UpdateResult] = (
                 await self.openapi_client.points_api.delete_payload(
                     collection_name=collection_name,
@@ -1976,7 +1976,7 @@ async def clear_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
                 points_selector
             )
             shard_key_selector = shard_key_selector or opt_shard_key_selector
22 changes: 11 additions & 11 deletions qdrant_client/local/async_qdrant_local.py
@@ -145,7 +145,7 @@ def _save(self) -> None:
                 {
                     "collections": {
                         collection_name: to_dict(collection.config)
-                        for collection_name, collection in self.collections.items()
+                        for (collection_name, collection) in self.collections.items()
                     },
                     "aliases": self.aliases,
                 }
@@ -357,7 +357,7 @@ def _resolve_prefetch_input(
         if prefetch.query is None:
             return prefetch
         prefetch = deepcopy(prefetch)
-        query, mentioned_ids = self._resolve_query_input(
+        (query, mentioned_ids) = self._resolve_query_input(
             collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
         )
         prefetch.query = query
@@ -383,7 +383,7 @@ async def query_points(
     ) -> types.QueryResponse:
         collection = self._get_collection(collection_name)
         if query is not None:
-            query, mentioned_ids = self._resolve_query_input(
+            (query, mentioned_ids) = self._resolve_query_input(
                 collection_name, query, using, lookup_from
             )
             query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
@@ -454,7 +454,7 @@ async def query_points_groups(
     ) -> types.GroupsResult:
         collection = self._get_collection(collection_name)
         if query is not None:
-            query, mentioned_ids = self._resolve_query_input(
+            (query, mentioned_ids) = self._resolve_query_input(
                 collection_name, query, using, lookup_from
             )
             query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
@@ -814,7 +814,7 @@ async def get_collection_aliases(
         return types.CollectionsAliasesResponse(
             aliases=[
                 rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
-                for alias_name, name in self.aliases.items()
+                for (alias_name, name) in self.aliases.items()
                 if name == collection_name
             ]
         )
@@ -825,7 +825,7 @@ async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
         return types.CollectionsAliasesResponse(
             aliases=[
                 rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
-                for alias_name, name in self.aliases.items()
+                for (alias_name, name) in self.aliases.items()
             ]
         )

@@ -835,7 +835,7 @@ async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
         return types.CollectionsResponse(
             collections=[
                 rest_models.CollectionDescription(name=name)
-                for name, _ in self.collections.items()
+                for (name, _) in self.collections.items()
             ]
         )

@@ -876,7 +876,7 @@ async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
         del _collection
         self.aliases = {
             alias_name: name
-            for alias_name, name in self.aliases.items()
+            for (alias_name, name) in self.aliases.items()
             if name != collection_name
         }
         collection_path = self._collection_path(collection_name)
@@ -917,12 +917,12 @@ async def create_collection(
         self.collections[collection_name] = collection
         if src_collection and from_collection_name:
             batch_size = 100
-            records, next_offset = await self.scroll(
+            (records, next_offset) = await self.scroll(
                 from_collection_name, limit=2, with_vectors=True
             )
             self.upload_records(collection_name, records)
             while next_offset is not None:
-                records, next_offset = await self.scroll(
+                (records, next_offset) = await self.scroll(
                     from_collection_name, offset=next_offset, limit=batch_size, with_vectors=True
                 )
                 self.upload_records(collection_name, records)
@@ -998,7 +998,7 @@ def uuid_generator() -> Generator[str, None, None]:
                 vector=(vector.tolist() if isinstance(vector, np.ndarray) else vector) or {},
                 payload=payload or {},
             )
-            for point_id, vector, payload in zip(
+            for (point_id, vector, payload) in zip(
                 ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
             )
         ]

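All four files are regenerated output of the sync-to-async code generator; the changes only wrap tuple-unpacking targets in parentheses and rewrap long assignments, with no behavioral difference. For context, a minimal sketch of using the regenerated async client; the URL, collection name, and query vector are illustrative assumptions, not part of this commit:

```python
import asyncio

from qdrant_client import AsyncQdrantClient


async def main() -> None:
    # Assumes a Qdrant server reachable on localhost; the URL and the
    # collection name "demo" are placeholders for this sketch.
    client = AsyncQdrantClient(url="http://localhost:6333")

    # query_points is one of the methods whose generated body changed here;
    # every public call on the async client is awaited.
    response = await client.query_points(
        collection_name="demo",
        query=[0.1, 0.2, 0.3, 0.4],  # a raw dense query vector
        limit=5,
    )
    for point in response.points:
        print(point.id, point.score)

    await client.close()


asyncio.run(main())
```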