diff --git a/qdrant_client/local/qdrant_local.py b/qdrant_client/local/qdrant_local.py index e092945b..752d8200 100644 --- a/qdrant_client/local/qdrant_local.py +++ b/qdrant_client/local/qdrant_local.py @@ -3,7 +3,6 @@ import logging import os import shutil -from copy import deepcopy from io import TextIOWrapper from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union @@ -224,6 +223,14 @@ def recommend_batch( with_payload=request.with_payload, with_vectors=request.with_vector, score_threshold=request.score_threshold, + using=request.using, + lookup_from_collection=self._get_collection(request.lookup_from.collection) + if request.lookup_from + else None, + lookup_from_vector_name=request.lookup_from.vector + if request.lookup_from + else None, + strategy=request.strategy, ) for request in requests ] diff --git a/tests/congruence_tests/test_common.py b/tests/congruence_tests/test_common.py index 501f4587..98f00292 100644 --- a/tests/congruence_tests/test_common.py +++ b/tests/congruence_tests/test_common.py @@ -119,7 +119,7 @@ def compare_scored_record( point1.id == point2.id ), f"point1[{idx}].id = {point1.id}, point2[{idx}].id = {point2.id}" # adjust precision depending on the magnitude of score - max_difference = 1e-4 * 10**(math.floor(math.log(abs(point2.score), 10))) + max_difference = 1e-4 * 10 ** (math.floor(math.log(abs(point2.score), 10))) assert ( abs(point1.score - point2.score) < max_difference ), f"point1[{idx}].score = {point1.score}, point2[{idx}].score = {point2.score}, max_difference = {max_difference}" @@ -135,7 +135,12 @@ def compare_records(res1: list, res2: list) -> None: res1_item = res1[i] res2_item = res2[i] - if isinstance(res1_item, models.ScoredPoint) and isinstance(res2_item, models.ScoredPoint): + if isinstance(res1_item, list) and isinstance(res2_item, list): + compare_records(res1_item, res2_item) + + elif isinstance(res1_item, models.ScoredPoint) and isinstance( + res2_item, models.ScoredPoint + ): compare_scored_record(res1_item, res2_item, i) elif isinstance(res1_item, models.Record) and isinstance(res2_item, models.Record): diff --git a/tests/congruence_tests/test_group_recommend.py b/tests/congruence_tests/test_group_recommend.py index c688ee52..1f7920e9 100644 --- a/tests/congruence_tests/test_group_recommend.py +++ b/tests/congruence_tests/test_group_recommend.py @@ -11,7 +11,6 @@ init_remote, ) from tests.fixtures.filters import one_random_filter_please -from tests.fixtures.payload import one_random_payload_please secondary_collection_name = "secondary_collection" @@ -35,7 +34,7 @@ def simple_recommend_groups_image(self, client: QdrantBase) -> models.GroupsResu group_size=self.group_size, search_params=models.SearchParams(exact=True), ) - + def simple_recommend_groups_best_scores(self, client: QdrantBase) -> models.GroupsResult: return client.recommend_groups( collection_name=COLLECTION_NAME, diff --git a/tests/congruence_tests/test_recommendation.py b/tests/congruence_tests/test_recommendation.py index 568d6b26..3d7047a5 100644 --- a/tests/congruence_tests/test_recommendation.py +++ b/tests/congruence_tests/test_recommendation.py @@ -8,21 +8,22 @@ COLLECTION_NAME, compare_client_results, generate_fixtures, + image_vector_size, init_client, init_local, init_remote, - image_vector_size, ) from tests.fixtures.filters import one_random_filter_please secondary_collection_name = "secondary_collection" + class TestSimpleRecommendation: - + __test__ = False + def __init__(self): self.query_image = np.random.random(image_vector_size).tolist() - @classmethod def simple_recommend_image(cls, client: QdrantBase) -> List[models.ScoredPoint]: return client.recommend( @@ -82,19 +83,22 @@ def filter_recommend_text( limit=10, using="text", ) - + @classmethod def best_score_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]: return client.recommend( collection_name=COLLECTION_NAME, - positive=[10, 20,], + positive=[ + 10, + 20, + ], negative=[], with_payload=True, limit=10, using="image", strategy=models.RecommendStrategy.BEST_SCORE, ) - + @classmethod def only_negatives_best_score_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]: return client.recommend( @@ -106,7 +110,7 @@ def only_negatives_best_score_recommend(cls, client: QdrantBase) -> List[models. using="image", strategy=models.RecommendStrategy.BEST_SCORE, ) - + @classmethod def avg_vector_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]: return client.recommend( @@ -118,7 +122,7 @@ def avg_vector_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]: using="image", strategy=models.RecommendStrategy.AVERAGE_VECTOR, ) - + def recommend_from_raw_vectors(self, client: QdrantBase) -> List[models.ScoredPoint]: return client.recommend( collection_name=COLLECTION_NAME, @@ -128,7 +132,7 @@ def recommend_from_raw_vectors(self, client: QdrantBase) -> List[models.ScoredPo limit=10, using="image", ) - + def recommend_from_raw_vectors_and_ids(self, client: QdrantBase) -> List[models.ScoredPoint]: return client.recommend( collection_name=COLLECTION_NAME, @@ -139,6 +143,32 @@ def recommend_from_raw_vectors_and_ids(self, client: QdrantBase) -> List[models. using="image", ) + @staticmethod + def recommend_batch(client: QdrantBase) -> List[List[models.ScoredPoint]]: + return client.recommend_batch( + collection_name=COLLECTION_NAME, + requests=[ + models.RecommendRequest( + positive=[3], + negative=[], + limit=1, + using="image", + strategy=models.RecommendStrategy.AVERAGE_VECTOR, + ), + models.RecommendRequest( + positive=[10], + negative=[], + limit=2, + using="image", + strategy=models.RecommendStrategy.BEST_SCORE, + lookup_from=models.LookupLocation( + collection=secondary_collection_name, + vector="image", + ), + ), + ], + ) + def test_simple_recommend() -> None: fixture_records = generate_fixtures() @@ -160,10 +190,15 @@ def test_simple_recommend() -> None: compare_client_results(local_client, remote_client, searcher.simple_recommend_negative) compare_client_results(local_client, remote_client, searcher.recommend_from_another_collection) compare_client_results(local_client, remote_client, searcher.best_score_recommend) - compare_client_results(local_client, remote_client, searcher.only_negatives_best_score_recommend) + compare_client_results( + local_client, remote_client, searcher.only_negatives_best_score_recommend + ) compare_client_results(local_client, remote_client, searcher.avg_vector_recommend) compare_client_results(local_client, remote_client, searcher.recommend_from_raw_vectors) - compare_client_results(local_client, remote_client, searcher.recommend_from_raw_vectors_and_ids) + compare_client_results( + local_client, remote_client, searcher.recommend_from_raw_vectors_and_ids + ) + compare_client_results(local_client, remote_client, searcher.recommend_batch) for _ in range(10): query_filter = one_random_filter_please()