enhance: add search iterator v2
Signed-off-by: Patrick Weizhi Xu <[email protected]>
PwzXxm committed Nov 29, 2024
1 parent 053c568 commit 6398b3a
Showing 6 changed files with 165 additions and 41 deletions.
4 changes: 4 additions & 0 deletions pymilvus/client/abstract.py
@@ -497,11 +497,15 @@ def __init__(
)
nq_thres += topk
self._session_ts = session_ts
self._search_iterator_v2_results = res.search_iterator_v2_results
super().__init__(data)

def get_session_ts(self):
return self._session_ts

def get_search_iterator_v2_results_info(self):
return self._search_iterator_v2_results

def get_fields_by_range(
self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
4 changes: 4 additions & 0 deletions pymilvus/client/constants.py
@@ -16,6 +16,10 @@
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
ITER_SEARCH_ID_KEY = "search_iter_id"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"

RANKER_TYPE_RRF = "rrf"
20 changes: 20 additions & 0 deletions pymilvus/client/prepare.py
@@ -19,6 +19,10 @@
DYNAMIC_FIELD_NAME,
GROUP_BY_FIELD,
GROUP_SIZE,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
PAGE_RETAIN_ORDER_FIELD,
RANK_GROUP_SCORER,
@@ -941,6 +945,22 @@ def search_requests_with_expr(
if is_iterator is not None:
search_params[ITERATOR_FIELD] = is_iterator

is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY)
if is_search_iter_v2 is not None:
search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2

search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY)
if search_iter_batch_size is not None:
search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size

search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY)
if search_iter_last_bound is not None:
search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound

search_iter_id = kwargs.get(ITER_SEARCH_ID_KEY)
if search_iter_id is not None:
search_params[ITER_SEARCH_ID_KEY] = search_iter_id

group_by_field = kwargs.get(GROUP_BY_FIELD)
if group_by_field is not None:
search_params[GROUP_BY_FIELD] = group_by_field
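The four new keys are copied from kwargs into search_params with the same guard pattern the function already uses for the other iterator fields. A minimal sketch of what that produces for a v2 iterator request — the concrete values are illustrative, not taken from the commit:

```python
# Illustrative sketch of the guard pattern above: each v2 iterator key is
# copied into search_params only when the caller actually supplied it.
# The key strings match the constants added in pymilvus/client/constants.py.
kwargs = {
    "search_iter_v2": True,
    "search_iter_batch_size": 500,
    "search_iter_last_bound": 0.73,  # hypothetical bound from the previous page
    "search_iter_id": "it-abc123",   # hypothetical server-issued token
}

search_params = {"iterator": True}
for key in (
    "search_iter_v2",
    "search_iter_batch_size",
    "search_iter_last_bound",
    "search_iter_id",
):
    value = kwargs.get(key)
    if value is not None:
        search_params[key] = value

print(search_params)
# {'iterator': True, 'search_iter_v2': True, 'search_iter_batch_size': 500,
#  'search_iter_last_bound': 0.73, 'search_iter_id': 'it-abc123'}
```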
38 changes: 22 additions & 16 deletions pymilvus/orm/collection.py
@@ -42,7 +42,7 @@
from .constants import UNLIMITED
from .future import MutationFuture, SearchFuture
from .index import Index
from .iterator import QueryIterator, SearchIterator
from .iterator import QueryIterator, SearchIterator, SearchIteratorV2
from .mutation import MutationResult
from .partition import Partition
from .prepare import Prepare
@@ -969,26 +969,32 @@ def search_iterator(
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
round_decimal: int = -1,
use_v1: Optional[bool] = False,
**kwargs,
):
if expr is not None and not isinstance(expr, str):
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
return SearchIterator(
connection=self._get_connection(),
collection_name=self._name,
data=data,
ann_field=anns_field,
param=param,
batch_size=batch_size,
limit=limit,
expr=expr,
partition_names=partition_names,
output_fields=output_fields,
timeout=timeout,
round_decimal=round_decimal,
schema=self._schema_dict,

iterator_params = {
"connection": self._get_connection(),
"collection_name": self._name,
"data": data,
"anns_field": anns_field,
"param": param,
"batch_size": batch_size,
"limit": limit,
"expr": expr,
"partition_names": partition_names,
"output_fields": output_fields,
"timeout": timeout,
"round_decimal": round_decimal,
"schema": self._schema_dict,
**kwargs,
)
}

if use_v1:
return SearchIterator(**iterator_params)
return SearchIteratorV2(**iterator_params)

def query(
self,
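With this dispatch, existing callers of Collection.search_iterator get the v2 iterator by default and can opt back into the old implementation with use_v1=True. A minimal usage sketch, assuming a Milvus server on localhost and a loaded collection "demo" with a float-vector field "embedding" (all placeholder names):

```python
from pymilvus import Collection, connections

connections.connect(uri="http://localhost:19530")  # hypothetical endpoint
collection = Collection("demo")                    # hypothetical, already-loaded collection

it = collection.search_iterator(
    data=[[0.1, 0.2, 0.3, 0.4]],   # exactly one query vector is allowed
    anns_field="embedding",        # hypothetical vector field name
    param={"metric_type": "L2"},   # a metric type is mandatory for iterators
    batch_size=500,
    # use_v1=True,                 # uncomment to fall back to the old SearchIterator
)

while True:
    page = it.next()               # each call fetches the next SearchPage
    if len(page) == 0:             # an empty page signals exhaustion
        it.close()
        break
    for hit in page:
        print(hit.id, hit.distance)
```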
5 changes: 5 additions & 0 deletions pymilvus/orm/constants.py
@@ -48,6 +48,11 @@
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
ITER_SEARCH_ID_KEY = "search_iter_id"
ITER_SEARCH_TTL_KEY = "search_iter_ttl"
PRINT_ITERATOR_CURSOR = "print_iterator_cursor"
DEFAULT_MAX_L2_DISTANCE = 99999999.0
DEFAULT_MIN_IP_DISTANCE = -99999999.0
135 changes: 110 additions & 25 deletions pymilvus/orm/iterator.py
@@ -27,6 +27,11 @@
GUARANTEE_TIMESTAMP,
INT64_MAX,
IS_PRIMARY,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_TTL_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
ITERATOR_SESSION_CP_FILE,
ITERATOR_SESSION_TS_FIELD,
@@ -51,7 +56,7 @@
LOGGER.setLevel(logging.INFO)
QueryIterator = TypeVar("QueryIterator")
SearchIterator = TypeVar("SearchIterator")

SearchIteratorV2 = TypeVar("SearchIteratorV2")
log = logging.getLogger(__name__)


@@ -87,6 +92,13 @@ def check_set_flag(obj: Any, flag_name: str, kwargs: Dict[str, Any], key: str):
setattr(obj, flag_name, kwargs.get(key, False))


def check_batch_size(batch_size: int):
if batch_size < 0:
raise ParamError(message="batch size cannot be less than zero")
if batch_size > MAX_BATCH_SIZE:
raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")


class QueryIterator:
def __init__(
self,
@@ -192,10 +204,7 @@ def __check_set_reduce_stop_for_best(self):
self._kwargs[REDUCE_STOP_FOR_BEST] = "False"

def __check_set_batch_size(self, batch_size: int):
if batch_size < 0:
raise ParamError(message="batch size cannot be less than zero")
if batch_size > MAX_BATCH_SIZE:
raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
check_batch_size(batch_size)
self._kwargs[BATCH_SIZE] = batch_size
self._kwargs[MILVUS_LIMIT] = batch_size

@@ -432,13 +441,31 @@ def distances(self):
return distances


def check_num_queries(data: Union[List, utils.SparseMatrixInputType]):
rows = entity_helper.get_input_num_rows(data)
if rows > 1:
raise ParamError(message="Not support search iteration over multiple vectors at present")
if rows == 0:
raise ParamError(message="vector_data for search cannot be empty")


def check_metrics(param: Dict):
if param[METRIC_TYPE] is None or param[METRIC_TYPE] == "":
raise ParamError(message="must specify metrics type for search iterator")


def check_offset(kwargs: Dict):
if kwargs.get(OFFSET, 0) != 0:
raise ParamError(message="Not support offset when searching iteration")


class SearchIterator:
def __init__(
self,
connection: Connections,
collection_name: str,
data: Union[List, utils.SparseMatrixInputType],
ann_field: str,
anns_field: str,
param: Dict,
batch_size: Optional[int] = 1000,
limit: Optional[int] = UNLIMITED,
@@ -450,18 +477,14 @@ def __init__(
schema: Optional[CollectionSchema] = None,
**kwargs,
) -> SearchIterator:
rows = entity_helper.get_input_num_rows(data)
if rows > 1:
raise ParamError(
message="Not support search iteration over multiple vectors at present"
)
if rows == 0:
raise ParamError(message="vector_data for search cannot be empty")
check_num_queries(data)
check_metrics(param)
check_offset(kwargs)
self._conn = connection
self._iterator_params = {
"collection_name": collection_name,
"data": data,
"ann_field": ann_field,
"anns_field": anns_field,
BATCH_SIZE: batch_size,
"output_fields": output_fields,
"partition_names": partition_names,
@@ -478,8 +501,6 @@ def __init__(
self._schema = schema
self._limit = limit
self._returned_count = 0
self.__check_metrics()
self.__check_offset()
self.__check_rm_range_search_parameters()
self.__setup__pk_prop()
check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
@@ -561,10 +582,6 @@ def __setup__pk_prop(self):
if self._pk_field_name is None or self._pk_field_name == "":
raise ParamError(message="schema must contain pk field, broke")

def __check_metrics(self):
if self._param[METRIC_TYPE] is None or self._param[METRIC_TYPE] == "":
raise ParamError(message="must specify metrics type for search iterator")

"""we use search && range search to implement search iterator,
so range search parameters are disabled to clients"""

@@ -587,10 +604,6 @@ def __check_rm_range_search_parameters(self):
f"smalled than range_filter, please adjust your parameter"
)

def __check_offset(self):
if self._kwargs.get(OFFSET, 0) != 0:
raise ParamError(message="Not support offset when searching iteration")

def __update_filtered_ids(self, res: SearchPage):
if len(res) == 0:
return
@@ -698,7 +711,7 @@ def __execute_next_search(
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
self._iterator_params["ann_field"],
self._iterator_params["anns_field"],
next_params,
extend_batch_size(self._iterator_params[BATCH_SIZE], next_params, to_extend_batch),
next_expr,
@@ -784,3 +797,75 @@ def release_cache(self, cache_id: int):
NO_CACHE_ID = -1
# Singleton Mode in Python
iterator_cache = IteratorCache()


class SearchIteratorV2:
def __init__(
self,
connection: Connections,
collection_name: str,
data: Union[List, utils.SparseMatrixInputType],
anns_field: str,
param: Dict,
batch_size: int = 1000,
expr: Optional[str] = None,
partition_names: Optional[List[str]] = None,
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
ttl: Optional[int] = None,
round_decimal: int = -1,
**kwargs,
) -> SearchIteratorV2:
check_num_queries(data)
check_metrics(param)
check_offset(kwargs)
check_batch_size(batch_size)

# drop any caller-supplied limit from kwargs for compatibility; batch_size drives paging
if MILVUS_LIMIT in kwargs:
del kwargs[MILVUS_LIMIT]

self._conn = connection
self._params = {
"collection_name": collection_name,
"data": data,
"anns_field": anns_field,
"param": deepcopy(param),
"limit": batch_size,
"expression": expr,
"partition_names": partition_names,
"output_fields": output_fields,
"round_decimal": round_decimal,
"timeout": timeout,
ITERATOR_FIELD: True,
ITER_SEARCH_V2_KEY: True,
ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
ITER_SEARCH_TTL_KEY: ttl,
GUARANTEE_TIMESTAMP: 0,
**kwargs,
}

def next(self):
res = self._conn.search(**self._params)
iter_info = res.get_search_iterator_v2_results_info()
self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound

# patch token and guarantee timestamp for the first next() call
if ITER_SEARCH_ID_KEY not in self._params:
self._params[ITER_SEARCH_ID_KEY] = iter_info.token
if self._params[GUARANTEE_TIMESTAMP] <= 0:
if res.get_session_ts() > 0:
self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()
else:
log.warning(
"failed to set up mvccTs from milvus server, use client-side ts instead"
)
self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()

# return SearchPage for compatibility
if len(res) > 0:
return SearchPage(res[0])
return SearchPage(None)

def close(self):
pass
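
SearchIteratorV2 can also be constructed directly, using the same connection handle that Collection.search_iterator passes in (see iterator_params in collection.py above). A hedged sketch — collection and field names are placeholders, and ttl is simply forwarded to the server under the search_iter_ttl key:

```python
from pymilvus import Collection, connections
from pymilvus.orm.iterator import SearchIteratorV2

connections.connect(uri="http://localhost:19530")  # hypothetical endpoint
col = Collection("demo")                           # hypothetical collection

it = SearchIteratorV2(
    connection=col._get_connection(),  # same handler Collection.search_iterator uses
    collection_name="demo",
    data=[[0.1, 0.2, 0.3, 0.4]],
    anns_field="embedding",            # hypothetical vector field
    param={"metric_type": "L2"},
    batch_size=500,                    # becomes both the request limit and the batch-size key
    ttl=300,                           # optional iterator TTL, forwarded as search_iter_ttl
)

first = it.next()   # first call stores the server token and pins the mvcc timestamp
second = it.next()  # later calls resume from last_bound under the same token
it.close()          # currently a no-op
```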
