diff --git a/qdrant_client/conversions/conversion.py b/qdrant_client/conversions/conversion.py index 96bc0816..83c4592e 100644 --- a/qdrant_client/conversions/conversion.py +++ b/qdrant_client/conversions/conversion.py @@ -92,8 +92,8 @@ def payload_to_grpc(payload: Dict[str, Any]) -> Dict[str, Value]: return dict((key, json_to_value(val)) for key, val in payload.items()) -def grpc_to_payload(grpc: Dict[str, Value]) -> Dict[str, Any]: - return dict((key, value_to_json(val)) for key, val in grpc.items()) +def grpc_to_payload(grpc_: Dict[str, Value]) -> Dict[str, Any]: + return dict((key, value_to_json(val)) for key, val in grpc_.items()) def grpc_payload_schema_to_field_type(model: grpc.PayloadSchemaType) -> grpc.FieldType: @@ -157,6 +157,17 @@ def convert_filter(cls, model: grpc.Filter) -> rest.Filter: must=[cls.convert_condition(condition) for condition in model.must], should=[cls.convert_condition(condition) for condition in model.should], must_not=[cls.convert_condition(condition) for condition in model.must_not], + min_should=( + rest.MinShould( + conditions=[ + cls.convert_condition(condition) + for condition in model.min_should.conditions + ], + min_count=model.min_should.min_count, + ) + if model.HasField("min_should") + else None + ), ) @classmethod @@ -218,9 +229,11 @@ def convert_collection_config(cls, model: grpc.CollectionConfig) -> rest.Collect optimizer_config=cls.convert_optimizer_config(model.optimizer_config), params=cls.convert_collection_params(model.params), wal_config=cls.convert_wal_config(model.wal_config), - quantization_config=cls.convert_quantization_config(model.quantization_config) - if model.HasField("quantization_config") - else None, + quantization_config=( + cls.convert_quantization_config(model.quantization_config) + if model.HasField("quantization_config") + else None + ), ) @classmethod @@ -228,12 +241,12 @@ def convert_hnsw_config_diff(cls, model: grpc.HnswConfigDiff) -> rest.HnswConfig return rest.HnswConfigDiff( ef_construct=model.ef_construct if model.HasField("ef_construct") else None, m=model.m if model.HasField("m") else None, - full_scan_threshold=model.full_scan_threshold - if model.HasField("full_scan_threshold") - else None, - max_indexing_threads=model.max_indexing_threads - if model.HasField("max_indexing_threads") - else None, + full_scan_threshold=( + model.full_scan_threshold if model.HasField("full_scan_threshold") else None + ), + max_indexing_threads=( + model.max_indexing_threads if model.HasField("max_indexing_threads") else None + ), on_disk=model.on_disk if model.HasField("on_disk") else None, payload_m=model.payload_m if model.HasField("payload_m") else None, ) @@ -243,12 +256,12 @@ def convert_hnsw_config(cls, model: grpc.HnswConfigDiff) -> rest.HnswConfig: return rest.HnswConfig( ef_construct=model.ef_construct if model.HasField("ef_construct") else None, m=model.m if model.HasField("m") else None, - full_scan_threshold=model.full_scan_threshold - if model.HasField("full_scan_threshold") - else None, - max_indexing_threads=model.max_indexing_threads - if model.HasField("max_indexing_threads") - else None, + full_scan_threshold=( + model.full_scan_threshold if model.HasField("full_scan_threshold") else None + ), + max_indexing_threads=( + model.max_indexing_threads if model.HasField("max_indexing_threads") else None + ), on_disk=model.on_disk if model.HasField("on_disk") else None, payload_m=model.payload_m if model.HasField("payload_m") else None, ) @@ -256,30 +269,34 @@ def convert_hnsw_config(cls, model: grpc.HnswConfigDiff) -> rest.HnswConfig: @classmethod def convert_optimizer_config(cls, model: grpc.OptimizersConfigDiff) -> rest.OptimizersConfig: return rest.OptimizersConfig( - default_segment_number=model.default_segment_number - if model.HasField("default_segment_number") - else None, - deleted_threshold=model.deleted_threshold - if model.HasField("deleted_threshold") - else None, - flush_interval_sec=model.flush_interval_sec - if model.HasField("flush_interval_sec") - else None, - indexing_threshold=model.indexing_threshold - if model.HasField("indexing_threshold") - else None, - max_optimization_threads=model.max_optimization_threads - if model.HasField("max_optimization_threads") - else None, - max_segment_size=model.max_segment_size - if model.HasField("max_segment_size") - else None, - memmap_threshold=model.memmap_threshold - if model.HasField("memmap_threshold") - else None, - vacuum_min_vector_number=model.vacuum_min_vector_number - if model.HasField("vacuum_min_vector_number") - else None, + default_segment_number=( + model.default_segment_number if model.HasField("default_segment_number") else None + ), + deleted_threshold=( + model.deleted_threshold if model.HasField("deleted_threshold") else None + ), + flush_interval_sec=( + model.flush_interval_sec if model.HasField("flush_interval_sec") else None + ), + indexing_threshold=( + model.indexing_threshold if model.HasField("indexing_threshold") else None + ), + max_optimization_threads=( + model.max_optimization_threads + if model.HasField("max_optimization_threads") + else None + ), + max_segment_size=( + model.max_segment_size if model.HasField("max_segment_size") else None + ), + memmap_threshold=( + model.memmap_threshold if model.HasField("memmap_threshold") else None + ), + vacuum_min_vector_number=( + model.vacuum_min_vector_number + if model.HasField("vacuum_min_vector_number") + else None + ), ) @classmethod @@ -299,9 +316,9 @@ def convert_distance(cls, model: grpc.Distance) -> rest.Distance: def convert_wal_config(cls, model: grpc.WalConfigDiff) -> rest.WalConfig: return rest.WalConfig( wal_capacity_mb=model.wal_capacity_mb if model.HasField("wal_capacity_mb") else None, - wal_segments_ahead=model.wal_segments_ahead - if model.HasField("wal_segments_ahead") - else None, + wal_segments_ahead=( + model.wal_segments_ahead if model.HasField("wal_segments_ahead") else None + ), ) @classmethod @@ -314,9 +331,11 @@ def convert_payload_schema( def convert_payload_schema_info(cls, model: grpc.PayloadSchemaInfo) -> rest.PayloadIndexInfo: return rest.PayloadIndexInfo( data_type=cls.convert_payload_schema_type(model.data_type), - params=cls.convert_payload_schema_params(model.params) - if model.HasField("params") - else None, + params=( + cls.convert_payload_schema_params(model.params) + if model.HasField("params") + else None + ), points=model.points, ) @@ -420,9 +439,11 @@ def convert_search_params(cls, model: grpc.SearchParams) -> rest.SearchParams: return rest.SearchParams( hnsw_ef=model.hnsw_ef if model.HasField("hnsw_ef") else None, exact=model.exact if model.HasField("exact") else None, - quantization=cls.convert_quantization_search_params(model.quantization) - if model.HasField("quantization") - else None, + quantization=( + cls.convert_quantization_search_params(model.quantization) + if model.HasField("quantization") + else None + ), indexed_only=model.indexed_only if model.HasField("indexed_only") else None, ) @@ -439,9 +460,9 @@ def convert_scored_point(cls, model: grpc.ScoredPoint) -> rest.ScoredPoint: score=model.score, vector=cls.convert_vectors(model.vectors) if model.HasField("vectors") else None, version=model.version, - shard_key=cls.convert_shard_key(model.shard_key) - if model.HasField("shard_key") - else None, + shard_key=( + cls.convert_shard_key(model.shard_key) if model.HasField("shard_key") else None + ), ) @classmethod @@ -534,28 +555,32 @@ def convert_match(cls, model: grpc.Match) -> rest.Match: def convert_wal_config_diff(cls, model: grpc.WalConfigDiff) -> rest.WalConfigDiff: return rest.WalConfigDiff( wal_capacity_mb=model.wal_capacity_mb if model.HasField("wal_capacity_mb") else None, - wal_segments_ahead=model.wal_segments_ahead - if model.HasField("wal_segments_ahead") - else None, + wal_segments_ahead=( + model.wal_segments_ahead if model.HasField("wal_segments_ahead") else None + ), ) @classmethod def convert_collection_params(cls, model: grpc.CollectionParams) -> rest.CollectionParams: return rest.CollectionParams( - vectors=cls.convert_vectors_config(model.vectors_config) - if model.HasField("vectors_config") - else None, + vectors=( + cls.convert_vectors_config(model.vectors_config) + if model.HasField("vectors_config") + else None + ), shard_number=model.shard_number, on_disk_payload=model.on_disk_payload, - replication_factor=model.replication_factor - if model.HasField("replication_factor") - else None, - read_fan_out_factor=model.read_fan_out_factor - if model.HasField("read_fan_out_factor") - else None, - write_consistency_factor=model.write_consistency_factor - if model.HasField("write_consistency_factor") - else None, + replication_factor=( + model.replication_factor if model.HasField("replication_factor") else None + ), + read_fan_out_factor=( + model.read_fan_out_factor if model.HasField("read_fan_out_factor") else None + ), + write_consistency_factor=( + model.write_consistency_factor + if model.HasField("write_consistency_factor") + else None + ), ) @classmethod @@ -563,50 +588,64 @@ def convert_optimizers_config_diff( cls, model: grpc.OptimizersConfigDiff ) -> rest.OptimizersConfigDiff: return rest.OptimizersConfigDiff( - default_segment_number=model.default_segment_number - if model.HasField("default_segment_number") - else None, - deleted_threshold=model.deleted_threshold - if model.HasField("deleted_threshold") - else None, - flush_interval_sec=model.flush_interval_sec - if model.HasField("flush_interval_sec") - else None, - indexing_threshold=model.indexing_threshold - if model.HasField("indexing_threshold") - else None, - max_optimization_threads=model.max_optimization_threads - if model.HasField("max_optimization_threads") - else None, - max_segment_size=model.max_segment_size - if model.HasField("max_segment_size") - else None, - memmap_threshold=model.memmap_threshold - if model.HasField("memmap_threshold") - else None, - vacuum_min_vector_number=model.vacuum_min_vector_number - if model.HasField("vacuum_min_vector_number") - else None, + default_segment_number=( + model.default_segment_number if model.HasField("default_segment_number") else None + ), + deleted_threshold=( + model.deleted_threshold if model.HasField("deleted_threshold") else None + ), + flush_interval_sec=( + model.flush_interval_sec if model.HasField("flush_interval_sec") else None + ), + indexing_threshold=( + model.indexing_threshold if model.HasField("indexing_threshold") else None + ), + max_optimization_threads=( + model.max_optimization_threads + if model.HasField("max_optimization_threads") + else None + ), + max_segment_size=( + model.max_segment_size if model.HasField("max_segment_size") else None + ), + memmap_threshold=( + model.memmap_threshold if model.HasField("memmap_threshold") else None + ), + vacuum_min_vector_number=( + model.vacuum_min_vector_number + if model.HasField("vacuum_min_vector_number") + else None + ), ) @classmethod def convert_update_collection(cls, model: grpc.UpdateCollection) -> rest.UpdateCollection: return rest.UpdateCollection( - vectors=cls.convert_vectors_config_diff(model.vectors_config) - if model.HasField("vectors_config") - else None, - optimizers_config=cls.convert_optimizers_config_diff(model.optimizers_config) - if model.HasField("optimizers_config") - else None, - params=cls.convert_collection_params_diff(model.params) - if model.HasField("params") - else None, - hnsw_config=cls.convert_hnsw_config_diff(model.hnsw_config) - if model.HasField("hnsw_config") - else None, - quantization_config=cls.convert_quantization_config_diff(model.quantization_config) - if model.HasField("quantization_config") - else None, + vectors=( + cls.convert_vectors_config_diff(model.vectors_config) + if model.HasField("vectors_config") + else None + ), + optimizers_config=( + cls.convert_optimizers_config_diff(model.optimizers_config) + if model.HasField("optimizers_config") + else None + ), + params=( + cls.convert_collection_params_diff(model.params) + if model.HasField("params") + else None + ), + hnsw_config=( + cls.convert_hnsw_config_diff(model.hnsw_config) + if model.HasField("hnsw_config") + else None + ), + quantization_config=( + cls.convert_quantization_config_diff(model.quantization_config) + if model.HasField("quantization_config") + else None + ), ) @classmethod @@ -684,9 +723,9 @@ def convert_retrieved_point(cls, model: grpc.RetrievedPoint) -> rest.Record: id=cls.convert_point_id(model.id), payload=cls.convert_payload(model.payload), vector=cls.convert_vectors(model.vectors) if model.HasField("vectors") else None, - shard_key=cls.convert_shard_key(model.shard_key) - if model.HasField("shard_key") - else None, + shard_key=( + cls.convert_shard_key(model.shard_key) if model.HasField("shard_key") else None + ), ) @classmethod @@ -703,9 +742,11 @@ def convert_snapshot_description( ) -> rest.SnapshotDescription: return rest.SnapshotDescription( name=model.name, - creation_time=model.creation_time.ToDatetime().isoformat() - if model.HasField("creation_time") - else None, + creation_time=( + model.creation_time.ToDatetime().isoformat() + if model.HasField("creation_time") + else None + ), size=model.size, ) @@ -714,12 +755,16 @@ def convert_vector_params(cls, model: grpc.VectorParams) -> rest.VectorParams: return rest.VectorParams( size=model.size, distance=cls.convert_distance(model.distance), - hnsw_config=cls.convert_hnsw_config_diff(model.hnsw_config) - if model.HasField("hnsw_config") - else None, - quantization_config=cls.convert_quantization_config(model.quantization_config) - if model.HasField("quantization_config") - else None, + hnsw_config=( + cls.convert_hnsw_config_diff(model.hnsw_config) + if model.HasField("hnsw_config") + else None + ), + quantization_config=( + cls.convert_quantization_config(model.quantization_config) + if model.HasField("quantization_config") + else None + ), on_disk=model.on_disk if model.HasField("on_disk") else None, ) @@ -780,18 +825,24 @@ def convert_search_points(cls, model: grpc.SearchPoints) -> rest.SearchRequest: vector=rest.NamedVector(name=model.vector_name, vector=model.vector[:]), filter=cls.convert_filter(model.filter) if model.HasField("filter") else None, limit=model.limit, - with_payload=cls.convert_with_payload_interface(model.with_payload) - if model.HasField("with_payload") - else None, + with_payload=( + cls.convert_with_payload_interface(model.with_payload) + if model.HasField("with_payload") + else None + ), params=cls.convert_search_params(model.params) if model.HasField("params") else None, score_threshold=model.score_threshold if model.HasField("score_threshold") else None, offset=model.offset if model.HasField("offset") else None, - with_vector=cls.convert_with_vectors_selector(model.with_vectors) - if model.HasField("with_vectors") - else None, - shard_key=cls.convert_shard_key_selector(model.shard_key_selector) - if model.HasField("shard_key_selector") - else None, + with_vector=( + cls.convert_with_vectors_selector(model.with_vectors) + if model.HasField("with_vectors") + else None + ), + shard_key=( + cls.convert_shard_key_selector(model.shard_key_selector) + if model.HasField("shard_key_selector") + else None + ), ) @classmethod @@ -807,25 +858,35 @@ def convert_recommend_points(cls, model: grpc.RecommendPoints) -> rest.Recommend negative=negative_ids + negative_vectors, filter=cls.convert_filter(model.filter) if model.HasField("filter") else None, limit=model.limit, - with_payload=cls.convert_with_payload_interface(model.with_payload) - if model.HasField("with_payload") - else None, + with_payload=( + cls.convert_with_payload_interface(model.with_payload) + if model.HasField("with_payload") + else None + ), params=cls.convert_search_params(model.params) if model.HasField("params") else None, score_threshold=model.score_threshold if model.HasField("score_threshold") else None, offset=model.offset if model.HasField("offset") else None, - with_vector=cls.convert_with_vectors_selector(model.with_vectors) - if model.HasField("with_vectors") - else None, + with_vector=( + cls.convert_with_vectors_selector(model.with_vectors) + if model.HasField("with_vectors") + else None + ), using=model.using, - lookup_from=cls.convert_lookup_location(model.lookup_from) - if model.HasField("lookup_from") - else None, - strategy=cls.convert_recommend_strategy(model.strategy) - if model.HasField("strategy") - else None, - shard_key=cls.convert_shard_key_selector(model.shard_key_selector) - if model.HasField("shard_key_selector") - else None, + lookup_from=( + cls.convert_lookup_location(model.lookup_from) + if model.HasField("lookup_from") + else None + ), + strategy=( + cls.convert_recommend_strategy(model.strategy) + if model.HasField("strategy") + else None + ), + shard_key=( + cls.convert_shard_key_selector(model.shard_key_selector) + if model.HasField("shard_key_selector") + else None + ), ) @classmethod @@ -837,21 +898,29 @@ def convert_discover_points(cls, model: grpc.DiscoverPoints) -> rest.DiscoverReq context=context, filter=cls.convert_filter(model.filter) if model.HasField("filter") else None, limit=model.limit, - with_payload=cls.convert_with_payload_interface(model.with_payload) - if model.HasField("with_payload") - else None, + with_payload=( + cls.convert_with_payload_interface(model.with_payload) + if model.HasField("with_payload") + else None + ), params=cls.convert_search_params(model.params) if model.HasField("params") else None, offset=model.offset if model.HasField("offset") else None, - with_vector=cls.convert_with_vectors_selector(model.with_vectors) - if model.HasField("with_vectors") - else None, + with_vector=( + cls.convert_with_vectors_selector(model.with_vectors) + if model.HasField("with_vectors") + else None + ), using=model.using, - lookup_from=cls.convert_lookup_location(model.lookup_from) - if model.HasField("lookup_from") - else None, - shard_key=cls.convert_shard_key_selector(model.shard_key_selector) - if model.HasField("shard_key_selector") - else None, + lookup_from=( + cls.convert_lookup_location(model.lookup_from) + if model.HasField("lookup_from") + else None + ), + shard_key=( + cls.convert_shard_key_selector(model.shard_key_selector) + if model.HasField("shard_key_selector") + else None + ), ) @classmethod @@ -906,15 +975,17 @@ def convert_collection_params_diff( cls, model: grpc.CollectionParamsDiff ) -> rest.CollectionParamsDiff: return rest.CollectionParamsDiff( - replication_factor=model.replication_factor - if model.HasField("replication_factor") - else None, - write_consistency_factor=model.write_consistency_factor - if model.HasField("write_consistency_factor") - else None, - read_fan_out_factor=model.read_fan_out_factor - if model.HasField("read_fan_out_factor") - else None, + replication_factor=( + model.replication_factor if model.HasField("replication_factor") else None + ), + write_consistency_factor=( + model.write_consistency_factor + if model.HasField("write_consistency_factor") + else None + ), + read_fan_out_factor=( + model.read_fan_out_factor if model.HasField("read_fan_out_factor") else None + ), on_disk_payload=model.on_disk_payload if model.HasField("on_disk_payload") else None, ) @@ -1053,12 +1124,16 @@ def convert_group_id(cls, model: grpc.GroupId) -> rest.GroupId: def convert_with_lookup(cls, model: grpc.WithLookup) -> rest.WithLookup: return rest.WithLookup( collection=model.collection, - with_payload=cls.convert_with_payload_selector(model.with_payload) - if model.HasField("with_payload") - else None, - with_vectors=cls.convert_with_vectors_selector(model.with_vectors) - if model.HasField("with_vectors") - else None, + with_payload=( + cls.convert_with_payload_selector(model.with_payload) + if model.HasField("with_payload") + else None + ), + with_vectors=( + cls.convert_with_vectors_selector(model.with_vectors) + if model.HasField("with_vectors") + else None + ), ) @classmethod @@ -1080,12 +1155,16 @@ def convert_quantization_config_diff( @classmethod def convert_vector_params_diff(cls, model: grpc.VectorParamsDiff) -> rest.VectorParamsDiff: return rest.VectorParamsDiff( - hnsw_config=cls.convert_hnsw_config_diff(model.hnsw_config) - if model.HasField("hnsw_config") - else None, - quantization_config=cls.convert_quantization_config_diff(model.quantization_config) - if model.HasField("quantization_config") - else None, + hnsw_config=( + cls.convert_hnsw_config_diff(model.hnsw_config) + if model.HasField("hnsw_config") + else None + ), + quantization_config=( + cls.convert_quantization_config_diff(model.quantization_config) + if model.HasField("quantization_config") + else None + ), on_disk=model.on_disk if model.HasField("on_disk") else None, ) @@ -1270,9 +1349,9 @@ def convert_recommend_strategy(cls, model: grpc.RecommendStrategy) -> rest.Recom @classmethod def convert_sparse_index_config(cls, model: grpc.SparseIndexConfig) -> rest.SparseIndexParams: return rest.SparseIndexParams( - full_scan_threshold=model.full_scan_threshold - if model.HasField("full_scan_threshold") - else None, + full_scan_threshold=( + model.full_scan_threshold if model.HasField("full_scan_threshold") else None + ), on_disk=model.on_disk if model.HasField("on_disk") else None, ) @@ -1281,9 +1360,9 @@ def convert_sparse_vector_params( cls, model: grpc.SparseVectorParams ) -> rest.SparseVectorParams: return rest.SparseVectorParams( - index=cls.convert_sparse_index_config(model.index) - if model.index is not None - else None, + index=( + cls.convert_sparse_index_config(model.index) if model.index is not None else None + ), ) @classmethod @@ -1336,12 +1415,12 @@ def convert_start_from(cls, model: grpc.StartFrom) -> rest.StartFrom: def convert_order_by(cls, model: grpc.OrderBy) -> rest.OrderBy: return rest.OrderBy( key=model.key, - direction=cls.convert_direction(model.direction) - if model.HasField("direction") - else None, - start_from=cls.convert_start_from(model.start_from) - if model.HasField("start_from") - else None, + direction=( + cls.convert_direction(model.direction) if model.HasField("direction") else None + ), + start_from=( + cls.convert_start_from(model.start_from) if model.HasField("start_from") else None + ), ) @@ -1356,15 +1435,32 @@ class RestToGrpc: @classmethod def convert_filter(cls, model: rest.Filter) -> grpc.Filter: return grpc.Filter( - must=[cls.convert_condition(condition) for condition in model.must] - if model.must is not None - else None, - must_not=[cls.convert_condition(condition) for condition in model.must_not] - if model.must_not is not None - else None, - should=[cls.convert_condition(condition) for condition in model.should] - if model.should is not None - else None, + must=( + [cls.convert_condition(condition) for condition in model.must] + if model.must is not None + else None + ), + must_not=( + [cls.convert_condition(condition) for condition in model.must_not] + if model.must_not is not None + else None + ), + should=( + [cls.convert_condition(condition) for condition in model.should] + if model.should is not None + else None + ), + min_should=( + grpc.MinShould( + conditions=[ + cls.convert_condition(condition) + for condition in model.min_should.conditions + ], + min_count=model.min_should.min_count, + ) + if model.min_should is not None + else None + ), ) @classmethod @@ -1406,9 +1502,11 @@ def convert_collection_info(cls, model: rest.CollectionInfo) -> grpc.CollectionI return grpc.CollectionInfo( config=cls.convert_collection_config(model.config) if model.config else None, optimizer_status=cls.convert_optimizer_status(model.optimizer_status), - payload_schema=cls.convert_payload_schema(model.payload_schema) - if model.payload_schema is not None - else None, + payload_schema=( + cls.convert_payload_schema(model.payload_schema) + if model.payload_schema is not None + else None + ), segments_count=model.segments_count, status=cls.convert_collection_status(model.status), vectors_count=model.vectors_count, @@ -1529,9 +1627,11 @@ def convert_search_params(cls, model: rest.SearchParams) -> grpc.SearchParams: return grpc.SearchParams( hnsw_ef=model.hnsw_ef, exact=model.exact, - quantization=cls.convert_quantization_search_params(model.quantization) - if model.quantization is not None - else None, + quantization=( + cls.convert_quantization_search_params(model.quantization) + if model.quantization is not None + else None + ), indexed_only=model.indexed_only, ) @@ -1629,9 +1729,11 @@ def convert_collection_config(cls, model: rest.CollectionConfig) -> grpc.Collect hnsw_config=cls.convert_hnsw_config(model.hnsw_config), optimizer_config=cls.convert_optimizers_config(model.optimizer_config), wal_config=cls.convert_wal_config(model.wal_config), - quantization_config=cls.convert_quantization_config(model.quantization_config) - if model.quantization_config is not None - else None, + quantization_config=( + cls.convert_quantization_config(model.quantization_config) + if model.quantization_config is not None + else None + ), ) @classmethod @@ -1668,9 +1770,9 @@ def convert_distance(cls, model: rest.Distance) -> grpc.Distance: @classmethod def convert_collection_params(cls, model: rest.CollectionParams) -> grpc.CollectionParams: return grpc.CollectionParams( - vectors_config=cls.convert_vectors_config(model.vectors) - if model.vectors is not None - else None, + vectors_config=( + cls.convert_vectors_config(model.vectors) if model.vectors is not None else None + ), shard_number=model.shard_number, on_disk_payload=model.on_disk_payload or False, write_consistency_factor=model.write_consistency_factor, @@ -1712,21 +1814,31 @@ def convert_update_collection( ) -> grpc.UpdateCollection: return grpc.UpdateCollection( collection_name=collection_name, - optimizers_config=cls.convert_optimizers_config_diff(model.optimizers_config) - if model.optimizers_config is not None - else None, - vectors_config=cls.convert_vectors_config_diff(model.vectors) - if model.vectors is not None - else None, - params=cls.convert_collection_params_diff(model.params) - if model.params is not None - else None, - hnsw_config=cls.convert_hnsw_config_diff(model.hnsw_config) - if model.hnsw_config is not None - else None, - quantization_config=cls.convert_quantization_config_diff(model.quantization_config) - if model.quantization_config is not None - else None, + optimizers_config=( + cls.convert_optimizers_config_diff(model.optimizers_config) + if model.optimizers_config is not None + else None + ), + vectors_config=( + cls.convert_vectors_config_diff(model.vectors) + if model.vectors is not None + else None + ), + params=( + cls.convert_collection_params_diff(model.params) + if model.params is not None + else None + ), + hnsw_config=( + cls.convert_hnsw_config_diff(model.hnsw_config) + if model.hnsw_config is not None + else None + ), + quantization_config=( + cls.convert_quantization_config_diff(model.quantization_config) + if model.quantization_config is not None + else None + ), ) @classmethod @@ -1948,12 +2060,12 @@ def convert_direction(cls, model: rest.Direction) -> grpc.Direction: def convert_order_by(cls, model: rest.OrderBy) -> grpc.OrderBy: return grpc.OrderBy( key=model.key, - direction=cls.convert_direction(model.direction) - if model.direction is not None - else None, - start_from=cls.convert_start_from(model.start_from) - if model.start_from is not None - else None, + direction=( + cls.convert_direction(model.direction) if model.direction is not None else None + ), + start_from=( + cls.convert_start_from(model.start_from) if model.start_from is not None else None + ), ) @classmethod @@ -2001,12 +2113,16 @@ def convert_vector_params(cls, model: rest.VectorParams) -> grpc.VectorParams: return grpc.VectorParams( size=model.size, distance=cls.convert_distance(model.distance), - hnsw_config=cls.convert_hnsw_config_diff(model.hnsw_config) - if model.hnsw_config is not None - else None, - quantization_config=cls.convert_quantization_config(model.quantization_config) - if model.quantization_config is not None - else None, + hnsw_config=( + cls.convert_hnsw_config_diff(model.hnsw_config) + if model.hnsw_config is not None + else None + ), + quantization_config=( + cls.convert_quantization_config(model.quantization_config) + if model.quantization_config is not None + else None + ), on_disk=model.on_disk, ) @@ -2093,19 +2209,23 @@ def convert_search_request( sparse_indices=sparse_indices, filter=cls.convert_filter(model.filter) if model.filter is not None else None, limit=model.limit, - with_payload=cls.convert_with_payload_interface(model.with_payload) - if model.with_payload is not None - else None, + with_payload=( + cls.convert_with_payload_interface(model.with_payload) + if model.with_payload is not None + else None + ), params=cls.convert_search_params(model.params) if model.params is not None else None, score_threshold=model.score_threshold, offset=model.offset, vector_name=name, - with_vectors=cls.convert_with_vectors(model.with_vector) - if model.with_vector is not None - else None, - shard_key_selector=cls.convert_shard_key_selector(model.shard_key) - if model.shard_key - else None, + with_vectors=( + cls.convert_with_vectors(model.with_vector) + if model.with_vector is not None + else None + ), + shard_key_selector=( + cls.convert_shard_key_selector(model.shard_key) if model.shard_key else None + ), ) @classmethod @@ -2130,27 +2250,35 @@ def convert_recommend_request( negative=negative_ids, filter=cls.convert_filter(model.filter) if model.filter is not None else None, limit=model.limit, - with_payload=cls.convert_with_payload_interface(model.with_payload) - if model.with_payload is not None - else None, + with_payload=( + cls.convert_with_payload_interface(model.with_payload) + if model.with_payload is not None + else None + ), params=cls.convert_search_params(model.params) if model.params is not None else None, score_threshold=model.score_threshold, offset=model.offset, - with_vectors=cls.convert_with_vectors(model.with_vector) - if model.with_vector is not None - else None, + with_vectors=( + cls.convert_with_vectors(model.with_vector) + if model.with_vector is not None + else None + ), using=model.using, - lookup_from=cls.convert_lookup_location(model.lookup_from) - if model.lookup_from is not None - else None, - strategy=cls.convert_recommend_strategy(model.strategy) - if model.strategy is not None - else None, + lookup_from=( + cls.convert_lookup_location(model.lookup_from) + if model.lookup_from is not None + else None + ), + strategy=( + cls.convert_recommend_strategy(model.strategy) + if model.strategy is not None + else None + ), positive_vectors=positive_vectors, negative_vectors=negative_vectors, - shard_key_selector=cls.convert_shard_key_selector(model.shard_key) - if model.shard_key - else None, + shard_key_selector=( + cls.convert_shard_key_selector(model.shard_key) if model.shard_key else None + ), ) @classmethod @@ -2230,9 +2358,11 @@ def convert_tokenizer_type(cls, model: rest.TokenizerType) -> grpc.TokenizerType @classmethod def convert_text_index_params(cls, model: rest.TextIndexParams) -> grpc.TextIndexParams: return grpc.TextIndexParams( - tokenizer=cls.convert_tokenizer_type(model.tokenizer) - if model.tokenizer is not None - else None, + tokenizer=( + cls.convert_tokenizer_type(model.tokenizer) + if model.tokenizer is not None + else None + ), lowercase=model.lowercase, min_token_len=model.min_token_len, max_token_len=model.max_token_len, @@ -2407,12 +2537,16 @@ def convert_group_id(cls, model: rest.GroupId) -> grpc.GroupId: def convert_with_lookup(cls, model: rest.WithLookup) -> grpc.WithLookup: return grpc.WithLookup( collection=model.collection, - with_vectors=cls.convert_with_vectors(model.with_vectors) - if model.with_vectors is not None - else None, - with_payload=cls.convert_with_payload_interface(model.with_payload) - if model.with_payload is not None - else None, + with_vectors=( + cls.convert_with_vectors(model.with_vectors) + if model.with_vectors is not None + else None + ), + with_payload=( + cls.convert_with_payload_interface(model.with_payload) + if model.with_payload is not None + else None + ), ) @classmethod @@ -2441,12 +2575,16 @@ def convert_quantization_config_diff( @classmethod def convert_vector_params_diff(cls, model: rest.VectorParamsDiff) -> grpc.VectorParamsDiff: return grpc.VectorParamsDiff( - hnsw_config=cls.convert_hnsw_config_diff(model.hnsw_config) - if model.hnsw_config is not None - else None, - quantization_config=cls.convert_quantization_config_diff(model.quantization_config) - if model.quantization_config is not None - else None, + hnsw_config=( + cls.convert_hnsw_config_diff(model.hnsw_config) + if model.hnsw_config is not None + else None + ), + quantization_config=( + cls.convert_quantization_config_diff(model.quantization_config) + if model.quantization_config is not None + else None + ), on_disk=model.on_disk, ) @@ -2477,9 +2615,11 @@ def convert_point_insert_operation( grpc.PointStruct( id=RestToGrpc.convert_extended_point_id(model.batch.ids[idx]), vectors=vectors_batch[idx], - payload=RestToGrpc.convert_payload(model.batch.payloads[idx]) - if model.batch.payloads is not None - else None, + payload=( + RestToGrpc.convert_payload(model.batch.payloads[idx]) + if model.batch.payloads is not None + else None + ), ) for idx in range(len(model.batch.ids)) ] @@ -2665,9 +2805,9 @@ def convert_recommend_strategy(cls, model: rest.RecommendStrategy) -> grpc.Recom @classmethod def convert_sparse_index_config(cls, model: rest.SparseIndexConfig) -> grpc.SparseIndexConfig: return grpc.SparseIndexConfig( - full_scan_threshold=model.full_scan_threshold - if model.full_scan_threshold is not None - else None, + full_scan_threshold=( + model.full_scan_threshold if model.full_scan_threshold is not None else None + ), on_disk=model.on_disk if model.on_disk is not None else None, ) @@ -2676,9 +2816,9 @@ def convert_sparse_vector_params( cls, model: rest.SparseVectorParams ) -> grpc.SparseVectorParams: return grpc.SparseVectorParams( - index=cls.convert_sparse_index_config(model.index) - if model.index is not None - else None, + index=( + cls.convert_sparse_index_config(model.index) if model.index is not None else None + ), ) @classmethod diff --git a/qdrant_client/local/payload_filters.py b/qdrant_client/local/payload_filters.py index 62e16804..3d54e494 100644 --- a/qdrant_client/local/payload_filters.py +++ b/qdrant_client/local/payload_filters.py @@ -214,6 +214,17 @@ def check_should( return any(check_condition(condition, payload, point_id) for condition in conditions) +def check_min_should( + conditions: List[models.Condition], + payload: dict, + point_id: models.ExtendedPointId, + min_count: int, +) -> bool: + return ( + sum(check_condition(condition, payload, point_id) for condition in conditions) >= min_count + ) + + def check_filter( payload_filter: models.Filter, payload: dict, point_id: models.ExtendedPointId ) -> bool: @@ -226,6 +237,14 @@ def check_filter( if payload_filter.should is not None: if not check_should(payload_filter.should, payload, point_id): return False + if payload_filter.min_should is not None: + if not check_min_should( + payload_filter.min_should.conditions, + payload, + point_id, + payload_filter.min_should.min_count, + ): + return False return True diff --git a/tests/conversions/fixtures.py b/tests/conversions/fixtures.py index bbcaa27a..fc429143 100644 --- a/tests/conversions/fixtures.py +++ b/tests/conversions/fixtures.py @@ -130,6 +130,15 @@ must_not=[ grpc.Condition(filter=grpc.Filter(must=[grpc.Condition(field=field_condition_range)])) ], + min_should=grpc.MinShould( + conditions=[ + condition_has_id, + condition_is_empty, + condition_except_keywords, + condition_except_integers, + ], + min_count=3, + ), ) vector_param = grpc.VectorParams( diff --git a/tests/fixtures/filters.py b/tests/fixtures/filters.py index 0ade3cf0..0716e64c 100644 --- a/tests/fixtures/filters.py +++ b/tests/fixtures/filters.py @@ -292,6 +292,7 @@ def one_random_filter_please() -> models.Filter: two_should_filter, two_must_not_filter, should_must_filter, + min_should_filter, ] )() @@ -304,6 +305,18 @@ def should_filter() -> models.Filter: return models.Filter(should=[one_random_condition_please()]) +def min_should_filter() -> models.Filter: + min_count = random.randint(1, 3) + upper_bound = max(min_count + 1, min_count * 2) + num_conditions = random.randint(min_count, upper_bound) + return models.Filter( + min_should=models.MinShould( + conditions=[one_random_condition_please() for _ in range(num_conditions)], + min_count=min_count, + ) + ) + + def must_not_filter() -> models.Filter: return models.Filter(must_not=[one_random_condition_please()]) diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index a0d24914..2a723198 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -559,9 +559,43 @@ def test_qdrant_client_integration(prefer_grpc, numpy_upload, local_mode): with_payload=True, limit=5, ) - assert hits_should == hits_match_any + if version is None or (version >= "v1.8.0" or version == "dev"): + hits_min_should = client.search( + collection_name=COLLECTION_NAME, + query_vector=query_vector, + query_filter=Filter( + min_should=models.MinShould( + conditions=[ + FieldCondition(key="id_str", match=MatchValue(value="11")), + FieldCondition(key="rand_digit", match=MatchAny(any=list(range(10)))), + FieldCondition(key="id", match=MatchAny(any=list(range(100, 150)))), + ], + min_count=2, + ) + ), + with_payload=True, + limit=5, + ) + assert len(hits_min_should) > 0 + + hits_min_should_empty = client.search( + collection_name=COLLECTION_NAME, + query_vector=query_vector, + query_filter=Filter( + min_should=models.MinShould( + conditions=[ + FieldCondition(key="id_str", match=MatchValue(value="11")), + ], + min_count=2, + ) + ), + with_payload=True, + limit=5, + ) + assert len(hits_min_should_empty) == 0 + # Let's now query same vector with filter condition hits = client.search( collection_name=COLLECTION_NAME,