Skip to content

Commit

Permalink
fix: remove async function in AsyncClient connection initialization
Browse files Browse the repository at this point in the history
- Move check for channel ready and identifier interceptor setup before each query call instead of during the AsyncClient connection initialization

Signed-off-by: Ruichen Bao <[email protected]>
  • Loading branch information
brcarry committed Dec 20, 2024
1 parent 7bccdcb commit 73405f6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 49 deletions.
89 changes: 41 additions & 48 deletions pymilvus/client/async_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import grpc
from grpc._cython import cygrpc

from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure, upgrade_reminder
from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure
from pymilvus.exceptions import (
AmbiguousIndexName,
DescribeCollectionException,
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(
self._set_authorization(**kwargs)
self._setup_db_name(kwargs.get("db_name"))
self._setup_grpc_channel(**kwargs)
self._is_channel_ready = False
self.callbacks = []

def register_state_change_callback(self, callback: Callable):
Expand Down Expand Up @@ -108,33 +109,10 @@ def __enter__(self):
def __exit__(self: object, exc_type: object, exc_val: object, exc_tb: object):
    """Leave the ``with`` block without performing any cleanup.

    The implicit ``None`` return means an exception raised inside the
    ``with`` body is not suppressed.
    """

def _wait_for_channel_ready(self, timeout: Union[float] = 10, retry_interval: float = 1):
    """Block until the async channel is ready, then install the identifier interceptor.

    Args:
        timeout: Seconds to wait for the channel to become ready; also
            forwarded to ``_setup_identifier_interceptor``. ``None`` waits
            without a deadline.
        retry_interval: Accepted for interface compatibility; not used here.

    Raises:
        MilvusException: With ``Status.CONNECT_FAILED`` when the channel does
            not become ready in time.
    """

    async def wait_for_async_channel_ready():
        # Awaiting channel_ready() on its own has no deadline, so bound it
        # with wait_for to make the ``timeout`` argument effective.
        await asyncio.wait_for(self._async_channel.channel_ready(), timeout=timeout)

    try:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(wait_for_async_channel_ready())

        self._setup_identifier_interceptor(self._user, timeout=timeout)
    except (grpc.FutureTimeoutError, asyncio.TimeoutError) as e:
        raise MilvusException(
            code=Status.CONNECT_FAILED,
            message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
        ) from e
    # Any other exception propagates unchanged; re-raising it "from" itself
    # only made the exception its own __cause__.

def close(self):
    """Drop the registered state-change callbacks, then close the async channel."""
    # Callbacks are deregistered first, before the channel is closed.
    self.deregister_state_change_callbacks()
    # NOTE(review): if `_async_channel` is a grpc.aio channel, `close()`
    # returns a coroutine that is never awaited here -- confirm.
    self._async_channel.close()

def reset_db_name(self, db_name: str):
    """Switch the handler to ``db_name`` and rebuild the connection state.

    Runs three steps in order: store the new database name, recreate the
    gRPC channel, and set up the identifier interceptor again for the
    current user.
    """
    self._setup_db_name(db_name)
    self._setup_grpc_channel()
    # The identifier interceptor is set up again after the channel has been
    # recreated.
    self._setup_identifier_interceptor(self._user)

def _setup_authorization_interceptor(self, user: str, password: str, token: str):
keys = []
values = []
Expand Down Expand Up @@ -233,33 +211,51 @@ def _setup_grpc_channel(self, **kwargs):
self._request_id = None
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)

def _setup_identifier_interceptor(self, user: str, timeout: int = 10):
    """Register this client and tag later unary-unary calls with its identifier.

    Args:
        user: User name passed to the registration call.
        timeout: Forwarded to the registration call as ``timeout``.
    """
    hostname = socket.gethostname()
    self._identifier = self.__internal_register(user, hostname, timeout=timeout)

    # The interceptor is built from the key "identifier" and the identifier
    # returned by the registration call.
    identifier_interceptor = async_header_adder_interceptor(
        ["identifier"], [str(self._identifier)]
    )
    channel = self._async_channel
    channel._unary_unary_interceptors.append(identifier_interceptor)

    # The stub is re-created on the channel after the interceptor is appended.
    self._async_stub = milvus_pb2_grpc.MilvusServiceStub(channel)

@property
def server_address(self):
    """Address stored in ``_address`` for this handler (read-only)."""
    address = self._address
    return address

def get_server_type(self):
    """Return the server type derived from the host part of ``server_address``.

    The host is everything before the first ``":"``; it is handed to the
    module-level ``get_server_type`` helper.
    """
    host, _sep, _rest = self.server_address.partition(":")
    return get_server_type(host)

async def ensure_channel_ready(self):
    """Finish connection setup on first use; do nothing once it has succeeded.

    On the first successful call this waits for the async channel to become
    ready, registers the client through the ``Connect`` RPC, appends an
    interceptor carrying the returned identifier, re-creates the stub and
    sets ``_is_channel_ready``. Later calls return immediately.

    Raises:
        MilvusException: With ``Status.CONNECT_FAILED`` when a
            ``grpc.FutureTimeoutError`` is raised during setup. Any other
            exception propagates unchanged.
    """
    if self._is_channel_ready:
        return

    try:
        # wait for channel ready
        await self._async_channel.channel_ready()

        # set identifier interceptor
        host = socket.gethostname()
        req = Prepare.register_request(self._user, host)
        response = await self._async_stub.Connect(request=req)
        check_status(response.status)
        _async_identifier_interceptor = async_header_adder_interceptor(
            ["identifier"], [str(response.identifier)]
        )
        self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
        self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)

        # Only mark ready after every step succeeded, so a failed attempt
        # is retried on the next call.
        self._is_channel_ready = True
    except grpc.FutureTimeoutError as e:
        raise MilvusException(
            code=Status.CONNECT_FAILED,
            message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
        ) from e
    # No catch-all handler: re-raising an exception "from" itself only made
    # it its own __cause__.

