Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MongoDB Atlas $vectorSearch vector search #2825

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Next, you'll need to create a MongoDB Atlas cluster. Navigate to the [MongoDB Atlas

Create and name a cluster when prompted, then find it under `Database`. Select `Collections` and create either a blank collection or one from the provided sample data.

**Note:** The cluster you create must run MongoDB 7.0 or higher. If you are using a pre-7.0 version of MongoDB, you must use langchainjs version 0.0.163 or earlier.

### Creating an Index

After configuring your cluster, you'll need to create an index on the collection field you want to search over.
Expand Down
70 changes: 29 additions & 41 deletions langchain/src/vectorstores/mongodb_atlas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,63 +102,51 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
k: number,
filter?: MongoDBAtlasFilter
): Promise<[Document, number][]> {
const knnBeta: MongoDBDocument = {
vector: query,
path: this.embeddingKey,
k,
};

let preFilter: MongoDBDocument | undefined;
let postFilterPipeline: MongoDBDocument[] | undefined;
let includeEmbeddings: boolean | undefined;
if (
const postFilterPipeline = filter?.postFilterPipeline ?? [];
const preFilter: MongoDBDocument | undefined =
filter?.preFilter ||
filter?.postFilterPipeline ||
filter?.includeEmbeddings
) {
preFilter = filter.preFilter;
postFilterPipeline = filter.postFilterPipeline;
includeEmbeddings = filter.includeEmbeddings || false;
} else preFilter = filter;
? filter.preFilter
: filter;
const removeEmbeddingsPipeline = !filter?.includeEmbeddings
? [
{
$project: {
[this.embeddingKey]: 0,
},
},
]
: [];

if (preFilter) {
knnBeta.filter = preFilter;
}
const pipeline: MongoDBDocument[] = [
{
$search: {
$vectorSearch: {
queryVector: query,
index: this.indexName,
knnBeta,
path: this.embeddingKey,
limit: k,
numCandidates: 10 * k,
...(preFilter && { filter: preFilter }),
},
},
{
$set: {
score: { $meta: "searchScore" },
score: { $meta: "vectorSearchScore" },
},
},
...removeEmbeddingsPipeline,
...postFilterPipeline,
];

if (!includeEmbeddings) {
const removeEmbeddingsStage = {
$project: {
[this.embeddingKey]: 0,
},
};
pipeline.push(removeEmbeddingsStage);
}

if (postFilterPipeline) {
pipeline.push(...postFilterPipeline);
}
const results = this.collection.aggregate(pipeline);

const ret: [Document, number][] = [];
for await (const result of results) {
const { score, [this.textKey]: text, ...metadata } = result;
ret.push([new Document({ pageContent: text, metadata }), score]);
}
const results = this.collection
.aggregate(pipeline)
.map<[Document, number]>((result) => {
const { score, [this.textKey]: text, ...metadata } = result;
return [new Document({ pageContent: text, metadata }), score];
});

return ret;
return results.toArray();
}

/**
Expand Down
25 changes: 13 additions & 12 deletions langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import { test, expect } from "@jest/globals";
import { MongoClient } from "mongodb";
import { CohereEmbeddings } from "../../embeddings/cohere.js";
import { setTimeout } from "timers/promises";
import { MongoDBAtlasVectorSearch } from "../mongodb_atlas.js";

import { Document } from "../../document.js";
import { OpenAIEmbeddings } from "../../embeddings/openai.js";

/**
* The following JSON can be used to create an index in Atlas for OpenAI embeddings (1536 dimensions).
Expand All @@ -15,8 +16,9 @@ import { Document } from "../../document.js";
{
"mappings": {
"fields": {
"e": { "type": "number" },
"embedding": {
"dimensions": 1024,
"dimensions": 1536,
"similarity": "euclidean",
"type": "knnVector"
}
Expand All @@ -25,10 +27,6 @@ import { Document } from "../../document.js";
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment flags a change in the code that explicitly accesses an environment variable via process.env, and it should be reviewed by maintainers.

*/

/**
 * Returns a promise that settles once a timer of `ms` milliseconds fires.
 *
 * @param ms - Delay, in milliseconds, passed straight to `setTimeout`.
 * @returns A promise that resolves with no value when the timer callback runs.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((done) => {
    // The timer callback is the promise's own resolver, so nothing else is needed.
    setTimeout(done, ms);
  });
}

test.skip("MongoDBAtlasVectorSearch with external ids", async () => {
expect(process.env.MONGODB_ATLAS_URI).toBeDefined();

Expand All @@ -40,7 +38,7 @@ test.skip("MongoDBAtlasVectorSearch with external ids", async () => {
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
collection,
});

Expand All @@ -57,7 +55,7 @@ test.skip("MongoDBAtlasVectorSearch with external ids", async () => {
]);

// we sleep 2 seconds to make sure the index in atlas has replicated the new documents
await sleep(2000);
await setTimeout(2000);
const results: Document[] = await vectorStore.similaritySearch(
"Sandwich",
1
Expand All @@ -70,7 +68,7 @@ test.skip("MongoDBAtlasVectorSearch with external ids", async () => {

// we can pre filter the search
const preFilter = {
range: { lte: 1, path: "e" },
e: { $lte: 1 },
};

const filteredResults = await vectorStore.similaritySearch(
Expand Down Expand Up @@ -113,12 +111,12 @@ test.skip("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () =
const vectorStore = await MongoDBAtlasVectorSearch.fromTexts(
texts,
{},
new CohereEmbeddings(),
{ collection }
new OpenAIEmbeddings(),
{ collection, indexName: "default" }
);

// we sleep briefly to make sure the index in Atlas has replicated the new documents
await sleep(2000);
await setTimeout(5000);

const output = await vectorStore.maxMarginalRelevanceSearch("foo", {
k: 10,
Expand Down Expand Up @@ -158,6 +156,9 @@ test.skip("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () =
const retrieverActual = retrieverOutput.map((doc) => doc.pageContent);
const retrieverExpected = ["foo", "foy", "foo"];
expect(retrieverActual).toEqual(retrieverExpected);

const similarity = await vectorStore.similaritySearchWithScore("foo", 1);
expect(similarity.length).toBe(1);
} finally {
await client.close();
}
Expand Down