From 4d28791c28900b237f727e318d4ac517b17c97c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Rami=CC=81rez=20Mondrago=CC=81n?= Date: Wed, 10 Aug 2022 18:08:32 -0500 Subject: [PATCH] Add SDK Batch message type --- singer_sdk/helpers/_singer.py | 182 ++++++++++++++++++++++------- singer_sdk/io_base.py | 11 +- tests/core/test_singer_messages.py | 58 +++++++++ 3 files changed, 202 insertions(+), 49 deletions(-) create mode 100644 tests/core/test_singer_messages.py diff --git a/singer_sdk/helpers/_singer.py b/singer_sdk/helpers/_singer.py index 465e58fd96..2f3168dc42 100644 --- a/singer_sdk/helpers/_singer.py +++ b/singer_sdk/helpers/_singer.py @@ -1,17 +1,41 @@ +from __future__ import annotations + +import enum +import json import logging -from dataclasses import dataclass, fields +import sys +from dataclasses import asdict, dataclass, field, fields from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, Tuple, Union, cast from singer.catalog import Catalog as BaseCatalog from singer.catalog import CatalogEntry as BaseCatalogEntry +from singer.messages import Message from singer.schema import Schema +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + if sys.version_info >= (3, 8): + from typing import Literal + else: + from typing_extensions import Literal + Breadcrumb = Tuple[str, ...] logger = logging.getLogger(__name__) +class SingerMessageType(str, enum.Enum): + """Singer specification message types.""" + + RECORD = "RECORD" + SCHEMA = "SCHEMA" + STATE = "STATE" + ACTIVATE_VERSION = "ACTIVATE_VERSION" + BATCH = "BATCH" + + class SelectionMask(Dict[Breadcrumb, bool]): """Boolean mask for property selection in schemas and records.""" @@ -39,28 +63,28 @@ class InclusionType(str, Enum): AUTOMATIC = "automatic" UNSUPPORTED = "unsupported" - inclusion: Optional[InclusionType] = None - selected: Optional[bool] = None - selected_by_default: Optional[bool] = None + inclusion: InclusionType | None = None + selected: bool | None = None + selected_by_default: bool | None = None @classmethod - def from_dict(cls, value: Dict[str, Any]): + def from_dict(cls, value: dict[str, Any]): """Parse metadata dictionary.""" return cls( **{ - field.name: value.get(field.name.replace("_", "-")) - for field in fields(cls) + object_field.name: value.get(object_field.name.replace("_", "-")) + for object_field in fields(cls) } ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert metadata to a JSON-encodeable dictionary.""" result = {} - for field in fields(self): - value = getattr(self, field.name) + for object_field in fields(self): + value = getattr(self, object_field.name) if value is not None: - result[field.name.replace("_", "-")] = value + result[object_field.name.replace("_", "-")] = value return result @@ -69,19 +93,22 @@ def to_dict(self) -> Dict[str, Any]: class StreamMetadata(Metadata): """Stream metadata.""" - table_key_properties: Optional[List[str]] = None - forced_replication_method: Optional[str] = None - valid_replication_keys: Optional[List[str]] = None - schema_name: Optional[str] = None + table_key_properties: list[str] | None = None + forced_replication_method: str | None = None + valid_replication_keys: list[str] | None = None + schema_name: str | None = None + + +AnyMetadata: TypeAlias = Union[Metadata, StreamMetadata] -class MetadataMapping(Dict[Breadcrumb, Union[Metadata, StreamMetadata]]): +class MetadataMapping(Dict[Breadcrumb, AnyMetadata]): """Stream metadata mapping.""" @classmethod - def from_iterable(cls, iterable: Iterable[Dict[str, Any]]): + def from_iterable(cls, iterable: Iterable[dict[str, Any]]): """Create a metadata mapping from an iterable of metadata dictionaries.""" - mapping = cls() + mapping: dict[Breadcrumb, AnyMetadata] = cls() for d in iterable: breadcrumb = tuple(d["breadcrumb"]) metadata = d["metadata"] @@ -92,7 +119,7 @@ def from_iterable(cls, iterable: Iterable[Dict[str, Any]]): return mapping - def to_list(self) -> List[Dict[str, Any]]: + def to_list(self) -> list[dict[str, Any]]: """Convert mapping to a JSON-encodable list.""" return [ {"breadcrumb": list(k), "metadata": v.to_dict()} for k, v in self.items() @@ -112,11 +139,11 @@ def root(self): @classmethod def get_standard_metadata( cls, - schema: Optional[Dict[str, Any]] = None, - schema_name: Optional[str] = None, - key_properties: Optional[List[str]] = None, - valid_replication_keys: Optional[List[str]] = None, - replication_method: Optional[str] = None, + schema: dict[str, Any] | None = None, + schema_name: str | None = None, + key_properties: list[str] | None = None, + valid_replication_keys: list[str] | None = None, + replication_method: str | None = None, ): """Get default metadata for a stream.""" mapping = cls() @@ -211,18 +238,18 @@ class CatalogEntry(BaseCatalogEntry): tap_stream_id: str metadata: MetadataMapping schema: Schema - stream: Optional[str] = None - key_properties: Optional[List[str]] = None - replication_key: Optional[str] = None - is_view: Optional[bool] = None - database: Optional[str] = None - table: Optional[str] = None - row_count: Optional[int] = None - stream_alias: Optional[str] = None - replication_method: Optional[str] = None + stream: str | None = None + key_properties: list[str] | None = None + replication_key: str | None = None + is_view: bool | None = None + database: str | None = None + table: str | None = None + row_count: int | None = None + stream_alias: str | None = None + replication_method: str | None = None @classmethod - def from_dict(cls, stream: Dict[str, Any]): + def from_dict(cls, stream: dict[str, Any]): """Create a catalog entry from a dictionary.""" return cls( tap_stream_id=stream["tap_stream_id"], @@ -249,7 +276,7 @@ class Catalog(Dict[str, CatalogEntry], BaseCatalog): """Singer catalog mapping of stream entries.""" @classmethod - def from_dict(cls, data: Dict[str, List[Dict[str, Any]]]) -> "Catalog": + def from_dict(cls, data: dict[str, list[dict[str, Any]]]) -> Catalog: """Create a catalog from a dictionary.""" instance = cls() for stream in data.get("streams", []): @@ -257,7 +284,7 @@ def from_dict(cls, data: Dict[str, List[Dict[str, Any]]]) -> "Catalog": instance[entry.tap_stream_id] = entry return instance - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return a dictionary representation of the catalog. Returns: @@ -266,7 +293,7 @@ def to_dict(self) -> Dict[str, Any]: return cast(Dict[str, Any], super().to_dict()) @property - def streams(self) -> List[CatalogEntry]: + def streams(self) -> list[CatalogEntry]: """Get catalog entries.""" return list(self.values()) @@ -274,6 +301,83 @@ def add_stream(self, entry: CatalogEntry) -> None: """Add a stream entry to the catalog.""" self[entry.tap_stream_id] = entry - def get_stream(self, stream_id: str) -> Optional[CatalogEntry]: + def get_stream(self, stream_id: str) -> CatalogEntry | None: """Retrieve a stream entry from the catalog.""" return self.get(stream_id) + + +@dataclass +class BaseBatchFileEncoding: + """Base class for batch file encodings.""" + + registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {} + __encoding_format__: ClassVar[str] = "OVERRIDE_ME" + + # Base encoding fields + format: str = field(init=False) + """The format of the batch file.""" + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Register subclasses.""" + super().__init_subclass__(**kwargs) + cls.registered_encodings[cls.__encoding_format__] = cls + + def __post_init__(self) -> None: + self.format = self.__encoding_format__ + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding: + """Create an encoding from a dictionary.""" + encoding_format = data.pop("format") + encoding_cls = cls.registered_encodings[encoding_format] + return encoding_cls(**data) + + @classmethod + def parse(cls, value: str, **kwargs: Any): + """Parse a JSON Lines encoding from a dictionary. + + Args: + value: A dictionary containing the encoding. + kwargs: Additional keyword arguments for `json.loads`. + """ + return cls.from_dict(json.loads(value, **kwargs)) + + +@dataclass +class JSONLinesEncoding(BaseBatchFileEncoding): + """JSON Lines encoding for batch files.""" + + __encoding_format__ = "jsonl" + + compression: str | None = None + + +@dataclass +class SDKBatchMessage(Message): + """Singer batch message in the Meltano SDK flavor.""" + + type: Literal[SingerMessageType.BATCH] = field(init=False) + """The message type.""" + + stream: str + """The stream name.""" + + encoding: BaseBatchFileEncoding + """The file encoding of the batch.""" + + manifest: list[str] = field(default_factory=list) + """The manifest of files in the batch.""" + + def __post_init__(self): + if isinstance(self.encoding, dict): + self.encoding = BaseBatchFileEncoding.from_dict(self.encoding) + + self.type = SingerMessageType.BATCH + + def asdict(self): + """Return a dictionary representation of the message. + + Returns: + A dictionary with the defined message fields. + """ + return asdict(self) diff --git a/singer_sdk/io_base.py b/singer_sdk/io_base.py index f5da20d19a..32bb00e665 100644 --- a/singer_sdk/io_base.py +++ b/singer_sdk/io_base.py @@ -3,7 +3,6 @@ from __future__ import annotations import abc -import enum import json import logging import sys @@ -12,19 +11,11 @@ from typing import Counter as CounterType from singer_sdk.helpers._compat import final +from singer_sdk.helpers._singer import SingerMessageType logger = logging.getLogger(__name__) -class SingerMessageType(str, enum.Enum): - """Singer specification message types.""" - - RECORD = "RECORD" - SCHEMA = "SCHEMA" - STATE = "STATE" - ACTIVATE_VERSION = "ACTIVATE_VERSION" - - class SingerReader(metaclass=abc.ABCMeta): """Interface for all plugins reading Singer messages from stdin.""" diff --git a/tests/core/test_singer_messages.py b/tests/core/test_singer_messages.py new file mode 100644 index 0000000000..9601d453da --- /dev/null +++ b/tests/core/test_singer_messages.py @@ -0,0 +1,58 @@ +from dataclasses import asdict + +import pytest + +from singer_sdk.helpers._singer import ( + BaseBatchFileEncoding, + JSONLinesEncoding, + SDKBatchMessage, + SingerMessageType, +) + + +@pytest.mark.parametrize( + "encoding,expected", + [ + (JSONLinesEncoding("gzip"), {"compression": "gzip", "format": "jsonl"}), + (JSONLinesEncoding(), {"compression": None, "format": "jsonl"}), + ], + ids=["jsonl-compression-gzip", "jsonl-compression-none"], +) +def test_encoding_as_dict(encoding: BaseBatchFileEncoding, expected: dict) -> None: + """Test encoding as dict.""" + assert asdict(encoding) == expected + + +@pytest.mark.parametrize( + "message,expected", + [ + ( + SDKBatchMessage( + stream="test_stream", + encoding=JSONLinesEncoding("gzip"), + manifest=[ + "path/to/file1.jsonl.gz", + "path/to/file2.jsonl.gz", + ], + ), + { + "type": SingerMessageType.BATCH, + "stream": "test_stream", + "encoding": {"compression": "gzip", "format": "jsonl"}, + "manifest": [ + "path/to/file1.jsonl.gz", + "path/to/file2.jsonl.gz", + ], + }, + ) + ], + ids=["batch-message-jsonl"], +) +def test_batch_message_as_dict(message, expected): + """Test batch message as dict.""" + + dumped = message.asdict() + assert dumped == expected + + dumped.pop("type") + assert message.__class__(**dumped) == message