Skip to content

Commit

Permalink
fix: add missing parameters in recommend batch (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
joein authored Oct 15, 2023
1 parent 278a68d commit bf1c593
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 16 deletions.
9 changes: 8 additions & 1 deletion qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import os
import shutil
from copy import deepcopy
from io import TextIOWrapper
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -224,6 +223,14 @@ def recommend_batch(
with_payload=request.with_payload,
with_vectors=request.with_vector,
score_threshold=request.score_threshold,
using=request.using,
lookup_from_collection=self._get_collection(request.lookup_from.collection)
if request.lookup_from
else None,
lookup_from_vector_name=request.lookup_from.vector
if request.lookup_from
else None,
strategy=request.strategy,
)
for request in requests
]
Expand Down
9 changes: 7 additions & 2 deletions tests/congruence_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def compare_scored_record(
point1.id == point2.id
), f"point1[{idx}].id = {point1.id}, point2[{idx}].id = {point2.id}"
# adjust precision depending on the magnitude of score
max_difference = 1e-4 * 10**(math.floor(math.log(abs(point2.score), 10)))
max_difference = 1e-4 * 10 ** (math.floor(math.log(abs(point2.score), 10)))
assert (
abs(point1.score - point2.score) < max_difference
), f"point1[{idx}].score = {point1.score}, point2[{idx}].score = {point2.score}, max_difference = {max_difference}"
Expand All @@ -135,7 +135,12 @@ def compare_records(res1: list, res2: list) -> None:
res1_item = res1[i]
res2_item = res2[i]

if isinstance(res1_item, models.ScoredPoint) and isinstance(res2_item, models.ScoredPoint):
if isinstance(res1_item, list) and isinstance(res2_item, list):
compare_records(res1_item, res2_item)

elif isinstance(res1_item, models.ScoredPoint) and isinstance(
res2_item, models.ScoredPoint
):
compare_scored_record(res1_item, res2_item, i)

elif isinstance(res1_item, models.Record) and isinstance(res2_item, models.Record):
Expand Down
3 changes: 1 addition & 2 deletions tests/congruence_tests/test_group_recommend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
init_remote,
)
from tests.fixtures.filters import one_random_filter_please
from tests.fixtures.payload import one_random_payload_please

secondary_collection_name = "secondary_collection"

Expand All @@ -35,7 +34,7 @@ def simple_recommend_groups_image(self, client: QdrantBase) -> models.GroupsResu
group_size=self.group_size,
search_params=models.SearchParams(exact=True),
)

def simple_recommend_groups_best_scores(self, client: QdrantBase) -> models.GroupsResult:
return client.recommend_groups(
collection_name=COLLECTION_NAME,
Expand Down
57 changes: 46 additions & 11 deletions tests/congruence_tests/test_recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,22 @@
COLLECTION_NAME,
compare_client_results,
generate_fixtures,
image_vector_size,
init_client,
init_local,
init_remote,
image_vector_size,
)
from tests.fixtures.filters import one_random_filter_please

secondary_collection_name = "secondary_collection"


class TestSimpleRecommendation:

__test__ = False

def __init__(self):
self.query_image = np.random.random(image_vector_size).tolist()


@classmethod
def simple_recommend_image(cls, client: QdrantBase) -> List[models.ScoredPoint]:
return client.recommend(
Expand Down Expand Up @@ -82,19 +83,22 @@ def filter_recommend_text(
limit=10,
using="text",
)

@classmethod
def best_score_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]:
return client.recommend(
collection_name=COLLECTION_NAME,
positive=[10, 20,],
positive=[
10,
20,
],
negative=[],
with_payload=True,
limit=10,
using="image",
strategy=models.RecommendStrategy.BEST_SCORE,
)

@classmethod
def only_negatives_best_score_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]:
return client.recommend(
Expand All @@ -106,7 +110,7 @@ def only_negatives_best_score_recommend(cls, client: QdrantBase) -> List[models.
using="image",
strategy=models.RecommendStrategy.BEST_SCORE,
)

@classmethod
def avg_vector_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]:
return client.recommend(
Expand All @@ -118,7 +122,7 @@ def avg_vector_recommend(cls, client: QdrantBase) -> List[models.ScoredPoint]:
using="image",
strategy=models.RecommendStrategy.AVERAGE_VECTOR,
)

def recommend_from_raw_vectors(self, client: QdrantBase) -> List[models.ScoredPoint]:
return client.recommend(
collection_name=COLLECTION_NAME,
Expand All @@ -128,7 +132,7 @@ def recommend_from_raw_vectors(self, client: QdrantBase) -> List[models.ScoredPo
limit=10,
using="image",
)

def recommend_from_raw_vectors_and_ids(self, client: QdrantBase) -> List[models.ScoredPoint]:
return client.recommend(
collection_name=COLLECTION_NAME,
Expand All @@ -139,6 +143,32 @@ def recommend_from_raw_vectors_and_ids(self, client: QdrantBase) -> List[models.
using="image",
)

@staticmethod
def recommend_batch(client: QdrantBase) -> List[List[models.ScoredPoint]]:
return client.recommend_batch(
collection_name=COLLECTION_NAME,
requests=[
models.RecommendRequest(
positive=[3],
negative=[],
limit=1,
using="image",
strategy=models.RecommendStrategy.AVERAGE_VECTOR,
),
models.RecommendRequest(
positive=[10],
negative=[],
limit=2,
using="image",
strategy=models.RecommendStrategy.BEST_SCORE,
lookup_from=models.LookupLocation(
collection=secondary_collection_name,
vector="image",
),
),
],
)


def test_simple_recommend() -> None:
fixture_records = generate_fixtures()
Expand All @@ -160,10 +190,15 @@ def test_simple_recommend() -> None:
compare_client_results(local_client, remote_client, searcher.simple_recommend_negative)
compare_client_results(local_client, remote_client, searcher.recommend_from_another_collection)
compare_client_results(local_client, remote_client, searcher.best_score_recommend)
compare_client_results(local_client, remote_client, searcher.only_negatives_best_score_recommend)
compare_client_results(
local_client, remote_client, searcher.only_negatives_best_score_recommend
)
compare_client_results(local_client, remote_client, searcher.avg_vector_recommend)
compare_client_results(local_client, remote_client, searcher.recommend_from_raw_vectors)
compare_client_results(local_client, remote_client, searcher.recommend_from_raw_vectors_and_ids)
compare_client_results(
local_client, remote_client, searcher.recommend_from_raw_vectors_and_ids
)
compare_client_results(local_client, remote_client, searcher.recommend_batch)

for _ in range(10):
query_filter = one_random_filter_please()
Expand Down

0 comments on commit bf1c593

Please sign in to comment.