-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add sentence window retrieval (#7997)
* initial import * adding tests * adding license and release notes * adding missing release notes * working with any type of doc store * nit * adding get_class_object to serialization package * nit * refactoring get_class_object() * refactoring get_class_object() * chaning type and var names * more refactoring * Update haystack/core/serialization.py Co-authored-by: Vladimir Blagojevic <[email protected]> * Update haystack/core/serialization.py Co-authored-by: Vladimir Blagojevic <[email protected]> * Update test/core/test_serialization.py Co-authored-by: Vladimir Blagojevic <[email protected]> * more refactoring * more refactoring * Pydoc syntax --------- Co-authored-by: Vladimir Blagojevic <[email protected]>
- Loading branch information
1 parent
a77ea6b
commit 07e3056
Showing
6 changed files
with
336 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
139 changes: 139 additions & 0 deletions
139
haystack/components/retrievers/sentence_window_retrieval.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Any, Dict, List | ||
|
||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict | ||
from haystack.core.serialization import import_class_by_name | ||
from haystack.document_stores.types import DocumentStore | ||
|
||
|
||
@component
class SentenceWindowRetrieval:
    """
    A component that retrieves surrounding documents of a given document from the document store.

    This component is designed to work together with one of the existing retrievers, e.g. BM25Retriever,
    EmbeddingRetriever. One of these retrievers can be used to retrieve documents based on a query and then use this
    component to get the surrounding documents of the retrieved documents.
    """

    def __init__(self, document_store: DocumentStore, window_size: int = 3):
        """
        Creates a new SentenceWindowRetrieval component.

        :param document_store: The document store to use for retrieving the surrounding documents.
        :param window_size: The number of surrounding documents to retrieve on each side of a matched document.
        :raises ValueError: If `window_size` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError("The window_size parameter must be greater than 0.")

        self.window_size = window_size
        self.document_store = document_store

    @staticmethod
    def merge_documents_text(documents: List[Document]) -> str:
        """
        Merge a list of document text into a single string.

        This function concatenates the textual content of a list of documents into a single string, eliminating any
        overlapping content.

        :param documents: List of Documents to merge. Each document is expected to carry a
            `split_idx_start` meta entry giving its absolute start offset in the source text.
        :returns: The merged text with overlap regions appended only once.
        """
        # Chunks belong to one source text; order them by absolute start offset.
        sorted_docs = sorted(documents, key=lambda doc: doc.meta["split_idx_start"])
        merged_text = ""
        last_idx_end = 0
        for doc in sorted_docs:
            start = doc.meta["split_idx_start"]  # absolute start offset of the current chunk

            # if the start of the current content is before the end of the last appended content,
            # adjust it so only the non-overlapping tail of this chunk is appended
            start = max(start, last_idx_end)

            # append the non-overlapping part to the merged text
            merged_text = merged_text.strip()
            merged_text += doc.content[start - doc.meta["split_idx_start"] :]  # type: ignore

            # update the last end index
            last_idx_end = doc.meta["split_idx_start"] + len(doc.content)  # type: ignore

        return merged_text

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data, including the serialized document store.
        """
        docstore = self.document_store.to_dict()
        return default_to_dict(self, document_store=docstore, window_size=self.window_size)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetrieval":
        """
        Deserializes the component from a dictionary.

        :param data: Dictionary produced by `to_dict`.
        :returns:
            Deserialized component.
        :raises DeserializationError: If the document store information is missing or its class cannot be imported.
        """
        init_params = data.get("init_parameters", {})

        if "document_store" not in init_params:
            raise DeserializationError("Missing 'document_store' in serialization data")
        if "type" not in init_params["document_store"]:
            raise DeserializationError("Missing 'type' in document store's serialization data")

        # deserialize the document store
        doc_store_data = data["init_parameters"]["document_store"]
        try:
            doc_store_class = import_class_by_name(doc_store_data["type"])
        except ImportError as e:
            raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e

        data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)

        # deserialize the component
        return default_from_dict(cls, data)

    @component.output_types(context_windows=List[str])
    def run(self, retrieved_documents: List[Document]):
        """
        Based on the `source_id` and on the `doc.meta['split_id']` get surrounding documents from the document store.

        Implements the logic behind the sentence-window technique, retrieving the surrounding documents of a given
        document from the document store.

        :param retrieved_documents: List of retrieved documents from the previous retriever.
        :returns:
            A dictionary with the following keys:
            - `context_windows`: List of strings representing the context windows of the retrieved documents.
        :raises ValueError: If any retrieved document is missing `split_id` or `source_id` metadata.
        """
        if not all("split_id" in doc.meta for doc in retrieved_documents):
            raise ValueError("The retrieved documents must have 'split_id' in the metadata.")

        if not all("source_id" in doc.meta for doc in retrieved_documents):
            raise ValueError("The retrieved documents must have 'source_id' in the metadata.")

        context_windows = []
        for doc in retrieved_documents:
            source_id = doc.meta["source_id"]
            split_id = doc.meta["split_id"]
            # Closed-form window bounds (the original built full range() lists only to take
            # their min/max). min_before may be negative near the start of the source,
            # which is harmless for the ">=" filter below.
            min_before = split_id - self.window_size
            max_after = split_id + self.window_size
            context_docs = self.document_store.filter_documents(
                {
                    "operator": "AND",
                    "conditions": [
                        {"field": "source_id", "operator": "==", "value": source_id},
                        {"field": "split_id", "operator": ">=", "value": min_before},
                        {"field": "split_id", "operator": "<=", "value": max_after},
                    ],
                }
            )
            context_windows.append(self.merge_documents_text(context_docs))

        return {"context_windows": context_windows}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
