Skip to content

Commit

Permalink
langchain[patch]: Add async methods to TimeWeightedVectorStoreRetriev…
Browse files Browse the repository at this point in the history
…er (#19606)
  • Loading branch information
cbornet authored and hinthornw committed Apr 26, 2024
1 parent 99e3286 commit 9f71ac4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 47 deletions.
54 changes: 44 additions & 10 deletions libs/langchain/langchain/retrievers/time_weighted_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
Expand Down Expand Up @@ -89,17 +92,26 @@ def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
results[buffer_idx] = (doc, relevance)
return results

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
async def aget_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
    """Return documents that are salient to the query.

    Async counterpart of ``get_salient_docs``: scores come from the vector
    store, but the returned documents are the canonical copies held in
    ``self.memory_stream`` (looked up via their ``buffer_idx`` metadata).

    Args:
        query: Query string to score documents against.

    Returns:
        Mapping of buffer index -> (memory-stream document, relevance score).
    """
    scored: List[Tuple[Document, float]] = (
        await self.vectorstore.asimilarity_search_with_relevance_scores(
            query, **self.search_kwargs
        )
    )
    salient: Dict[int, Tuple[Document, float]] = {}
    for candidate, relevance in scored:
        # Documents without a buffer_idx were not added through this
        # retriever and cannot be mapped back to the memory stream.
        if "buffer_idx" not in candidate.metadata:
            continue
        idx = candidate.metadata["buffer_idx"]
        salient[idx] = (self.memory_stream[idx], relevance)
    return salient

def _get_rescored_docs(
self, docs_and_scores: Dict[Any, Tuple[Document, Optional[float]]]
) -> List[Document]:
"""Return documents that are relevant to the query."""
current_time = datetime.datetime.now()
docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :]
}
# If a doc is considered salient, update the salience score
docs_and_scores.update(self.get_salient_docs(query))
rescored_docs = [
(doc, self._get_combined_score(doc, relevance, current_time))
for doc, relevance in docs_and_scores.values()
Expand All @@ -114,6 +126,28 @@ def _get_relevant_documents(
result.append(buffered_doc)
return result

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :]
}
# If a doc is considered salient, update the salience score
docs_and_scores.update(self.get_salient_docs(query))
return self._get_rescored_docs(docs_and_scores)

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :]
}
# If a doc is considered salient, update the salience score
docs_and_scores.update(await self.aget_salient_docs(query))
return self._get_rescored_docs(docs_and_scores)

def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
current_time = kwargs.get("current_time")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,45 +36,13 @@ def add_texts(
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
kwargs: vectorstore specific parameters
Returns:
List of ids from adding the texts into the vectorstore.
"""
return list(texts)

async def aadd_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> List[str]:
    """Async text insertion is deliberately unsupported by this mock.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

def similarity_search(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
    """Mock similarity search: every query yields an empty result set."""
    no_matches: List[Document] = []
    return no_matches

@classmethod
def from_documents(
    cls: Type["MockVectorStore"],
    documents: List[Document],
    embedding: Embeddings,
    **kwargs: Any,
) -> "MockVectorStore":
    """Return VectorStore initialized from documents and embeddings.

    Splits each document into its text and metadata and defers to
    ``from_texts``, matching the base-class convention.
    """
    contents = [doc.page_content for doc in documents]
    metas = [doc.metadata for doc in documents]
    return cls.from_texts(contents, embedding, metadatas=metas, **kwargs)

@classmethod
def from_texts(
cls: Type["MockVectorStore"],
Expand All @@ -83,7 +51,6 @@ def from_texts(
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "MockVectorStore":
"""Return VectorStore initialized from texts and embeddings."""
return cls()

def _similarity_search_with_relevance_scores(
Expand All @@ -92,12 +59,16 @@ def _similarity_search_with_relevance_scores(
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and similarity scores, normalized on a scale from 0 to 1.
0 is dissimilar, 1 is most similar.
"""
return [(doc, 0.5) for doc in _get_example_memories()]

async def _asimilarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
return self._similarity_search_with_relevance_scores(query, k, **kwargs)


@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
Expand Down Expand Up @@ -146,6 +117,18 @@ def test_get_salient_docs(
assert doc in want


async def test_aget_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """aget_salient_docs returns a dict of (doc, score) pairs from the mock store."""
    query = "Test query"
    docs_and_scores = await time_weighted_retriever.aget_salient_docs(query)
    want = [(doc, 0.5) for doc in _get_example_memories()]
    assert isinstance(docs_and_scores, dict)
    assert len(docs_and_scores) == len(want)
    # Only the (doc, score) tuples matter here; the buffer-idx keys are
    # checked implicitly by the length assertion, so iterate values directly
    # instead of unpacking an unused key from .items().
    for doc_and_score in docs_and_scores.values():
        assert doc_and_score in want


def test_get_relevant_documents(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
Expand All @@ -164,6 +147,24 @@ def test_get_relevant_documents(
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now


async def test_aget_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Async retrieval returns all mock docs and refreshes last_accessed_at."""
    query = "Test query"
    relevant_documents = await time_weighted_retriever.aget_relevant_documents(query)
    want = [(doc, 0.5) for doc in _get_example_memories()]
    assert isinstance(relevant_documents, list)
    assert len(relevant_documents) == len(want)
    now = datetime.now()
    lower_bound = now - timedelta(hours=1)
    # Retrieval must have stamped each returned doc with a recent access time.
    for retrieved in relevant_documents:
        assert lower_bound < retrieved.metadata["last_accessed_at"] <= now
    # The same timestamps must be reflected in the retriever's memory stream.
    for remembered in time_weighted_retriever.memory_stream:
        assert lower_bound < remembered.metadata["last_accessed_at"] <= now


def test_add_documents(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
Expand All @@ -175,3 +176,16 @@ def test_add_documents(
time_weighted_retriever.memory_stream[-1].page_content
== documents[0].page_content
)


async def test_aadd_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """aadd_documents returns one id and appends the doc to the memory stream."""
    documents = [Document(page_content="test_add_documents document")]
    added_documents = await time_weighted_retriever.aadd_documents(documents)
    assert isinstance(added_documents, list)
    assert len(added_documents) == 1
    newest = time_weighted_retriever.memory_stream[-1]
    assert newest.page_content == documents[0].page_content

0 comments on commit 9f71ac4

Please sign in to comment.