refactor: improve support for dataclasses (#3142)
* refactor: improve support for dataclasses

* refactor: refactor class init

* refactor: remove unused import

* refactor: testing 3.7 diffs

* refactor: checking where meta is Optional

* refactor: reverting some changes on 3.7

* refactor: remove unused imports

* build: manual pre-commit run

* doc: run doc pre-commit manually

* refactor: post initialization hack for 3.7-3.10 compat.

TODO: investigate another method to improve 3.7 compatibility.

* doc: force pre-commit

* refactor: refactored for both Python 3.7 and 3.9

* docs: manually run pre-commit hooks

* docs: run api docs manually

* docs: fix wrong comment

* refactor: change no type-checked test code

* docs: update primitives

* docs: api documentation

* docs: api documentation

* refactor: minor test refactoring

* refactor: remove unused enumeration in test

* refactor: remove unneeded dir in gitignore

* refactor: exclude all private fields and change meta def

* refactor: add pydantic comment

* refactor: fix for mypy on Python 3.7

* refactor: revert custom init

* docs: update docs to new pydoc-markdown style

* Update test/nodes/test_generator.py

Co-authored-by: Sara Zan <[email protected]>
danielbichuetti and ZanSara authored Sep 9, 2022
1 parent 1a6cbca commit 621e1af
Showing 10 changed files with 53 additions and 55 deletions.
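The core of the `haystack/schema.py` change is moving init-only arguments such as `id_hash_keys` (and the `MultiLabel` flags) to `InitVar` fields, so they are accepted by `__init__` without becoming stored dataclass fields. A minimal sketch of the pattern using vanilla `dataclasses` (the commit itself uses Pydantic dataclasses, and `Doc` plus its hashing scheme here are illustrative, not Haystack's actual implementation):

```python
from dataclasses import InitVar, dataclass, field
from typing import List, Optional
import hashlib

@dataclass
class Doc:
    content: str
    meta: dict = field(default_factory=dict)
    id: Optional[str] = None
    # InitVar: accepted by __init__ and passed to __post_init__,
    # but never stored as a field on the instance
    id_hash_keys: InitVar[Optional[List[str]]] = None

    def __post_init__(self, id_hash_keys: Optional[List[str]]):
        if self.id is None:
            # Derive a stable id from the chosen fields (default: content only)
            keys = id_hash_keys or ["content"]
            raw = ":".join(str(getattr(self, k)) for k in keys)
            self.id = hashlib.md5(raw.encode("utf-8")).hexdigest()

d = Doc(content="hello world")
```

Because `id_hash_keys` is an `InitVar`, it never appears in `d.__dict__`, so helpers like `asdict()` only see the real fields.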
4 changes: 2 additions & 2 deletions docs/_src/api/api/file_converter.md
@@ -366,7 +366,7 @@ Defaults to "UTF-8" in order to support special characters (e.g. German Umlauts,

```python
def convert(file_path: Path,
meta: Optional[Dict[str, str]] = None,
meta: Optional[Dict[str, Any]] = None,
remove_numeric_tables: Optional[bool] = None,
valid_languages: Optional[List[str]] = None,
encoding: Optional[str] = None,
@@ -440,7 +440,7 @@ In this case the id will be generated by using the content and the defined metad

```python
def convert(file_path: Path,
meta: Optional[Dict[str, str]] = None,
meta: Optional[Dict[str, Any]] = None,
remove_numeric_tables: Optional[bool] = None,
valid_languages: Optional[List[str]] = None,
encoding: Optional[str] = None,
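The converter-signature change from `Dict[str, str]` to `Dict[str, Any]` matters because document metadata routinely carries non-string values. A hedged sketch of why the wider type is needed (`convert` here is a stand-in, not the real converter):

```python
from typing import Any, Dict, Optional

def convert(file_path: str, meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Stand-in for a file converter: "parse" the file and attach the
    # caller's metadata to the resulting document dict
    meta = meta or {}
    return {"content": f"<parsed {file_path}>", "meta": meta}

# Non-string values (page counts, flags) are valid metadata under Dict[str, Any]:
doc = convert("report.pdf", meta={"page_count": 12, "ocr": False})
```

Under the old `Dict[str, str]` annotation, a type checker would reject `page_count=12` and `ocr=False` even though the runtime handled them fine.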
5 changes: 1 addition & 4 deletions docs/_src/api/api/primitives.md
@@ -20,7 +20,7 @@ def __init__(content: Union[str, pd.DataFrame],
content_type: Literal["text", "table", "image", "audio"] = "text",
id: Optional[str] = None,
score: Optional[float] = None,
meta: Dict[str, Any] = None,
meta: Optional[Dict[str, Any]] = None,
embedding: Optional[np.ndarray] = None,
id_hash_keys: Optional[List[str]] = None)
```
@@ -29,13 +29,10 @@ One of the core data classes in Haystack. It's used to represent documents / pas

Documents are stored in DocumentStores, are returned by Retrievers, are the input for Readers and are used in
many other places that manipulate or interact with document-level data.

Note: There can be multiple Documents originating from one file (e.g. PDF), if you split the text
into smaller passages. We'll have one Document per passage in this case.

Each document has a unique ID. This can be supplied by the user or generated automatically.
It's particularly helpful for handling of duplicates and referencing documents in other objects (e.g. Labels)

There's an easy option to convert from/to dicts via `from_dict()` and `to_dict`.

**Arguments**:
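The `primitives.md` fix turns `meta: Dict[str, Any] = None` into `meta: Optional[Dict[str, Any]] = None`: if the default is `None`, the annotation must say so. The usual companion pattern normalizes `None` to a fresh dict inside `__init__`, sketched here with a hypothetical minimal class (not Haystack's full `Document`):

```python
from typing import Any, Dict, Optional

class Document:
    """Hypothetical minimal Document; only illustrates the meta-default pattern."""

    def __init__(self, content: str, meta: Optional[Dict[str, Any]] = None):
        self.content = content
        # Normalize None to a fresh dict so every instance owns its own meta
        # (avoids the classic shared-mutable-default bug)
        self.meta = meta if meta is not None else {}

a = Document("hello")
b = Document("world")
a.meta["lang"] = "en"
```

Each instance gets its own dict, so mutating `a.meta` leaves `b.meta` untouched.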
2 changes: 1 addition & 1 deletion docs/_src/api/api/summarizer.md
@@ -59,7 +59,7 @@ See the up-to-date list of available models on
**Example**

```python
| docs = [Document(text="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
| docs = [Document(content="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
| "The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by"
| "the shutoffs which were expected to last through at least midday tomorrow.")]
|
18 changes: 9 additions & 9 deletions docs/_src/api/openapi/openapi-1.8.1rc0.json
@@ -618,13 +618,15 @@
"Document": {
"title": "Document",
"required": [
"content",
"content_type",
"id",
"meta"
"content"
],
"type": "object",
"properties": {
"id": {
"title": "Id",
"type": "string"
},
"content": {
"title": "Content",
"anyOf": [
@@ -644,15 +646,13 @@
"image",
"audio"
],
"type": "string"
},
"id": {
"title": "Id",
"type": "string"
"type": "string",
"default": "text"
},
"meta": {
"title": "Meta",
"type": "object"
"type": "object",
"default": {}
},
"score": {
"title": "Score",
18 changes: 9 additions & 9 deletions docs/_src/api/openapi/openapi.json
@@ -618,13 +618,15 @@
"Document": {
"title": "Document",
"required": [
"content",
"content_type",
"id",
"meta"
"content"
],
"type": "object",
"properties": {
"id": {
"title": "Id",
"type": "string"
},
"content": {
"title": "Content",
"anyOf": [
@@ -644,15 +646,13 @@
"image",
"audio"
],
"type": "string"
},
"id": {
"title": "Id",
"type": "string"
"type": "string",
"default": "text"
},
"meta": {
"title": "Meta",
"type": "object"
"type": "object",
"default": {}
},
"score": {
"title": "Score",
6 changes: 3 additions & 3 deletions haystack/nodes/file_converter/pdf.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any

import os
import logging
@@ -74,7 +74,7 @@ def __init__(
def convert(
self,
file_path: Path,
meta: Optional[Dict[str, str]] = None,
meta: Optional[Dict[str, Any]] = None,
remove_numeric_tables: Optional[bool] = None,
valid_languages: Optional[List[str]] = None,
encoding: Optional[str] = None,
@@ -212,7 +212,7 @@ def __init__(
def convert(
self,
file_path: Path,
meta: Optional[Dict[str, str]] = None,
meta: Optional[Dict[str, Any]] = None,
remove_numeric_tables: Optional[bool] = None,
valid_languages: Optional[List[str]] = None,
encoding: Optional[str] = None,
2 changes: 1 addition & 1 deletion haystack/nodes/summarizer/transformers.py
@@ -28,7 +28,7 @@ class TransformersSummarizer(BaseSummarizer):
**Example**
```python
| docs = [Document(text="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
| docs = [Document(content="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
| "The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by"
| "the shutoffs which were expected to last through at least midday tomorrow.")]
|
38 changes: 20 additions & 18 deletions haystack/schema.py
@@ -2,11 +2,10 @@
import csv
import hashlib

import typing
from typing import Any, Optional, Dict, List, Union

try:
from typing import Literal
from typing import Literal # type: ignore
except ImportError:
from typing_extensions import Literal # type: ignore

@@ -16,21 +15,18 @@
import time
import json
import ast
from dataclasses import asdict
from dataclasses import asdict, InitVar

import mmh3
import numpy as np
import pandas as pd

from pydantic import BaseConfig
from pydantic import BaseConfig, Field
from pydantic.json import pydantic_encoder

if not typing.TYPE_CHECKING:
# We are using Pydantic dataclasses instead of vanilla Python's
# See #1598 for the reasons behind this choice & performance considerations
from pydantic.dataclasses import dataclass
else:
from dataclasses import dataclass # type: ignore # pylint: disable=ungrouped-imports
# We are using Pydantic dataclasses instead of vanilla Python's
# See #1598 for the reasons behind this choice & performance considerations
from pydantic.dataclasses import dataclass


logger = logging.getLogger(__name__)
@@ -41,12 +37,13 @@

@dataclass
class Document:
content: Union[str, pd.DataFrame]
content_type: Literal["text", "table", "image", "audio"]
id: str
meta: Dict[str, Any]
content: Union[str, pd.DataFrame]
content_type: Literal["text", "table", "image", "audio"] = Field(default="text")
meta: Dict[str, Any] = Field(default={})
score: Optional[float] = None
embedding: Optional[np.ndarray] = None
id_hash_keys: InitVar[Optional[List[str]]] = None

# We use a custom init here as we want some custom logic. The annotations above are however still needed in order
# to use some dataclass magic like "asdict()". See https://www.python.org/dev/peps/pep-0557/#custom-init-method
@@ -58,23 +55,19 @@ def __init__(
content_type: Literal["text", "table", "image", "audio"] = "text",
id: Optional[str] = None,
score: Optional[float] = None,
meta: Dict[str, Any] = None,
meta: Optional[Dict[str, Any]] = None,
embedding: Optional[np.ndarray] = None,
id_hash_keys: Optional[List[str]] = None,
):
"""
One of the core data classes in Haystack. It's used to represent documents / passages in a standardized way within Haystack.
Documents are stored in DocumentStores, are returned by Retrievers, are the input for Readers and are used in
many other places that manipulate or interact with document-level data.
Note: There can be multiple Documents originating from one file (e.g. PDF), if you split the text
into smaller passages. We'll have one Document per passage in this case.
Each document has a unique ID. This can be supplied by the user or generated automatically.
It's particularly helpful for handling of duplicates and referencing documents in other objects (e.g. Labels)
There's an easy option to convert from/to dicts via `from_dict()` and `to_dict`.
:param content: Content of the document. For most cases, this will be text, but it can be a table or image.
:param content_type: One of "text", "table" or "image". Haystack components can use this to adjust their
handling of Documents and check compatibility.
@@ -154,6 +147,9 @@ def to_dict(self, field_map={}) -> Dict:
inv_field_map = {v: k for k, v in field_map.items()}
_doc: Dict[str, str] = {}
for k, v in self.__dict__.items():
# Exclude internal fields (Pydantic, ...) fields from the conversion process
if k.startswith("__"):
continue
if k == "content":
# Convert pd.DataFrame to list of rows for serialization
if self.content_type == "table" and isinstance(self.content, pd.DataFrame):
@@ -184,6 +180,9 @@ def from_dict(
_doc["meta"] = {}
# copy additional fields into "meta"
for k, v in _doc.items():
# Exclude internal fields (Pydantic, ...) fields from the conversion process
if k.startswith("__"):
continue
if k not in init_args and k not in field_map:
_doc["meta"][k] = v
# remove additional fields from top level
@@ -615,6 +614,8 @@ class MultiLabel:
contexts: List[str]
offsets_in_contexts: List[Dict]
offsets_in_documents: List[Dict]
drop_negative_labels: InitVar[bool] = False
drop_no_answer: InitVar[bool] = False

def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answers=False, **kwargs):
"""
@@ -676,6 +677,7 @@ def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answ
# as separate no_answer labels, and thus with document.id but without answer.document_id.
# If we do not exclude them from document_ids this would be problematic for retriever evaluation as they do not contain the answer.
# Hence, we exclude them here as well.

self.document_ids = [l.document.id for l in self.labels if not l.no_answer]
self.contexts = [l.document.content for l in self.labels if not l.no_answer]

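The `to_dict`/`from_dict` hunks add a guard that skips keys starting with a double underscore, so Pydantic-internal bookkeeping attributes don't leak into serialized documents. A minimal sketch of that filter (`Record` and the `__fields_set__` attribute are illustrative stand-ins for the Pydantic internals):

```python
class Record:
    """Hypothetical object mixing public fields with an internal attribute."""

    def __init__(self):
        self.content = "some text"
        self.score = 0.5
        # Simulate an internal dunder attribute such as Pydantic's bookkeeping
        self.__dict__["__fields_set__"] = {"content", "score"}

    def to_dict(self):
        # Exclude internal (dunder-prefixed) attributes from serialization,
        # mirroring the startswith("__") check added in this commit
        return {k: v for k, v in self.__dict__.items() if not k.startswith("__")}

d = Record().to_dict()
```

Without the filter, the serialized dict would carry `__fields_set__`, which would then be misinterpreted as user metadata on round-trip through `from_dict`.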
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ dependencies = [
"importlib-metadata; python_version < '3.8'",
"torch>1.9,<1.13",
"requests",
"pydantic==1.9.2",
"pydantic",
"transformers==4.21.2",
"nltk",
"pandas",
13 changes: 6 additions & 7 deletions test/nodes/test_generator.py
@@ -2,7 +2,6 @@
import sys
from typing import List

import numpy as np
import pytest

from haystack.schema import Document
@@ -64,8 +63,8 @@ def test_generator_pipeline(document_store, retriever, rag_generator, docs_with_
def test_lfqa_pipeline(document_store, retriever, lfqa_generator, docs_with_true_emb):
# reuse existing DOCS but regenerate embeddings with retribert
docs: List[Document] = []
for idx, d in enumerate(docs_with_true_emb):
docs.append(Document(d.content, str(idx)))
for d in docs_with_true_emb:
docs.append(Document(content=d.content))
document_store.write_documents(docs)
document_store.update_embeddings(retriever)
query = "Tell me about Berlin?"
@@ -84,8 +83,8 @@ def test_lfqa_pipeline(document_store, retriever, lfqa_generator, docs_with_true
def test_lfqa_pipeline_unknown_converter(document_store, retriever, docs_with_true_emb):
# reuse existing DOCS but regenerate embeddings with retribert
docs: List[Document] = []
for idx, d in enumerate(docs_with_true_emb):
docs.append(Document(d.content, str(idx)))
for d in docs_with_true_emb:
docs.append(Document(content=d.content))
document_store.write_documents(docs)
document_store.update_embeddings(retriever)
seq2seq = Seq2SeqGenerator(model_name_or_path="patrickvonplaten/t5-tiny-random")
@@ -106,8 +105,8 @@ def test_lfqa_pipeline_unknown_converter(document_store, retriever, docs_with_tr
def test_lfqa_pipeline_invalid_converter(document_store, retriever, docs_with_true_emb):
# reuse existing DOCS but regenerate embeddings with retribert
docs: List[Document] = []
for idx, d in enumerate(docs_with_true_emb):
docs.append(Document(d.content, str(idx)))
for d in docs_with_true_emb:
docs.append(Document(content=d.content))
document_store.write_documents(docs)
document_store.update_embeddings(retriever)

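The test refactors replace `Document(d.content, str(idx))` with `Document(content=d.content)`: once ids are derived from content, there is no need to pass a positional index, and keyword construction survives signature reordering. A sketch of the behavior the tests now rely on (the hashing here is an assumed stand-in, not Haystack's exact scheme):

```python
import hashlib
from typing import Optional

class Document:
    """Hypothetical sketch; the real Haystack id derivation differs in detail."""

    def __init__(self, content: str, id: Optional[str] = None):
        self.content = content
        # Derive a stable id from the content when none is supplied
        self.id = id or hashlib.md5(content.encode("utf-8")).hexdigest()

docs = [Document(content=t) for t in ("Berlin is big.", "Paris is big.")]
```

Distinct contents yield distinct ids, identical contents yield the same id, and an explicit `id` still wins when given.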
