Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langchain: Add async methods to TimeWeightedVectorStoreRetriever #19606

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions libs/langchain/langchain/retrievers/time_weighted_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
Expand Down Expand Up @@ -89,17 +92,26 @@ def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
results[buffer_idx] = (doc, relevance)
return results

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
async def aget_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
"""Return documents that are salient to the query."""
docs_and_scores: List[Tuple[Document, float]]
docs_and_scores = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
results = {}
for fetched_doc, relevance in docs_and_scores:
if "buffer_idx" in fetched_doc.metadata:
buffer_idx = fetched_doc.metadata["buffer_idx"]
doc = self.memory_stream[buffer_idx]
results[buffer_idx] = (doc, relevance)
return results

def _get_rescored_docs(
self, docs_and_scores: Dict[Any, Tuple[Document, Optional[float]]]
) -> List[Document]:
"""Return documents that are relevant to the query."""
current_time = datetime.datetime.now()
docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :]
}
# If a doc is considered salient, update the salience score
docs_and_scores.update(self.get_salient_docs(query))
rescored_docs = [
(doc, self._get_combined_score(doc, relevance, current_time))
for doc, relevance in docs_and_scores.values()
Expand All @@ -114,6 +126,28 @@ def _get_relevant_documents(
result.append(buffered_doc)
return result

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :]
}
# If a doc is considered salient, update the salience score
docs_and_scores.update(self.get_salient_docs(query))
return self._get_rescored_docs(docs_and_scores)

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :]
}
# If a doc is considered salient, update the salience score
docs_and_scores.update(await self.aget_salient_docs(query))
return self._get_rescored_docs(docs_and_scores)

def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
current_time = kwargs.get("current_time")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,45 +36,13 @@ def add_texts(
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.

Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
kwargs: vectorstore specific parameters

Returns:
List of ids from adding the texts into the vectorstore.
"""
return list(texts)

async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore."""
raise NotImplementedError

def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
return []

@classmethod
def from_documents(
cls: Type["MockVectorStore"],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> "MockVectorStore":
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

@classmethod
def from_texts(
cls: Type["MockVectorStore"],
Expand All @@ -83,7 +51,6 @@ def from_texts(
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "MockVectorStore":
"""Return VectorStore initialized from texts and embeddings."""
return cls()

def _similarity_search_with_relevance_scores(
Expand All @@ -92,12 +59,16 @@ def _similarity_search_with_relevance_scores(
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and similarity scores, normalized on a scale from 0 to 1.

0 is dissimilar, 1 is most similar.
"""
return [(doc, 0.5) for doc in _get_example_memories()]

async def _asimilarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
return self._similarity_search_with_relevance_scores(query, k, **kwargs)


@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
Expand Down Expand Up @@ -146,6 +117,18 @@ def test_get_salient_docs(
assert doc in want


async def test_aget_salient_docs(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
query = "Test query"
docs_and_scores = await time_weighted_retriever.aget_salient_docs(query)
want = [(doc, 0.5) for doc in _get_example_memories()]
assert isinstance(docs_and_scores, dict)
assert len(docs_and_scores) == len(want)
for k, doc in docs_and_scores.items():
assert doc in want


def test_get_relevant_documents(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
Expand All @@ -164,6 +147,24 @@ def test_get_relevant_documents(
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now


async def test_aget_relevant_documents(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
query = "Test query"
relevant_documents = await time_weighted_retriever.aget_relevant_documents(query)
want = [(doc, 0.5) for doc in _get_example_memories()]
assert isinstance(relevant_documents, list)
assert len(relevant_documents) == len(want)
now = datetime.now()
for doc in relevant_documents:
# assert that the last_accessed_at is close to now.
assert now - timedelta(hours=1) < doc.metadata["last_accessed_at"] <= now

# assert that the last_accessed_at in the memory stream is updated.
for d in time_weighted_retriever.memory_stream:
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now


def test_add_documents(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
Expand All @@ -175,3 +176,16 @@ def test_add_documents(
time_weighted_retriever.memory_stream[-1].page_content
== documents[0].page_content
)


async def test_aadd_documents(
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
documents = [Document(page_content="test_add_documents document")]
added_documents = await time_weighted_retriever.aadd_documents(documents)
assert isinstance(added_documents, list)
assert len(added_documents) == 1
assert (
time_weighted_retriever.memory_stream[-1].page_content
== documents[0].page_content
)
Loading