From 76209d0e6633e7e332808bd2e48a60c61bc93453 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Sat, 18 May 2024 05:32:30 -0700 Subject: [PATCH] Python: Adds a memory connector for Azure Cosmos DB for NoSQL (#6195) ### Motivation and Context Azure Cosmos DB is adding Vector Similarity APIs to the NoSQL project, and would like Semantic Kernel users to be able to leverage them. This adds a Memory Connector implementation for Azure Cosmos DB's, including support for the new vector search functionality coming soon in Cosmos DB. ### Description ### Contribution Checklist - [ ] The code builds clean without any errors or warnings - [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [ ] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone :smile: --------- Co-authored-by: Eduard van Valkenburg Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com> --- python/poetry.lock | 34 ++- python/pyproject.toml | 7 +- .../memory/azure_cosmosdb_no_sql/__init__.py | 7 + .../azure_cosmosdb_no_sql_memory_store.py | 177 +++++++++++++++ ...test_azure_cosmosdb_no_sql_memory_store.py | 210 ++++++++++++++++++ 5 files changed, 424 insertions(+), 11 deletions(-) create mode 100644 python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/__init__.py create mode 100644 python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/azure_cosmosdb_no_sql_memory_store.py create mode 100644 python/tests/integration/connectors/memory/test_azure_cosmosdb_no_sql_memory_store.py diff --git a/python/poetry.lock b/python/poetry.lock index 5d3a489d6c77..44feb480dfb5 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -320,6 +320,21 @@ typing-extensions = ">=4.6.0" [package.extras] aio = ["aiohttp (>=3.0)"] +[[package]] +name = "azure-cosmos" +version = "4.7.0" +description = "Microsoft Azure Cosmos Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-cosmos-4.7.0.tar.gz", hash = "sha256:72d714033134656302a2e8957c4b93590673bd288b0ca60cb123e348ae99a241"}, + {file = "azure_cosmos-4.7.0-py3-none-any.whl", hash = "sha256:03d8c7740ddc2906fb16e07b136acc0fe6a6a02656db46c5dd6f1b127b58cc96"}, +] + +[package.dependencies] +azure-core = ">=1.25.1" +typing-extensions = ">=4.6.0" + [[package]] name = "azure-identity" version = "1.16.0" @@ -1333,12 +1348,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -3498,9 +3513,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -3794,8 +3809,8 @@ certifi = ">=2019.11.17" tqdm = ">=4.64.1" typing-extensions = ">=3.7.4" urllib3 = [ - {version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, {version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, ] [package.extras] @@ -4778,6 +4793,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4928,8 +4944,8 @@ grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ - {version = ">=1.26", markers = "python_version >= \"3.12\""}, {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, ] portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" @@ -6832,8 +6848,8 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["azure-core", "azure-identity", "azure-search-documents", "chromadb", "google-generativeai", "grpcio-status", "ipykernel", "milvus", "milvus", "pinecone-client", "psycopg", "pyarrow", "pymilvus", "pymilvus", "qdrant-client", "qdrant-client", "redis", "sentence-transformers", "torch", "transformers", "usearch", "weaviate-client"] -azure = ["azure-core", "azure-identity", "azure-search-documents"] +all = ["azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "chromadb", "google-generativeai", "grpcio-status", "ipykernel", "milvus", "milvus", "pinecone-client", "psycopg", "pyarrow", "pymilvus", "pymilvus", "qdrant-client", "qdrant-client", "redis", "sentence-transformers", "torch", "transformers", "usearch", "weaviate-client"] +azure = ["azure-core", "azure-cosmos", "azure-identity", "azure-search-documents"] chromadb = ["chromadb"] google = ["google-generativeai", "grpcio-status"] hugging-face = ["sentence-transformers", "torch", "transformers"] @@ -6849,4 +6865,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = "^3.10,<3.13" -content-hash = "8f37912da67cd7728e5b3555e5286fa4fe7a2faf63b240d26b6ae6360c3d2d7f" +content-hash = "855581d6ded65eebdd6fca14d076294e8f3508ef4270becfa30c8571d81b957e" diff --git a/python/pyproject.toml b/python/pyproject.toml index afe98f521880..100ec8980a64 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -63,6 +63,7 @@ redis = { version = "^4.6.0", optional = true} azure-search-documents = {version = "11.6.0b1", allow-prereleases = true, optional = true} azure-core = { version = "^1.28.0", optional = true} azure-identity = { version = "^1.13.0", optional = true} +azure-cosmos = { version = "^4.7.0", optional = true} usearch = { version = "^2.9", optional = true} pyarrow = { version = ">=12.0.1,<16.0.0", optional = true} @@ -86,6 +87,7 @@ optional = true google-generativeai = { version = ">=0.1,<0.4", markers = "python_version >= '3.9'"} azure-search-documents = {version = "11.6.0b1", allow-prereleases = true} azure-core = "^1.28.0" +azure-cosmos = "^4.7.0" transformers = "^4.28.1" sentence-transformers = "^2.2.2" torch = "^2.2.0" @@ -116,6 +118,7 @@ redis = "^4.6.0" azure-search-documents = {version = "11.6.0b1", allow-prereleases = true} azure-core = "^1.28.0" azure-identity = "^1.13.0" +azure-cosmos = "^4.7.0" usearch = "^2.9" pyarrow = ">=12.0.1,<16.0.0" msgraph-sdk = "^1.2.0" @@ -131,10 +134,10 @@ weaviate = ["weaviate-client"] pinecone = ["pinecone-client"] postgres = ["psycopg"] redis = ["redis"] -azure = ["azure-search-documents", "azure-core", "azure-identity", "msgraph-sdk"] +azure = ["azure-search-documents", "azure-core", "azure-identity", "azure-cosmos", "msgraph-sdk"] usearch = ["usearch", "pyarrow"] notebooks = ["ipykernel"] -all = ["google-generativeai", "grpcio-status", "transformers", "sentence-transformers", "torch", "qdrant-client", "chromadb", "pymilvus", "milvus", "weaviate-client", "pinecone-client", "psycopg", "redis", "azure-search-documents", "azure-core", "azure-identity", "usearch", "pyarrow", "ipykernel"] +all = ["google-generativeai", "grpcio-status", "transformers", "sentence-transformers", "torch", "qdrant-client", "chromadb", "pymilvus", "milvus", "weaviate-client", "pinecone-client", "psycopg", "redis", "azure-search-documents", "azure-core", "azure-identity", "azure-cosmos", "usearch", "pyarrow", "ipykernel"] [tool.ruff] lint.select = ["E", "F", "I"] diff --git a/python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/__init__.py b/python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/__init__.py new file mode 100644 index 000000000000..743cc61920df --- /dev/null +++ b/python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft. All rights reserved. + +from semantic_kernel.connectors.memory.azure_cosmosdb_no_sql.azure_cosmosdb_no_sql_memory_store import ( + AzureCosmosDBNoSQLMemoryStore, +) + +__all__ = ["AzureCosmosDBNoSQLMemoryStore"] diff --git a/python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/azure_cosmosdb_no_sql_memory_store.py b/python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/azure_cosmosdb_no_sql_memory_store.py new file mode 100644 index 000000000000..632869960971 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/azure_cosmosdb_no_sql_memory_store.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft. All rights reserved. + +import json +from typing import Any, Dict, List, Tuple + +import numpy as np +from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy +from numpy import ndarray + +from semantic_kernel.memory.memory_record import MemoryRecord +from semantic_kernel.memory.memory_store_base import MemoryStoreBase + + +# You can read more about vector search using AzureCosmosDBNoSQL here. +# https://aka.ms/CosmosVectorSearch +class AzureCosmosDBNoSQLMemoryStore(MemoryStoreBase): + cosmos_client: CosmosClient = None + database: DatabaseProxy + container: ContainerProxy + database_name: str = None + partition_key: str = None + vector_embedding_policy: [Dict[str, Any]] = None + indexing_policy: [Dict[str, Any]] = None + cosmos_container_properties: [Dict[str, Any]] = None + + def __init__( + self, + cosmos_client: CosmosClient, + database_name: str, + partition_key: str, + vector_embedding_policy: [Dict[str, Any]], + indexing_policy: [Dict[str, Any]], + cosmos_container_properties: [Dict[str, Any]], + ): + if indexing_policy["vectorIndexes"] is None or len(indexing_policy["vectorIndexes"]) == 0: + raise ValueError("vectorIndexes cannot be null or empty in the indexing_policy.") + if vector_embedding_policy is None or len(vector_embedding_policy["vectorEmbeddings"]) == 0: + raise ValueError("vectorEmbeddings cannot be null or empty in the vector_embedding_policy.") + + self.cosmos_client = cosmos_client + self.database_name = database_name + self.partition_key = partition_key + self.vector_embedding_policy = vector_embedding_policy + self.indexing_policy = indexing_policy + self.cosmos_container_properties = cosmos_container_properties + + async def create_collection(self, collection_name: str) -> None: + # Create the database if it already doesn't exist + self.database = await self.cosmos_client.create_database_if_not_exists(id=self.database_name) + + # Create the collection if it already doesn't exist + self.container = await self.database.create_container_if_not_exists( + id=collection_name, + partition_key=self.cosmos_container_properties["partition_key"], + indexing_policy=self.indexing_policy, + vector_embedding_policy=self.vector_embedding_policy, + ) + + async def get_collections(self) -> List[str]: + return [container["id"] async for container in self.database.list_containers()] + + async def delete_collection(self, collection_name: str) -> None: + return await self.database.delete_container(collection_name) + + async def does_collection_exist(self, collection_name: str) -> bool: + return collection_name in [container["id"] async for container in self.database.list_containers()] + + async def upsert(self, collection_name: str, record: MemoryRecord) -> str: + result = await self.upsert_batch(collection_name, [record]) + return result[0] + + async def upsert_batch(self, collection_name: str, records: List[MemoryRecord]) -> List[str]: + doc_ids: List[str] = [] + for record in records: + cosmosRecord: dict = { + "id": record.id, + "embedding": record.embedding.tolist(), + "text": record.text, + "description": record.description, + "metadata": self.__serialize_metadata(record), + } + if record.timestamp is not None: + cosmosRecord["timeStamp"] = record.timestamp + + await self.container.create_item(cosmosRecord) + doc_ids.append(cosmosRecord["id"]) + return doc_ids + + async def get(self, collection_name: str, key: str, with_embedding: bool) -> MemoryRecord: + item = await self.container.read_item(key, partition_key=key) + return MemoryRecord.local_record( + id=item["id"], + embedding=np.array(item["embedding"]) if with_embedding else np.array([]), + text=item["text"], + description=item["description"], + additional_metadata=item["metadata"], + timestamp=item.get("timestamp", None), + ) + + async def get_batch(self, collection_name: str, keys: List[str], with_embeddings: bool) -> List[MemoryRecord]: + query = "SELECT * FROM c WHERE ARRAY_CONTAINS(@ids, c.id)" + parameters = [{"name": "@ids", "value": keys}] + + all_results = [] + items = [item async for item in self.container.query_items(query, parameters=parameters)] + for item in items: + MemoryRecord.local_record( + id=item["id"], + embedding=np.array(item["embedding"]) if with_embeddings else np.array([]), + text=item["text"], + description=item["description"], + additional_metadata=item["metadata"], + timestamp=item.get("timestamp", None), + ) + all_results.append(item) + return all_results + + async def remove(self, collection_name: str, key: str) -> None: + await self.container.delete_item(key, partition_key=key) + + async def remove_batch(self, collection_name: str, keys: List[str]) -> None: + for key in keys: + await self.container.delete_item(key, partition_key=key) + + async def get_nearest_matches( + self, collection_name: str, embedding: ndarray, limit: int, min_relevance_score: float, with_embeddings: bool + ) -> List[Tuple[MemoryRecord, float]]: + embedding_key = self.vector_embedding_policy["vectorEmbeddings"][0]["path"][1:] + query = ( + "SELECT TOP {} c.id, c.{}, c.text, c.description, c.metadata, " + "c.timestamp, VectorDistance(c.{}, {}) AS SimilarityScore FROM c ORDER BY " + "VectorDistance(c.{}, {})".format( + limit, embedding_key, embedding_key, embedding.tolist(), embedding_key, embedding.tolist() + ) + ) + + items = [item async for item in self.container.query_items(query=query)] + nearest_results = [] + for item in items: + score = item["SimilarityScore"] + if score < min_relevance_score: + continue + result = MemoryRecord.local_record( + id=item["id"], + embedding=np.array(item["embedding"]) if with_embeddings else np.array([]), + text=item["text"], + description=item["description"], + additional_metadata=item["metadata"], + timestamp=item.get("timestamp", None), + ) + nearest_results.append((result, score)) + return nearest_results + + async def get_nearest_match( + self, collection_name: str, embedding: ndarray, min_relevance_score: float, with_embedding: bool + ) -> Tuple[MemoryRecord, float]: + nearest_results = await self.get_nearest_matches( + collection_name=collection_name, + embedding=embedding, + limit=1, + min_relevance_score=min_relevance_score, + with_embeddings=with_embedding, + ) + if len(nearest_results) > 0: + return nearest_results[0] + else: + return None + + @staticmethod + def __serialize_metadata(record: MemoryRecord) -> str: + return json.dumps( + { + "text": record.text, + "description": record.description, + "additional_metadata": record.additional_metadata, + } + ) diff --git a/python/tests/integration/connectors/memory/test_azure_cosmosdb_no_sql_memory_store.py b/python/tests/integration/connectors/memory/test_azure_cosmosdb_no_sql_memory_store.py new file mode 100644 index 000000000000..68352a4398d0 --- /dev/null +++ b/python/tests/integration/connectors/memory/test_azure_cosmosdb_no_sql_memory_store.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft. All rights reserved. +from typing import List + +import numpy as np +import pytest +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient + +from semantic_kernel.memory.memory_record import MemoryRecord +from semantic_kernel.memory.memory_store_base import MemoryStoreBase + +try: + from semantic_kernel.connectors.memory.azure_cosmosdb_no_sql.azure_cosmosdb_no_sql_memory_store import ( + AzureCosmosDBNoSQLMemoryStore, + ) + + azure_cosmosdb_no_sql_memory_store_installed = True +except AssertionError: + azure_cosmosdb_no_sql_memory_store_installed = False + +pytest_mark = pytest.mark.skipif( + not azure_cosmosdb_no_sql_memory_store_installed, + reason="Azure CosmosDB No SQL Memory Store is not installed", +) + +# Host and Key for CosmosDB No SQl +HOST = "" +KEY = "" + +if not HOST or KEY: + skip_test = True +else: + skip_test = False + +cosmos_client = CosmosClient(HOST, KEY) +database_name = "sk_python_db" +container_name = "sk_python_container" +partition_key = PartitionKey(path="/id") +cosmos_container_properties = {"partition_key": partition_key} + + +async def azure_cosmosdb_no_sql_memory_store() -> MemoryStoreBase: + store = AzureCosmosDBNoSQLMemoryStore( + cosmos_client=cosmos_client, + database_name=database_name, + partition_key=partition_key.path, + vector_embedding_policy=get_vector_embedding_policy("cosine", "float32", 5), + indexing_policy=get_vector_indexing_policy("flat"), + cosmos_container_properties=cosmos_container_properties, + ) + return store + + +@pytest.mark.asyncio +@pytest.mark.skipif(skip_test, reason="Skipping test because HOST or KEY is not set") +async def test_create_get_drop_exists_collection(): + store = await azure_cosmosdb_no_sql_memory_store() + + await store.create_collection(collection_name=container_name) + + collection_list = await store.get_collections() + assert container_name in collection_list + + await store.delete_collection(collection_name=container_name) + + result = await store.does_collection_exist(collection_name=container_name) + assert result is False + + +@pytest.mark.asyncio +@pytest.mark.skipif(skip_test, reason="Skipping test because HOST or KEY is not set") +async def test_upsert_and_get_and_remove(): + store = await azure_cosmosdb_no_sql_memory_store() + await store.create_collection(collection_name=container_name) + record = get_vector_items()[0] + + doc_id = await store.upsert(container_name, record) + assert doc_id == record.id + + result = await store.get(container_name, record.id, with_embedding=True) + + assert result is not None + assert result.id == record.id + assert all(result._embedding[i] == record._embedding[i] for i in range(len(result._embedding))) + await store.remove(container_name, record.id) + + +@pytest.mark.asyncio +@pytest.mark.skipif(skip_test, reason="Skipping test because HOST or KEY is not set") +async def test_upsert_batch_and_get_batch_remove_batch(): + store = await azure_cosmosdb_no_sql_memory_store() + await store.create_collection(collection_name=container_name) + records = get_vector_items() + + doc_ids = await store.upsert_batch(container_name, records) + assert len(doc_ids) == 3 + assert all(doc_id in [record.id for record in records] for doc_id in doc_ids) + + results = await store.get_batch(container_name, [record.id for record in records], with_embeddings=True) + + assert len(results) == 3 + assert all(result["id"] in [record.id for record in records] for result in results) + + await store.remove_batch(container_name, [record.id for record in records]) + + +@pytest.mark.asyncio +@pytest.mark.skipif(skip_test, reason="Skipping test because HOST or KEY is not set") +async def test_get_nearest_match(): + store = await azure_cosmosdb_no_sql_memory_store() + await store.create_collection(collection_name=container_name) + records = get_vector_items() + await store.upsert_batch(container_name, records) + + test_embedding = get_vector_items()[0].embedding.copy() + test_embedding[0] = test_embedding[0] + 0.1 + + result = await store.get_nearest_match(container_name, test_embedding, min_relevance_score=0.0, with_embedding=True) + + assert result is not None + assert result[1] > 0.0 + + await store.remove_batch(container_name, [record.id for record in records]) + + +@pytest.mark.asyncio +@pytest.mark.skipif(skip_test, reason="Skipping test because HOST or KEY is not set") +async def test_get_nearest_matches(): + store = await azure_cosmosdb_no_sql_memory_store() + await store.create_collection(collection_name=container_name) + records = get_vector_items() + await store.upsert_batch(container_name, records) + + test_embedding = get_vector_items()[0].embedding.copy() + test_embedding[0] = test_embedding[0] + 0.1 + + result = await store.get_nearest_matches( + container_name, test_embedding, limit=3, min_relevance_score=0.0, with_embeddings=True + ) + + assert result is not None + assert len(result) == 3 + assert all(result[i][0].id in [record.id for record in records] for i in range(3)) + + await store.remove_batch(container_name, [record.id for record in records]) + + +def get_vector_indexing_policy(embedding_type): + return { + "indexingMode": "consistent", + "includedPaths": [{"path": "/*"}], + "vectorIndexes": [{"path": "/embedding", "type": f"{embedding_type}"}], + } + + +def get_vector_embedding_policy(distance_function, data_type, dimensions): + return { + "vectorEmbeddings": [ + { + "path": "/embedding", + "dataType": f"{data_type}", + "dimensions": dimensions, + "distanceFunction": f"{distance_function}", + } + ] + } + + +def create_embedding(non_zero_pos: int) -> np.ndarray: + # Create a NumPy array with a single non-zero value of dimension 1546 + embedding = np.zeros(5) + embedding[non_zero_pos - 1] = 1.0 + return embedding + + +def get_vector_items() -> List[MemoryRecord]: + records = [] + record = MemoryRecord( + id="test_id1", + text="sample text1", + is_reference=False, + embedding=create_embedding(non_zero_pos=2), + description="description", + additional_metadata="additional metadata", + external_source_name="external source", + ) + records.append(record) + + record = MemoryRecord( + id="test_id2", + text="sample text2", + is_reference=False, + embedding=create_embedding(non_zero_pos=3), + description="description", + additional_metadata="additional metadata", + external_source_name="external source", + ) + records.append(record) + + record = MemoryRecord( + id="test_id3", + text="sample text3", + is_reference=False, + embedding=create_embedding(non_zero_pos=4), + description="description", + additional_metadata="additional metadata", + external_source_name="external source", + ) + records.append(record) + return records