Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to specify ID when adding to the FAISS vectorstore. #5190

Merged
merged 6 commits into from
May 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions langchain/vectorstores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __add(
texts: Iterable[str],
embeddings: Iterable[List[float]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
if not isinstance(self.docstore, AddableMixin):
Expand All @@ -107,6 +108,8 @@ def __add(
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
# Add to the index, the index_to_id mapping, and the docstore.
starting_len = len(self.index_to_docstore_id)
faiss = dependable_faiss_import()
Expand All @@ -115,10 +118,7 @@ def __add(
faiss.normalize_L2(vector)
self.index.add(vector)
# Get list of index, id, and docs.
full_info = [
(starting_len + i, str(uuid.uuid4()), doc)
for i, doc in enumerate(documents)
]
full_info = [(starting_len + i, ids[i], doc) for i, doc in enumerate(documents)]
# Add information to docstore and index.
self.docstore.add({_id: doc for _, _id, doc in full_info})
index_to_id = {index: _id for index, _id, _ in full_info}
Expand All @@ -129,13 +129,15 @@ def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.

Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of unique IDs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be passed to __add right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


Returns:
List of ids from adding the texts into the vectorstore.
Expand All @@ -147,12 +149,13 @@ def add_texts(
)
# Embed and create the documents.
embeddings = [self.embedding_function(text) for text in texts]
return self.__add(texts, embeddings, metadatas, **kwargs)
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

def add_embeddings(
self,
text_embeddings: Iterable[Tuple[str, List[float]]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Expand All @@ -161,6 +164,7 @@ def add_embeddings(
text_embeddings: Iterable pairs of string and embedding to
add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of unique IDs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be passed to __add right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


Returns:
List of ids from adding the texts into the vectorstore.
Expand All @@ -174,7 +178,7 @@ def add_embeddings(

texts = [te[0] for te in text_embeddings]
embeddings = [te[1] for te in text_embeddings]
return self.__add(texts, embeddings, metadatas, **kwargs)
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
Expand Down Expand Up @@ -346,13 +350,13 @@ def merge_from(self, target: FAISS) -> None:
# Merge two IndexFlatL2
self.index.merge_from(target.index)

# Create new id for docs from target FAISS object
# Get id and docs from target FAISS object
full_info = []
for i in target.index_to_docstore_id:
doc = target.docstore.search(target.index_to_docstore_id[i])
for i, target_id in target.index_to_docstore_id.items():
doc = target.docstore.search(target_id)
if not isinstance(doc, Document):
raise ValueError("Document should be returned")
full_info.append((starting_len + i, str(uuid.uuid4()), doc))
full_info.append((starting_len + i, target_id, doc))

# Add information to docstore and index_to_docstore_id.
self.docstore.add({_id: doc for _, _id, doc in full_info})
Expand All @@ -366,6 +370,7 @@ def __from(
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
normalize_L2: bool = False,
**kwargs: Any,
) -> FAISS:
Expand All @@ -376,13 +381,13 @@ def __from(
faiss.normalize_L2(vector)
index.add(vector)
documents = []
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
docstore = InMemoryDocstore(
{index_to_id[i]: doc for i, doc in enumerate(documents)}
)
index_to_id = dict(enumerate(ids))
docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents)))
return cls(
embedding.embed_query,
index,
Expand All @@ -398,6 +403,7 @@ def from_texts(
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> FAISS:
"""Construct FAISS wrapper from raw documents.
Expand All @@ -422,7 +428,8 @@ def from_texts(
texts,
embeddings,
embedding,
metadatas,
metadatas=metadatas,
ids=ids,
**kwargs,
)

Expand All @@ -432,6 +439,7 @@ def from_embeddings(
text_embeddings: List[Tuple[str, List[float]]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be passed to __from right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

**kwargs: Any,
) -> FAISS:
"""Construct FAISS wrapper from raw documents.
Expand Down Expand Up @@ -459,7 +467,8 @@ def from_embeddings(
texts,
embeddings,
embedding,
metadatas,
metadatas=metadatas,
ids=ids,
**kwargs,
)

Expand Down