Skip to content

Commit

Permalink
Rename searchV2 to HybridSearch (#1835)
Browse files Browse the repository at this point in the history
Signed-off-by: xige-16 <[email protected]>
  • Loading branch information
xige-16 authored Dec 28, 2023
1 parent 62c526a commit bd08a46
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 352 deletions.
16 changes: 8 additions & 8 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,16 +708,16 @@ def _execute_search(
return SearchFuture(None, None, e)
raise e from e

def _execute_searchV2(
self, request: milvus_types.SearchRequestV2, timeout: Optional[float] = None, **kwargs
def _execute_hybrid_search(
self, request: milvus_types.HybridSearchRequest, timeout: Optional[float] = None, **kwargs
):
try:
if kwargs.get("_async", False):
future = self._stub.SearchV2.future(request, timeout=timeout)
future = self._stub.HybridSearch.future(request, timeout=timeout)
func = kwargs.get("_callback", None)
return SearchFuture(future, func)

response = self._stub.SearchV2(request, timeout=timeout)
response = self._stub.HybridSearch(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal)
Expand Down Expand Up @@ -767,7 +767,7 @@ def search(
return self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs)

@retry_on_rpc_failure()
def searchV2(
def hybrid_search(
self,
collection_name: str,
reqs: List[AnnSearchRequest],
Expand Down Expand Up @@ -802,7 +802,7 @@ def searchV2(
)
requests.append(search_request)

search_request_v2 = Prepare.search_requestV2_with_ranker(
hybrid_search_request = Prepare.hybrid_search_request_with_ranker(
collection_name,
requests,
rerank.dict(),
Expand All @@ -812,8 +812,8 @@ def searchV2(
round_decimal,
**kwargs,
)
return self._execute_searchV2(
search_request_v2, timeout, round_decimal=round_decimal, **kwargs
return self._execute_hybrid_search(
hybrid_search_request, timeout, round_decimal=round_decimal, **kwargs
)

@retry_on_rpc_failure()
Expand Down
6 changes: 3 additions & 3 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def dump(v: Dict):
return request

@classmethod
def search_requestV2_with_ranker(
def hybrid_search_request_with_ranker(
cls,
collection_name: str,
reqs: List,
Expand All @@ -663,7 +663,7 @@ def search_requestV2_with_ranker(
output_fields: Optional[List[str]] = None,
round_decimal: int = -1,
**kwargs,
) -> milvus_types.SearchRequestV2:
) -> milvus_types.HybridSearchRequest:
use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)
rerank_param["limit"] = limit
rerank_param["round_decimal"] = round_decimal
Expand All @@ -673,7 +673,7 @@ def dump(v: Dict):
return ujson.dumps(v)
return str(v)

request = milvus_types.SearchRequestV2(
request = milvus_types.HybridSearchRequest(
collection_name=collection_name,
partition_names=partition_names,
requests=reqs,
Expand Down
36 changes: 18 additions & 18 deletions pymilvus/grpc_gen/common_pb2.py

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions pymilvus/grpc_gen/common_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class MsgType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
DescribeIndex: _ClassVar[MsgType]
DropIndex: _ClassVar[MsgType]
GetIndexStatistics: _ClassVar[MsgType]
AlterIndex: _ClassVar[MsgType]
Insert: _ClassVar[MsgType]
Delete: _ClassVar[MsgType]
Flush: _ClassVar[MsgType]
Expand Down Expand Up @@ -278,6 +279,11 @@ class ObjectPrivilege(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
PrivilegeDropDatabase: _ClassVar[ObjectPrivilege]
PrivilegeListDatabases: _ClassVar[ObjectPrivilege]
PrivilegeFlushAll: _ClassVar[ObjectPrivilege]
PrivilegeCreatePartition: _ClassVar[ObjectPrivilege]
PrivilegeDropPartition: _ClassVar[ObjectPrivilege]
PrivilegeShowPartitions: _ClassVar[ObjectPrivilege]
PrivilegeHasPartition: _ClassVar[ObjectPrivilege]
PrivilegeGetFlushState: _ClassVar[ObjectPrivilege]

class StateCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
Expand Down Expand Up @@ -409,6 +415,7 @@ CreateIndex: MsgType
DescribeIndex: MsgType
DropIndex: MsgType
GetIndexStatistics: MsgType
AlterIndex: MsgType
Insert: MsgType
Delete: MsgType
Flush: MsgType
Expand Down Expand Up @@ -532,6 +539,11 @@ PrivilegeCreateDatabase: ObjectPrivilege
PrivilegeDropDatabase: ObjectPrivilege
PrivilegeListDatabases: ObjectPrivilege
PrivilegeFlushAll: ObjectPrivilege
PrivilegeCreatePartition: ObjectPrivilege
PrivilegeDropPartition: ObjectPrivilege
PrivilegeShowPartitions: ObjectPrivilege
PrivilegeHasPartition: ObjectPrivilege
PrivilegeGetFlushState: ObjectPrivilege
Initializing: StateCode
Healthy: StateCode
Abnormal: StateCode
Expand Down
606 changes: 310 additions & 296 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -700,20 +700,22 @@ class MutationResult(_message.Message):
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., IDs: _Optional[_Union[_schema_pb2.IDs, _Mapping]] = ..., succ_index: _Optional[_Iterable[int]] = ..., err_index: _Optional[_Iterable[int]] = ..., acknowledged: bool = ..., insert_cnt: _Optional[int] = ..., delete_cnt: _Optional[int] = ..., upsert_cnt: _Optional[int] = ..., timestamp: _Optional[int] = ...) -> None: ...

class DeleteRequest(_message.Message):
__slots__ = ["base", "db_name", "collection_name", "partition_name", "expr", "hash_keys"]
__slots__ = ["base", "db_name", "collection_name", "partition_name", "expr", "hash_keys", "consistency_level"]
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
PARTITION_NAME_FIELD_NUMBER: _ClassVar[int]
EXPR_FIELD_NUMBER: _ClassVar[int]
HASH_KEYS_FIELD_NUMBER: _ClassVar[int]
CONSISTENCY_LEVEL_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
partition_name: str
expr: str
hash_keys: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., expr: _Optional[str] = ..., hash_keys: _Optional[_Iterable[int]] = ...) -> None: ...
consistency_level: _common_pb2.ConsistencyLevel
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., expr: _Optional[str] = ..., hash_keys: _Optional[_Iterable[int]] = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ...) -> None: ...

class SearchRequest(_message.Message):
__slots__ = ["base", "db_name", "collection_name", "partition_names", "dsl", "placeholder_group", "dsl_type", "output_fields", "search_params", "travel_timestamp", "guarantee_timestamp", "nq", "not_return_all_meta", "consistency_level", "use_default_consistency", "search_by_primary_keys"]
Expand Down Expand Up @@ -771,7 +773,7 @@ class SearchResults(_message.Message):
collection_name: str
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Union[_schema_pb2.SearchResultData, _Mapping]] = ..., collection_name: _Optional[str] = ...) -> None: ...

class SearchRequestV2(_message.Message):
class HybridSearchRequest(_message.Message):
__slots__ = ["base", "db_name", "collection_name", "partition_names", "requests", "rank_params", "travel_timestamp", "guarantee_timestamp", "not_return_all_meta", "output_fields", "consistency_level", "use_default_consistency"]
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
Expand Down Expand Up @@ -1212,22 +1214,24 @@ class GetFlushAllStateResponse(_message.Message):
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., flushed: bool = ...) -> None: ...

class ImportRequest(_message.Message):
__slots__ = ["collection_name", "partition_name", "channel_names", "row_based", "files", "options", "db_name"]
__slots__ = ["collection_name", "partition_name", "channel_names", "row_based", "files", "options", "db_name", "clustering_info"]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
PARTITION_NAME_FIELD_NUMBER: _ClassVar[int]
CHANNEL_NAMES_FIELD_NUMBER: _ClassVar[int]
ROW_BASED_FIELD_NUMBER: _ClassVar[int]
FILES_FIELD_NUMBER: _ClassVar[int]
OPTIONS_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
CLUSTERING_INFO_FIELD_NUMBER: _ClassVar[int]
collection_name: str
partition_name: str
channel_names: _containers.RepeatedScalarFieldContainer[str]
row_based: bool
files: _containers.RepeatedScalarFieldContainer[str]
options: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
db_name: str
def __init__(self, collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., channel_names: _Optional[_Iterable[str]] = ..., row_based: bool = ..., files: _Optional[_Iterable[str]] = ..., options: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., db_name: _Optional[str] = ...) -> None: ...
clustering_info: bytes
def __init__(self, collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., channel_names: _Optional[_Iterable[str]] = ..., row_based: bool = ..., files: _Optional[_Iterable[str]] = ..., options: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., db_name: _Optional[str] = ..., clustering_info: _Optional[bytes] = ...) -> None: ...

class ImportResponse(_message.Message):
__slots__ = ["status", "tasks"]
Expand Down
20 changes: 10 additions & 10 deletions pymilvus/grpc_gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ def __init__(self, channel):
request_serializer=milvus__pb2.SearchRequest.SerializeToString,
response_deserializer=milvus__pb2.SearchResults.FromString,
)
self.SearchV2 = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/SearchV2',
request_serializer=milvus__pb2.SearchRequestV2.SerializeToString,
self.HybridSearch = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/HybridSearch',
request_serializer=milvus__pb2.HybridSearchRequest.SerializeToString,
response_deserializer=milvus__pb2.SearchResults.FromString,
)
self.Flush = channel.unary_unary(
Expand Down Expand Up @@ -642,7 +642,7 @@ def Search(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SearchV2(self, request, context):
def HybridSearch(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
Expand Down Expand Up @@ -1114,9 +1114,9 @@ def add_MilvusServiceServicer_to_server(servicer, server):
request_deserializer=milvus__pb2.SearchRequest.FromString,
response_serializer=milvus__pb2.SearchResults.SerializeToString,
),
'SearchV2': grpc.unary_unary_rpc_method_handler(
servicer.SearchV2,
request_deserializer=milvus__pb2.SearchRequestV2.FromString,
'HybridSearch': grpc.unary_unary_rpc_method_handler(
servicer.HybridSearch,
request_deserializer=milvus__pb2.HybridSearchRequest.FromString,
response_serializer=milvus__pb2.SearchResults.SerializeToString,
),
'Flush': grpc.unary_unary_rpc_method_handler(
Expand Down Expand Up @@ -1948,7 +1948,7 @@ def Search(request,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def SearchV2(request,
def HybridSearch(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -1958,8 +1958,8 @@ def SearchV2(request,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/SearchV2',
milvus__pb2.SearchRequestV2.SerializeToString,
return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/HybridSearch',
milvus__pb2.HybridSearchRequest.SerializeToString,
milvus__pb2.SearchResults.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
Expand Down
Loading

0 comments on commit bd08a46

Please sign in to comment.