diff --git a/pymilvus/client/async_grpc_handler.py b/pymilvus/client/async_grpc_handler.py index 1d8825be6..5532b4880 100644 --- a/pymilvus/client/async_grpc_handler.py +++ b/pymilvus/client/async_grpc_handler.py @@ -2,6 +2,7 @@ import base64 import copy import socket +import time from pathlib import Path from typing import Callable, Dict, List, Optional, Union from urllib import parse @@ -9,9 +10,11 @@ import grpc from grpc._cython import cygrpc -from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder +from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure, upgrade_reminder from pymilvus.exceptions import ( + AmbiguousIndexName, DescribeCollectionException, + ExceptionsMessage, MilvusException, ParamError, ) @@ -32,6 +35,7 @@ from .types import ( DataType, ExtraList, + IndexState, Status, get_cost_extra, ) @@ -291,6 +295,43 @@ async def load_collection( response = await self._async_stub.LoadCollection(request, timeout=timeout) check_status(response) + await self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh) + + @retry_on_rpc_failure() + async def wait_for_loading_collection( + self, collection_name: str, timeout: Optional[float] = None, is_refresh: bool = False + ): + start = time.time() + + def can_loop(t: int) -> bool: + return True if timeout is None else t <= (start + timeout) + + while can_loop(time.time()): + progress = await self.get_loading_progress( + collection_name, timeout=timeout, is_refresh=is_refresh + ) + if progress >= 100: + return + await asyncio.sleep(Config.WaitTimeDurationWhenLoad) + raise MilvusException( + message=f"wait for loading collection timeout, collection: {collection_name}" + ) + + @retry_on_rpc_failure() + async def get_loading_progress( + self, + collection_name: str, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + is_refresh: bool = False, + ): + request = Prepare.get_loading_progress(collection_name, partition_names) + response = await self._async_stub.GetLoadingProgress(request, timeout=timeout) + check_status(response.status) + if is_refresh: + return response.refresh_progress + return response.progress + @retry_on_rpc_failure() async def describe_collection( self, collection_name: str, timeout: Optional[float] = None, **kwargs @@ -635,8 +676,67 @@ async def create_index( status = await self._async_stub.CreateIndex(index_param, timeout=timeout) check_status(status) + index_success, fail_reason = await self.wait_for_creating_index( + collection_name=collection_name, + index_name=index_name, + timeout=timeout, + field_name=field_name, + ) + + if not index_success: + raise MilvusException(message=fail_reason) + return Status(status.code, status.reason) + @retry_on_rpc_failure() + async def wait_for_creating_index( + self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs + ): + timestamp = await self.alloc_timestamp() + start = time.time() + while True: + await asyncio.sleep(0.5) + state, fail_reason = await self.get_index_state( + collection_name, index_name, timeout=timeout, timestamp=timestamp, **kwargs + ) + if state == IndexState.Finished: + return True, fail_reason + if state == IndexState.Failed: + return False, fail_reason + end = time.time() + if isinstance(timeout, int) and end - start > timeout: + msg = ( + f"collection {collection_name} create index {index_name} " + f"timeout in {timeout}s" + ) + raise MilvusException(message=msg) + + @retry_on_rpc_failure() + async def get_index_state( + self, + collection_name: str, + index_name: str, + timeout: Optional[float] = None, + timestamp: Optional[int] = None, + **kwargs, + ): + request = Prepare.describe_index_request(collection_name, index_name, timestamp) + response = await self._async_stub.DescribeIndex(request, timeout=timeout) + status = response.status + check_status(status) + + if len(response.index_descriptions) == 1: + index_desc = response.index_descriptions[0] + return index_desc.state, index_desc.index_state_fail_reason + # just for create_index. + field_name = kwargs.pop("field_name", "") + if field_name != "": + for index_desc in response.index_descriptions: + if index_desc.field_name == field_name: + return index_desc.state, index_desc.index_state_fail_reason + + raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName) + @retry_on_rpc_failure() async def get( self, @@ -705,3 +805,11 @@ async def wait_for_connect_response(): check_status(response.status) return response.identifier + + @retry_on_rpc_failure() + @ignore_unimplemented(0) + async def alloc_timestamp(self, timeout: Optional[float] = None) -> int: + request = milvus_types.AllocTimestampRequest() + response = await self._async_stub.AllocTimestamp(request, timeout=timeout) + check_status(response.status) + return response.timestamp