diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index 28ae0b309..e2a87a634 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -68,7 +68,27 @@ req = AnnSearchRequest(**search_param) req_list.append(req) -print("rank by RRFRanker") +print(fmt.format("rank by RRFRanker")) +hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"]) +for hits in hybrid_res: + for hit in hits: + print(f" hybrid search hit: {hit}") + +req_list = [] +for i in range(len(field_names)): + # 4. generate search data + vectors_to_search = rng.random((nq, dim)) + search_param = { + "data": vectors_to_search, + "anns_field": field_names[i], + "param": {"metric_type": "L2"}, + "limit": default_limit, + "expr": "random > {radius}", + "expr_params": {"radius": 0.5}} + req = AnnSearchRequest(**search_param) + req_list.append(req) + +print(fmt.format("rank by RRFRanker with expression template")) hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"]) for hits in hybrid_res: for hit in hits: diff --git a/examples/hybrid_search/hybrid_search.py b/examples/hybrid_search/hybrid_search.py index 6a13045f0..3b2d5b899 100644 --- a/examples/hybrid_search/hybrid_search.py +++ b/examples/hybrid_search/hybrid_search.py @@ -91,3 +91,25 @@ for hits in hybrid_res: for hit in hits: print(f" hybrid search hit: {hit}") + +print("rank by WightedRanker with expression template") +req_list = [] +for i in range(len(field_names)): + # 4. generate search data + vectors_to_search = rng.random((nq, dim)) + search_param = { + "data": vectors_to_search, + "anns_field": field_names[i], + "param": {"metric_type": "L2"}, + "limit": default_limit, + "expr": "random > {radius}", + "expr_params": {"radius": 0.5}} + req = AnnSearchRequest(**search_param) + req_list.append(req) + +hybrid_res = hello_milvus.hybrid_search(req_list, WeightedRanker(*weights), default_limit, output_fields=["random"]) + +print("rank by WightedRanker") +for hits in hybrid_res: + for hit in hits: + print(f" hybrid search hit: {hit}") diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 52cd3ef4f..7cbff180a 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -417,6 +417,7 @@ def __init__( param: Dict, limit: int, expr: Optional[str] = None, + expr_params: Optional[dict] = None, ): self._data = data self._anns_field = anns_field @@ -426,6 +427,7 @@ def __init__( if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) self._expr = expr + self._expr_params = expr_params @property def data(self): @@ -447,6 +449,10 @@ def limit(self): def expr(self): return self._expr + @property + def expr_params(self): + return self._expr_params + def __str__(self): return { "anns_field": self.anns_field, diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 210a5228f..258711125 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -868,6 +868,7 @@ def hybrid_search( req.expr, partition_names=partition_names, round_decimal=round_decimal, + expr_params=req.expr_params, **kwargs, ) requests.append(search_request) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 4f3fd5e9c..a3fd3c942 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -38,10 +38,10 @@ class Prepare: @classmethod def create_collection_request( - cls, - collection_name: str, - fields: Union[Dict[str, Iterable], CollectionSchema], - **kwargs, + cls, + collection_name: str, + fields: Union[Dict[str, Iterable], CollectionSchema], + **kwargs, ) -> milvus_types.CreateCollectionRequest: """ Args: @@ -106,9 +106,9 @@ def create_collection_request( @classmethod def get_schema_from_collection_schema( - cls, - collection_name: str, - fields: CollectionSchema, + cls, + collection_name: str, + fields: CollectionSchema, ) -> schema_types.CollectionSchema: coll_description = fields.description if not isinstance(coll_description, (str, bytes)): @@ -164,9 +164,9 @@ def get_schema_from_collection_schema( @staticmethod def get_field_schema( - field: Dict, - primary_field: Any, - auto_id_field: Any, + field: Dict, + primary_field: Any, + auto_id_field: Any, ) -> (schema_types.FieldSchema, Any, Any): field_name = field.get("name") if field_name is None: @@ -229,10 +229,10 @@ def get_field_schema( @classmethod def get_schema( - cls, - collection_name: str, - fields: Dict[str, Iterable], - **kwargs, + cls, + collection_name: str, + fields: Dict[str, Iterable], + **kwargs, ) -> schema_types.CollectionSchema: if not isinstance(fields, dict): raise ParamError(message="Param fields must be a dict") @@ -269,17 +269,17 @@ def drop_collection_request(cls, collection_name: str) -> milvus_types.DropColle @classmethod def describe_collection_request( - cls, - collection_name: str, + cls, + collection_name: str, ) -> milvus_types.DescribeCollectionRequest: return milvus_types.DescribeCollectionRequest(collection_name=collection_name) @classmethod def alter_collection_request( - cls, - collection_name: str, - properties: Optional[Dict] = None, - delete_keys: Optional[List[str]] = None, + cls, + collection_name: str, + properties: Optional[Dict] = None, + delete_keys: Optional[List[str]] = None, ) -> milvus_types.AlterCollectionRequest: kvs = [] if properties: @@ -291,7 +291,7 @@ def alter_collection_request( @classmethod def alter_collection_field_request( - cls, collection_name: str, field_name: str, field_param: Dict + cls, collection_name: str, field_name: str, field_param: Dict ) -> milvus_types.AlterCollectionFieldRequest: kvs = [] if field_param: @@ -349,10 +349,10 @@ def partition_stats_request(cls, collection_name: str, partition_name: str): @classmethod def show_partitions_request( - cls, - collection_name: str, - partition_names: Optional[List[str]] = None, - type_in_memory: bool = False, + cls, + collection_name: str, + partition_names: Optional[List[str]] = None, + type_in_memory: bool = False, ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.ShowPartitionsRequest(collection_name=collection_name) @@ -371,7 +371,7 @@ def show_partitions_request( @classmethod def get_loading_progress( - cls, collection_name: str, partition_names: Optional[List[str]] = None + cls, collection_name: str, partition_names: Optional[List[str]] = None ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.GetLoadingProgressRequest(collection_name=collection_name) @@ -420,10 +420,10 @@ def _num_input_fields(fields_info: List[Dict], is_upsert: bool): @staticmethod def _parse_row_request( - request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], - fields_info: List[Dict], - enable_dynamic: bool, - entities: List, + request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], + fields_info: List[Dict], + enable_dynamic: bool, + entities: List, ): input_fields_info = [ field for field in fields_info if Prepare._is_input_field(field, is_upsert=False) @@ -462,7 +462,7 @@ def _parse_row_request( if k in fields_data: field_info, field_data = field_info_map[k], fields_data[k] if field_info.get("nullable", False) or field_info.get( - "default_value", None + "default_value", None ): field_data.valid_data.append(v is not None) entity_helper.pack_field_value_to_field_data(v, field_data, field_info) @@ -502,10 +502,10 @@ def _parse_row_request( @staticmethod def _parse_upsert_row_request( - request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], - fields_info: List[Dict], - enable_dynamic: bool, - entities: List, + request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], + fields_info: List[Dict], + enable_dynamic: bool, + entities: List, ): input_fields_info = [ field for field in fields_info if Prepare._is_input_field(field, is_upsert=True) @@ -544,7 +544,7 @@ def _parse_upsert_row_request( if k in fields_data: field_info, field_data = field_info_map[k], fields_data[k] if field_info.get("nullable", False) or field_info.get( - "default_value", None + "default_value", None ): field_data.valid_data.append(v is not None) entity_helper.pack_field_value_to_field_data(v, field_data, field_info) @@ -597,12 +597,12 @@ def _parse_upsert_row_request( @classmethod def row_insert_param( - cls, - collection_name: str, - entities: List, - partition_name: str, - fields_info: Dict, - enable_dynamic: bool = False, + cls, + collection_name: str, + entities: List, + partition_name: str, + fields_info: Dict, + enable_dynamic: bool = False, ): if not fields_info: raise ParamError(message="Missing collection meta to validate entities") @@ -619,12 +619,12 @@ def row_insert_param( @classmethod def row_upsert_param( - cls, - collection_name: str, - entities: List, - partition_name: str, - fields_info: Any, - enable_dynamic: bool = False, + cls, + collection_name: str, + entities: List, + partition_name: str, + fields_info: Any, + enable_dynamic: bool = False, ): if not fields_info: raise ParamError(message="Missing collection meta to validate entities") @@ -641,14 +641,14 @@ def row_upsert_param( @staticmethod def _pre_insert_batch_check( - entities: List, - fields_info: Any, + entities: List, + fields_info: Any, ): for entity in entities: if ( - entity.get("name") is None - or entity.get("values") is None - or entity.get("type") is None + entity.get("name") is None + or entity.get("values") is None + or entity.get("type") is None ): raise ParamError( message="Missing param in entities, a field must have type, name and values" @@ -672,14 +672,14 @@ def _pre_insert_batch_check( @staticmethod def _pre_upsert_batch_check( - entities: List, - fields_info: Any, + entities: List, + fields_info: Any, ): for entity in entities: if ( - entity.get("name") is None - or entity.get("values") is None - or entity.get("type") is None + entity.get("name") is None + or entity.get("values") is None + or entity.get("type") is None ): raise ParamError( message="Missing param in entities, a field must have type, name and values" @@ -702,10 +702,10 @@ def _pre_upsert_batch_check( @staticmethod def _parse_batch_request( - request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], - entities: List, - fields_info: Any, - location: Dict, + request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], + entities: List, + fields_info: Any, + location: Dict, ): pre_field_size = 0 try: @@ -739,11 +739,11 @@ def _parse_batch_request( @classmethod def batch_insert_param( - cls, - collection_name: str, - entities: List, - partition_name: str, - fields_info: Any, + cls, + collection_name: str, + entities: List, + partition_name: str, + fields_info: Any, ): location = cls._pre_insert_batch_check(entities, fields_info) tag = partition_name if isinstance(partition_name, str) else "" @@ -753,11 +753,11 @@ def batch_insert_param( @classmethod def batch_upsert_param( - cls, - collection_name: str, - entities: List, - partition_name: str, - fields_info: Any, + cls, + collection_name: str, + entities: List, + partition_name: str, + fields_info: Any, ): location = cls._pre_upsert_batch_check(entities, fields_info) tag = partition_name if isinstance(partition_name, str) else "" @@ -767,12 +767,12 @@ def batch_upsert_param( @classmethod def delete_request( - cls, - collection_name: str, - filter: str, - partition_name: Optional[str] = None, - consistency_level: Optional[Union[int, str]] = None, - **kwargs, + cls, + collection_name: str, + filter: str, + partition_name: Optional[str] = None, + consistency_level: Optional[Union[int, str]] = None, + **kwargs, ): check.validate_strs( collection_name=collection_name, @@ -849,10 +849,10 @@ def add_array_data(v: List) -> schema_types.TemplateArrayValue: data.bool_data.data.extend(v) return data if element_type in ( - schema_types.Int8, - schema_types.Int16, - schema_types.Int32, - schema_types.Int64, + schema_types.Int8, + schema_types.Int16, + schema_types.Int32, + schema_types.Int64, ): data.long_data.data.extend(v) return data @@ -879,10 +879,10 @@ def add_data(v: Any) -> schema_types.TemplateValue: data.bool_val = v return data if dtype in ( - schema_types.Int8, - schema_types.Int16, - schema_types.Int32, - schema_types.Int64, + schema_types.Int8, + schema_types.Int16, + schema_types.Int32, + schema_types.Int64, ): data.int64_val = v return data @@ -904,17 +904,17 @@ def add_data(v: Any) -> schema_types.TemplateValue: @classmethod def search_requests_with_expr( - cls, - collection_name: str, - data: Union[List, utils.SparseMatrixInputType], - anns_field: str, - param: Dict, - limit: int, - expr: Optional[str] = None, - partition_names: Optional[List[str]] = None, - output_fields: Optional[List[str]] = None, - round_decimal: int = -1, - **kwargs, + cls, + collection_name: str, + data: Union[List, utils.SparseMatrixInputType], + anns_field: str, + param: Dict, + limit: int, + expr: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + round_decimal: int = -1, + **kwargs, ) -> milvus_types.SearchRequest: use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs) @@ -997,7 +997,8 @@ def search_requests_with_expr( placeholder_group=plg_str, dsl_type=common_types.DslType.BoolExprV1, search_params=req_params, - expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})), + expr_template_values=cls.prepare_expression_template( + {} if kwargs.get("expr_params") is None else kwargs.get("expr_params")), ) if expr is not None: request.dsl = expr @@ -1006,15 +1007,15 @@ def search_requests_with_expr( @classmethod def hybrid_search_request_with_ranker( - cls, - collection_name: str, - reqs: List, - rerank_param: Dict, - limit: int, - partition_names: Optional[List[str]] = None, - output_fields: Optional[List[str]] = None, - round_decimal: int = -1, - **kwargs, + cls, + collection_name: str, + reqs: List, + rerank_param: Dict, + limit: int, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + round_decimal: int = -1, + **kwargs, ) -> milvus_types.HybridSearchRequest: use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs) rerank_param["limit"] = limit @@ -1115,7 +1116,7 @@ def create_index_request(cls, collection_name: str, field_name: str, params: Dic @classmethod def alter_index_properties_request( - cls, collection_name: str, index_name: str, properties: dict + cls, collection_name: str, index_name: str, properties: dict ): params = [] for k, v in properties.items(): @@ -1126,7 +1127,7 @@ def alter_index_properties_request( @classmethod def drop_index_properties_request( - cls, collection_name: str, index_name: str, delete_keys: List[str] + cls, collection_name: str, index_name: str, delete_keys: List[str] ): return milvus_types.AlterIndexRequest( collection_name=collection_name, index_name=index_name, delete_keys=delete_keys @@ -1134,7 +1135,7 @@ def drop_index_properties_request( @classmethod def describe_index_request( - cls, collection_name: str, index_name: str, timestamp: Optional[int] = None + cls, collection_name: str, index_name: str, timestamp: Optional[int] = None ): return milvus_types.DescribeIndexRequest( collection_name=collection_name, index_name=index_name, timestamp=timestamp @@ -1154,14 +1155,14 @@ def get_index_state_request(cls, collection_name: str, index_name: str): @classmethod def load_collection( - cls, - db_name: str, - collection_name: str, - replica_number: int, - refresh: bool, - resource_groups: List[str], - load_fields: List[str], - skip_load_dynamic_field: bool, + cls, + db_name: str, + collection_name: str, + replica_number: int, + refresh: bool, + resource_groups: List[str], + load_fields: List[str], + skip_load_dynamic_field: bool, ): return milvus_types.LoadCollectionRequest( db_name=db_name, @@ -1181,15 +1182,15 @@ def release_collection(cls, db_name: str, collection_name: str): @classmethod def load_partitions( - cls, - db_name: str, - collection_name: str, - partition_names: List[str], - replica_number: int, - refresh: bool, - resource_groups: List[str], - load_fields: List[str], - skip_load_dynamic_field: bool, + cls, + db_name: str, + collection_name: str, + partition_names: List[str], + replica_number: int, + refresh: bool, + resource_groups: List[str], + load_fields: List[str], + skip_load_dynamic_field: bool, ): return milvus_types.LoadPartitionsRequest( db_name=db_name, @@ -1251,11 +1252,11 @@ def dummy_request(cls, request_type: Any): @classmethod def retrieve_request( - cls, - collection_name: str, - ids: List[str], - output_fields: List[str], - partition_names: List[str], + cls, + collection_name: str, + ids: List[str], + output_fields: List[str], + partition_names: List[str], ): ids = schema_types.IDs(int_id=schema_types.LongArray(data=ids)) return milvus_types.RetrieveRequest( @@ -1268,12 +1269,12 @@ def retrieve_request( @classmethod def query_request( - cls, - collection_name: str, - expr: str, - output_fields: List[str], - partition_names: List[str], - **kwargs, + cls, + collection_name: str, + expr: str, + output_fields: List[str], + partition_names: List[str], + **kwargs, ): use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs) req = milvus_types.QueryRequest( @@ -1314,11 +1315,11 @@ def query_request( @classmethod def load_balance_request( - cls, - collection_name: str, - src_node_id: int, - dst_node_ids: List[int], - sealed_segment_ids: List[int], + cls, + collection_name: str, + src_node_id: int, + dst_node_ids: List[int], + sealed_segment_ids: List[int], ): return milvus_types.LoadBalanceRequest( collectionName=collection_name, @@ -1476,13 +1477,13 @@ def select_user_request(cls, username: str, include_role_info: bool): @classmethod def operate_privilege_request( - cls, - role_name: str, - object: Any, - object_name: str, - privilege: str, - db_name: str, - operate_privilege_type: Any, + cls, + role_name: str, + object: Any, + object_name: str, + privilege: str, + db_name: str, + operate_privilege_type: Any, ): check_pass_param(role_name=role_name) check_pass_param(object=object) @@ -1504,12 +1505,12 @@ def operate_privilege_request( @classmethod def operate_privilege_v2_request( - cls, - role_name: str, - privilege: str, - operate_privilege_type: Any, - db_name: str, - collection_name: str, + cls, + role_name: str, + privilege: str, + operate_privilege_type: Any, + db_name: str, + collection_name: str, ): check_pass_param( role_name=role_name, @@ -1678,7 +1679,7 @@ def list_privilege_groups_req(cls): @classmethod def operate_privilege_group_req( - cls, privilege_group: str, privileges: List[str], operate_privilege_group_type: Any + cls, privilege_group: str, privileges: List[str], operate_privilege_group_type: Any ): check_pass_param(privilege_group=privilege_group) check_pass_param(privileges=privileges)