@retry_on_rpc_failure()
async def create_collection(
    self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs
):
    """Create a collection through the ``CreateCollection`` RPC.

    Args:
        collection_name: Name of the collection to create.
        fields: Field definitions passed to the request builder.
        timeout: Optional RPC timeout, forwarded to the stub call.
        **kwargs: Extra options forwarded to the request builder.

    Raises:
        Whatever ``check_pass_param`` or ``check_status`` raise for invalid
        parameters or a failed response.
    """
    # Connection setup runs before parameter validation.
    await self.ensure_channel_ready()
    check_pass_param(collection_name=collection_name, timeout=timeout)

    create_req = Prepare.create_collection_request(collection_name, fields, **kwargs)
    status = await self._async_stub.CreateCollection(create_req, timeout=timeout)
    check_status(status)

@retry_on_rpc_failure()
async def drop_collection(self, collection_name: str, timeout: Optional[float] = None):
await self.ensure_channel_ready()
check_pass_param(collection_name=collection_name, timeout=timeout)
request = Prepare.drop_collection_request(collection_name)
response = await self._async_stub.DropCollection(request, timeout=timeout)
Expand All @@ -273,6 +269,7 @@ async def load_collection(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
collection_name=collection_name, replica_number=replica_number, timeout=timeout
)
Expand Down Expand Up @@ -336,6 +333,7 @@ async def get_loading_progress(
async def describe_collection(
self, collection_name: str, timeout: Optional[float] = None, **kwargs
):
await self.ensure_channel_ready()
check_pass_param(collection_name=collection_name, timeout=timeout)
request = Prepare.describe_collection_request(collection_name)
response = await self._async_stub.DescribeCollection(request, timeout=timeout)
Expand Down Expand Up @@ -366,6 +364,7 @@ async def insert_rows(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
request = await self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
Expand Down Expand Up @@ -400,6 +399,7 @@ async def _prepare_row_insert_request(
enable_dynamic=enable_dynamic,
)

@retry_on_rpc_failure()
async def delete(
self,
collection_name: str,
Expand All @@ -408,6 +408,7 @@ async def delete(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(collection_name=collection_name, timeout=timeout)
try:
req = Prepare.delete_request(
Expand Down Expand Up @@ -462,6 +463,7 @@ async def upsert(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
if not check_invalid_binary_vector(entities):
raise ParamError(message="Invalid binary vector data exists")

Expand Down Expand Up @@ -507,6 +509,7 @@ async def upsert_rows(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
if isinstance(entities, dict):
entities = [entities]
request = await self._prepare_row_upsert_request(
Expand Down Expand Up @@ -561,6 +564,7 @@ async def search(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
limit=limit,
round_decimal=round_decimal,
Expand Down Expand Up @@ -598,6 +602,7 @@ async def hybrid_search(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
limit=limit,
round_decimal=round_decimal,
Expand Down Expand Up @@ -651,7 +656,7 @@ async def create_index(
collection_desc = await self.describe_collection(
collection_name, timeout=timeout, **copy_kwargs
)

await self.ensure_channel_ready()
valid_field = False
for fields in collection_desc["fields"]:
if field_name != fields["name"]:
Expand Down Expand Up @@ -747,6 +752,7 @@ async def get(
timeout: Optional[float] = None,
):
# TODO: some check
await self.ensure_channel_ready()
request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names)
return await self._async_stub.Retrieve.get(request, timeout=timeout)

Expand All @@ -760,6 +766,7 @@ async def query(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
if output_fields is not None and not isinstance(output_fields, (list,)):
raise ParamError(message="Invalid query format. 'output_fields' must be a list")
request = Prepare.query_request(
Expand Down Expand Up @@ -792,20 +799,6 @@ async def query(
extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts
return ExtraList(results, extra=extra_dict)

@retry_on_rpc_failure()
@upgrade_reminder
def __internal_register(self, user: str, host: str, **kwargs) -> int:
    """Register ``user``/``host`` via the ``Connect`` RPC and return the identifier.

    The RPC is awaited by driving the event loop from this synchronous
    method with ``run_until_complete``. ``**kwargs`` is accepted but not
    used in the body.
    """
    register_req = Prepare.register_request(user, host)

    async def _connect():
        return await self._async_stub.Connect(request=register_req)

    reply = asyncio.get_event_loop().run_until_complete(_connect())

    check_status(reply.status)
    return reply.identifier

@retry_on_rpc_failure()
@ignore_unimplemented(0)
async def alloc_timestamp(self, timeout: Optional[float] = None) -> int:
Expand Down
4 changes: 3 additions & 1 deletion pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,9 @@ def connect_milvus(**kwargs):
t = kwargs.get("timeout")
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT

gh._wait_for_channel_ready(timeout=timeout)
if not _async:
gh._wait_for_channel_ready(timeout=timeout)

if kwargs.get("keep_alive", False):
gh.register_state_change_callback(
ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle
Expand Down

0 comments on commit 73405f6

Please sign in to comment.