From 7afba2d4fc1540d85877b10cbd9140d104adbe2b Mon Sep 17 00:00:00 2001
From: Daniel Bichuetti
Date: Fri, 9 Sep 2022 06:31:37 -0300
Subject: [PATCH] refactor: improve support for dataclasses (#3142)

* refactor: improve support for dataclasses

* refactor: refactor class init

* refactor: remove unused import

* refactor: testing 3.7 diffs

* refactor: check where meta is Optional

* refactor: reverting some changes on 3.7

* refactor: remove unused imports

* build: manual pre-commit run

* doc: run doc pre-commit manually

* refactor: post initialization hack for 3.7-3.10 compat. TODO: investigate another method to improve 3.7 compatibility.

* doc: force pre-commit

* refactor: refactored for both Python 3.7 and 3.9

* docs: manually run pre-commit hooks

* docs: run api docs manually

* docs: fix wrong comment

* refactor: change no type-checked test code

* docs: update primitives

* docs: api documentation

* docs: api documentation

* refactor: minor test refactoring

* refactor: remove unused enumeration on test

* refactor: remove unneeded dir in gitignore

* refactor: exclude all private fields and change meta def

* refactor: add pydantic comment

* refactor: fix for mypy on Python 3.7

* refactor: revert custom init

* docs: update docs to new pydoc-markdown style

* Update test/nodes/test_generator.py

Co-authored-by: Sara Zan
---
 docs/_src/api/api/file_converter.md         |  4 +--
 docs/_src/api/api/primitives.md             |  5 +--
 docs/_src/api/api/summarizer.md             |  2 +-
 docs/_src/api/openapi/openapi-1.8.1rc0.json | 18 +++++-----
 docs/_src/api/openapi/openapi.json          | 18 +++++-----
 haystack/nodes/file_converter/pdf.py        |  6 ++--
 haystack/nodes/summarizer/transformers.py   |  2 +-
 haystack/schema.py                          | 38 +++++++++++----------
 pyproject.toml                              |  2 +-
 test/nodes/test_generator.py                | 13 ++++---
 10 files changed, 53 insertions(+), 55 deletions(-)

diff --git a/docs/_src/api/api/file_converter.md b/docs/_src/api/api/file_converter.md
index 0d2eb4374c..43cd2437ce 100644
--- a/docs/_src/api/api/file_converter.md
+++ b/docs/_src/api/api/file_converter.md
@@ -366,7 +366,7 @@ Defaults to "UTF-8" in order to support special characters (e.g. German Umlauts,
 
 ```python
 def convert(file_path: Path,
-            meta: Optional[Dict[str, str]] = None,
+            meta: Optional[Dict[str, Any]] = None,
             remove_numeric_tables: Optional[bool] = None,
             valid_languages: Optional[List[str]] = None,
             encoding: Optional[str] = None,
@@ -440,7 +440,7 @@ In this case the id will be generated by using the content and the defined metad
 
 ```python
 def convert(file_path: Path,
-            meta: Optional[Dict[str, str]] = None,
+            meta: Optional[Dict[str, Any]] = None,
             remove_numeric_tables: Optional[bool] = None,
             valid_languages: Optional[List[str]] = None,
             encoding: Optional[str] = None,
diff --git a/docs/_src/api/api/primitives.md b/docs/_src/api/api/primitives.md
index 28270f6dcb..f50bf68c1e 100644
--- a/docs/_src/api/api/primitives.md
+++ b/docs/_src/api/api/primitives.md
@@ -20,7 +20,7 @@ def __init__(content: Union[str, pd.DataFrame],
              content_type: Literal["text", "table", "image", "audio"] = "text",
              id: Optional[str] = None,
              score: Optional[float] = None,
-             meta: Dict[str, Any] = None,
+             meta: Optional[Dict[str, Any]] = None,
              embedding: Optional[np.ndarray] = None,
              id_hash_keys: Optional[List[str]] = None)
 ```
@@ -29,13 +29,10 @@ One of the core data classes in Haystack. It's used to represent documents / pas
 Documents are stored in DocumentStores, are returned by Retrievers, are the input for Readers and
 are used in many other places that manipulate or interact with document-level data.
-
 Note: There can be multiple Documents originating from one file (e.g. PDF),
 if you split the text into smaller passages. We'll have one Document per passage in this case.
-
 Each document has a unique ID. This can be supplied by the user or generated automatically.
 It's particularly helpful for handling of duplicates and referencing documents in other objects (e.g. Labels)
-
 There's an easy option to convert from/to dicts via `from_dict()` and `to_dict`.
 
 **Arguments**:
 
diff --git a/docs/_src/api/api/summarizer.md b/docs/_src/api/api/summarizer.md
index ee6c388040..d91840de00 100644
--- a/docs/_src/api/api/summarizer.md
+++ b/docs/_src/api/api/summarizer.md
@@ -59,7 +59,7 @@ See the up-to-date list of available models on
 **Example**
 
 ```python
-| docs = [Document(text="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
+| docs = [Document(content="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
 |        "The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by"
 |        "the shutoffs which were expected to last through at least midday tomorrow.")]
 |
diff --git a/docs/_src/api/openapi/openapi-1.8.1rc0.json b/docs/_src/api/openapi/openapi-1.8.1rc0.json
index 7e588f01fa..2ad4d3dc22 100644
--- a/docs/_src/api/openapi/openapi-1.8.1rc0.json
+++ b/docs/_src/api/openapi/openapi-1.8.1rc0.json
@@ -618,13 +618,15 @@
       "Document": {
         "title": "Document",
         "required": [
-          "content",
-          "content_type",
           "id",
-          "meta"
+          "content"
         ],
         "type": "object",
         "properties": {
+          "id": {
+            "title": "Id",
+            "type": "string"
+          },
           "content": {
             "title": "Content",
             "anyOf": [
@@ -644,15 +646,13 @@
               "image",
               "audio"
             ],
-            "type": "string"
-          },
-          "id": {
-            "title": "Id",
-            "type": "string"
+            "type": "string",
+            "default": "text"
           },
           "meta": {
             "title": "Meta",
-            "type": "object"
+            "type": "object",
+            "default": {}
           },
           "score": {
             "title": "Score",
diff --git a/docs/_src/api/openapi/openapi.json b/docs/_src/api/openapi/openapi.json
index 7e588f01fa..2ad4d3dc22 100644
--- a/docs/_src/api/openapi/openapi.json
+++ b/docs/_src/api/openapi/openapi.json
@@ -618,13 +618,15 @@
       "Document": {
         "title": "Document",
         "required": [
-          "content",
-          "content_type",
           "id",
-          "meta"
+          "content"
         ],
         "type": "object",
         "properties": {
+          "id": {
+            "title": "Id",
+            "type": "string"
+          },
           "content": {
             "title": "Content",
             "anyOf": [
@@ -644,15 +646,13 @@
               "image",
               "audio"
             ],
-            "type": "string"
-          },
-          "id": {
-            "title": "Id",
-            "type": "string"
+            "type": "string",
+            "default": "text"
           },
           "meta": {
             "title": "Meta",
-            "type": "object"
+            "type": "object",
+            "default": {}
           },
           "score": {
             "title": "Score",
diff --git a/haystack/nodes/file_converter/pdf.py b/haystack/nodes/file_converter/pdf.py
index b65d2b9b76..25899fb5e6 100644
--- a/haystack/nodes/file_converter/pdf.py
+++ b/haystack/nodes/file_converter/pdf.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Any
 
 import os
 import logging
@@ -74,7 +74,7 @@ def __init__(
     def convert(
         self,
         file_path: Path,
-        meta: Optional[Dict[str, str]] = None,
+        meta: Optional[Dict[str, Any]] = None,
         remove_numeric_tables: Optional[bool] = None,
         valid_languages: Optional[List[str]] = None,
         encoding: Optional[str] = None,
@@ -212,7 +212,7 @@ def __init__(
     def convert(
         self,
        file_path: Path,
-        meta: Optional[Dict[str, str]] = None,
+        meta: Optional[Dict[str, Any]] = None,
         remove_numeric_tables: Optional[bool] = None,
         valid_languages: Optional[List[str]] = None,
         encoding: Optional[str] = None,
diff --git a/haystack/nodes/summarizer/transformers.py b/haystack/nodes/summarizer/transformers.py
index 9fc3d8068d..531c246b26 100644
--- a/haystack/nodes/summarizer/transformers.py
+++ b/haystack/nodes/summarizer/transformers.py
@@ -28,7 +28,7 @@ class TransformersSummarizer(BaseSummarizer):
     **Example**
 
     ```python
-    | docs = [Document(text="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
+    | docs = [Document(content="PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions."
     |        "The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by"
     |        "the shutoffs which were expected to last through at least midday tomorrow.")]
     |
diff --git a/haystack/schema.py b/haystack/schema.py
index 55fce78a0f..fb4fbaadcb 100644
--- a/haystack/schema.py
+++ b/haystack/schema.py
@@ -2,11 +2,10 @@
 
 import csv
 import hashlib
-import typing
 from typing import Any, Optional, Dict, List, Union
 
 try:
-    from typing import Literal
+    from typing import Literal  # type: ignore
 except ImportError:
     from typing_extensions import Literal  # type: ignore
 
@@ -16,21 +15,18 @@
 import time
 import json
 import ast
-from dataclasses import asdict
+from dataclasses import asdict, InitVar
 
 import mmh3
 import numpy as np
 import pandas as pd
 
-from pydantic import BaseConfig
+from pydantic import BaseConfig, Field
 from pydantic.json import pydantic_encoder
 
-if not typing.TYPE_CHECKING:
-    # We are using Pydantic dataclasses instead of vanilla Python's
-    # See #1598 for the reasons behind this choice & performance considerations
-    from pydantic.dataclasses import dataclass
-else:
-    from dataclasses import dataclass  # type: ignore # pylint: disable=ungrouped-imports
+# We are using Pydantic dataclasses instead of vanilla Python's
+# See #1598 for the reasons behind this choice & performance considerations
+from pydantic.dataclasses import dataclass
 
 logger = logging.getLogger(__name__)
 
@@ -41,12 +37,13 @@
 
 @dataclass
 class Document:
-    content: Union[str, pd.DataFrame]
-    content_type: Literal["text", "table", "image", "audio"]
     id: str
-    meta: Dict[str, Any]
+    content: Union[str, pd.DataFrame]
+    content_type: Literal["text", "table", "image", "audio"] = Field(default="text")
+    meta: Dict[str, Any] = Field(default={})
     score: Optional[float] = None
     embedding: Optional[np.ndarray] = None
+    id_hash_keys: InitVar[Optional[List[str]]] = None
 
     # We use a custom init here as we want some custom logic. The annotations above are however still needed in order
     # to use some dataclass magic like "asdict()". See https://www.python.org/dev/peps/pep-0557/#custom-init-method
@@ -58,7 +55,7 @@ def __init__(
         self,
         content: Union[str, pd.DataFrame],
         content_type: Literal["text", "table", "image", "audio"] = "text",
         id: Optional[str] = None,
         score: Optional[float] = None,
-        meta: Dict[str, Any] = None,
+        meta: Optional[Dict[str, Any]] = None,
         embedding: Optional[np.ndarray] = None,
         id_hash_keys: Optional[List[str]] = None,
     ):
@@ -66,15 +63,11 @@ def __init__(
         One of the core data classes in Haystack. It's used to represent documents / passages in a standardized way within Haystack.
         Documents are stored in DocumentStores, are returned by Retrievers, are the input for Readers and are used in
         many other places that manipulate or interact with document-level data.
-
         Note: There can be multiple Documents originating from one file (e.g. PDF), if you split the text into smaller passages. We'll have one Document per passage in this case.
-
         Each document has a unique ID. This can be supplied by the user or generated automatically.
         It's particularly helpful for handling of duplicates and referencing documents in other objects (e.g. Labels)
-
         There's an easy option to convert from/to dicts via `from_dict()` and `to_dict`.
-
         :param content: Content of the document. For most cases, this will be text, but it can be a table or image.
         :param content_type: One of "text", "table" or "image". Haystack components can use this to adjust their
                              handling of Documents and check compatibility.
@@ -154,6 +147,9 @@ def to_dict(self, field_map={}) -> Dict:
         inv_field_map = {v: k for k, v in field_map.items()}
         _doc: Dict[str, str] = {}
         for k, v in self.__dict__.items():
+            # Exclude internal fields (Pydantic, ...) from the conversion process
+            if k.startswith("__"):
+                continue
             if k == "content":
                 # Convert pd.DataFrame to list of rows for serialization
                 if self.content_type == "table" and isinstance(self.content, pd.DataFrame):
@@ -184,6 +180,9 @@ def from_dict(
             _doc["meta"] = {}
         # copy additional fields into "meta"
         for k, v in _doc.items():
+            # Exclude internal fields (Pydantic, ...) from the conversion process
+            if k.startswith("__"):
+                continue
             if k not in init_args and k not in field_map:
                 _doc["meta"][k] = v
         # remove additional fields from top level
@@ -615,6 +614,8 @@ class MultiLabel:
     contexts: List[str]
     offsets_in_contexts: List[Dict]
     offsets_in_documents: List[Dict]
+    drop_negative_labels: InitVar[bool] = False
+    drop_no_answers: InitVar[bool] = False
 
     def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answers=False, **kwargs):
         """
@@ -676,6 +677,7 @@ def __init__(self, labels: List[Label], drop_negative_labels=False, drop_no_answ
         # as separate no_answer labels, and thus with document.id but without answer.document_id.
         # If we do not exclude them from document_ids this would be problematic for retriever evaluation as they do not contain the answer.
         # Hence, we exclude them here as well.
+
         self.document_ids = [l.document.id for l in self.labels if not l.no_answer]
         self.contexts = [l.document.content for l in self.labels if not l.no_answer]
 
diff --git a/pyproject.toml b/pyproject.toml
index 0420497fbb..f46af0d4da 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,7 +50,7 @@ dependencies = [
     "importlib-metadata; python_version < '3.8'",
     "torch>1.9,<1.13",
     "requests",
-    "pydantic==1.9.2",
+    "pydantic",
     "transformers==4.21.2",
     "nltk",
     "pandas",
diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py
index 23c1a9abab..c7c042857e 100644
--- a/test/nodes/test_generator.py
+++ b/test/nodes/test_generator.py
@@ -2,7 +2,6 @@
 import sys
 from typing import List
 
-import numpy as np
 import pytest
 
 from haystack.schema import Document
@@ -64,8 +63,8 @@ def test_generator_pipeline(document_store, retriever, rag_generator, docs_with_
 def test_lfqa_pipeline(document_store, retriever, lfqa_generator, docs_with_true_emb):
     # reuse existing DOCS but regenerate embeddings with retribert
     docs: List[Document] = []
-    for idx, d in enumerate(docs_with_true_emb):
-        docs.append(Document(d.content, str(idx)))
+    for d in docs_with_true_emb:
+        docs.append(Document(content=d.content))
     document_store.write_documents(docs)
     document_store.update_embeddings(retriever)
     query = "Tell me about Berlin?"
@@ -84,8 +83,8 @@ def test_lfqa_pipeline(document_store, retriever, lfqa_generator, docs_with_true
 def test_lfqa_pipeline_unknown_converter(document_store, retriever, docs_with_true_emb):
     # reuse existing DOCS but regenerate embeddings with retribert
     docs: List[Document] = []
-    for idx, d in enumerate(docs_with_true_emb):
-        docs.append(Document(d.content, str(idx)))
+    for d in docs_with_true_emb:
+        docs.append(Document(content=d.content))
     document_store.write_documents(docs)
     document_store.update_embeddings(retriever)
     seq2seq = Seq2SeqGenerator(model_name_or_path="patrickvonplaten/t5-tiny-random")
@@ -106,8 +105,8 @@ def test_lfqa_pipeline_unknown_converter(document_store, retriever, docs_with_tr
 def test_lfqa_pipeline_invalid_converter(document_store, retriever, docs_with_true_emb):
     # reuse existing DOCS but regenerate embeddings with retribert
     docs: List[Document] = []
-    for idx, d in enumerate(docs_with_true_emb):
-        docs.append(Document(d.content, str(idx)))
+    for d in docs_with_true_emb:
+        docs.append(Document(content=d.content))
     document_store.write_documents(docs)
     document_store.update_embeddings(retriever)
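
The heart of this patch is the pattern `haystack/schema.py` now applies to `Document` (and, with the new `InitVar`s, to `MultiLabel`): a Pydantic dataclass whose class-level annotations carry `Field` defaults and an `InitVar`, combined with a hand-written `__init__`. Below is a minimal, self-contained sketch of that pattern; the `Record` class and its id-hashing logic are illustrative stand-ins, not Haystack's actual implementation, and it assumes pydantic 1.x (which this patch targets).

```python
from dataclasses import InitVar, asdict
from hashlib import md5
from typing import Any, Dict, List, Optional

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class Record:
    # Field order mirrors the patch: `id` first, everything else defaulted,
    # so only `id` and `content` end up required in the generated schema.
    id: str
    content: str
    content_type: str = Field(default="text")
    meta: Dict[str, Any] = Field(default={})
    score: Optional[float] = None
    # InitVar: accepted by the constructor and visible to the dataclass
    # machinery, but never stored as a field on the instance.
    id_hash_keys: InitVar[Optional[List[str]]] = None

    # A hand-written __init__ keeps the flexible argument handling; the
    # annotations above still enable dataclass features such as asdict().
    def __init__(
        self,
        content: str,
        content_type: str = "text",
        id: Optional[str] = None,
        score: Optional[float] = None,
        meta: Optional[Dict[str, Any]] = None,
        id_hash_keys: Optional[List[str]] = None,
    ):
        self.content = content
        self.content_type = content_type
        self.score = score
        self.meta = meta or {}
        # Derive a stable id from the chosen hash keys when none is supplied
        # (Haystack hashes with mmh3; md5 is used here just to stay stdlib).
        keys = id_hash_keys or ["content"]
        raw = "".join(str(getattr(self, k)) for k in keys)
        self.id = id or md5(raw.encode("utf-8")).hexdigest()


r = Record(content="PG&E scheduled the blackouts.")
print(r.id, asdict(r))
```

Because `id_hash_keys` is an `InitVar`, it is consumed during construction and stays out of `asdict()` output, and because `content_type` and `meta` now carry defaults, only `id` and `content` remain required — which is what the regenerated OpenAPI `Document` schema above reflects.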
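The new `k.startswith("__")` guards in `to_dict()`/`from_dict()` exist because pydantic's dataclass wrapper can keep dunder-named bookkeeping entries (e.g. an "initialised" flag) in the instance `__dict__`, and copying that dict verbatim would leak them into serialized documents. A rough sketch of the guarded round trip, using hypothetical helpers and the `Record` class from the sketch above:

```python
from typing import Any, Dict

# Constructor arguments of Record; anything else found at the top level of a
# dict is treated as user metadata, mirroring Document.from_dict().
INIT_ARGS = {"id", "content", "content_type", "meta", "score"}


def record_to_dict(record: Record) -> Dict[str, Any]:
    doc: Dict[str, Any] = {}
    for k, v in record.__dict__.items():
        # Exclude internal (pydantic) bookkeeping entries from the output
        if k.startswith("__"):
            continue
        doc[k] = v
    return doc


def record_from_dict(data: Dict[str, Any]) -> Record:
    data = dict(data)
    meta = dict(data.pop("meta", None) or {})
    # Fold unknown top-level keys into meta instead of dropping them
    for k in [k for k in data if k not in INIT_ARGS and not k.startswith("__")]:
        meta[k] = data.pop(k)
    return Record(meta=meta, **data)


d = record_to_dict(Record(content="hello"))
print(record_from_dict({**d, "source": "unit-test"}).meta)  # {'source': 'unit-test'}
```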