diff --git a/examples/simple_async.py b/examples/simple_async.py index ee51f4389..91d7d681c 100644 --- a/examples/simple_async.py +++ b/examples/simple_async.py @@ -31,6 +31,7 @@ index_params.add_index(field_name = "embeddings", index_type = "HNSW", metric_type="L2", nlist=128) index_params.add_index(field_name = "embeddings2",index_type = "HNSW", metric_type="L2", nlist=128) +# Always use `await` when you want to guarantee the execution order of tasks. async def recreate_collection(): print(fmt.format("Start dropping collection")) await async_milvus_client.drop_collection(collection_name) diff --git a/pymilvus/client/async_grpc_handler.py b/pymilvus/client/async_grpc_handler.py index 794858dd7..1d8825be6 100644 --- a/pymilvus/client/async_grpc_handler.py +++ b/pymilvus/client/async_grpc_handler.py @@ -19,7 +19,7 @@ from pymilvus.grpc_gen import milvus_pb2_grpc from pymilvus.settings import Config -from . import entity_helper, interceptor, ts_utils, utils +from . import entity_helper, ts_utils, utils from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult from .async_interceptor import async_header_adder_interceptor from .check import ( @@ -62,8 +62,8 @@ def __init__( self._request_id = None self._user = kwargs.get("user") self._set_authorization(**kwargs) - self._setup_db_interceptor(kwargs.get("db_name")) - self._setup_grpc_channel() # init channel and stub + self._setup_db_name(kwargs.get("db_name")) + self._setup_grpc_channel(**kwargs) self.callbacks = [] def register_state_change_callback(self, callback: Callable): @@ -96,12 +96,7 @@ def _set_authorization(self, **kwargs): self._server_pem_path = kwargs.get("server_pem_path", "") self._server_name = kwargs.get("server_name", "") - self._authorization_interceptor = None - self._setup_authorization_interceptor( - kwargs.get("user"), - kwargs.get("password"), - kwargs.get("token"), - ) + self._async_authorization_interceptor = None def __enter__(self): return self @@ -132,7 +127,7 @@ def close(self): self._async_channel.close() def reset_db_name(self, db_name: str): - self._setup_db_interceptor(db_name) + self._setup_db_name(db_name) self._setup_grpc_channel() self._setup_identifier_interceptor(self._user) @@ -148,16 +143,19 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str) keys.append("authorization") values.append(authorization) if len(keys) > 0 and len(values) > 0: - self._authorization_interceptor = interceptor.header_adder_interceptor(keys, values) + self._async_authorization_interceptor = async_header_adder_interceptor(keys, values) + self._final_channel._unary_unary_interceptors.append( + self._async_authorization_interceptor + ) - def _setup_db_interceptor(self, db_name: str): + def _setup_db_name(self, db_name: str): if db_name is None: - self._db_interceptor = None + self._db_name = None else: check_pass_param(db_name=db_name) - self._db_interceptor = interceptor.header_adder_interceptor(["dbname"], [db_name]) + self._db_name = db_name - def _setup_grpc_channel(self): + def _setup_grpc_channel(self, **kwargs): if self._async_channel is None: opts = [ (cygrpc.ChannelArgKey.max_send_message_length, -1), @@ -203,21 +201,31 @@ def _setup_grpc_channel(self): # avoid to add duplicate headers. self._final_channel = self._async_channel - if self._log_level: + if self._async_authorization_interceptor: + self._final_channel._unary_unary_interceptors.append( + self._async_authorization_interceptor + ) + else: + self._setup_authorization_interceptor( + kwargs.get("user"), + kwargs.get("password"), + kwargs.get("token"), + ) + if self._db_name: + async_db_interceptor = async_header_adder_interceptor(["dbname"], [self._db_name]) + self._final_channel._unary_unary_interceptors.append(async_db_interceptor) + if self._log_level: async_log_level_interceptor = async_header_adder_interceptor( ["log_level"], [self._log_level] ) self._final_channel._unary_unary_interceptors.append(async_log_level_interceptor) - self._log_level = None if self._request_id: - async_request_id_interceptor = async_header_adder_interceptor( ["client_request_id"], [self._request_id] ) self._final_channel._unary_unary_interceptors.append(async_request_id_interceptor) - self._request_id = None self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) diff --git a/pymilvus/client/async_interceptor.py b/pymilvus/client/async_interceptor.py index c456a44f2..db96b416f 100644 --- a/pymilvus/client/async_interceptor.py +++ b/pymilvus/client/async_interceptor.py @@ -72,7 +72,7 @@ async def intercept_stream_stream( return await continuation(new_details, new_request_iterator) -def async_header_adder_interceptor(headers: List[str], values: List[str]): +def async_header_adder_interceptor(headers: List[str], values: Union[List[str], List[bytes]]): def intercept_call(client_call_details: ClientCallDetails, request: Any): metadata = [] if client_call_details.metadata: