Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: add search iterator v2 #2395

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,15 @@ def __init__(
)
nq_thres += topk
self._session_ts = session_ts
self._search_iterator_v2_results = res.search_iterator_v2_results
super().__init__(data)

def get_session_ts(self):
    """Return the iterator session timestamp recorded for this result set."""
    session_ts = self._session_ts
    return session_ts

def get_search_iterator_v2_results_info(self):
    """Return the search-iterator-v2 results info captured from the server response."""
    info = self._search_iterator_v2_results
    return info

def get_fields_by_range(
self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
ITER_SEARCH_ID_KEY = "search_iter_id"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"

RANKER_TYPE_RRF = "rrf"
Expand Down
20 changes: 20 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
DYNAMIC_FIELD_NAME,
GROUP_BY_FIELD,
GROUP_SIZE,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
PAGE_RETAIN_ORDER_FIELD,
RANK_GROUP_SCORER,
Expand Down Expand Up @@ -941,6 +945,22 @@ def search_requests_with_expr(
if is_iterator is not None:
search_params[ITERATOR_FIELD] = is_iterator

is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY)
if is_search_iter_v2 is not None:
search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2

search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY)
if search_iter_batch_size is not None:
search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size

search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY)
if search_iter_last_bound is not None:
search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound

search_iter_id = kwargs.get(ITER_SEARCH_ID_KEY)
if search_iter_id is not None:
search_params[ITER_SEARCH_ID_KEY] = search_iter_id

