Skip to content

Commit

Permalink
qdrant: Add similarity_search_with_score_by_vector() function to th…
Browse files Browse the repository at this point in the history
…e `QdrantVectorStore` (#29641)

Added `similarity_search_with_score_by_vector()` function to the
`QdrantVectorStore` class.

It is required when we want to query multiple time with the same
embeddings. It was present in the now deprecated original `Qdrant`
vectorstore implementation, but was absent from the new one. It is also
implemented in a number of others `VectorStore` implementations

I have added tests for this new function

Note that I also argued in this discussion that it should be part of the
general `VectorStore`
#29638

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
vemonet and efriis authored Feb 7, 2025
1 parent 488cb4a commit 3645181
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 8 deletions.
47 changes: 39 additions & 8 deletions libs/partners/qdrant/langchain_qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def similarity_search_with_score(
for result in results
]

def similarity_search_by_vector(
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
Expand All @@ -578,11 +578,11 @@ def similarity_search_by_vector(
score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
) -> List[tuple[Document, float]]:
"""Return docs most similar to embedding vector.
Returns:
List of Documents most similar to the query.
List of Documents most similar to the query and distance for each.
"""
qdrant_filter = filter

Expand All @@ -609,15 +609,46 @@ def similarity_search_by_vector(
).points

return [
self._document_from_point(
result,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
(
self._document_from_point(
result,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
),
result.score,
)
for result in results
]

def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[models.Filter] = None,
search_params: Optional[models.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Returns:
List of Documents most similar to the query.
"""
results = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))

def max_marginal_relevance_search(
self,
query: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,39 @@ def test_similarity_search_by_vector(
assert_documents_equals(output, [Document(page_content="foo")])


@pytest.mark.parametrize("location", qdrant_locations())
@pytest.mark.parametrize("content_payload_key", [QdrantVectorStore.CONTENT_KEY, "foo"])
@pytest.mark.parametrize(
"metadata_payload_key", [QdrantVectorStore.METADATA_KEY, "bar"]
)
@pytest.mark.parametrize("vector_name", ["", "my-vector"])
@pytest.mark.parametrize("batch_size", [1, 64])
def test_similarity_search_with_score_by_vector(
location: str,
content_payload_key: str,
metadata_payload_key: str,
vector_name: str,
batch_size: int,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = QdrantVectorStore.from_texts(
texts,
ConsistentFakeEmbeddings(),
location=location,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
)
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = docsearch.similarity_search_with_score_by_vector(embeddings, k=1)
assert len(output) == 1
document, score = output[0]
assert_documents_equals([document], [Document(page_content="foo")])
assert score >= 0


@pytest.mark.parametrize("location", qdrant_locations())
@pytest.mark.parametrize(
"metadata_payload_key", [QdrantVectorStore.METADATA_KEY, "bar"]
Expand Down

0 comments on commit 3645181

Please sign in to comment.