diff --git a/libs/community/langchain_community/retrievers/knn.py b/libs/community/langchain_community/retrievers/knn.py index 045d11cc1d357..434837f1e3080 100644 --- a/libs/community/langchain_community/retrievers/knn.py +++ b/libs/community/langchain_community/retrievers/knn.py @@ -5,7 +5,7 @@ from __future__ import annotations import concurrent.futures -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional import numpy as np from langchain_core.callbacks import CallbackManagerForRetrieverRun @@ -38,6 +38,8 @@ class KNNRetriever(BaseRetriever): """Index of embeddings.""" texts: List[str] """List of texts to index.""" + metadatas: Optional[List[dict]] = None + """List of metadatas corresponding with each text.""" k: int = 4 """Number of results to return.""" relevancy_threshold: Optional[float] = None @@ -51,10 +53,32 @@ class Config: @classmethod def from_texts( - cls, texts: List[str], embeddings: Embeddings, **kwargs: Any + cls, + texts: List[str], + embeddings: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, ) -> KNNRetriever: index = create_index(texts, embeddings) - return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) + return cls( + embeddings=embeddings, + index=index, + texts=texts, + metadatas=metadatas, + **kwargs, + ) + + @classmethod + def from_documents( + cls, + documents: Iterable[Document], + embeddings: Embeddings, + **kwargs: Any, + ) -> KNNRetriever: + texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) + return cls.from_texts( + texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs + ) def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -71,7 +95,10 @@ def _get_relevant_documents( normalized_similarities = (similarities - np.min(similarities)) / denominator top_k_results = [ - Document(page_content=self.texts[row]) + Document( + page_content=self.texts[row], + metadata=self.metadatas[row] if self.metadatas else {}, + ) for row in sorted_ix[0 : self.k] if ( self.relevancy_threshold is None diff --git a/libs/community/tests/unit_tests/retrievers/test_knn.py b/libs/community/tests/unit_tests/retrievers/test_knn.py index 6132021c8882a..2956b2b58e62a 100644 --- a/libs/community/tests/unit_tests/retrievers/test_knn.py +++ b/libs/community/tests/unit_tests/retrievers/test_knn.py @@ -1,3 +1,5 @@ +from langchain_core.documents import Document + from langchain_community.embeddings import FakeEmbeddings from langchain_community.retrievers.knn import KNNRetriever @@ -9,3 +11,19 @@ def test_from_texts(self) -> None: texts=input_texts, embeddings=FakeEmbeddings(size=100) ) assert len(knn_retriever.texts) == 3 + + def test_from_documents(self) -> None: + input_docs = [ + Document(page_content="I have a pen.", metadata={"page": 1}), + Document(page_content="Do you have a pen?", metadata={"page": 2}), + Document(page_content="I have a bag.", metadata={"page": 3}), + ] + knn_retriever = KNNRetriever.from_documents( + documents=input_docs, embeddings=FakeEmbeddings(size=100) + ) + assert knn_retriever.texts == [ + "I have a pen.", + "Do you have a pen?", + "I have a bag.", + ] + assert knn_retriever.metadatas == [{"page": 1}, {"page": 2}, {"page": 3}]