releasenotes/notes/add-sentence-window-retrieval-5de4b0d6b2e8b0d6.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--- | ||
|
||
features: | ||
- | | ||
Adds a new component that performs sentence-window retrieval, i.e. it retrieves the surrounding documents of a
given document from the document store. This is useful when a document is split into multiple chunks and you want to
retrieve the surrounding context of a given chunk.
143 changes: 143 additions & 0 deletions
143
test/components/retrievers/test_sentence_window_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import pytest | ||
|
||
from haystack import Document, DeserializationError | ||
from haystack.components.retrievers.sentence_window_retrieval import SentenceWindowRetrieval | ||
from haystack.document_stores.in_memory import InMemoryDocumentStore | ||
from haystack.components.preprocessors import DocumentSplitter | ||
|
||
|
||
class TestSentenceWindowRetrieval:
    """Tests for SentenceWindowRetrieval: construction, (de)serialization, validation, and end-to-end retrieval."""

    def test_init_default(self):
        # Default window size is 3.
        retrieval = SentenceWindowRetrieval(InMemoryDocumentStore())
        assert retrieval.window_size == 3

    def test_init_with_parameters(self):
        # An explicit window_size is stored as-is.
        retrieval = SentenceWindowRetrieval(InMemoryDocumentStore(), window_size=5)
        assert retrieval.window_size == 5

    def test_init_with_invalid_window_size_parameter(self):
        # window_size must be >= 1; negative values are rejected at construction time.
        with pytest.raises(ValueError):
            SentenceWindowRetrieval(InMemoryDocumentStore(), window_size=-2)

    def test_merge_documents(self):
        # Three consecutive splits of the same source document with overlapping content;
        # merging must drop the overlap regions and reconstruct the original text.
        docs = [
            {
                "id": "doc_0",
                "content": "This is a text with some words. There is a ",
                "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
                "page_number": 1,
                "split_id": 0,
                "split_idx_start": 0,
                "_split_overlap": [{"doc_id": "doc_1", "range": (0, 22)}],
            },
            {
                "id": "doc_1",
                "content": "some words. There is a second sentence. And there is ",
                "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
                "page_number": 1,
                "split_id": 1,
                "split_idx_start": 21,
                "_split_overlap": [{"doc_id": "doc_0", "range": (20, 42)}, {"doc_id": "doc_2", "range": (0, 29)}],
            },
            {
                "id": "doc_2",
                "content": "second sentence. And there is also a third sentence",
                "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99",
                "page_number": 1,
                "split_id": 2,
                "split_idx_start": 45,
                "_split_overlap": [{"doc_id": "doc_1", "range": (23, 52)}],
            },
        ]
        merged_text = SentenceWindowRetrieval.merge_documents_text([Document.from_dict(doc) for doc in docs])
        expected = "This is a text with some words. There is a second sentence. And there is also a third sentence"
        assert merged_text == expected

    def test_to_dict(self):
        # Serialization embeds both the component type and the document store's own serialized dict.
        window_retrieval = SentenceWindowRetrieval(InMemoryDocumentStore())
        data = window_retrieval.to_dict()

        assert data["type"] == "haystack.components.retrievers.sentence_window_retrieval.SentenceWindowRetrieval"
        assert data["init_parameters"]["window_size"] == 3
        assert (
            data["init_parameters"]["document_store"]["type"]
            == "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore"
        )

    def test_from_dict(self):
        # Round trip: the nested document store dict is deserialized into a real store instance.
        data = {
            "type": "haystack.components.retrievers.sentence_window_retrieval.SentenceWindowRetrieval",
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {},
                },
                "window_size": 5,
            },
        }
        component = SentenceWindowRetrieval.from_dict(data)
        assert isinstance(component.document_store, InMemoryDocumentStore)
        assert component.window_size == 5

    def test_from_dict_without_docstore(self):
        # A missing 'document_store' entry must be reported as a DeserializationError.
        data = {"type": "SentenceWindowRetrieval", "init_parameters": {}}
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            SentenceWindowRetrieval.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        # A document store dict without its 'type' key cannot be deserialized.
        data = {"type": "SentenceWindowRetrieval", "init_parameters": {"document_store": {"init_parameters": {}}}}
        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
            SentenceWindowRetrieval.from_dict(data)

    def test_from_dict_non_existing_docstore(self):
        # An unimportable document store class is surfaced as a DeserializationError.
        data = {
            "type": "SentenceWindowRetrieval",
            "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
        }
        with pytest.raises(DeserializationError):
            SentenceWindowRetrieval.from_dict(data)

    def test_document_without_split_id(self):
        # run() requires 'split_id' in every document's metadata.
        docs = [
            Document(content="This is a text with some words. There is a ", meta={"id": "doc_0"}),
            Document(content="some words. There is a second sentence. And there is ", meta={"id": "doc_1"}),
        ]
        with pytest.raises(ValueError):
            retriever = SentenceWindowRetrieval(document_store=InMemoryDocumentStore(), window_size=3)
            retriever.run(retrieved_documents=docs)

    def test_document_without_source_id(self):
        # run() also requires 'source_id' in every document's metadata.
        docs = [
            Document(content="This is a text with some words. There is a ", meta={"id": "doc_0", "split_id": 0}),
            Document(
                content="some words. There is a second sentence. And there is ", meta={"id": "doc_1", "split_id": 1}
            ),
        ]
        with pytest.raises(ValueError):
            retriever = SentenceWindowRetrieval(document_store=InMemoryDocumentStore(), window_size=3)
            retriever.run(retrieved_documents=docs)

    @pytest.mark.integration
    def test_run_with_pipeline(self):
        # End-to-end: split a text into overlapping word chunks, index them, retrieve the
        # window around the 4th chunk, and check the merged window reconstructs the full text.
        splitter = DocumentSplitter(split_length=10, split_overlap=5, split_by="word")
        text = (
            "This is a text with some words. There is a second sentence. And there is also a third sentence. "
            "It also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence"
        )

        doc = Document(content=text)

        docs = splitter.run([doc])
        ds = InMemoryDocumentStore()
        ds.write_documents(docs["documents"])

        retriever = SentenceWindowRetrieval(document_store=ds, window_size=3)
        result = retriever.run(retrieved_documents=[list(ds.storage.values())[3]])
        expected = {
            "context_windows": [
                "This is a text with some words. There is a second sentence. And there is also a third sentence. It "
                "also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence"
            ]
        }

        assert result == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import datetime | ||
