diff --git a/examples/simple_async.py b/examples/simple_async.py
new file mode 100644
index 000000000..ee51f4389
--- /dev/null
+++ b/examples/simple_async.py
@@ -0,0 +1,124 @@
+from pymilvus import (
+    DataType,
+    MilvusClient,
+    AsyncMilvusClient,
+    AnnSearchRequest,
+    RRFRanker,
+)
+import numpy as np
+import asyncio
+import time
+import random
+
+fmt = "\n=== {:30} ===\n"
+num_entities, dim = 100, 8
+default_limit = 3
+collection_name = "hello_milvus"
+rng = np.random.default_rng(seed=19530)
+
+milvus_client = MilvusClient("example.db")
+async_milvus_client = AsyncMilvusClient("example.db")
+
+loop = asyncio.get_event_loop()
+
+schema = milvus_client.create_schema(auto_id=False, description="hello_milvus is the simplest demo to introduce the APIs")
+schema.add_field("pk", DataType.VARCHAR, is_primary=True, max_length=100)
+schema.add_field("random", DataType.DOUBLE)
+schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
+schema.add_field("embeddings2", DataType.FLOAT_VECTOR, dim=dim)
+
+index_params = milvus_client.prepare_index_params()
+index_params.add_index(field_name="embeddings", index_type="HNSW", metric_type="L2", nlist=128)
+index_params.add_index(field_name="embeddings2", index_type="HNSW", metric_type="L2", nlist=128)
+
+async def recreate_collection():
+    print(fmt.format("Start dropping collection"))
+    await async_milvus_client.drop_collection(collection_name)
+    print(fmt.format("Dropping collection done"))
+    print(fmt.format("Start creating collection"))
+    await async_milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")
+    print(fmt.format("Creating collection done"))
+
+has_collection = milvus_client.has_collection(collection_name, timeout=5)
+if has_collection:
+    loop.run_until_complete(recreate_collection())
+else:
+    print(fmt.format("Start creating collection"))
+    loop.run_until_complete(async_milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong"))
+    print(fmt.format("Creating collection done"))
+
+print(fmt.format(" all collections "))
+print(milvus_client.list_collections())
+
+print(fmt.format(f"schema of collection {collection_name}"))
+print(milvus_client.describe_collection(collection_name))
+
+async def async_insert(collection_name):
+    entities = [
+        # provide the pk field because `auto_id` is set to False
+        [str(i) for i in range(num_entities)],
+        rng.random(num_entities).tolist(),  # field random, only supports list
+        rng.random((num_entities, dim)),  # field embeddings, supports numpy.ndarray and list
+        rng.random((num_entities, dim)),  # field embeddings2, supports numpy.ndarray and list
+    ]
+    rows = [{"pk": entities[0][i], "random": entities[1][i], "embeddings": entities[2][i], "embeddings2": entities[3][i]} for i in range(num_entities)]
+    print(fmt.format("Start async inserting entities"))
+
+    start_time = time.time()
+    tasks = []
+    for row in rows:
+        task = async_milvus_client.insert(collection_name, [row])
+        tasks.append(task)
+    await asyncio.gather(*tasks)
+    end_time = time.time()
+    print(fmt.format("Total time: {:.2f} seconds".format(end_time - start_time)))
+    print(fmt.format("Async inserting entities done"))
+
+loop.run_until_complete(async_insert(collection_name))
+
+async def other_async_task(collection_name):
+    tasks = []
+    # search
+    random_vector = rng.random((1, dim))
+    random_vector2 = rng.random((1, dim))
+    task = async_milvus_client.search(collection_name, random_vector, limit=default_limit, output_fields=["pk"], anns_field="embeddings")
+    tasks.append(task)
+    # hybrid search
+    search_param = {
+        "data": random_vector,
+        "anns_field": "embeddings",
+        "param": {"metric_type": "L2"},
+        "limit": default_limit,
+        "expr": "random > 0.5"}
+    req = AnnSearchRequest(**search_param)
+    task = async_milvus_client.hybrid_search(collection_name, [req], RRFRanker(), default_limit, output_fields=["pk"])
+    tasks.append(task)
+    # get
+    random_pk = random.randint(0, num_entities - 1)
+    task = async_milvus_client.get(collection_name=collection_name, ids=[random_pk])
+    tasks.append(task)
+    # query
+    task = async_milvus_client.query(collection_name=collection_name, filter="", limit=default_limit)
+    tasks.append(task)
+    # delete
+    task = async_milvus_client.delete(collection_name=collection_name, ids=[random_pk])
+    tasks.append(task)
+    # insert
+    task = async_milvus_client.insert(
+        collection_name=collection_name,
+        data=[{"pk": str(random_pk), "random": random_vector[0][0], "embeddings": random_vector[0], "embeddings2": random_vector[0]}],
+    )
+    tasks.append(task)
+    # upsert
+    task = async_milvus_client.upsert(
+        collection_name=collection_name,
+        data=[{"pk": str(random_pk), "random": random_vector2[0][0], "embeddings": random_vector2[0], "embeddings2": random_vector2[0]}],
+    )
+    tasks.append(task)
+
+    results = await asyncio.gather(*tasks)
+    return results
+
+results = loop.run_until_complete(other_async_task(collection_name))
+for r in results:
+    print(r)
\ No newline at end of file
diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py
index 7c01c7d78..f73aee203 100644
--- a/pymilvus/__init__.py
+++ b/pymilvus/__init__.py
@@ -31,7 +31,7 @@
     MilvusException,
     MilvusUnavailableException,
 )
-from .milvus_client import MilvusClient
+from .milvus_client import AsyncMilvusClient, MilvusClient
 from .orm import db, utility
 from .orm.collection import Collection
 from .orm.connections import Connections, connections
@@ -73,6 +73,7 @@ __all__ = [
     "AnnSearchRequest",
+    "AsyncMilvusClient",
     "BulkInsertState",
     "Collection",
     "CollectionSchema",
diff --git a/pymilvus/client/async_grpc_handler.py b/pymilvus/client/async_grpc_handler.py
new file mode 100644
index 000000000..794858dd7
--- /dev/null
+++ b/pymilvus/client/async_grpc_handler.py
@@ -0,0 +1,699 @@
+import asyncio
+import base64
+import copy
+import socket
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+from urllib import parse
+
+import grpc
+from grpc._cython import cygrpc
+
+from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder
+from pymilvus.exceptions import (
+    DescribeCollectionException,
+    MilvusException,
+    ParamError,
+)
+from pymilvus.grpc_gen import milvus_pb2 as milvus_types
+from pymilvus.grpc_gen import milvus_pb2_grpc
+from pymilvus.settings import Config
+
+from . import entity_helper, interceptor, ts_utils, utils
+from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
+from .async_interceptor import async_header_adder_interceptor
+from .check import (
+    check_pass_param,
+    is_legal_host,
+    is_legal_port,
+)
+from .constants import ITERATOR_SESSION_TS_FIELD
+from .prepare import Prepare
+from .types import (
+    DataType,
+    ExtraList,
+    Status,
+    get_cost_extra,
+)
+from .utils import (
+    check_invalid_binary_vector,
+    check_status,
+    get_server_type,
+    is_successful,
+    len_of,
+)
+
+
+class AsyncGrpcHandler:
+    def __init__(
+        self,
+        uri: str = Config.GRPC_URI,
+        host: str = "",
+        port: str = "",
+        channel: Optional[grpc.aio.Channel] = None,
+        **kwargs,
+    ) -> None:
+        self._async_stub = None
+        self._async_channel = channel
+
+        addr = kwargs.get("address")
+        self._address = addr if addr is not None else self.__get_address(uri, host, port)
+        self._log_level = None
+        self._request_id = None
+        self._user = kwargs.get("user")
+        self._set_authorization(**kwargs)
+        self._setup_db_interceptor(kwargs.get("db_name"))
+        self._setup_grpc_channel()  # init channel and stub
+        self.callbacks = []
+
+    def register_state_change_callback(self, callback: Callable):
+        self.callbacks.append(callback)
+        self._async_channel.subscribe(callback, try_to_connect=True)
+
+    def deregister_state_change_callbacks(self):
+        for callback in self.callbacks:
+            self._async_channel.unsubscribe(callback)
+        self.callbacks = []
+
+    def __get_address(self, uri: str, host: str, port: str) -> str:
+        if host != "" and port != "" and is_legal_host(host) and is_legal_port(port):
+            return f"{host}:{port}"
+
+        try:
+            parsed_uri = parse.urlparse(uri)
+        except Exception as e:
+            raise ParamError(message=f"Illegal uri: [{uri}], {e}") from e
+        return parsed_uri.netloc
+
+    def _set_authorization(self, **kwargs):
+        secure = kwargs.get("secure", False)
+        if not isinstance(secure, bool):
+            raise ParamError(message="secure must be bool type")
+        self._secure = secure
+        self._client_pem_path = kwargs.get("client_pem_path", "")
+        self._client_key_path = kwargs.get("client_key_path", "")
+        self._ca_pem_path = kwargs.get("ca_pem_path", "")
+        self._server_pem_path = kwargs.get("server_pem_path", "")
+        self._server_name = kwargs.get("server_name", "")
+
+        self._authorization_interceptor = None
+        self._setup_authorization_interceptor(
+            kwargs.get("user"),
+            kwargs.get("password"),
+            kwargs.get("token"),
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self: object, exc_type: object, exc_val: object, exc_tb: object):
+        pass
+
+    def _wait_for_channel_ready(self, timeout: float = 10, retry_interval: float = 1):
+        try:
+
+            async def wait_for_async_channel_ready():
+                await self._async_channel.channel_ready()
+
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(wait_for_async_channel_ready())
+
+            self._setup_identifier_interceptor(self._user, timeout=timeout)
+        except grpc.FutureTimeoutError as e:
+            raise MilvusException(
+                code=Status.CONNECT_FAILED,
+                message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
+            ) from e
+        except Exception as e:
+            raise e from e
+
+    def close(self):
+        self.deregister_state_change_callbacks()
+        self._async_channel.close()
+
+    def reset_db_name(self, db_name: str):
+        self._setup_db_interceptor(db_name)
+        self._setup_grpc_channel()
+        self._setup_identifier_interceptor(self._user)
+
+    def _setup_authorization_interceptor(self, user: str, password: str, token: str):
+        keys = []
+        values = []
+        if token:
+            authorization = base64.b64encode(f"{token}".encode())
+            keys.append("authorization")
+            values.append(authorization)
+        elif user and password:
+            authorization = base64.b64encode(f"{user}:{password}".encode())
+            keys.append("authorization")
+            values.append(authorization)
+        if len(keys) > 0 and len(values) > 0:
+            self._authorization_interceptor = interceptor.header_adder_interceptor(keys, values)
+
+    def _setup_db_interceptor(self, db_name: str):
+        if db_name is None:
+            self._db_interceptor = None
+        else:
+            check_pass_param(db_name=db_name)
+            self._db_interceptor = interceptor.header_adder_interceptor(["dbname"], [db_name])
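_set_authorization and _setup_authorization_interceptor above turn either a token or a user/password pair into a base64-encoded "authorization" header, and the *_pem_path kwargs select one-way or mutual TLS in _setup_grpc_channel below. A hedged sketch of how these kwargs might be supplied from the client side (paths and credentials are placeholders, and it assumes AsyncMilvusClient forwards extra kwargs to the handler unchanged):

    from pymilvus import AsyncMilvusClient

    client = AsyncMilvusClient(
        uri="https://localhost:19530",          # placeholder endpoint
        token="root:Milvus",                    # becomes the "authorization" header
        secure=True,
        server_pem_path="/path/to/server.pem",  # one-way TLS; use the client_*/ca_* trio for mTLS
    )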
+    def _setup_grpc_channel(self):
+        if self._async_channel is None:
+            opts = [
+                (cygrpc.ChannelArgKey.max_send_message_length, -1),
+                (cygrpc.ChannelArgKey.max_receive_message_length, -1),
+                ("grpc.enable_retries", 1),
+                ("grpc.keepalive_time_ms", 55000),
+            ]
+            if not self._secure:
+                self._async_channel = grpc.aio.insecure_channel(
+                    self._address,
+                    options=opts,
+                )
+            else:
+                if self._server_name != "":
+                    opts.append(("grpc.ssl_target_name_override", self._server_name))
+
+                root_cert, private_k, cert_chain = None, None, None
+                if self._server_pem_path != "":
+                    with Path(self._server_pem_path).open("rb") as f:
+                        root_cert = f.read()
+                elif (
+                    self._client_pem_path != ""
+                    and self._client_key_path != ""
+                    and self._ca_pem_path != ""
+                ):
+                    with Path(self._ca_pem_path).open("rb") as f:
+                        root_cert = f.read()
+                    with Path(self._client_key_path).open("rb") as f:
+                        private_k = f.read()
+                    with Path(self._client_pem_path).open("rb") as f:
+                        cert_chain = f.read()
+
+                creds = grpc.ssl_channel_credentials(
+                    root_certificates=root_cert,
+                    private_key=private_k,
+                    certificate_chain=cert_chain,
+                )
+                self._async_channel = grpc.aio.secure_channel(
+                    self._address,
+                    creds,
+                    options=opts,
+                )
+
+        # avoid adding duplicate headers
+        self._final_channel = self._async_channel
+        if self._log_level:
+            async_log_level_interceptor = async_header_adder_interceptor(
+                ["log_level"], [self._log_level]
+            )
+            self._final_channel._unary_unary_interceptors.append(async_log_level_interceptor)
+            self._log_level = None
+        if self._request_id:
+            async_request_id_interceptor = async_header_adder_interceptor(
+                ["client_request_id"], [self._request_id]
+            )
+            self._final_channel._unary_unary_interceptors.append(async_request_id_interceptor)
+            self._request_id = None
+        self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)
+
+    def _setup_identifier_interceptor(self, user: str, timeout: int = 10):
+        host = socket.gethostname()
+        self._identifier = self.__internal_register(user, host, timeout=timeout)
+        _async_identifier_interceptor = async_header_adder_interceptor(
+            ["identifier"], [str(self._identifier)]
+        )
+        self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
+        self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)
+
+    @property
+    def server_address(self):
+        return self._address
+
+    def get_server_type(self):
+        return get_server_type(self.server_address.split(":")[0])
+
+    @retry_on_rpc_failure()
+    async def create_collection(
+        self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs
+    ):
+        check_pass_param(collection_name=collection_name, timeout=timeout)
+        request = Prepare.create_collection_request(collection_name, fields, **kwargs)
+        response = await self._async_stub.CreateCollection(request, timeout=timeout)
+        check_status(response)
+
+    @retry_on_rpc_failure()
+    async def drop_collection(self, collection_name: str, timeout: Optional[float] = None):
+        check_pass_param(collection_name=collection_name, timeout=timeout)
+        request = Prepare.drop_collection_request(collection_name)
+        response = await self._async_stub.DropCollection(request, timeout=timeout)
+        check_status(response)
+
+    @retry_on_rpc_failure()
+    async def load_collection(
+        self,
+        collection_name: str,
+        replica_number: int = 1,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        check_pass_param(
+            collection_name=collection_name, replica_number=replica_number, timeout=timeout
+        )
+        refresh = kwargs.get("refresh", kwargs.get("_refresh", False))
+        resource_groups = kwargs.get("resource_groups", kwargs.get("_resource_groups"))
+        load_fields = kwargs.get("load_fields", kwargs.get("_load_fields"))
+        skip_load_dynamic_field = kwargs.get(
+            "skip_load_dynamic_field", kwargs.get("_skip_load_dynamic_field", False)
+        )
+
+        request = Prepare.load_collection(
+            "",
+            collection_name,
+            replica_number,
+            refresh,
+            resource_groups,
+            load_fields,
+            skip_load_dynamic_field,
+        )
+        response = await self._async_stub.LoadCollection(request, timeout=timeout)
+        check_status(response)
+
+    @retry_on_rpc_failure()
+    async def describe_collection(
+        self, collection_name: str, timeout: Optional[float] = None, **kwargs
+    ):
+        check_pass_param(collection_name=collection_name, timeout=timeout)
+        request = Prepare.describe_collection_request(collection_name)
+        response = await self._async_stub.DescribeCollection(request, timeout=timeout)
+        status = response.status
+
+        if is_successful(status):
+            return CollectionSchema(raw=response).dict()
+
+        raise DescribeCollectionException(status.code, status.reason, status.error_code)
+
+    async def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
+        schema = kwargs.get("schema")
+        if not schema:
+            schema = await self.describe_collection(collection_name, timeout=timeout)
+
+        fields_info = schema.get("fields")
+        enable_dynamic = schema.get("enable_dynamic_field", False)
+
+        return fields_info, enable_dynamic
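describe_collection returns CollectionSchema(raw=response).dict(), and _get_info reads only two keys from it. For orientation, a hedged sketch of the shape being consumed (the field entries are illustrative, not exhaustive):

    # Illustrative only; real dicts carry more keys (params, auto_id, ...).
    schema_dict = {
        "collection_name": "hello_milvus",
        "enable_dynamic_field": False,
        "fields": [
            {"name": "pk", "type": DataType.VARCHAR, "is_primary": True},
            {"name": "embeddings", "type": DataType.FLOAT_VECTOR, "params": {"dim": 8}},
        ],
    }
    fields_info = schema_dict.get("fields")
    enable_dynamic = schema_dict.get("enable_dynamic_field", False)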
+    @retry_on_rpc_failure()
+    async def insert_rows(
+        self,
+        collection_name: str,
+        entities: Union[Dict, List[Dict]],
+        partition_name: Optional[str] = None,
+        schema: Optional[dict] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        request = await self._prepare_row_insert_request(
+            collection_name, entities, partition_name, schema, timeout, **kwargs
+        )
+        resp = await self._async_stub.Insert(request=request, timeout=timeout)
+        check_status(resp.status)
+        ts_utils.update_collection_ts(collection_name, resp.timestamp)
+        return MutationResult(resp)
+
+    async def _prepare_row_insert_request(
+        self,
+        collection_name: str,
+        entity_rows: Union[List[Dict], Dict],
+        partition_name: Optional[str] = None,
+        schema: Optional[dict] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        if isinstance(entity_rows, dict):
+            entity_rows = [entity_rows]
+
+        if not isinstance(schema, dict):
+            schema = await self.describe_collection(collection_name, timeout=timeout)
+
+        fields_info = schema.get("fields")
+        enable_dynamic = schema.get("enable_dynamic_field", False)
+
+        return Prepare.row_insert_param(
+            collection_name,
+            entity_rows,
+            partition_name,
+            fields_info,
+            enable_dynamic=enable_dynamic,
+        )
+
+    async def delete(
+        self,
+        collection_name: str,
+        expression: str,
+        partition_name: Optional[str] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        check_pass_param(collection_name=collection_name, timeout=timeout)
+        try:
+            req = Prepare.delete_request(
+                collection_name=collection_name,
+                filter=expression,
+                partition_name=partition_name,
+                consistency_level=kwargs.pop("consistency_level", 0),
+                **kwargs,
+            )
+
+            response = await self._async_stub.Delete(req, timeout=timeout)
+
+            m = MutationResult(response)
+            ts_utils.update_collection_ts(collection_name, m.timestamp)
+        except Exception as err:
+            raise err from err
+        else:
+            return m
+
+    async def _prepare_batch_upsert_request(
+        self,
+        collection_name: str,
+        entities: List,
+        partition_name: Optional[str] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        param = kwargs.get("upsert_param")
+        if param and not isinstance(param, milvus_types.UpsertRequest):
+            raise ParamError(message="The value of key 'upsert_param' is invalid")
+        if not isinstance(entities, list):
+            raise ParamError(message="'entities' must be a list, please provide valid entity data.")
+
+        schema = kwargs.get("schema")
+        if not schema:
+            schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
+
+        fields_info = schema["fields"]
+
+        return (
+            param
+            if param
+            else Prepare.batch_upsert_param(collection_name, entities, partition_name, fields_info)
+        )
+
+    @retry_on_rpc_failure()
+    async def upsert(
+        self,
+        collection_name: str,
+        entities: List,
+        partition_name: Optional[str] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        if not check_invalid_binary_vector(entities):
+            raise ParamError(message="Invalid binary vector data exists")
+
+        try:
+            request = await self._prepare_batch_upsert_request(
+                collection_name, entities, partition_name, timeout, **kwargs
+            )
+            response = await self._async_stub.Upsert(request, timeout=timeout)
+            check_status(response.status)
+            m = MutationResult(response)
+            ts_utils.update_collection_ts(collection_name, m.timestamp)
+        except Exception as err:
+            raise err from err
+        else:
+            return m
+
+    async def _prepare_row_upsert_request(
+        self,
+        collection_name: str,
+        rows: List,
+        partition_name: Optional[str] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        if not isinstance(rows, list):
+            raise ParamError(message="'rows' must be a list, please provide valid row data.")
+
+        fields_info, enable_dynamic = await self._get_info(collection_name, timeout, **kwargs)
+        return Prepare.row_upsert_param(
+            collection_name,
+            rows,
+            partition_name,
+            fields_info,
+            enable_dynamic=enable_dynamic,
+        )
+
+    @retry_on_rpc_failure()
+    async def upsert_rows(
+        self,
+        collection_name: str,
+        entities: List,
+        partition_name: Optional[str] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        if isinstance(entities, dict):
+            entities = [entities]
+        request = await self._prepare_row_upsert_request(
+            collection_name, entities, partition_name, timeout, **kwargs
+        )
+        response = await self._async_stub.Upsert(request, timeout=timeout)
+        check_status(response.status)
+        m = MutationResult(response)
+        ts_utils.update_collection_ts(collection_name, m.timestamp)
+        return m
+
+    async def _execute_search(
+        self, request: milvus_types.SearchRequest, timeout: Optional[float] = None, **kwargs
+    ):
+        try:
+            response = await self._async_stub.Search(request, timeout=timeout)
+            check_status(response.status)
+            round_decimal = kwargs.get("round_decimal", -1)
+            return SearchResult(
+                response.results,
+                round_decimal,
+                status=response.status,
+                session_ts=response.session_ts,
+            )
+        except Exception as e:
+            raise e from e
+
+    async def _execute_hybrid_search(
+        self, request: milvus_types.HybridSearchRequest, timeout: Optional[float] = None, **kwargs
+    ):
+        try:
+            response = await self._async_stub.HybridSearch(request, timeout=timeout)
+            check_status(response.status)
+            round_decimal = kwargs.get("round_decimal", -1)
+            return SearchResult(response.results, round_decimal, status=response.status)
+
+        except Exception as e:
+            raise e from e
+
+    @retry_on_rpc_failure()
+    async def search(
+        self,
+        collection_name: str,
+        data: Union[List[List[float]], utils.SparseMatrixInputType],
+        anns_field: str,
+        param: Dict,
+        limit: int,
+        expression: Optional[str] = None,
+        partition_names: Optional[List[str]] = None,
+        output_fields: Optional[List[str]] = None,
+        round_decimal: int = -1,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        check_pass_param(
+            limit=limit,
+            round_decimal=round_decimal,
+            anns_field=anns_field,
+            search_data=data,
+            partition_name_array=partition_names,
+            output_fields=output_fields,
+            guarantee_timestamp=kwargs.get("guarantee_timestamp"),
+            timeout=timeout,
+        )
+        request = Prepare.search_requests_with_expr(
+            collection_name,
+            data,
+            anns_field,
+            param,
+            limit,
+            expression,
+            partition_names,
+            output_fields,
+            round_decimal,
+            **kwargs,
+        )
+        return await self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs)
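search above takes the already-normalized positional arguments (anns_field, param, limit) that the client layer produces, so the usual entry point is AsyncMilvusClient.search. A hedged sketch (the endpoint, collection, and field names are assumptions):

    import asyncio
    from pymilvus import AsyncMilvusClient

    async def run_search():
        client = AsyncMilvusClient(uri="http://localhost:19530")
        hits = await client.search(
            "hello_milvus",
            data=[[0.1] * 8],                     # one 8-dim query vector
            anns_field="embeddings",
            search_params={"metric_type": "L2"},
            limit=3,
            output_fields=["pk"],
        )
        print(hits)
        client.close()

    asyncio.run(run_search())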
+    @retry_on_rpc_failure()
+    async def hybrid_search(
+        self,
+        collection_name: str,
+        reqs: List[AnnSearchRequest],
+        rerank: BaseRanker,
+        limit: int,
+        partition_names: Optional[List[str]] = None,
+        output_fields: Optional[List[str]] = None,
+        round_decimal: int = -1,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        check_pass_param(
+            limit=limit,
+            round_decimal=round_decimal,
+            partition_name_array=partition_names,
+            output_fields=output_fields,
+            guarantee_timestamp=kwargs.get("guarantee_timestamp"),
+            timeout=timeout,
+        )
+
+        requests = []
+        for req in reqs:
+            search_request = Prepare.search_requests_with_expr(
+                collection_name,
+                req.data,
+                req.anns_field,
+                req.param,
+                req.limit,
+                req.expr,
+                partition_names=partition_names,
+                round_decimal=round_decimal,
+                **kwargs,
+            )
+            requests.append(search_request)
+
+        hybrid_search_request = Prepare.hybrid_search_request_with_ranker(
+            collection_name,
+            requests,
+            rerank.dict(),
+            limit,
+            partition_names,
+            output_fields,
+            round_decimal,
+            **kwargs,
+        )
+        return await self._execute_hybrid_search(
+            hybrid_search_request, timeout, round_decimal=round_decimal, **kwargs
+        )
+
+    @retry_on_rpc_failure()
+    async def create_index(
+        self,
+        collection_name: str,
+        field_name: str,
+        params: Dict,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        index_name = kwargs.pop("index_name", Config.IndexName)
+        copy_kwargs = copy.deepcopy(kwargs)
+
+        collection_desc = await self.describe_collection(
+            collection_name, timeout=timeout, **copy_kwargs
+        )
+
+        valid_field = False
+        for fields in collection_desc["fields"]:
+            if field_name != fields["name"]:
+                continue
+            valid_field = True
+            if fields["type"] not in {
+                DataType.FLOAT_VECTOR,
+                DataType.BINARY_VECTOR,
+                DataType.FLOAT16_VECTOR,
+                DataType.BFLOAT16_VECTOR,
+                DataType.SPARSE_FLOAT_VECTOR,
+            }:
+                break
+
+        if not valid_field:
+            raise MilvusException(message=f"cannot create index on nonexistent field: {field_name}")
+
+        index_param = Prepare.create_index_request(
+            collection_name, field_name, params, index_name=index_name
+        )
+
+        status = await self._async_stub.CreateIndex(index_param, timeout=timeout)
+        check_status(status)
+
+        return Status(status.code, status.reason)
+
+    @retry_on_rpc_failure()
+    async def get(
+        self,
+        collection_name: str,
+        ids: List[int],
+        output_fields: Optional[List[str]] = None,
+        partition_names: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+    ):
+        # TODO: some check
+        request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names)
+        return await self._async_stub.Retrieve(request, timeout=timeout)
+
+    @retry_on_rpc_failure()
+    async def query(
+        self,
+        collection_name: str,
+        expr: str,
+        output_fields: Optional[List[str]] = None,
+        partition_names: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        if output_fields is not None and not isinstance(output_fields, (list,)):
+            raise ParamError(message="Invalid query format. 'output_fields' must be a list")
+        request = Prepare.query_request(
+            collection_name, expr, output_fields, partition_names, **kwargs
+        )
+        response = await self._async_stub.Query(request, timeout=timeout)
+        check_status(response.status)
+
+        num_fields = len(response.fields_data)
+        # check has fields
+        if num_fields == 0:
+            raise MilvusException(message="No fields returned")
+
+        # check if all lists are of the same length
+        it = iter(response.fields_data)
+        num_entities = len_of(next(it))
+        if not all(len_of(field_data) == num_entities for field_data in it):
+            raise MilvusException(message="The length of fields data is inconsistent")
+
+        _, dynamic_fields = entity_helper.extract_dynamic_field_from_result(response)
+
+        results = []
+        for index in range(num_entities):
+            entity_row_data = entity_helper.extract_row_data_from_fields_data(
+                response.fields_data, index, dynamic_fields
+            )
+            results.append(entity_row_data)
+
+        extra_dict = get_cost_extra(response.status)
+        extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts
+        return ExtraList(results, extra=extra_dict)
+
+    @retry_on_rpc_failure()
+    @upgrade_reminder
+    def __internal_register(self, user: str, host: str, **kwargs) -> int:
+        req = Prepare.register_request(user, host)
+
+        async def wait_for_connect_response():
+            return await self._async_stub.Connect(request=req)
+
+        loop = asyncio.get_event_loop()
+        response = loop.run_until_complete(wait_for_connect_response())
+
+        check_status(response.status)
+        return response.identifier
diff --git a/pymilvus/client/async_interceptor.py b/pymilvus/client/async_interceptor.py
new file mode 100644
index 000000000..c456a44f2
--- /dev/null
+++ b/pymilvus/client/async_interceptor.py
@@ -0,0 +1,94 @@
+from typing import (
+    Any,
+    Callable,
+    List,
+    Union,
+)
+
+from grpc.aio import (
+    ClientCallDetails,
+    StreamStreamClientInterceptor,
+    StreamUnaryClientInterceptor,
+    UnaryStreamClientInterceptor,
+    UnaryUnaryClientInterceptor,
+)
+from grpc.aio._call import (
+    StreamStreamCall,
+    StreamUnaryCall,
+    UnaryStreamCall,
+    UnaryUnaryCall,
+)
+from grpc.aio._typing import (
+    RequestIterableType,
+    RequestType,
+    ResponseIterableType,
+    ResponseType,
+)
+
+
+class _GenericAsyncClientInterceptor(
+    UnaryUnaryClientInterceptor,
+    UnaryStreamClientInterceptor,
+    StreamUnaryClientInterceptor,
+    StreamStreamClientInterceptor,
+):
+    def __init__(self, interceptor_function: Callable):
+        self._fn = interceptor_function
+
+    async def intercept_unary_unary(
+        self,
+        continuation: Callable[[ClientCallDetails, RequestType], UnaryUnaryCall],
+        client_call_details: ClientCallDetails,
+        request: RequestType,
+    ) -> Union[UnaryUnaryCall, ResponseType]:
+        new_details, new_request = self._fn(client_call_details, request)
+        return await continuation(new_details, new_request)
+
+    async def intercept_unary_stream(
+        self,
+        continuation: Callable[[ClientCallDetails, RequestType], UnaryStreamCall],
+        client_call_details: ClientCallDetails,
+        request: RequestType,
+    ) -> Union[ResponseIterableType, UnaryStreamCall]:
+        new_details, new_request = self._fn(client_call_details, request)
+        return await continuation(new_details, new_request)
+
+    async def intercept_stream_unary(
+        self,
+        continuation: Callable[[ClientCallDetails, RequestType], StreamUnaryCall],
+        client_call_details: ClientCallDetails,
+        request_iterator: RequestIterableType,
+    ) -> StreamUnaryCall:
+        new_details, new_request_iterator = self._fn(client_call_details, request_iterator)
+        return await continuation(new_details, new_request_iterator)
+
+    async def intercept_stream_stream(
+        self,
+        continuation: Callable[[ClientCallDetails, RequestType], StreamStreamCall],
+        client_call_details: ClientCallDetails,
+        request_iterator: RequestIterableType,
+    ) -> Union[ResponseIterableType, StreamStreamCall]:
+        new_details, new_request_iterator = self._fn(client_call_details, request_iterator)
+        return await continuation(new_details, new_request_iterator)
+
+
+def async_header_adder_interceptor(headers: List[str], values: List[str]):
+    def intercept_call(client_call_details: ClientCallDetails, request: Any):
+        metadata = []
+        if client_call_details.metadata:
+            metadata = list(client_call_details.metadata)
+
+        for header, value in zip(headers, values):
+            metadata.append((header, value))
+
+        new_details = ClientCallDetails(
+            method=client_call_details.method,
+            timeout=client_call_details.timeout,
+            metadata=metadata,
+            credentials=client_call_details.credentials,
+            wait_for_ready=client_call_details.wait_for_ready,
+        )
+
+        return new_details, request
+
+    return _GenericAsyncClientInterceptor(intercept_call)
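async_header_adder_interceptor can also be attached at channel-construction time, instead of being appended to a live channel's private _unary_unary_interceptors list as the handler does. A hedged sketch (the target address is an assumption):

    import grpc
    from pymilvus.client.async_interceptor import async_header_adder_interceptor

    # Stamp a static header onto every call made through this channel.
    header_interceptor = async_header_adder_interceptor(["client_request_id"], ["req-123"])
    channel = grpc.aio.insecure_channel(
        "localhost:19530",
        interceptors=[header_interceptor],
    )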
diff --git a/pymilvus/milvus_client/__init__.py b/pymilvus/milvus_client/__init__.py
index 3dcab3795..e8d26df40 100644
--- a/pymilvus/milvus_client/__init__.py
+++ b/pymilvus/milvus_client/__init__.py
@@ -1,4 +1,5 @@
+from .async_milvus_client import AsyncMilvusClient
 from .index import IndexParams
 from .milvus_client import MilvusClient
 
-__all__ = ["IndexParams", "MilvusClient"]
+__all__ = ["AsyncMilvusClient", "IndexParams", "MilvusClient"]
diff --git a/pymilvus/milvus_client/async_milvus_client.py b/pymilvus/milvus_client/async_milvus_client.py
new file mode 100644
index 000000000..2c5a000ab
--- /dev/null
+++ b/pymilvus/milvus_client/async_milvus_client.py
@@ -0,0 +1,572 @@
+import logging
+from typing import Dict, List, Optional, Union
+from uuid import uuid4
+
+from pymilvus.client.abstract import AnnSearchRequest, BaseRanker
+from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL
+from pymilvus.client.types import (
+    ExceptionsMessage,
+    ExtraList,
+    OmitZeroDict,
+    construct_cost_extra,
+)
+from pymilvus.exceptions import (
+    DataTypeNotMatchException,
+    MilvusException,
+    ParamError,
+    PrimaryKeyException,
+)
+from pymilvus.orm import utility
+from pymilvus.orm.collection import CollectionSchema
+from pymilvus.orm.connections import connections
+from pymilvus.orm.types import DataType
+
+from .index import IndexParams
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class AsyncMilvusClient:
+    """AsyncMilvusClient is an EXPERIMENTAL class
+    which provides only part of MilvusClient's methods."""
+
+    def __init__(
+        self,
+        uri: str = "http://localhost:19530",
+        user: str = "",
+        password: str = "",
+        db_name: str = "",
+        token: str = "",
+        timeout: Optional[float] = None,
+        **kwargs,
+    ) -> None:
+        self._using = self._create_connection(
+            uri, user, password, db_name, token, timeout=timeout, **kwargs
+        )
+        self.is_self_hosted = bool(utility.get_server_type(using=self._using) == "milvus")
+
+    async def create_collection(
+        self,
+        collection_name: str,
+        dimension: Optional[int] = None,
+        primary_field_name: str = "id",  # default is "id"
+        id_type: str = "int",  # or "string"
+        vector_field_name: str = "vector",  # default is "vector"
+        metric_type: str = "COSINE",
+        auto_id: bool = False,
+        timeout: Optional[float] = None,
+        schema: Optional[CollectionSchema] = None,
+        index_params: Optional[IndexParams] = None,
+        **kwargs,
+    ):
+        if schema is None:
+            return await self._fast_create_collection(
+                collection_name,
+                dimension,
+                primary_field_name=primary_field_name,
+                id_type=id_type,
+                vector_field_name=vector_field_name,
+                metric_type=metric_type,
+                auto_id=auto_id,
+                timeout=timeout,
+                **kwargs,
+            )
+
+        return await self._create_collection_with_schema(
+            collection_name, schema, index_params, timeout=timeout, **kwargs
+        )
+
+    async def _fast_create_collection(
+        self,
+        collection_name: str,
+        dimension: int,
+        primary_field_name: str = "id",  # default is "id"
+        id_type: Union[DataType, str] = DataType.INT64,  # or "string"
+        vector_field_name: str = "vector",  # default is "vector"
+        metric_type: str = "COSINE",
+        auto_id: bool = False,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        if dimension is None:
+            msg = "missing required argument: 'dimension'"
+            raise TypeError(msg)
+        if "enable_dynamic_field" not in kwargs:
+            kwargs["enable_dynamic_field"] = True
+
+        schema = self.create_schema(auto_id=auto_id, **kwargs)
+
+        if id_type in ("int", DataType.INT64):
+            pk_data_type = DataType.INT64
+        elif id_type in ("string", "str", DataType.VARCHAR):
+            pk_data_type = DataType.VARCHAR
+        else:
+            raise PrimaryKeyException(message=ExceptionsMessage.PrimaryFieldType)
+
+        pk_args = {}
+        if "max_length" in kwargs and pk_data_type == DataType.VARCHAR:
+            pk_args["max_length"] = kwargs["max_length"]
+
+        schema.add_field(primary_field_name, pk_data_type, is_primary=True, **pk_args)
+        vector_type = DataType.FLOAT_VECTOR
+        schema.add_field(vector_field_name, vector_type, dim=dimension)
+        schema.verify()
+
+        conn = self._get_connection()
+        if "consistency_level" not in kwargs:
+            kwargs["consistency_level"] = DEFAULT_CONSISTENCY_LEVEL
+        try:
+            await conn.create_collection(collection_name, schema, timeout=timeout, **kwargs)
+            logger.debug("Successfully created collection: %s", collection_name)
+        except Exception as ex:
+            logger.error("Failed to create collection: %s", collection_name)
+            raise ex from ex
+
+        index_params = IndexParams()
+        index_params.add_index(vector_field_name, "", "", metric_type=metric_type)
+        await self.create_index(collection_name, index_params, timeout=timeout)
+        await self.load_collection(collection_name, timeout=timeout)
+
+    async def _create_collection_with_schema(
+        self,
+        collection_name: str,
+        schema: CollectionSchema,
+        index_params: IndexParams,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        schema.verify()
+
+        conn = self._get_connection()
+        if "consistency_level" not in kwargs:
+            kwargs["consistency_level"] = DEFAULT_CONSISTENCY_LEVEL
+        try:
+            await conn.create_collection(collection_name, schema, timeout=timeout, **kwargs)
+            logger.debug("Successfully created collection: %s", collection_name)
+        except Exception as ex:
+            logger.error("Failed to create collection: %s", collection_name)
+            raise ex from ex
+
+        if index_params:
+            await self.create_index(collection_name, index_params, timeout=timeout)
+            await self.load_collection(collection_name, timeout=timeout)
+
+    async def drop_collection(
+        self, collection_name: str, timeout: Optional[float] = None, **kwargs
+    ):
+        conn = self._get_connection()
+        await conn.drop_collection(collection_name, timeout=timeout, **kwargs)
+        logger.debug("Successfully dropped collection: %s", collection_name)
+
+    async def load_collection(
+        self, collection_name: str, timeout: Optional[float] = None, **kwargs
+    ):
+        conn = self._get_connection()
+        try:
+            await conn.load_collection(collection_name, timeout=timeout, **kwargs)
+        except MilvusException as ex:
+            logger.error("Failed to load collection: %s", collection_name)
+            raise ex from ex
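When schema is omitted, create_collection takes the _fast_create_collection path above: it builds a two-field schema, creates a default vector index (empty index type, COSINE metric by default), and loads the collection. A hedged usage sketch (the endpoint and names are assumptions):

    import asyncio
    from pymilvus import AsyncMilvusClient

    async def quick_create():
        client = AsyncMilvusClient(uri="http://localhost:19530")
        # Builds "id" (INT64 primary key) and "vector" (FLOAT_VECTOR, dim=8),
        # creates the default index, then loads the collection.
        await client.create_collection("quick_demo", dimension=8)
        client.close()

    asyncio.run(quick_create())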
+    async def create_index(
+        self,
+        collection_name: str,
+        index_params: IndexParams,
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        for index_param in index_params:
+            await self._create_index(collection_name, index_param, timeout=timeout, **kwargs)
+
+    async def _create_index(
+        self, collection_name: str, index_param: Dict, timeout: Optional[float] = None, **kwargs
+    ):
+        conn = self._get_connection()
+        try:
+            params = index_param.pop("params", {})
+            field_name = index_param.pop("field_name", "")
+            index_name = index_param.pop("index_name", "")
+            params.update(index_param)
+            await conn.create_index(
+                collection_name,
+                field_name,
+                params,
+                timeout=timeout,
+                index_name=index_name,
+                **kwargs,
+            )
+            logger.debug("Successfully created an index on collection: %s", collection_name)
+        except Exception as ex:
+            logger.error("Failed to create an index on collection: %s", collection_name)
+            raise ex from ex
+
+    async def insert(
+        self,
+        collection_name: str,
+        data: Union[Dict, List[Dict]],
+        timeout: Optional[float] = None,
+        partition_name: Optional[str] = "",
+        **kwargs,
+    ) -> Dict:
+        # If no data provided, we cannot input anything
+        if isinstance(data, Dict):
+            data = [data]
+
+        msg = "wrong type of argument 'data', "
+        msg += f"expected 'Dict' or list of 'Dict', got '{type(data).__name__}'"
+
+        if not isinstance(data, List):
+            raise TypeError(msg)
+
+        if len(data) == 0:
+            return {"insert_count": 0, "ids": []}
+
+        conn = self._get_connection()
+        # Insert into the collection.
+        try:
+            res = await conn.insert_rows(
+                collection_name, data, partition_name=partition_name, timeout=timeout
+            )
+        except Exception as ex:
+            raise ex from ex
+        return OmitZeroDict(
+            {
+                "insert_count": res.insert_count,
+                "ids": res.primary_keys,
+                "cost": res.cost,
+            }
+        )
+
+    async def upsert(
+        self,
+        collection_name: str,
+        data: Union[Dict, List[Dict]],
+        timeout: Optional[float] = None,
+        partition_name: Optional[str] = "",
+        **kwargs,
+    ) -> Dict:
+        # If no data provided, we cannot input anything
+        if isinstance(data, Dict):
+            data = [data]
+
+        msg = "wrong type of argument 'data', "
+        msg += f"expected 'Dict' or list of 'Dict', got '{type(data).__name__}'"
+
+        if not isinstance(data, List):
+            raise TypeError(msg)
+
+        if len(data) == 0:
+            return {"upsert_count": 0}
+
+        conn = self._get_connection()
+        # Upsert into the collection.
+        try:
+            res = await conn.upsert_rows(
+                collection_name, data, partition_name=partition_name, timeout=timeout, **kwargs
+            )
+        except Exception as ex:
+            raise ex from ex
+
+        return OmitZeroDict(
+            {
+                "upsert_count": res.upsert_count,
+                "cost": res.cost,
+            }
+        )
+
+    async def hybrid_search(
+        self,
+        collection_name: str,
+        reqs: List[AnnSearchRequest],
+        ranker: BaseRanker,
+        limit: int = 10,
+        output_fields: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+        partition_names: Optional[List[str]] = None,
+        **kwargs,
+    ) -> List[List[dict]]:
+        conn = self._get_connection()
+        try:
+            res = await conn.hybrid_search(
+                collection_name,
+                reqs,
+                ranker,
+                limit=limit,
+                partition_names=partition_names,
+                output_fields=output_fields,
+                timeout=timeout,
+                **kwargs,
+            )
+        except Exception as ex:
+            logger.error("Failed to hybrid search collection: %s", collection_name)
+            raise ex from ex
+
+        ret = []
+        for hits in res:
+            ret.append([hit.to_dict() for hit in hits])
+
+        return ExtraList(ret, extra=construct_cost_extra(res.cost))
+
+    async def search(
+        self,
+        collection_name: str,
+        data: Union[List[list], list],
+        filter: str = "",
+        limit: int = 10,
+        output_fields: Optional[List[str]] = None,
+        search_params: Optional[dict] = None,
+        timeout: Optional[float] = None,
+        partition_names: Optional[List[str]] = None,
+        anns_field: Optional[str] = None,
+        **kwargs,
+    ) -> List[List[dict]]:
+        conn = self._get_connection()
+        try:
+            res = await conn.search(
+                collection_name,
+                data,
+                anns_field or "",
+                search_params or {},
+                expression=filter,
+                limit=limit,
+                output_fields=output_fields,
+                partition_names=partition_names,
+                expr_params=kwargs.pop("filter_params", {}),
+                timeout=timeout,
+                **kwargs,
+            )
+        except Exception as ex:
+            logger.error("Failed to search collection: %s", collection_name)
+            raise ex from ex
+
+        ret = []
+        for hits in res:
+            query_result = []
+            for hit in hits:
+                query_result.append(hit.to_dict())
+            ret.append(query_result)
+
+        return ExtraList(ret, extra=construct_cost_extra(res.cost))
+
+    async def query(
+        self,
+        collection_name: str,
+        filter: str = "",
+        output_fields: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+        ids: Optional[Union[List, str, int]] = None,
+        partition_names: Optional[List[str]] = None,
+        **kwargs,
+    ) -> List[dict]:
+        if filter and not isinstance(filter, str):
+            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))
+
+        if filter and ids is not None:
+            raise ParamError(message=ExceptionsMessage.AmbiguousQueryFilterParam)
+
+        if isinstance(ids, (int, str)):
+            ids = [ids]
+
+        conn = self._get_connection()
+
+        if ids:
+            try:
+                schema_dict = await conn.describe_collection(
+                    collection_name, timeout=timeout, **kwargs
+                )
+            except Exception as ex:
+                logger.error("Failed to describe collection: %s", collection_name)
+                raise ex from ex
+            filter = self._pack_pks_expr(schema_dict, ids)
+
+        if not output_fields:
+            output_fields = ["*"]
+
+        try:
+            res = await conn.query(
+                collection_name,
+                expr=filter,
+                output_fields=output_fields,
+                partition_names=partition_names,
+                timeout=timeout,
+                expr_params=kwargs.pop("filter_params", {}),
+                **kwargs,
+            )
+        except Exception as ex:
+            logger.error("Failed to query collection: %s", collection_name)
+            raise ex from ex
+
+        return res
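query accepts either a filter expression or a list of ids, which is first resolved against the primary-key field via describe_collection and _pack_pks_expr; passing both raises ParamError. A hedged sketch of both paths (the endpoint, collection, and pk values are assumptions):

    import asyncio
    from pymilvus import AsyncMilvusClient

    async def query_demo():
        client = AsyncMilvusClient(uri="http://localhost:19530")
        # Filter path: a boolean expression over scalar fields.
        by_filter = await client.query("hello_milvus", filter="random > 0.5", limit=3)
        # Ids path: packed into a `pk in [...]` expression internally.
        by_ids = await client.query("hello_milvus", ids=["0", "1"])
        print(by_filter, by_ids)
        client.close()

    asyncio.run(query_demo())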
+    async def get(
+        self,
+        collection_name: str,
+        ids: Union[list, str, int],
+        output_fields: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+        partition_names: Optional[List[str]] = None,
+        **kwargs,
+    ) -> List[dict]:
+        if not isinstance(ids, list):
+            ids = [ids]
+
+        if len(ids) == 0:
+            return []
+
+        conn = self._get_connection()
+        try:
+            schema_dict = await conn.describe_collection(collection_name, timeout=timeout, **kwargs)
+        except Exception as ex:
+            logger.error("Failed to describe collection: %s", collection_name)
+            raise ex from ex
+
+        if not output_fields:
+            output_fields = ["*"]
+
+        expr = self._pack_pks_expr(schema_dict, ids)
+        try:
+            res = await conn.query(
+                collection_name,
+                expr=expr,
+                output_fields=output_fields,
+                partition_names=partition_names,
+                timeout=timeout,
+                **kwargs,
+            )
+        except Exception as ex:
+            logger.error("Failed to get collection: %s", collection_name)
+            raise ex from ex
+
+        return res
+
+    async def delete(
+        self,
+        collection_name: str,
+        ids: Optional[Union[list, str, int]] = None,
+        timeout: Optional[float] = None,
+        filter: Optional[str] = None,
+        partition_name: Optional[str] = None,
+        **kwargs,
+    ) -> Dict[str, int]:
+        pks = kwargs.get("pks", [])
+        if isinstance(pks, (int, str)):
+            pks = [pks]
+
+        for pk in pks:
+            if not isinstance(pk, (int, str)):
+                msg = f"wrong type of argument pks, expect list, int or str, got '{type(pk).__name__}'"
+                raise TypeError(msg)
+
+        if ids is not None:
+            if isinstance(ids, (int, str)):
+                pks.append(ids)
+            elif isinstance(ids, list):
+                for id in ids:
+                    if not isinstance(id, (int, str)):
+                        msg = f"wrong type of argument ids, expect list, int or str, got '{type(id).__name__}'"
+                        raise TypeError(msg)
+                pks.extend(ids)
+            else:
+                msg = f"wrong type of argument ids, expect list, int or str, got '{type(ids).__name__}'"
+                raise TypeError(msg)
+
+        # validate the ambiguous delete filter param before the describe collection rpc
+        if filter and len(pks) > 0:
+            raise ParamError(message=ExceptionsMessage.AmbiguousDeleteFilterParam)
+
+        expr = ""
+        conn = self._get_connection()
+        if len(pks) > 0:
+            try:
+                schema_dict = await conn.describe_collection(
+                    collection_name, timeout=timeout, **kwargs
+                )
+            except Exception as ex:
+                logger.error("Failed to describe collection: %s", collection_name)
+                raise ex from ex
+            expr = self._pack_pks_expr(schema_dict, pks)
+        else:
+            if not isinstance(filter, str):
+                raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))
+            expr = filter
+
+        ret_pks = []
+        try:
+            res = await conn.delete(
+                collection_name=collection_name,
+                expression=expr,
+                partition_name=partition_name,
+                expr_params=kwargs.pop("filter_params", {}),
+                timeout=timeout,
+                **kwargs,
+            )
+            if res.primary_keys:
+                ret_pks.extend(res.primary_keys)
+        except Exception as ex:
+            logger.error("Failed to delete primary keys in collection: %s", collection_name)
+            raise ex from ex
+
+        # compatible with deletions that return primary keys
+        if ret_pks:
+            return ret_pks
+
+        return OmitZeroDict({"delete_count": res.delete_count, "cost": res.cost})
+
+    @classmethod
+    def create_schema(cls, **kwargs):
+        kwargs["check_fields"] = False  # do not check fields for now
+        return CollectionSchema([], **kwargs)
+
+    def close(self):
+        connections.disconnect(self._using)
+
+    def _get_connection(self):
+        return connections._fetch_handler(self._using)
+
+    def _create_connection(
+        self,
+        uri: str,
+        user: str = "",
+        password: str = "",
+        db_name: str = "",
+        token: str = "",
+        **kwargs,
+    ) -> str:
+        """Create the connection to the Milvus server."""
+        # TODO: Implement reuse with new uri style
+        using = uuid4().hex
+        try:
+            connections.connect(
+                using, user, password, db_name, token, uri=uri, _async=True, **kwargs
+            )
+        except Exception as ex:
+            logger.error("Failed to create new connection using: %s", using)
+            raise ex from ex
+        else:
+            logger.debug("Created new connection using: %s", using)
+            return using
+
+    def _extract_primary_field(self, schema_dict: Dict) -> dict:
+        fields = schema_dict.get("fields", [])
+        if not fields:
+            return {}
+
+        for field_dict in fields:
+            if field_dict.get("is_primary", None) is not None:
+                return field_dict
+
+        return {}
+
+    def _pack_pks_expr(self, schema_dict: Dict, pks: List) -> str:
+        primary_field = self._extract_primary_field(schema_dict)
+        pk_field_name = primary_field["name"]
+        data_type = primary_field["type"]
+
+        # Varchar pks need to be quoted in the expression
+        if data_type == DataType.VARCHAR:
+            ids = ["'" + str(entry) + "'" for entry in pks]
+            expr = f"""{pk_field_name} in [{','.join(ids)}]"""
+        else:
+            ids = [str(entry) for entry in pks]
+            expr = f"{pk_field_name} in [{','.join(ids)}]"
+        return expr
diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py
index 7855c72e4..5d6306661 100644
--- a/pymilvus/orm/connections.py
+++ b/pymilvus/orm/connections.py
@@ -18,6 +18,7 @@
 from typing import Callable, Tuple, Union
 from urllib import parse
 
+from pymilvus.client.async_grpc_handler import AsyncGrpcHandler
 from pymilvus.client.check import is_legal_address, is_legal_host, is_legal_port
 from pymilvus.client.grpc_handler import GrpcHandler
 from pymilvus.exceptions import (
@@ -303,6 +304,7 @@ def connect(
         password: str = "",
         db_name: str = "default",
         token: str = "",
+        _async: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -357,7 +359,6 @@ def connect(
         >>> from pymilvus import connections
         >>> connections.connect("test", host="localhost", port="19530")
         """
-
         if kwargs.get("uri") and parse.urlparse(kwargs["uri"]).scheme.lower() not in [
             "unix",
             "http",
@@ -394,7 +395,7 @@ def connect(
             kwargs_copy["token"] = token
 
         def connect_milvus(**kwargs):
-            gh = GrpcHandler(**kwargs)
+            gh = GrpcHandler(**kwargs) if not _async else AsyncGrpcHandler(**kwargs)
 
             t = kwargs.get("timeout")
             timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
@@ -532,7 +533,9 @@ def has_connection(self, alias: str) -> bool:
             raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))
         return alias in self._connected_alias
 
-    def _fetch_handler(self, alias: str = Config.MILVUS_CONN_ALIAS) -> GrpcHandler:
+    def _fetch_handler(
+        self, alias: str = Config.MILVUS_CONN_ALIAS
+    ) -> Union[GrpcHandler, AsyncGrpcHandler]:
         """Retrieves a GrpcHandler by alias."""
         if not isinstance(alias, str):
             raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))
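Taken together, AsyncMilvusClient registers its connection with connections.connect(..., _async=True), so _fetch_handler hands back an AsyncGrpcHandler and every data-path call is awaitable. A hedged end-to-end sketch (the endpoint and data are assumptions):

    import asyncio
    from pymilvus import AsyncMilvusClient

    async def main():
        client = AsyncMilvusClient(uri="http://localhost:19530")
        await client.create_collection("demo", dimension=8)
        await client.insert("demo", {"id": 1, "vector": [0.1] * 8})
        await client.delete("demo", ids=[1])
        await client.drop_collection("demo")
        client.close()

    asyncio.run(main())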