Skip to content

Commit

Permalink
Suport search on multi-vec fields (#1813)
Browse files Browse the repository at this point in the history
Signed-off-by: xige-16 <[email protected]>
  • Loading branch information
xige-16 authored Dec 20, 2023
1 parent c07942a commit 7b7516f
Show file tree
Hide file tree
Showing 15 changed files with 1,049 additions and 409 deletions.
5 changes: 4 additions & 1 deletion pymilvus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
RemoteBulkWriter,
)
from .client import __version__
from .client.abstract import Hit, Hits, SearchResult
from .client.abstract import AnnSearchRequest, Hit, Hits, RRFRanker, SearchResult, WeightedRanker
from .client.asynch import SearchFuture
from .client.prepare import Prepare
from .client.stub import Milvus
Expand Down Expand Up @@ -147,4 +147,7 @@
"bulk_import",
"get_import_progress",
"list_import_jobs",
"AnnSearchRequest",
"RRFRanker",
"WeightedRanker",
]
98 changes: 96 additions & 2 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import ujson

from pymilvus.exceptions import MilvusException
from pymilvus.exceptions import DataTypeNotMatchException, ExceptionsMessage, MilvusException
from pymilvus.grpc_gen import schema_pb2
from pymilvus.settings import Config

from .constants import DEFAULT_CONSISTENCY_LEVEL
from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
from .types import DataType


Expand Down Expand Up @@ -271,6 +271,100 @@ def __next__(self) -> Any:
raise StopIteration


class BaseRanker:
def __int__(self):
return

def dict(self):
return {}

def __str__(self):
return self.dict().__str__()


class RRFRanker(BaseRanker):
def __init__(
self,
k: int = 60,
):
self._strategy = RANKER_TYPE_RRF
self._k = k

def dict(self):
params = {
"k": self._k,
}
return {
"strategy": self._strategy,
"params": params,
}


class WeightedRanker(BaseRanker):
def __init__(self, *nums):
self._strategy = RANKER_TYPE_WEIGHTED
weights = []
for num in nums:
weights.append(num)
self._weights = weights

def dict(self):
params = {
"weights": self._weights,
}
return {
"strategy": self._strategy,
"params": params,
}


class AnnSearchRequest:
def __init__(
self,
data: List,
anns_field: str,
param: Dict,
limit: int,
expr: Optional[str] = None,
):
self._data = data
self._anns_field = anns_field
self._param = param
self._limit = limit

if expr is not None and not isinstance(expr, str):
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
self._expr = expr

@property
def data(self):
return self._data

@property
def anns_field(self):
return self._anns_field

@property
def param(self):
return self._param

@property
def limit(self):
return self._limit

@property
def expr(self):
return self._expr

def __str__(self):
return {
"anns_field": self.anns_field,
"param": self.param,
"limit": self.limit,
"expr": self.expr,
}.__str__()


class SearchResult(list):
"""nq results: List[Hits]"""

Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
DEFAULT_RESOURCE_GROUP = "__default_resource_group"
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
GROUP_BY_FIELD = "group_by_field"

RANKER_TYPE_RRF = "rrf"
RANKER_TYPE_WEIGHTED = "weighted"
71 changes: 70 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pymilvus.settings import Config

from . import entity_helper, interceptor, ts_utils
from .abstract import CollectionSchema, MutationResult, SearchResult
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
from .asynch import (
CreateIndexFuture,
FlushFuture,
Expand Down Expand Up @@ -708,6 +708,25 @@ def _execute_search(
return SearchFuture(None, None, e)
raise e from e

def _execute_searchV2(
self, request: milvus_types.SearchRequestV2, timeout: Optional[float] = None, **kwargs
):
try:
if kwargs.get("_async", False):
future = self._stub.SearchV2.future(request, timeout=timeout)
func = kwargs.get("_callback", None)
return SearchFuture(future, func)

response = self._stub.SearchV2(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal)

except Exception as e:
if kwargs.get("_async", False):
return SearchFuture(None, None, e)
raise e from e

@retry_on_rpc_failure()
def search(
self,
Expand Down Expand Up @@ -747,6 +766,56 @@ def search(
)
return self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs)

@retry_on_rpc_failure()
def searchV2(
self,
collection_name: str,
reqs: List[AnnSearchRequest],
rerank: BaseRanker,
limit: int,
partition_names: Optional[List[str]] = None,
output_fields: Optional[List[str]] = None,
round_decimal: int = -1,
timeout: Optional[float] = None,
**kwargs,
):
check_pass_param(
limit=limit,
round_decimal=round_decimal,
partition_name_array=partition_names,
output_fields=output_fields,
guarantee_timestamp=kwargs.get("guarantee_timestamp", None),
)

requests = []
for req in reqs:
search_request = Prepare.search_requests_with_expr(
collection_name,
req.data,
req.anns_field,
req.param,
req.limit,
req.expr,
partition_names=partition_names,
round_decimal=round_decimal,
**kwargs,
)
requests.append(search_request)

search_request_v2 = Prepare.search_requestV2_with_ranker(
collection_name,
requests,
rerank.dict(),
limit,
partition_names,
output_fields,
round_decimal,
**kwargs,
)
return self._execute_searchV2(
search_request_v2, timeout, round_decimal=round_decimal, **kwargs
)

@retry_on_rpc_failure()
def get_query_segment_info(self, collection_name: str, timeout: float = 30, **kwargs):
req = Prepare.get_query_segment_info_request(collection_name)
Expand Down
41 changes: 41 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,47 @@ def dump(v: Dict):

return request

@classmethod
def search_requestV2_with_ranker(
cls,
collection_name: str,
reqs: List,
rerank_param: Dict,
limit: int,
partition_names: Optional[List[str]] = None,
output_fields: Optional[List[str]] = None,
round_decimal: int = -1,
**kwargs,
) -> milvus_types.SearchRequestV2:

use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)
rerank_param["limit"] = limit
rerank_param["round_decimal"] = round_decimal

def dump(v: Dict):
if isinstance(v, dict):
return ujson.dumps(v)
return str(v)

request = milvus_types.SearchRequestV2(
collection_name=collection_name,
partition_names=partition_names,
requests=reqs,
output_fields=output_fields,
guarantee_timestamp=kwargs.get("guarantee_timestamp", 0),
use_default_consistency=use_default_consistency,
consistency_level=kwargs.get("consistency_level", 0),
)

request.rank_params.extend(
[
common_types.KeyValuePair(key=str(key), value=dump(value))
for key, value in rerank_param.items()
]
)

return request

@classmethod
def create_alias_request(cls, collection_name: str, alias: str):
return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias)
Expand Down
Loading

0 comments on commit 7b7516f

Please sign in to comment.