|
||
import sys | ||
from unittest.mock import Mock | ||
|
||
|
@@ -10,7 +12,12 @@ | |
from haystack.core.component import component | ||
from haystack.core.errors import DeserializationError | ||
from haystack.testing import factory | ||
from haystack.core.serialization import default_to_dict, default_from_dict, generate_qualified_class_name | ||
from haystack.core.serialization import ( | ||
default_to_dict, | ||
default_from_dict, | ||
generate_qualified_class_name, | ||
import_class_by_name, | ||
) | ||
|
||
|
||
def test_default_component_to_dict(): | ||
|
@@ -87,3 +94,16 @@ def test_get_qualified_class_name(): | |
comp = MyComponent() | ||
res = generate_qualified_class_name(type(comp)) | ||
assert res == "haystack.testing.factory.MyComponent" | ||
|
||
|
||
def test_import_class_by_name():
    """A fully qualified dotted path resolves to the actual class object."""
    fully_qualified_name = "haystack.core.pipeline.Pipeline"
    resolved_class = import_class_by_name(fully_qualified_name)
    instance = resolved_class()
    assert isinstance(instance, Pipeline)
|
||
|
||
def test_import_class_by_name_no_valid_class():
    """A dotted path that resolves to nothing raises ImportError."""
    invalid_path = "some.invalid.class"
    with pytest.raises(ImportError):
        import_class_by_name(invalid_path)