From 62a50f32cf140b24876517219726a28465ef640e Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 16 Dec 2024 20:14:36 +0100 Subject: [PATCH] Python: Qdrant - fix in filter and 100% test coverage (#9982) ### Motivation and Context This PR fixes a small error in the filter creation logic and improves test coverage for Qdrant. ### Description `_create_filter` now wraps each filter value in a list before passing it to `MatchAny(any=...)`, and the error message raised when no vector is supplied now reads "Search requires a vector.". The unit tests gain coverage for plain search, named-vector search, filtered search, and the missing-vector failure path. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../memory/qdrant/qdrant_collection.py | 4 +- .../connectors/memory/qdrant/test_qdrant.py | 84 ++++++++++++++++--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py index 5fb8c177be89..cb30fa0cdc76 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py @@ -188,7 +188,7 @@ async def _inner_search( else: query_vector = vector if query_vector is None: - raise VectorSearchExecutionException("Search requires either a vector.") + raise VectorSearchExecutionException("Search requires a vector.") results = await self.qdrant_client.search( collection_name=self.collection_name, query_vector=query_vector, @@ -214,7 +214,7 @@ def _get_score_from_result(self, result: ScoredPoint) -> float: def _create_filter(self, options: VectorSearchOptions) -> Filter: return Filter( must=[ - FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value)) + FieldCondition(key=filter.field_name, 
match=MatchAny(any=[filter.value])) for filter in options.filter.filters ] ) diff --git a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py index ce00e7d88c95..c92571daf238 100644 --- a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py +++ b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py @@ -4,17 +4,19 @@ from pytest import fixture, mark, raises from qdrant_client.async_qdrant_client import AsyncQdrantClient -from qdrant_client.models import Datatype, Distance, VectorParams +from qdrant_client.models import Datatype, Distance, FieldCondition, Filter, MatchAny, VectorParams from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField +from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions from semantic_kernel.exceptions.memory_connector_exceptions import ( MemoryConnectorException, MemoryConnectorInitializationError, VectorStoreModelValidationError, ) +from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient" @@ -119,9 +121,10 @@ def mock_search(): yield mock_search -def test_vector_store_defaults(vector_store): - assert vector_store.qdrant_client is not None - assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333" +async def test_vector_store_defaults(vector_store): + async with vector_store: + assert vector_store.qdrant_client is not None + assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333" def test_vector_store_with_client(): @@ -162,18 +165,18 @@ def test_get_collection(vector_store, data_model_definition, 
qdrant_unit_test_en assert vector_store.vector_record_collections["test"] == collection -def test_collection_init(data_model_definition, qdrant_unit_test_env): - collection = QdrantCollection( +async def test_collection_init(data_model_definition, qdrant_unit_test_env): + async with QdrantCollection( data_model_type=dict, collection_name="test", data_model_definition=data_model_definition, env_file_path="test.env", - ) - assert collection.collection_name == "test" - assert collection.qdrant_client is not None - assert collection.data_model_type is dict - assert collection.data_model_definition == data_model_definition - assert collection.named_vectors + ) as collection: + assert collection.collection_name == "test" + assert collection.qdrant_client is not None + assert collection.data_model_type is dict + assert collection.data_model_definition == data_model_definition + assert collection.named_vectors def test_collection_init_fail(data_model_definition): @@ -275,8 +278,63 @@ async def test_create_index_fail(collection_to_use, request): await collection.create_collection() -async def test_search(collection): +async def test_search(collection, mock_search): results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False)) async for result in results.results: assert result.record["id"] == "id1" break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=[1.0, 2.0, 3.0], + query_filter=Filter(must=[]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_named_vectors(collection, mock_search): + collection.named_vectors = True + results = await collection._inner_search( + vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(vector_field_name="vector", include_vectors=False) + ) + async for result in results.results: + assert result.record["id"] == "id1" + break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + 
collection_name="test", + query_vector=("vector", [1.0, 2.0, 3.0]), + query_filter=Filter(must=[]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_filter(collection, mock_search): + results = await collection._inner_search( + vector=[1.0, 2.0, 3.0], + options=VectorSearchOptions(include_vectors=False, filter=VectorSearchFilter.equal_to("id", "id1")), + ) + async for result in results.results: + assert result.record["id"] == "id1" + break + + assert mock_search.call_count == 1 + mock_search.assert_called_with( + collection_name="test", + query_vector=[1.0, 2.0, 3.0], + query_filter=Filter(must=[FieldCondition(key="id", match=MatchAny(any=["id1"]))]), + with_vectors=False, + limit=3, + offset=0, + ) + + +async def test_search_fail(collection): + with raises(VectorSearchExecutionException, match="Search requires a vector."): + await collection._inner_search(options=VectorSearchOptions(include_vectors=False))