Skip to content

Commit

Permalink
fix: fix prefetch conversion, fix local mode query batch points offse…
Browse files Browse the repository at this point in the history
…t, add tests (#812)
  • Loading branch information
joein authored Oct 15, 2024
1 parent f2494ca commit 5897726
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 3 deletions.
5 changes: 4 additions & 1 deletion qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2901,9 +2901,12 @@ def convert_search_points(
def convert_query_request(
cls, model: rest.QueryRequest, collection_name: str
) -> grpc.QueryPoints:
prefetch = (
[model.prefetch] if isinstance(model.prefetch, rest.Prefetch) else model.prefetch
)
return grpc.QueryPoints(
collection_name=collection_name,
prefetch=[cls.convert_prefetch_query(prefetch) for prefetch in model.prefetch]
prefetch=[cls.convert_prefetch_query(p) for p in prefetch]
if model.prefetch is not None
else None,
query=cls.convert_query_interface(model.query) if model.query is not None else None,
Expand Down
2 changes: 1 addition & 1 deletion qdrant_client/local/async_qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ async def query_batch_points(
prefetch=request.prefetch,
query_filter=request.filter,
limit=request.limit,
offset=request.offset,
offset=request.offset or 0,
with_payload=request.with_payload,
with_vectors=request.with_vector,
score_threshold=request.score_threshold,
Expand Down
2 changes: 1 addition & 1 deletion qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def query_batch_points(
prefetch=request.prefetch,
query_filter=request.filter,
limit=request.limit,
offset=request.offset,
offset=request.offset or 0,
with_payload=request.with_payload,
with_vectors=request.with_vector,
score_threshold=request.score_threshold,
Expand Down
3 changes: 3 additions & 0 deletions tests/congruence_tests/test_query_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def __init__(self):
self.dense_vector_query_batch_text.append(
models.QueryRequest(
query=np.random.random(text_vector_size).tolist(),
prefetch=models.Prefetch(
query=np.random.random(text_vector_size).tolist(), limit=5, using="text"
),
limit=5,
using="text",
with_payload=True,
Expand Down
22 changes: 22 additions & 0 deletions tests/conversions/test_validate_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,25 @@ def test_convert_flat_filter():
assert recovered.must[0] == rest_filter.must
assert recovered.should[0] == rest_filter.should
assert recovered.must_not[0] == rest_filter.must_not


def test_query_points():
from qdrant_client import models
from qdrant_client.conversions.conversion import GrpcToRest, RestToGrpc

prefetch = models.Prefetch(query=models.NearestQuery(nearest=[1.0, 2.0]))
query_request = models.QueryRequest(
query=1,
limit=5,
using="test",
with_payload=True,
prefetch=prefetch,
)
grpc_query_request = RestToGrpc.convert_query_request(query_request, "check")
recovered = GrpcToRest.convert_query_points(grpc_query_request)

assert recovered.query == models.NearestQuery(nearest=query_request.query)
assert recovered.limit == query_request.limit
assert recovered.using == query_request.using
assert recovered.with_payload == query_request.with_payload
assert recovered.prefetch[0] == query_request.prefetch

0 comments on commit 5897726

Please sign in to comment.