diff --git a/pymilvus/client/async_grpc_handler.py b/pymilvus/client/async_grpc_handler.py index 5532b4880..c9fe1320e 100644 --- a/pymilvus/client/async_grpc_handler.py +++ b/pymilvus/client/async_grpc_handler.py @@ -10,7 +10,7 @@ import grpc from grpc._cython import cygrpc -from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure, upgrade_reminder +from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure from pymilvus.exceptions import ( AmbiguousIndexName, DescribeCollectionException, @@ -68,6 +68,7 @@ def __init__( self._set_authorization(**kwargs) self._setup_db_name(kwargs.get("db_name")) self._setup_grpc_channel(**kwargs) + self._is_channel_ready = False self.callbacks = [] def register_state_change_callback(self, callback: Callable): @@ -108,33 +109,10 @@ def __enter__(self): def __exit__(self: object, exc_type: object, exc_val: object, exc_tb: object): pass - def _wait_for_channel_ready(self, timeout: Union[float] = 10, retry_interval: float = 1): - try: - - async def wait_for_async_channel_ready(): - await self._async_channel.channel_ready() - - loop = asyncio.get_event_loop() - loop.run_until_complete(wait_for_async_channel_ready()) - - self._setup_identifier_interceptor(self._user, timeout=timeout) - except grpc.FutureTimeoutError as e: - raise MilvusException( - code=Status.CONNECT_FAILED, - message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable", - ) from e - except Exception as e: - raise e from e - def close(self): self.deregister_state_change_callbacks() self._async_channel.close() - def reset_db_name(self, db_name: str): - self._setup_db_name(db_name) - self._setup_grpc_channel() - self._setup_identifier_interceptor(self._user) - def _setup_authorization_interceptor(self, user: str, password: str, token: str): keys = [] values = [] @@ -233,15 +211,6 @@ def _setup_grpc_channel(self, **kwargs): self._request_id = None self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) - def _setup_identifier_interceptor(self, user: str, timeout: int = 10): - host = socket.gethostname() - self._identifier = self.__internal_register(user, host, timeout=timeout) - _async_identifier_interceptor = async_header_adder_interceptor( - ["identifier"], [str(self._identifier)] - ) - self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor) - self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel) - @property def server_address(self): return self._address @@ -249,10 +218,36 @@ def server_address(self): def get_server_type(self): return get_server_type(self.server_address.split(":")[0]) + async def ensure_channel_ready(self): + try: + if not self._is_channel_ready: + # wait for channel ready + await self._async_channel.channel_ready() + # set identifier interceptor + host = socket.gethostname() + req = Prepare.register_request(self._user, host) + response = await self._async_stub.Connect(request=req) + check_status(response.status) + _async_identifier_interceptor = async_header_adder_interceptor( + ["identifier"], [str(response.identifier)] + ) + self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor) + self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel) + + self._is_channel_ready = True + except grpc.FutureTimeoutError as e: + raise MilvusException( + code=Status.CONNECT_FAILED, + message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable", + ) from e + except Exception as e: + raise e from e + @retry_on_rpc_failure() async def create_collection( self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs ): + await self.ensure_channel_ready() check_pass_param(collection_name=collection_name, timeout=timeout) request = Prepare.create_collection_request(collection_name, fields, **kwargs) response = await self._async_stub.CreateCollection(request, timeout=timeout) @@ -260,6 +255,7 @@ async def create_collection( @retry_on_rpc_failure() async def drop_collection(self, collection_name: str, timeout: Optional[float] = None): + await self.ensure_channel_ready() check_pass_param(collection_name=collection_name, timeout=timeout) request = Prepare.drop_collection_request(collection_name) response = await self._async_stub.DropCollection(request, timeout=timeout) @@ -273,6 +269,7 @@ async def load_collection( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() check_pass_param( collection_name=collection_name, replica_number=replica_number, timeout=timeout ) @@ -336,6 +333,7 @@ async def get_loading_progress( async def describe_collection( self, collection_name: str, timeout: Optional[float] = None, **kwargs ): + await self.ensure_channel_ready() check_pass_param(collection_name=collection_name, timeout=timeout) request = Prepare.describe_collection_request(collection_name) response = await self._async_stub.DescribeCollection(request, timeout=timeout) @@ -366,6 +364,7 @@ async def insert_rows( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() request = await self._prepare_row_insert_request( collection_name, entities, partition_name, schema, timeout, **kwargs ) @@ -400,6 +399,7 @@ async def _prepare_row_insert_request( enable_dynamic=enable_dynamic, ) + @retry_on_rpc_failure() async def delete( self, collection_name: str, @@ -408,6 +408,7 @@ async def delete( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() check_pass_param(collection_name=collection_name, timeout=timeout) try: req = Prepare.delete_request( @@ -462,6 +463,7 @@ async def upsert( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() if not check_invalid_binary_vector(entities): raise ParamError(message="Invalid binary vector data exists") @@ -507,6 +509,7 @@ async def upsert_rows( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() if isinstance(entities, dict): entities = [entities] request = await self._prepare_row_upsert_request( @@ -561,6 +564,7 @@ async def search( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() check_pass_param( limit=limit, round_decimal=round_decimal, @@ -598,6 +602,7 @@ async def hybrid_search( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() check_pass_param( limit=limit, round_decimal=round_decimal, @@ -651,7 +656,7 @@ async def create_index( collection_desc = await self.describe_collection( collection_name, timeout=timeout, **copy_kwargs ) - + await self.ensure_channel_ready() valid_field = False for fields in collection_desc["fields"]: if field_name != fields["name"]: @@ -747,6 +752,7 @@ async def get( timeout: Optional[float] = None, ): # TODO: some check + await self.ensure_channel_ready() request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names) return await self._async_stub.Retrieve.get(request, timeout=timeout) @@ -760,6 +766,7 @@ async def query( timeout: Optional[float] = None, **kwargs, ): + await self.ensure_channel_ready() if output_fields is not None and not isinstance(output_fields, (list,)): raise ParamError(message="Invalid query format. 'output_fields' must be a list") request = Prepare.query_request( @@ -792,20 +799,6 @@ async def query( extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts return ExtraList(results, extra=extra_dict) - @retry_on_rpc_failure() - @upgrade_reminder - def __internal_register(self, user: str, host: str, **kwargs) -> int: - req = Prepare.register_request(user, host) - - async def wait_for_connect_response(): - return await self._async_stub.Connect(request=req) - - loop = asyncio.get_event_loop() - response = loop.run_until_complete(wait_for_connect_response()) - - check_status(response.status) - return response.identifier - @retry_on_rpc_failure() @ignore_unimplemented(0) async def alloc_timestamp(self, timeout: Optional[float] = None) -> int: diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 5d6306661..151b31e7d 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -400,7 +400,9 @@ def connect_milvus(**kwargs): t = kwargs.get("timeout") timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT - gh._wait_for_channel_ready(timeout=timeout) + if not _async: + gh._wait_for_channel_ready(timeout=timeout) + if kwargs.get("keep_alive", False): gh.register_state_change_callback( ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle