diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py
index a7711749b..aeb4f9f63 100644
--- a/pymilvus/client/abstract.py
+++ b/pymilvus/client/abstract.py
@@ -497,11 +497,15 @@ def __init__(
             )
             nq_thres += topk
         self._session_ts = session_ts
+        self._search_iterator_v2_results = res.search_iterator_v2_results
         super().__init__(data)
 
     def get_session_ts(self):
         return self._session_ts
 
+    def get_search_iterator_v2_results_info(self):
+        return self._search_iterator_v2_results
+
     def get_fields_by_range(
         self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
     ) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py
index efb38aa7a..a7265f0b7 100644
--- a/pymilvus/client/constants.py
+++ b/pymilvus/client/constants.py
@@ -16,6 +16,10 @@
 STRICT_GROUP_SIZE = "strict_group_size"
 ITERATOR_FIELD = "iterator"
 ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
+ITER_SEARCH_V2_KEY = "search_iter_v2"
+ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
+ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
+ITER_SEARCH_TOKEN_KEY = "search_iter_token"
 PAGE_RETAIN_ORDER_FIELD = "page_retain_order"
 
 RANKER_TYPE_RRF = "rrf"
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 4bbc95b9a..1ec051a51 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -20,6 +20,10 @@
     GROUP_BY_FIELD,
     GROUP_SIZE,
     ITERATOR_FIELD,
+    ITER_SEARCH_V2_KEY,
+    ITER_SEARCH_BATCH_SIZE_KEY,
+    ITER_SEARCH_LAST_BOUND_KEY,
+    ITER_SEARCH_TOKEN_KEY,
     PAGE_RETAIN_ORDER_FIELD,
     RANK_GROUP_SCORER,
     REDUCE_STOP_FOR_BEST,
@@ -940,6 +944,22 @@ def search_requests_with_expr(
         is_iterator = kwargs.get(ITERATOR_FIELD)
         if is_iterator is not None:
             search_params[ITERATOR_FIELD] = is_iterator
+
+        is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY)
+        if is_search_iter_v2 is not None:
+            search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2
+
+        search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY)
+        if search_iter_batch_size is not None:
+            search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size
+
+        search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY)
+        if search_iter_last_bound is not None:
+            search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound
+
+        search_iter_token = kwargs.get(ITER_SEARCH_TOKEN_KEY)
+        if search_iter_token is not None:
+            search_params[ITER_SEARCH_TOKEN_KEY] = search_iter_token
 
         group_by_field = kwargs.get(GROUP_BY_FIELD)
         if group_by_field is not None:
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py
index a31d374cf..491c721a1 100644
--- a/pymilvus/orm/collection.py
+++ b/pymilvus/orm/collection.py
@@ -42,7 +42,7 @@
 from .constants import UNLIMITED
 from .future import MutationFuture, SearchFuture
 from .index import Index
-from .iterator import QueryIterator, SearchIterator
+from .iterator import QueryIterator, SearchIterator, SearchIteratorV2
 from .mutation import MutationResult
 from .partition import Partition
 from .prepare import Prepare
@@ -969,26 +969,33 @@ def search_iterator(
         output_fields: Optional[List[str]] = None,
         timeout: Optional[float] = None,
         round_decimal: int = -1,
+        use_v1: Optional[bool] = False,
         **kwargs,
     ):
         if expr is not None and not isinstance(expr, str):
             raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
-        return SearchIterator(
-            connection=self._get_connection(),
-            collection_name=self._name,
-            data=data,
-            ann_field=anns_field,
-            param=param,
-            batch_size=batch_size,
-            limit=limit,
-            expr=expr,
-            partition_names=partition_names,
-            output_fields=output_fields,
-            timeout=timeout,
-            round_decimal=round_decimal,
-            schema=self._schema_dict,
+
+        iterator_params = {
+            "connection": self._get_connection(),
+            "collection_name": self._name,
+            "data": data,
+            "anns_field": anns_field,
+            "param": param,
+            "batch_size": batch_size,
+            "limit": limit,
+            "expr": expr,
+            "partition_names": partition_names,
+            "output_fields": output_fields,
+            "timeout": timeout,
+            "round_decimal": round_decimal,
+            "schema": self._schema_dict,
             **kwargs,
-        )
+        }
+
+        if use_v1:
+            return SearchIterator(**iterator_params)
+        else:
+            return SearchIteratorV2(**iterator_params)
 
     def query(
         self,
diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py
index 6862ab75f..766db3146 100644
--- a/pymilvus/orm/constants.py
+++ b/pymilvus/orm/constants.py
@@ -48,6 +48,11 @@
 REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
 ITERATOR_FIELD = "iterator"
 ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
+ITER_SEARCH_V2_KEY = "search_iter_v2"
+ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
+ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
+ITER_SEARCH_TOKEN_KEY = "search_iter_token"
+ITER_SEARCH_TTL_KEY = "search_iter_ttl"
 PRINT_ITERATOR_CURSOR = "print_iterator_cursor"
 DEFAULT_MAX_L2_DISTANCE = 99999999.0
 DEFAULT_MIN_IP_DISTANCE = -99999999.0
diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py
index 15118523f..c86b016ee 100644
--- a/pymilvus/orm/iterator.py
+++ b/pymilvus/orm/iterator.py
@@ -30,6 +30,11 @@
     ITERATOR_FIELD,
     ITERATOR_SESSION_CP_FILE,
     ITERATOR_SESSION_TS_FIELD,
+    ITER_SEARCH_V2_KEY,
+    ITER_SEARCH_BATCH_SIZE_KEY,
+    ITER_SEARCH_LAST_BOUND_KEY,
+    ITER_SEARCH_TOKEN_KEY,
+    ITER_SEARCH_TTL_KEY,
     MAX_BATCH_SIZE,
     MAX_FILTERED_IDS_COUNT_ITERATION,
     MAX_TRY_TIME,
@@ -51,7 +56,7 @@
 LOGGER.setLevel(logging.INFO)
 QueryIterator = TypeVar("QueryIterator")
 SearchIterator = TypeVar("SearchIterator")
-
+SearchIteratorV2 = TypeVar("SearchIteratorV2")
 log = logging.getLogger(__name__)
 
@@ -86,6 +91,11 @@
 def check_set_flag(obj: Any, flag_name: str, kwargs: Dict[str, Any], key: str):
     setattr(obj, flag_name, kwargs.get(key, False))
 
+def check_batch_size(batch_size: int):
+    if batch_size < 0:
+        raise ParamError(message="batch size cannot be less than zero")
+    if batch_size > MAX_BATCH_SIZE:
+        raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
 
 class QueryIterator:
     def __init__(
@@ -192,10 +202,7 @@ def __check_set_reduce_stop_for_best(self):
             self._kwargs[REDUCE_STOP_FOR_BEST] = "False"
 
     def __check_set_batch_size(self, batch_size: int):
-        if batch_size < 0:
-            raise ParamError(message="batch size cannot be less than zero")
-        if batch_size > MAX_BATCH_SIZE:
-            raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
+        check_batch_size(batch_size)
         self._kwargs[BATCH_SIZE] = batch_size
         self._kwargs[MILVUS_LIMIT] = batch_size
 
@@ -432,13 +439,30 @@ def distances(self):
         return distances
 
 
+def check_num_queries(data: Union[List, utils.SparseMatrixInputType]):
+    rows = entity_helper.get_input_num_rows(data)
+    if rows > 1:
+        raise ParamError(
+            message="Not support search iteration over multiple vectors at present"
+        )
+    if rows == 0:
+        raise ParamError(message="vector_data for search cannot be empty")
+
+def check_metrics(param: Dict):
+    if param[METRIC_TYPE] is None or param[METRIC_TYPE] == "":
+        raise ParamError(message="must specify metrics type for search iterator")
+
+def check_offset(kwargs: Dict):
+    if kwargs.get(OFFSET, 0) != 0:
+        raise ParamError(message="Not support offset when searching iteration")
+
 class SearchIterator:
     def __init__(
         self,
         connection: Connections,
         collection_name: str,
         data: Union[List, utils.SparseMatrixInputType],
-        ann_field: str,
+        anns_field: str,
         param: Dict,
         batch_size: Optional[int] = 1000,
         limit: Optional[int] = UNLIMITED,
@@ -450,18 +474,14 @@ def __init__(
         schema: Optional[CollectionSchema] = None,
         **kwargs,
     ) -> SearchIterator:
-        rows = entity_helper.get_input_num_rows(data)
-        if rows > 1:
-            raise ParamError(
-                message="Not support search iteration over multiple vectors at present"
-            )
-        if rows == 0:
-            raise ParamError(message="vector_data for search cannot be empty")
+        check_num_queries(data)
+        check_metrics(param)
+        check_offset(kwargs)
         self._conn = connection
         self._iterator_params = {
             "collection_name": collection_name,
             "data": data,
-            "ann_field": ann_field,
+            "anns_field": anns_field,
             BATCH_SIZE: batch_size,
             "output_fields": output_fields,
             "partition_names": partition_names,
@@ -478,8 +498,6 @@ def __init__(
         self._schema = schema
         self._limit = limit
         self._returned_count = 0
-        self.__check_metrics()
-        self.__check_offset()
         self.__check_rm_range_search_parameters()
         self.__setup__pk_prop()
         check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
@@ -561,10 +579,6 @@ def __setup__pk_prop(self):
         if self._pk_field_name is None or self._pk_field_name == "":
             raise ParamError(message="schema must contain pk field, broke")
 
-    def __check_metrics(self):
-        if self._param[METRIC_TYPE] is None or self._param[METRIC_TYPE] == "":
-            raise ParamError(message="must specify metrics type for search iterator")
-
     """we use search && range search to implement search iterator,
     so range search parameters are disabled to clients"""
 
@@ -587,10 +601,6 @@ def __check_rm_range_search_parameters(self):
                 f"smalled than range_filter, please adjust your parameter"
             )
 
-    def __check_offset(self):
-        if self._kwargs.get(OFFSET, 0) != 0:
-            raise ParamError(message="Not support offset when searching iteration")
-
     def __update_filtered_ids(self, res: SearchPage):
         if len(res) == 0:
             return
@@ -698,7 +708,7 @@ def __execute_next_search(
         res = self._conn.search(
             self._iterator_params["collection_name"],
             self._iterator_params["data"],
-            self._iterator_params["ann_field"],
+            self._iterator_params["anns_field"],
             next_params,
             extend_batch_size(self._iterator_params[BATCH_SIZE], next_params, to_extend_batch),
             next_expr,
@@ -784,3 +794,74 @@ def release_cache(self, cache_id: int):
 NO_CACHE_ID = -1
 # Singleton Mode in Python
 iterator_cache = IteratorCache()
+
+class SearchIteratorV2:
+    def __init__(
+        self,
+        connection: Connections,
+        collection_name: str,
+        data: Union[List, utils.SparseMatrixInputType],
+        anns_field: str,
+        param: Dict,
+        batch_size: int = 1000,
+        expr: Optional[str] = None,
+        partition_names: Optional[List[str]] = None,
+        output_fields: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+        ttl: Optional[int] = None,
+        round_decimal: int = -1,
+        **kwargs,
+    ) -> SearchIteratorV2:
+        check_num_queries(data)
+        check_metrics(param)
+        check_offset(kwargs)
+        check_batch_size(batch_size)
+
+        # delete limit from the incoming kwargs for compatibility
+        if MILVUS_LIMIT in kwargs:
+            del kwargs[MILVUS_LIMIT]
+
+        self._conn = connection
+        self._params = {
+            "collection_name": collection_name,
+            "data": data,
+            "anns_field": anns_field,
+            "param": deepcopy(param),
+            "limit": batch_size,
+            "expression": expr,
+            "partition_names": partition_names,
+            "output_fields": output_fields,
+            "round_decimal": round_decimal,
+            "timeout": timeout,
+            ITERATOR_FIELD: True,
+            ITER_SEARCH_V2_KEY: True,
+            ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
+            ITER_SEARCH_TTL_KEY: ttl,
+            GUARANTEE_TIMESTAMP: 0,
+            **kwargs,
+        }
+
+    def next(self):
+        res = self._conn.search(**self._params)
+        iter_info = res.get_search_iterator_v2_results_info()
+        self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound
+
+        # patch the token and the guarantee timestamp on the first next() call
+        if ITER_SEARCH_TOKEN_KEY not in self._params:
+            self._params[ITER_SEARCH_TOKEN_KEY] = iter_info.token
+        if self._params[GUARANTEE_TIMESTAMP] <= 0:
+            if res.get_session_ts() > 0:
+                self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()
+            else:
+                log.warning(
+                    "failed to set up mvccTs from milvus server, use client-side ts instead"
+                )
+                self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()
+
+        # return a SearchPage for compatibility with the v1 iterator
+        if len(res) > 0:
+            return SearchPage(res[0])
+        return SearchPage(None)
+
+    def close(self):
+        pass
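
Usage sketch (illustrative, not part of the patch): with this change, Collection.search_iterator returns a SearchIteratorV2 by default and falls back to the v1 SearchIterator only when use_v1=True. The collection name, vector field, dimension, and index params below are assumptions, as is a Milvus server reachable on localhost.

    from pymilvus import Collection, connections

    connections.connect(uri="http://localhost:19530")  # assumes a running Milvus
    collection = Collection("demo")  # hypothetical collection with a 128-dim "embedding" field

    # use_v1 defaults to False, so this returns a SearchIteratorV2
    iterator = collection.search_iterator(
        data=[[0.1] * 128],  # exactly one query vector; check_num_queries rejects more
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"ef": 64}},  # metric_type is mandatory (check_metrics)
        batch_size=500,  # page size, forwarded as search_iter_batch_size
        output_fields=["id"],
    )

    while True:
        page = iterator.next()  # a SearchPage of up to batch_size hits
        if len(page) == 0:  # an empty page signals exhaustion
            iterator.close()
            break
        for hit in page:
            print(hit.id, hit.distance)

Unlike v1, which pages by re-issuing range searches and filtering duplicates on the client, v2 keeps the cursor server-side: the first next() call stores the server-issued token and pins guarantee_timestamp to the session mvccTs (falling back to a client-side timestamp), and every subsequent call echoes last_bound back so the server can resume where the previous page ended.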