group_by_field = kwargs.get(GROUP_BY_FIELD)
if group_by_field is not None:
search_params[GROUP_BY_FIELD] = group_by_field
Expand Down
33 changes: 31 additions & 2 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .constants import UNLIMITED
from .future import MutationFuture, SearchFuture
from .index import Index
from .iterator import QueryIterator, SearchIterator
from .iterator import QueryIterator, SearchIterator, SearchIteratorV2
from .mutation import MutationResult
from .partition import Partition
from .prepare import Prepare
Expand Down Expand Up @@ -977,7 +977,7 @@ def search_iterator(
connection=self._get_connection(),
collection_name=self._name,
data=data,
ann_field=anns_field,
anns_field=anns_field,
param=param,
batch_size=batch_size,
limit=limit,
Expand All @@ -990,6 +990,35 @@ def search_iterator(
**kwargs,
)

def search_iterator_v2(
    self,
    data: Union[List, utils.SparseMatrixInputType],
    anns_field: str,
    param: Dict,
    batch_size: Optional[int] = 1000,
    expr: Optional[str] = None,
    partition_names: Optional[List[str]] = None,
    output_fields: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    round_decimal: int = -1,
    **kwargs,
):
    """Create a v2 search iterator over this collection.

    Args:
        data: a single query vector (or sparse-matrix input) to iterate over.
        anns_field: name of the vector field to search against.
        param: search parameters (metric type, index params, ...).
        batch_size: number of hits fetched per ``next()`` call. Defaults to 1000.
        expr: optional boolean filter expression.
        partition_names: optional list of partitions to restrict the search to.
        output_fields: optional list of fields to return with each hit.
        timeout: optional per-request timeout in seconds.
        round_decimal: decimals kept for returned distances; -1 keeps all.
        **kwargs: forwarded verbatim to ``SearchIteratorV2`` (e.g. ``ttl``).

    Returns:
        A ``SearchIteratorV2`` bound to this collection's connection and schema.
    """
    conn = self._get_connection()
    return SearchIteratorV2(
        connection=conn,
        collection_name=self._name,
        data=data,
        anns_field=anns_field,
        param=param,
        batch_size=batch_size,
        expr=expr,
        partition_names=partition_names,
        output_fields=output_fields,
        timeout=timeout,
        round_decimal=round_decimal,
        schema=self._schema_dict,
        **kwargs,
    )

def query(
self,
expr: str,
Expand Down
5 changes: 5 additions & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
ITER_SEARCH_ID_KEY = "search_iter_id"
ITER_SEARCH_TTL_KEY = "search_iter_ttl"
PRINT_ITERATOR_CURSOR = "print_iterator_cursor"
DEFAULT_MAX_L2_DISTANCE = 99999999.0
DEFAULT_MIN_IP_DISTANCE = -99999999.0
Expand Down
140 changes: 115 additions & 25 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
GUARANTEE_TIMESTAMP,
INT64_MAX,
IS_PRIMARY,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_TTL_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
ITERATOR_SESSION_CP_FILE,
ITERATOR_SESSION_TS_FIELD,
Expand All @@ -51,7 +56,7 @@
LOGGER.setLevel(logging.INFO)
QueryIterator = TypeVar("QueryIterator")
SearchIterator = TypeVar("SearchIterator")

SearchIteratorV2 = TypeVar("SearchIteratorV2")
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -87,6 +92,13 @@ def check_set_flag(obj: Any, flag_name: str, kwargs: Dict[str, Any], key: str):
setattr(obj, flag_name, kwargs.get(key, False))


def check_batch_size(batch_size: int):
    """Validate that *batch_size* lies within [0, MAX_BATCH_SIZE].

    Raises:
        ParamError: when the value is negative or exceeds MAX_BATCH_SIZE.
    """
    # The two guards are disjoint, so their order does not affect behavior.
    if batch_size > MAX_BATCH_SIZE:
        raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
    if batch_size < 0:
        raise ParamError(message="batch size cannot be less than zero")


class QueryIterator:
def __init__(
self,
Expand Down Expand Up @@ -192,10 +204,7 @@ def __check_set_reduce_stop_for_best(self):
self._kwargs[REDUCE_STOP_FOR_BEST] = "False"

def __check_set_batch_size(self, batch_size: int):
if batch_size < 0:
raise ParamError(message="batch size cannot be less than zero")
if batch_size > MAX_BATCH_SIZE:
raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
check_batch_size(batch_size)
self._kwargs[BATCH_SIZE] = batch_size
self._kwargs[MILVUS_LIMIT] = batch_size

Expand Down Expand Up @@ -432,13 +441,31 @@ def distances(self):
return distances


def check_num_queries(data: Union[List, utils.SparseMatrixInputType]):
    """Ensure the search input holds exactly one query vector.

    Raises:
        ParamError: when *data* is empty or contains more than one row.
    """
    # The two guards are disjoint, so their order does not affect behavior.
    num_rows = entity_helper.get_input_num_rows(data)
    if num_rows == 0:
        raise ParamError(message="vector_data for search cannot be empty")
    if num_rows > 1:
        raise ParamError(message="Not support search iteration over multiple vectors at present")


def check_metrics(param: Dict):
    """Ensure the search *param* dict carries a non-empty metric type.

    Raises:
        ParamError: when METRIC_TYPE is missing, ``None``, or an empty string.
    """
    # Use .get() so a dict that lacks the key raises the friendly ParamError
    # below instead of an opaque KeyError (param[METRIC_TYPE] would throw
    # before the intended validation message could be produced).
    if param.get(METRIC_TYPE) is None or param.get(METRIC_TYPE) == "":
        raise ParamError(message="must specify metrics type for search iterator")


def check_offset(kwargs: Dict):
    """Reject a non-zero ``offset`` argument — iterators manage paging themselves.

    Raises:
        ParamError: when the caller passed offset != 0.
    """
    offset = kwargs.get(OFFSET, 0)
    if offset != 0:
        raise ParamError(message="Not support offset when searching iteration")


class SearchIterator:
def __init__(
self,
connection: Connections,
collection_name: str,
data: Union[List, utils.SparseMatrixInputType],
ann_field: str,
anns_field: str,
param: Dict,
batch_size: Optional[int] = 1000,
limit: Optional[int] = UNLIMITED,
Expand All @@ -450,18 +477,14 @@ def __init__(
schema: Optional[CollectionSchema] = None,
**kwargs,
) -> SearchIterator:
rows = entity_helper.get_input_num_rows(data)
if rows > 1:
raise ParamError(
message="Not support search iteration over multiple vectors at present"
)
if rows == 0:
raise ParamError(message="vector_data for search cannot be empty")
check_num_queries(data)
check_metrics(param)
check_offset(kwargs)
self._conn = connection
self._iterator_params = {
"collection_name": collection_name,
"data": data,
"ann_field": ann_field,
"anns_field": anns_field,
BATCH_SIZE: batch_size,
"output_fields": output_fields,
"partition_names": partition_names,
Expand All @@ -478,8 +501,6 @@ def __init__(
self._schema = schema
self._limit = limit
self._returned_count = 0
self.__check_metrics()
self.__check_offset()
self.__check_rm_range_search_parameters()
self.__setup__pk_prop()
check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
Expand Down Expand Up @@ -561,10 +582,6 @@ def __setup__pk_prop(self):
if self._pk_field_name is None or self._pk_field_name == "":
raise ParamError(message="schema must contain pk field, broke")

def __check_metrics(self):
if self._param[METRIC_TYPE] is None or self._param[METRIC_TYPE] == "":
raise ParamError(message="must specify metrics type for search iterator")

"""we use search && range search to implement search iterator,
so range search parameters are disabled to clients"""

Expand All @@ -587,10 +604,6 @@ def __check_rm_range_search_parameters(self):
f"smalled than range_filter, please adjust your parameter"
)

def __check_offset(self):
if self._kwargs.get(OFFSET, 0) != 0:
raise ParamError(message="Not support offset when searching iteration")

def __update_filtered_ids(self, res: SearchPage):
if len(res) == 0:
return
Expand Down Expand Up @@ -698,7 +711,7 @@ def __execute_next_search(
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
self._iterator_params["ann_field"],
self._iterator_params["anns_field"],
next_params,
extend_batch_size(self._iterator_params[BATCH_SIZE], next_params, to_extend_batch),
next_expr,
Expand Down Expand Up @@ -784,3 +797,80 @@ def release_cache(self, cache_id: int):
NO_CACHE_ID = -1
# Singleton Mode in Python
iterator_cache = IteratorCache()


class SearchIteratorV2:
    """Server-driven search iterator (v2).

    Unlike the v1 ``SearchIterator``, which pages client-side with range
    filters, this iterator forwards a server-issued token and last-bound
    value on every call so the server can resume the scan where it left off.

    Args:
        connection: live connection used to issue ``search`` RPCs.
        collection_name: target collection.
        data: a single query vector (or sparse-matrix input); multiple
            vectors are rejected.
        anns_field: vector field to search against.
        param: search parameters; must include a metric type.
        batch_size: hits fetched per ``next()`` call (0 < size <= MAX_BATCH_SIZE).
        expr: optional boolean filter expression.
        partition_names: optional partitions to restrict the search to.
        output_fields: optional fields to return with each hit.
        timeout: optional per-request timeout in seconds.
        ttl: optional server-side time-to-live for the iterator session.
        round_decimal: decimals kept for returned distances; -1 keeps all.
        **kwargs: extra search options; any ``limit`` entry is discarded.

    Raises:
        ParamError: on invalid data/metrics/offset/batch_size (see checks).
    """

    def __init__(
        self,
        connection: Connections,
        collection_name: str,
        data: Union[List, utils.SparseMatrixInputType],
        anns_field: str,
        param: Dict,
        batch_size: int = 1000,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        output_fields: Optional[List[str]] = None,
        timeout: Optional[float] = None,
        ttl: Optional[int] = None,
        round_decimal: int = -1,
        **kwargs,
    ) -> SearchIteratorV2:
        check_num_queries(data)
        check_metrics(param)
        check_offset(kwargs)
        check_batch_size(batch_size)

        # Delete limit from incoming kwargs for compatibility: each next()
        # call fetches exactly one batch, so batch_size acts as the limit.
        if MILVUS_LIMIT in kwargs:
            del kwargs[MILVUS_LIMIT]

        self._conn = connection
        # _params is replayed on every next() call; next() mutates it in
        # place (last bound, iterator id, guarantee timestamp).
        self._params = {
            "collection_name": collection_name,
            "data": data,
            "anns_field": anns_field,
            "param": deepcopy(param),  # copied so later mutation never leaks to the caller
            "limit": batch_size,
            "expression": expr,
            "partition_names": partition_names,
            "output_fields": output_fields,
            "round_decimal": round_decimal,
            "timeout": timeout,
            ITERATOR_FIELD: True,
            ITER_SEARCH_V2_KEY: True,
            ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
            ITER_SEARCH_TTL_KEY: ttl,
            GUARANTEE_TIMESTAMP: 0,
            **kwargs,
        }

    def next(self):
        """Fetch the next batch of results.

        Returns:
            SearchPage: the next batch; empty when the scan is exhausted.

        Raises:
            MilvusException: when the server does not support iterator v2
                (no token returned on the first call).
        """
        res = self._conn.search(**self._params)
        iter_info = res.get_search_iterator_v2_results_info()
        # Defensive: guard before dereferencing so an old server that omits
        # the v2 info entirely yields the friendly error below rather than
        # an AttributeError. NOTE(review): presumably the proto field is a
        # default (empty-token) message rather than None — confirm.
        if iter_info is None:
            raise MilvusException(
                message="The server does not support Search Iterator V2. Please upgrade your Milvus server, or create a search_iterator with use_v1=True instead"
            )
        self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound

        # patch token and guarantee timestamp for the first next() call
        if ITER_SEARCH_ID_KEY not in self._params:
            if iter_info.token is not None and iter_info.token != "":
                self._params[ITER_SEARCH_ID_KEY] = iter_info.token
            else:
                raise MilvusException(
                    message="The server does not support Search Iterator V2. Please upgrade your Milvus server, or create a search_iterator with use_v1=True instead"
                )
        if self._params[GUARANTEE_TIMESTAMP] <= 0:
            if res.get_session_ts() > 0:
                self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()
            else:
                log.warning(
                    "failed to set up mvccTs from milvus server, use client-side ts instead"
                )
                self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()

        # return SearchPage for compatibility with the v1 iterator interface
        if len(res) > 0:
            return SearchPage(res[0])
        return SearchPage(None)

    def close(self):
        """No-op: no client-side state to release; kept for v1 interface parity."""
Loading