Add support for MongoDB Atlas $vectorSearch vector search #11139

Merged: 14 commits, Sep 28, 2023
95 changes: 53 additions & 42 deletions libs/langchain/langchain/vectorstores/mongodb_atlas.py
@@ -89,6 +89,18 @@ def from_connection_string(
embedding: Embeddings,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct a `MongoDB Atlas Vector Search` vector store
from a MongoDB connection URI.

Args:
connection_string: A valid MongoDB connection URI.
namespace: A valid MongoDB namespace (database and collection).
embedding: The text embedding model to use for the vector store.

Returns:
A new MongoDBAtlasVectorSearch instance.

"""
try:
from pymongo import MongoClient
except ImportError:
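
A rough usage sketch of the constructor documented above; the URI, namespace, embedding model, and index name are placeholder assumptions, not values taken from this PR.

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch

# All values below are illustrative placeholders.
vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
    "mongodb+srv://user:password@cluster0.example.mongodb.net",
    "my_database.my_collection",
    OpenAIEmbeddings(),
    index_name="my-vector-index",  # assumes the index_name kwarg is forwarded to the class
)
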
@@ -149,24 +161,23 @@ def _similarity_search_with_score(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
knn_beta = {
"vector": embedding,
params = {
"queryVector": embedding,
"path": self._embedding_key,
"k": k,
"numCandidates": k * 10,
"limit": k,
"index": self._index_name,
}
if pre_filter:
knn_beta["filter"] = pre_filter
params["filter"] = pre_filter
query = {"$vectorSearch": params}

pipeline = [
{
"$search": {
"index": self._index_name,
"knnBeta": knn_beta,
}
},
{"$set": {"score": {"$meta": "searchScore"}}},
query,
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
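
Putting the new pieces together, the aggregation this method now issues looks roughly like the sketch below; the concrete values stand in for self._embedding_key, self._index_name, k, and an embedded query, and are not literal values from this PR.

query_vector = [0.021, -0.37, 0.84]  # stand-in for self._embedding.embed_query(query)

pipeline = [
    {
        "$vectorSearch": {
            "queryVector": query_vector,
            "path": "embedding",      # self._embedding_key
            "numCandidates": 40,      # k * 10, with k = 4
            "limit": 4,               # k
            "index": "vector_index",  # self._index_name
            # "filter": {...}         # added only when pre_filter is supplied
        }
    },
    {"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
results = collection.aggregate(pipeline)  # assumes `collection` is a pymongo Collection
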
@@ -183,12 +194,12 @@ def similarity_search_with_score(
query: str,
*,
k: int = 4,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
"""Return MongoDB documents most similar to query, along with scores.
"""Return MongoDB documents most similar to the given query and their scores.

Use the knnBeta Operator available in MongoDB Atlas Search
Uses the knnBeta Operator available in MongoDB Atlas Search.
This feature is in early access and available only for evaluation purposes, to
validate functionality, and to gather feedback from a small closed group of
early access users. It is not recommended for production deployments as we
@@ -197,14 +208,14 @@ def similarity_search_with_score(

Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: (Optional) dictionary of argument(s) to prefilter document
fields on.
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
following the knnBeta vector search.

Returns:
List of Documents most similar to the query and score for each
List of documents most similar to the query and their scores.
"""
embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
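
A brief usage sketch of this method; the query text, the "c" metadata field in pre_filter, and the $project stage are illustrative assumptions, and Atlas generally requires filter fields to be declared in the vector index definition.

docs_and_scores = vectorstore.similarity_search_with_score(
    "What is a sandwich?",
    k=2,
    pre_filter={"c": {"$lte": 10}},                          # assumed filterable metadata field
    post_filter_pipeline=[{"$project": {"embedding": 0}}],   # drop raw vectors from results
)
for doc, score in docs_and_scores:
    print(score, doc.page_content)
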
@@ -219,29 +230,29 @@ def similarity_search(
self,
query: str,
k: int = 4,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return MongoDB documents most similar to query.
"""Return MongoDB documents most similar to the given query.

Use the knnBeta Operator available in MongoDB Atlas Search
Uses the knnBeta Operator available in MongoDB Atlas Search.
This feature is in early access and available only for evaluation purposes, to
validate functionality, and to gather feedback from a small closed group of
early access users. It is not recommended for production deployments as we may
introduce breaking changes.
early access users. It is not recommended for production deployments as we
may introduce breaking changes.
For more: https://www.mongodb.com/docs/atlas/atlas-search/knn-beta

Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: (Optional) dictionary of argument(s) to prefilter document
fields on.
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
following the knnBeta vector search.

Returns:
List of Documents most similar to the query and score for each
List of documents most similar to the query.
"""
docs_and_scores = self.similarity_search_with_score(
query,
@@ -257,30 +268,30 @@ def max_marginal_relevance_search(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
"""Return documents selected using the maximal marginal relevance.

Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.

Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
fetch_k: Optional Number of Documents to fetch before passing to MMR
k: (Optional) number of documents to return. Defaults to 4.
fetch_k: (Optional) number of documents to fetch before passing to MMR
algorithm. Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
pre_filter: (Optional) dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
following the knnBeta vector search.
Returns:
List of Documents selected by maximal marginal relevance.
List of documents selected by maximal marginal relevance.
"""
query_embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
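
A minimal MMR usage sketch, assuming a populated vector store; lambda_mult=0.25 is an arbitrary choice that leans toward diversity.

# Fetch 20 candidates, then return the 4 that best balance relevance and diversity.
docs = vectorstore.max_marginal_relevance_search(
    "Sandwich",
    k=4,
    fetch_k=20,
    lambda_mult=0.25,  # closer to 0 favours diversity, closer to 1 favours raw similarity
)
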
@@ -303,11 +314,11 @@ def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[List[Dict]] = None,
collection: Optional[Collection[MongoDBDocumentType]] = None,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
"""Construct a `MongoDB Atlas Vector Search` vector store from raw documents.

This is a user-friendly interface that:
1. Embeds documents.
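
A rough from_texts sketch with placeholder texts and metadata; the collection and index_name keyword arguments mirror how the integration tests below construct the store, and the embedding model is an assumption.

# `collection` is assumed to be a pymongo Collection backed by an Atlas cluster.
texts = ["Dogs are tough.", "Cats have fluff."]
metadatas = [{"a": 1}, {"b": 1}]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
    texts,
    OpenAIEmbeddings(),
    metadatas=metadatas,
    collection=collection,
    index_name="langchain-test-index",
)
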
@@ -3,49 +3,54 @@

import os
from time import sleep
from typing import TYPE_CHECKING, Any
from typing import Any

import pytest

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch

if TYPE_CHECKING:
from pymongo import MongoClient

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")

# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
# connections.

def get_collection() -> Any:
from pymongo import MongoClient

@pytest.fixture
def collection() -> Any:
test_client = MongoClient(CONNECTION_STRING)
test_client: MongoClient = MongoClient(CONNECTION_STRING)
return test_client[DB_NAME][COLLECTION_NAME]


@pytest.fixture()
def collection() -> Any:
return get_collection()


class TestMongoDBAtlasVectorSearch:
@classmethod
def setup_class(cls, collection: Any) -> None:
def setup_class(cls) -> None:
# ensure the test collection is empty
collection = get_collection()
assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501

@classmethod
def teardown_class(cls, collection: Any) -> None:
def teardown_class(cls) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]

@pytest.fixture(autouse=True)
def setup(self, collection: Any) -> None:
def setup(self) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]

def test_from_documents(self, embedding_openai: Embeddings) -> None:
def test_from_documents(
self, embedding_openai: Embeddings, collection: Any
) -> None:
"""Test end to end construction and search."""
documents = [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
@@ -64,7 +69,7 @@ def test_from_documents(self, embedding_openai: Embeddings) -> None:
assert output[0].page_content == "What is a sandwich?"
assert output[0].metadata["c"] == 1

def test_from_texts(self, embedding_openai: Embeddings) -> None:
def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
@@ -81,7 +86,9 @@ def test_from_texts(self, embedding_openai: Embeddings) -> None:
output = vectorstore.similarity_search("Sandwich", k=1)
assert output[0].page_content == "What is a sandwich?"

def test_from_texts_with_metadatas(self, embedding_openai: Embeddings) -> None:
def test_from_texts_with_metadatas(
self, embedding_openai: Embeddings, collection: Any
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
@@ -102,7 +109,7 @@ def test_from_texts_with_metadatas(self, embedding_openai: Embeddings) -> None:
assert output[0].metadata["c"] == 1

def test_from_texts_with_metadatas_and_pre_filter(
self, embedding_openai: Embeddings
self, embedding_openai: Embeddings, collection: Any
) -> None:
texts = [
"Dogs are tough.",
@@ -124,7 +131,7 @@ def test_from_texts_with_metadatas_and_pre_filter(
)
assert output == []

def test_mmr(self, embedding_openai: Embeddings) -> None:
def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None:
texts = ["foo", "foo", "fou", "foy"]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,