Skip to content

Commit

Permalink
fix: do not modify input structs in-place (#715)
Browse files Browse the repository at this point in the history
* fix: do not modify input structs in-place

* fix: regen async
  • Loading branch information
joein committed Aug 9, 2024
1 parent 35834ff commit 81aeaec
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
3 changes: 3 additions & 0 deletions qdrant_client/local/async_qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
import os
import shutil
from copy import deepcopy
from io import TextIOWrapper
from typing import (
Any,
Expand Down Expand Up @@ -290,6 +291,7 @@ def input_into_vector(vector_input: types.VectorInput) -> types.VectorInput:
else:
return vector_input

query = deepcopy(query)
if isinstance(query, rest_models.NearestQuery):
query.nearest = input_into_vector(query.nearest)
elif isinstance(query, rest_models.RecommendQuery):
Expand Down Expand Up @@ -353,6 +355,7 @@ def _resolve_prefetch_input(
) -> types.Prefetch:
if prefetch.query is None:
return prefetch
prefetch = deepcopy(prefetch)
(query, mentioned_ids) = self._resolve_query_input(
collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
)
Expand Down
11 changes: 3 additions & 8 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,6 @@ def query_points(
Assumes all vectors have been homogenized so that there are no ids in the inputs
"""
scored_points = []

prefetches = []
if prefetch is not None:
prefetches = prefetch if isinstance(prefetch, list) else [prefetch]
Expand Down Expand Up @@ -709,11 +707,6 @@ def _prefetch(self, prefetch: types.Prefetch, offset: int) -> List[types.ScoredP
)

if len(inner_prefetches) > 0:
# Recursive case: inner prefetches
prefetches = (
prefetch.prefetch if isinstance(prefetch.prefetch, list) else [prefetch.prefetch]
)

sources = [
self._prefetch(inner_prefetch, offset) for inner_prefetch in inner_prefetches
]
Expand Down Expand Up @@ -782,7 +775,7 @@ def _merge_sources(
sources_ids.add(point.id)

if len(sources_ids) == 0:
# no need to perform a query if there no matches for the sources
# no need to perform a query if there are no matches for the sources
return []
else:
filter_with_sources = _include_ids_in_filter(query_filter, list(sources_ids))
Expand Down Expand Up @@ -2173,6 +2166,7 @@ def ignore_mentioned_ids_filter(
if query_filter is None:
query_filter = models.Filter(must_not=[ignore_mentioned_ids])
else:
query_filter = deepcopy(query_filter)
if query_filter.must_not is None:
query_filter.must_not = [ignore_mentioned_ids]
else:
Expand All @@ -2192,6 +2186,7 @@ def _include_ids_in_filter(
if query_filter is None:
query_filter = models.Filter(must=[include_ids])
else:
query_filter = deepcopy(query_filter)
if query_filter.must is None:
query_filter.must = [include_ids]
else:
Expand Down
3 changes: 3 additions & 0 deletions qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import shutil
from copy import deepcopy
from io import TextIOWrapper
from typing import (
Any,
Expand Down Expand Up @@ -305,6 +306,7 @@ def input_into_vector(
else:
return vector_input

query = deepcopy(query)
if isinstance(query, rest_models.NearestQuery):
query.nearest = input_into_vector(query.nearest)

Expand Down Expand Up @@ -375,6 +377,7 @@ def _resolve_prefetch_input(
if prefetch.query is None:
return prefetch

prefetch = deepcopy(prefetch)
query, mentioned_ids = self._resolve_query_input(
collection_name,
prefetch.query,
Expand Down
72 changes: 71 additions & 1 deletion tests/congruence_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
multi_vector_config,
)
from tests.fixtures.filters import one_random_filter_please
from tests.fixtures.points import generate_random_sparse_vector, generate_random_multivector
from tests.fixtures.points import (
generate_random_sparse_vector,
generate_random_multivector,
generate_points,
)
from tests.utils import read_version

SECONDARY_COLLECTION_NAME = "congruence_secondary_collection"
Expand Down Expand Up @@ -1165,3 +1169,69 @@ def test_flat_query_multivector_interface(prefer_grpc):
init_client(remote_client, fixture_points, vectors_config=multi_vector_config)

compare_client_results(local_client, remote_client, searcher.multivec_query_text)


@pytest.mark.parametrize("prefer_grpc", (False, True))
def test_original_input_persistence(prefer_grpc):
    """Check that ``query_points`` does not modify the request models it is given.

    A single ``RecommendInput`` instance is shared between a prefetch (run on
    the sparse vector) and the top-level query (run on the dense vector). If
    the local client rewrote the prefetch input in place, the shared instance
    would already be altered by the time the top-level query uses it.

    Args:
        prefer_grpc: forwarded to ``init_remote`` to select the remote transport.
    """
    num_points = 50
    dense_vector_name = "text"
    sparse_vector_name = "sparse-text"

    vectors_config = {"text": models.VectorParams(size=50, distance=models.Distance.COSINE)}
    sparse_vectors_config = {"sparse-text": models.SparseVectorParams()}
    fixture_points = generate_fixtures(vectors_sizes={"text": 50}, num=num_points)
    sparse_fixture_points = generate_sparse_fixtures(num=num_points)

    # Merge dense and sparse fixtures so that every point carries both vectors.
    points = [
        models.PointStruct(
            id=point.id,
            payload=point.payload,
            vector={
                "text": point.vector["text"],
                "sparse-text": sparse_point.vector["sparse-text"],
            },
        )
        for point, sparse_point in zip(fixture_points, sparse_fixture_points)
    ]

    local_client = init_local()
    init_client(
        local_client,
        points,
        vectors_config=vectors_config,
        sparse_vectors_config=sparse_vectors_config,
    )
    remote_client = init_remote(prefer_grpc=prefer_grpc)
    init_client(
        remote_client,
        points,
        vectors_config=vectors_config,
        sparse_vectors_config=sparse_vectors_config,
    )

    point_id = 1

    def build_request():
        # Build fresh objects for each client so one client's handling of the
        # request cannot leak into the other's.
        shared = models.RecommendInput(positive=[point_id], negative=[])
        request_prefetch = [
            models.Prefetch(
                query=models.RecommendQuery(recommend=shared),
                using=sparse_vector_name,
            ),
        ]
        return shared, request_prefetch

    shared_instance, prefetch = build_request()
    local_client.query_points(
        collection_name=COLLECTION_NAME,
        prefetch=prefetch,
        query=models.RecommendQuery(recommend=shared_instance),
        using=dense_vector_name,
    )
    # Explicit check of the property this test is named for: the caller's
    # object must look exactly as it did before the call.
    assert shared_instance.positive == [point_id]
    assert shared_instance.negative == []

    shared_instance, prefetch = build_request()
    # The remote call is kept as a smoke check that the same request shape is
    # accepted by the server.
    remote_client.query_points(
        collection_name=COLLECTION_NAME,
        prefetch=prefetch,
        query=models.RecommendQuery(recommend=shared_instance),
        using=dense_vector_name,
    )

0 comments on commit 81aeaec

Please sign in to comment.