-
Notifications
You must be signed in to change notification settings - Fork 16.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow to specify ID when adding to the FAISS vectorstore. #5190
Changes from all commits
96da672
09ea5d5
c352c51
2c2b689
f371c17
cd75e3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,6 +96,7 @@ def __add( | |
texts: Iterable[str], | ||
embeddings: Iterable[List[float]], | ||
metadatas: Optional[List[dict]] = None, | ||
ids: Optional[List[str]] = None, | ||
**kwargs: Any, | ||
) -> List[str]: | ||
if not isinstance(self.docstore, AddableMixin): | ||
|
@@ -107,6 +108,8 @@ def __add( | |
for i, text in enumerate(texts): | ||
metadata = metadatas[i] if metadatas else {} | ||
documents.append(Document(page_content=text, metadata=metadata)) | ||
if ids is None: | ||
ids = [str(uuid.uuid4()) for _ in texts] | ||
# Add to the index, the index_to_id mapping, and the docstore. | ||
starting_len = len(self.index_to_docstore_id) | ||
faiss = dependable_faiss_import() | ||
|
@@ -115,10 +118,7 @@ def __add( | |
faiss.normalize_L2(vector) | ||
self.index.add(vector) | ||
# Get list of index, id, and docs. | ||
full_info = [ | ||
(starting_len + i, str(uuid.uuid4()), doc) | ||
for i, doc in enumerate(documents) | ||
] | ||
full_info = [(starting_len + i, ids[i], doc) for i, doc in enumerate(documents)] | ||
# Add information to docstore and index. | ||
self.docstore.add({_id: doc for _, _id, doc in full_info}) | ||
index_to_id = {index: _id for index, _id, _ in full_info} | ||
|
@@ -129,13 +129,15 @@ def add_texts( | |
self, | ||
texts: Iterable[str], | ||
metadatas: Optional[List[dict]] = None, | ||
ids: Optional[List[str]] = None, | ||
**kwargs: Any, | ||
) -> List[str]: | ||
"""Run more texts through the embeddings and add to the vectorstore. | ||
|
||
Args: | ||
texts: Iterable of strings to add to the vectorstore. | ||
metadatas: Optional list of metadatas associated with the texts. | ||
ids: Optional list of unique IDs. | ||
|
||
Returns: | ||
List of ids from adding the texts into the vectorstore. | ||
|
@@ -147,12 +149,13 @@ def add_texts( | |
) | ||
# Embed and create the documents. | ||
embeddings = [self.embedding_function(text) for text in texts] | ||
return self.__add(texts, embeddings, metadatas, **kwargs) | ||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs) | ||
|
||
def add_embeddings( | ||
self, | ||
text_embeddings: Iterable[Tuple[str, List[float]]], | ||
metadatas: Optional[List[dict]] = None, | ||
ids: Optional[List[str]] = None, | ||
**kwargs: Any, | ||
) -> List[str]: | ||
"""Run more texts through the embeddings and add to the vectorstore. | ||
|
@@ -161,6 +164,7 @@ def add_embeddings( | |
text_embeddings: Iterable pairs of string and embedding to | ||
add to the vectorstore. | ||
metadatas: Optional list of metadatas associated with the texts. | ||
ids: Optional list of unique IDs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be passed to __add right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
Returns: | ||
List of ids from adding the texts into the vectorstore. | ||
|
@@ -174,7 +178,7 @@ def add_embeddings( | |
|
||
texts = [te[0] for te in text_embeddings] | ||
embeddings = [te[1] for te in text_embeddings] | ||
return self.__add(texts, embeddings, metadatas, **kwargs) | ||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs) | ||
|
||
def similarity_search_with_score_by_vector( | ||
self, embedding: List[float], k: int = 4 | ||
|
@@ -346,13 +350,13 @@ def merge_from(self, target: FAISS) -> None: | |
# Merge two IndexFlatL2 | ||
self.index.merge_from(target.index) | ||
|
||
# Create new id for docs from target FAISS object | ||
# Get id and docs from target FAISS object | ||
full_info = [] | ||
for i in target.index_to_docstore_id: | ||
doc = target.docstore.search(target.index_to_docstore_id[i]) | ||
for i, target_id in target.index_to_docstore_id.items(): | ||
doc = target.docstore.search(target_id) | ||
if not isinstance(doc, Document): | ||
raise ValueError("Document should be returned") | ||
full_info.append((starting_len + i, str(uuid.uuid4()), doc)) | ||
full_info.append((starting_len + i, target_id, doc)) | ||
|
||
# Add information to docstore and index_to_docstore_id. | ||
self.docstore.add({_id: doc for _, _id, doc in full_info}) | ||
|
@@ -366,6 +370,7 @@ def __from( | |
embeddings: List[List[float]], | ||
embedding: Embeddings, | ||
metadatas: Optional[List[dict]] = None, | ||
ids: Optional[List[str]] = None, | ||
normalize_L2: bool = False, | ||
**kwargs: Any, | ||
) -> FAISS: | ||
|
@@ -376,13 +381,13 @@ def __from( | |
faiss.normalize_L2(vector) | ||
index.add(vector) | ||
documents = [] | ||
if ids is None: | ||
ids = [str(uuid.uuid4()) for _ in texts] | ||
for i, text in enumerate(texts): | ||
metadata = metadatas[i] if metadatas else {} | ||
documents.append(Document(page_content=text, metadata=metadata)) | ||
index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))} | ||
docstore = InMemoryDocstore( | ||
{index_to_id[i]: doc for i, doc in enumerate(documents)} | ||
) | ||
index_to_id = dict(enumerate(ids)) | ||
docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents))) | ||
return cls( | ||
embedding.embed_query, | ||
index, | ||
|
@@ -398,6 +403,7 @@ def from_texts( | |
texts: List[str], | ||
embedding: Embeddings, | ||
metadatas: Optional[List[dict]] = None, | ||
ids: Optional[List[str]] = None, | ||
**kwargs: Any, | ||
) -> FAISS: | ||
"""Construct FAISS wrapper from raw documents. | ||
|
@@ -422,7 +428,8 @@ def from_texts( | |
texts, | ||
embeddings, | ||
embedding, | ||
metadatas, | ||
metadatas=metadatas, | ||
ids=ids, | ||
**kwargs, | ||
) | ||
|
||
|
@@ -432,6 +439,7 @@ def from_embeddings( | |
text_embeddings: List[Tuple[str, List[float]]], | ||
embedding: Embeddings, | ||
metadatas: Optional[List[dict]] = None, | ||
ids: Optional[List[str]] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be passed to __from right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
**kwargs: Any, | ||
) -> FAISS: | ||
"""Construct FAISS wrapper from raw documents. | ||
|
@@ -459,7 +467,8 @@ def from_embeddings( | |
texts, | ||
embeddings, | ||
embedding, | ||
metadatas, | ||
metadatas=metadatas, | ||
ids=ids, | ||
**kwargs, | ||
) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be passed to __add right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done