Skip to content

Commit

Permalink
feat: add sentence window retrieval (#7997)
Browse files Browse the repository at this point in the history
* initial import

* adding tests

* adding license and release notes

* adding missing release notes

* working with any type of doc store

* nit

* adding get_class_object to serialization package

* nit

* refactoring get_class_object()

* refactoring get_class_object()

* chaning type and var names

* more refactoring

* Update haystack/core/serialization.py

Co-authored-by: Vladimir Blagojevic <[email protected]>

* Update haystack/core/serialization.py

Co-authored-by: Vladimir Blagojevic <[email protected]>

* Update test/core/test_serialization.py

Co-authored-by: Vladimir Blagojevic <[email protected]>

* more refactoring

* more refactoring

* Pydoc syntax

---------

Co-authored-by: Vladimir Blagojevic <[email protected]>
  • Loading branch information
davidsbatista and vblagoje committed Jul 10, 2024
1 parent a77ea6b commit 07e3056
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 3 deletions.
3 changes: 2 additions & 1 deletion haystack/components/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
from haystack.components.retrievers.filter_retriever import FilterRetriever
from haystack.components.retrievers.in_memory.bm25_retriever import InMemoryBM25Retriever
from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
from haystack.components.retrievers.sentence_window_retrieval import SentenceWindowRetrieval

__all__ = ["FilterRetriever", "InMemoryEmbeddingRetriever", "InMemoryBM25Retriever"]
__all__ = ["FilterRetriever", "InMemoryEmbeddingRetriever", "InMemoryBM25Retriever", "SentenceWindowRetrieval"]
139 changes: 139 additions & 0 deletions haystack/components/retrievers/sentence_window_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.core.serialization import import_class_by_name
from haystack.document_stores.types import DocumentStore


@component
class SentenceWindowRetrieval:
    """
    A component that retrieves surrounding documents of a given document from the document store.

    This component is designed to work together with one of the existing retrievers, e.g. BM25Retriever,
    EmbeddingRetriever. One of these retrievers can be used to retrieve documents based on a query and then use this
    component to get the surrounding documents of the retrieved documents.
    """

    def __init__(self, document_store: DocumentStore, window_size: int = 3):
        """
        Creates a new SentenceWindowRetrieval component.

        :param document_store: The document store to use for retrieving the surrounding documents.
        :param window_size: The number of surrounding documents to retrieve on each side of a match.
        :raises ValueError: If `window_size` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError("The window_size parameter must be greater than 0.")

        self.window_size = window_size
        self.document_store = document_store

    @staticmethod
    def merge_documents_text(documents: List[Document]) -> str:
        """
        Merge a list of document texts into a single string.

        This function concatenates the textual content of a list of documents into a single string, eliminating any
        overlapping content. Overlaps are detected through the `split_idx_start` metadata field, i.e. the character
        offset of each chunk within the original document.

        :param documents: List of Documents to merge.
        :returns: The merged text of all documents, with overlapping content included only once.
        """
        sorted_docs = sorted(documents, key=lambda doc: doc.meta["split_idx_start"])
        merged_text = ""
        last_idx_end = 0
        for doc in sorted_docs:
            start = doc.meta["split_idx_start"]  # start of the current content

            # if the start of the current content is before the end of the last appended content, adjust it
            start = max(start, last_idx_end)

            # append the non-overlapping part to the merged text
            merged_text = merged_text.strip()
            merged_text += doc.content[start - doc.meta["split_idx_start"] :]  # type: ignore

            # update the last end index
            last_idx_end = doc.meta["split_idx_start"] + len(doc.content)  # type: ignore

        return merged_text

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        docstore = self.document_store.to_dict()
        return default_to_dict(self, document_store=docstore, window_size=self.window_size)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetrieval":
        """
        Deserializes the component from a dictionary.

        :returns:
            Deserialized component.
        :raises DeserializationError:
            If the document store data is missing, lacks a `type`, or its class cannot be imported.
        """
        init_params = data.get("init_parameters", {})

        if "document_store" not in init_params:
            raise DeserializationError("Missing 'document_store' in serialization data")
        if "type" not in init_params["document_store"]:
            raise DeserializationError("Missing 'type' in document store's serialization data")

        # deserialize the document store
        doc_store_data = data["init_parameters"]["document_store"]
        try:
            doc_store_class = import_class_by_name(doc_store_data["type"])
        except ImportError as e:
            raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e

        data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)

        # deserialize the component
        return default_from_dict(cls, data)

    @component.output_types(context_windows=List[str])
    def run(self, retrieved_documents: List[Document]):
        """
        Based on the `source_id` and on the `doc.meta['split_id']` get surrounding documents from the document store.

        Implements the logic behind the sentence-window technique, retrieving the surrounding documents of a given
        document from the document store.

        :param retrieved_documents: List of retrieved documents from the previous retriever.
        :returns:
            A dictionary with the following keys:
            - `context_windows`: List of strings representing the context windows of the retrieved documents.
        :raises ValueError:
            If any retrieved document is missing the `split_id` or `source_id` metadata field.
        """
        if not all("split_id" in doc.meta for doc in retrieved_documents):
            raise ValueError("The retrieved documents must have 'split_id' in the metadata.")

        if not all("source_id" in doc.meta for doc in retrieved_documents):
            raise ValueError("The retrieved documents must have 'source_id' in the metadata.")

        context_windows = []
        for doc in retrieved_documents:
            source_id = doc.meta["source_id"]
            split_id = doc.meta["split_id"]
            # The window spans `window_size` splits on each side of the current split.
            # Computed directly instead of building throwaway `range` lists just to take min/max.
            min_before = split_id - self.window_size
            max_after = split_id + self.window_size
            context_docs = self.document_store.filter_documents(
                {
                    "operator": "AND",
                    "conditions": [
                        {"field": "source_id", "operator": "==", "value": source_id},
                        {"field": "split_id", "operator": ">=", "value": min_before},
                        {"field": "split_id", "operator": "<=", "value": max_after},
                    ],
                }
            )
            context_windows.append(self.merge_documents_text(context_docs))

        return {"context_windows": context_windows}
25 changes: 24 additions & 1 deletion haystack/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import inspect
from collections.abc import Callable
from dataclasses import dataclass
from importlib import import_module
from typing import Any, Dict, Optional, Type

from haystack.core.component.component import _hook_component_init
from haystack.core.component.component import _hook_component_init, logger
from haystack.core.errors import DeserializationError, SerializationError


Expand Down Expand Up @@ -189,3 +190,25 @@ def default_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
if data["type"] != generate_qualified_class_name(cls):
raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")
return cls(**init_params)


def import_class_by_name(fully_qualified_name: str) -> Type[object]:
    """
    Utility function to import (load) a class object based on its fully qualified class name.

    This function dynamically imports a class based on its string name.
    It splits the name into module path and class name, imports the module,
    and returns the class object.

    :param fully_qualified_name: the fully qualified class name as a string
    :returns: the class object.
    :raises ImportError: If the class cannot be imported or found.
    """
    try:
        # A name without a dot makes rsplit return a single element, raising ValueError on unpacking;
        # it is caught below so callers always get the documented ImportError.
        module_path, class_name = fully_qualified_name.rsplit(".", 1)
        # Lazy %-style args: formatting is skipped entirely when DEBUG is disabled.
        logger.debug("Attempting to import class '%s' from module '%s'", class_name, module_path)
        module = import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError) as error:
        logger.error("Failed to import class '%s'", fully_qualified_name)
        raise ImportError(f"Could not import class '{fully_qualified_name}'") from error
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---

features:
- |
Added a new component that performs sentence-window retrieval, i.e. it retrieves the surrounding documents of a
given document from the document store. This is useful when a document is split into multiple chunks and you want to
retrieve the surrounding context of a given chunk.
143 changes: 143 additions & 0 deletions test/components/retrievers/test_sentence_window_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import pytest

from haystack import Document, DeserializationError
from haystack.components.retrievers.sentence_window_retrieval import SentenceWindowRetrieval
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.preprocessors import DocumentSplitter


class TestSentenceWindowRetrieval:
    """Unit tests for the SentenceWindowRetrieval component."""

    def test_init_default(self):
        # The default window size is 3 surrounding documents.
        retrieval = SentenceWindowRetrieval(InMemoryDocumentStore())
        assert retrieval.window_size == 3

    def test_init_with_parameters(self):
        # An explicitly provided window size is stored as given.
        retrieval = SentenceWindowRetrieval(InMemoryDocumentStore(), window_size=5)
        assert retrieval.window_size == 5

    def test_init_with_invalid_window_size_parameter(self):
        # window_size must be at least 1; negative values are rejected at construction time.
        with pytest.raises(ValueError):
            SentenceWindowRetrieval(InMemoryDocumentStore(), window_size=-2)

    def test_merge_documents(self):
        # Three overlapping splits of the same source document. Merging must drop the
        # overlapping prefixes (tracked via `split_idx_start`) and reconstruct the original text.
        docs = [
            {
                "id": "doc_0",
                "content": "This is a text with some words. There is a ",
                "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
                "page_number": 1,
                "split_id": 0,
                "split_idx_start": 0,
                "_split_overlap": [{"doc_id": "doc_1", "range": (0, 22)}],
            },
            {
                "id": "doc_1",
                "content": "some words. There is a second sentence. And there is ",
                "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
                "page_number": 1,
                "split_id": 1,
                "split_idx_start": 21,
                "_split_overlap": [{"doc_id": "doc_0", "range": (20, 42)}, {"doc_id": "doc_2", "range": (0, 29)}],
            },
            {
                "id": "doc_2",
                "content": "second sentence. And there is also a third sentence",
                "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
                "page_number": 1,
                "split_id": 2,
                "split_idx_start": 45,
                "_split_overlap": [{"doc_id": "doc_1", "range": (23, 52)}],
            },
        ]
        merged_text = SentenceWindowRetrieval.merge_documents_text([Document.from_dict(doc) for doc in docs])
        expected = "This is a text with some words. There is a second sentence. And there is also a third sentence"
        assert merged_text == expected

    def test_to_dict(self):
        # Serialization records the component type, window size, and the nested document store type.
        window_retrieval = SentenceWindowRetrieval(InMemoryDocumentStore())
        data = window_retrieval.to_dict()

        assert data["type"] == "haystack.components.retrievers.sentence_window_retrieval.SentenceWindowRetrieval"
        assert data["init_parameters"]["window_size"] == 3
        assert (
            data["init_parameters"]["document_store"]["type"]
            == "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
        )

    def test_from_dict(self):
        # Round-trip: the nested document store is rebuilt from its own serialized form.
        data = {
            "type": "haystack.components.retrievers.sentence_window_retrieval.SentenceWindowRetrieval",
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {},
                },
                "window_size": 5,
            },
        }
        component = SentenceWindowRetrieval.from_dict(data)
        assert isinstance(component.document_store, InMemoryDocumentStore)
        assert component.window_size == 5

    def test_from_dict_without_docstore(self):
        # Deserialization fails fast when the document store entry is absent.
        data = {"type": "SentenceWindowRetrieval", "init_parameters": {}}
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            SentenceWindowRetrieval.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        # Deserialization fails fast when the document store entry has no 'type' key.
        data = {"type": "SentenceWindowRetrieval", "init_parameters": {"document_store": {"init_parameters": {}}}}
        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
            SentenceWindowRetrieval.from_dict(data)

    def test_from_dict_non_existing_docstore(self):
        # An unimportable document store class surfaces as a DeserializationError.
        data = {
            "type": "SentenceWindowRetrieval",
            "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
        }
        with pytest.raises(DeserializationError):
            SentenceWindowRetrieval.from_dict(data)

    def test_document_without_split_id(self):
        # run() requires 'split_id' metadata on every retrieved document.
        docs = [
            Document(content="This is a text with some words. There is a ", meta={"id": "doc_0"}),
            Document(content="some words. There is a second sentence. And there is ", meta={"id": "doc_1"}),
        ]
        with pytest.raises(ValueError):
            retriever = SentenceWindowRetrieval(document_store=InMemoryDocumentStore(), window_size=3)
            retriever.run(retrieved_documents=docs)

    def test_document_without_source_id(self):
        # run() requires 'source_id' metadata on every retrieved document.
        docs = [
            Document(content="This is a text with some words. There is a ", meta={"id": "doc_0", "split_id": 0}),
            Document(
                content="some words. There is a second sentence. And there is ", meta={"id": "doc_1", "split_id": 1}
            ),
        ]
        with pytest.raises(ValueError):
            retriever = SentenceWindowRetrieval(document_store=InMemoryDocumentStore(), window_size=3)
            retriever.run(retrieved_documents=docs)

    @pytest.mark.integration
    def test_run_with_pipeline(self):
        # End-to-end: split a document into overlapping word chunks, index them, then
        # retrieve the window around one chunk and check the merged context text.
        splitter = DocumentSplitter(split_length=10, split_overlap=5, split_by="word")
        text = (
            "This is a text with some words. There is a second sentence. And there is also a third sentence. "
            "It also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence"
        )

        doc = Document(content=text)

        docs = splitter.run([doc])
        ds = InMemoryDocumentStore()
        ds.write_documents(docs["documents"])

        retriever = SentenceWindowRetrieval(document_store=ds, window_size=3)
        result = retriever.run(retrieved_documents=[list(ds.storage.values())[3]])
        expected = {
            "context_windows": [
                "This is a text with some words. There is a second sentence. And there is also a third sentence. It "
                "also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence"
            ]
        }

        assert result == expected
22 changes: 21 additions & 1 deletion test/core/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import datetime

import sys
from unittest.mock import Mock

Expand All @@ -10,7 +12,12 @@
from haystack.core.component import component
from haystack.core.errors import DeserializationError
from haystack.testing import factory
from haystack.core.serialization import default_to_dict, default_from_dict, generate_qualified_class_name
from haystack.core.serialization import (
default_to_dict,
default_from_dict,
generate_qualified_class_name,
import_class_by_name,
)


def test_default_component_to_dict():
Expand Down Expand Up @@ -87,3 +94,16 @@ def test_get_qualified_class_name():
comp = MyComponent()
res = generate_qualified_class_name(type(comp))
assert res == "haystack.testing.factory.MyComponent"


def test_import_class_by_name():
    """Importing a known class via its fully qualified name yields that class."""
    qualified_name = "haystack.core.pipeline.Pipeline"
    imported_class = import_class_by_name(qualified_name)
    instance = imported_class()
    assert isinstance(instance, Pipeline)


def test_import_class_by_name_no_valid_class():
    """A name that resolves to no importable class raises ImportError."""
    with pytest.raises(ImportError):
        import_class_by_name("some.invalid.class")

0 comments on commit 07e3056

Please sign in to comment.