Skip to content

Commit

Permalink
Add SDK Batch message type
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Aug 11, 2022
1 parent cc7e06d commit 4d28791
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 49 deletions.
182 changes: 143 additions & 39 deletions singer_sdk/helpers/_singer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,41 @@
from __future__ import annotations

import enum
import json
import logging
from dataclasses import dataclass, fields
import sys
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, Tuple, Union, cast

from singer.catalog import Catalog as BaseCatalog
from singer.catalog import CatalogEntry as BaseCatalogEntry
from singer.messages import Message
from singer.schema import Schema

if TYPE_CHECKING:
from typing_extensions import TypeAlias

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

Breadcrumb = Tuple[str, ...]

logger = logging.getLogger(__name__)


class SingerMessageType(str, enum.Enum):
"""Singer specification message types."""

RECORD = "RECORD"
SCHEMA = "SCHEMA"
STATE = "STATE"
ACTIVATE_VERSION = "ACTIVATE_VERSION"
BATCH = "BATCH"


class SelectionMask(Dict[Breadcrumb, bool]):
"""Boolean mask for property selection in schemas and records."""

Expand Down Expand Up @@ -39,28 +63,28 @@ class InclusionType(str, Enum):
AUTOMATIC = "automatic"
UNSUPPORTED = "unsupported"

inclusion: Optional[InclusionType] = None
selected: Optional[bool] = None
selected_by_default: Optional[bool] = None
inclusion: InclusionType | None = None
selected: bool | None = None
selected_by_default: bool | None = None

@classmethod
def from_dict(cls, value: Dict[str, Any]):
def from_dict(cls, value: dict[str, Any]):
"""Parse metadata dictionary."""
return cls(
**{
field.name: value.get(field.name.replace("_", "-"))
for field in fields(cls)
object_field.name: value.get(object_field.name.replace("_", "-"))
for object_field in fields(cls)
}
)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert metadata to a JSON-encodeable dictionary."""
result = {}

for field in fields(self):
value = getattr(self, field.name)
for object_field in fields(self):
value = getattr(self, object_field.name)
if value is not None:
result[field.name.replace("_", "-")] = value
result[object_field.name.replace("_", "-")] = value

return result

Expand All @@ -69,19 +93,22 @@ def to_dict(self) -> Dict[str, Any]:
class StreamMetadata(Metadata):
"""Stream metadata."""

table_key_properties: Optional[List[str]] = None
forced_replication_method: Optional[str] = None
valid_replication_keys: Optional[List[str]] = None
schema_name: Optional[str] = None
table_key_properties: list[str] | None = None
forced_replication_method: str | None = None
valid_replication_keys: list[str] | None = None
schema_name: str | None = None


AnyMetadata: TypeAlias = Union[Metadata, StreamMetadata]


class MetadataMapping(Dict[Breadcrumb, Union[Metadata, StreamMetadata]]):
class MetadataMapping(Dict[Breadcrumb, AnyMetadata]):
"""Stream metadata mapping."""

@classmethod
def from_iterable(cls, iterable: Iterable[Dict[str, Any]]):
def from_iterable(cls, iterable: Iterable[dict[str, Any]]):
"""Create a metadata mapping from an iterable of metadata dictionaries."""
mapping = cls()
mapping: dict[Breadcrumb, AnyMetadata] = cls()
for d in iterable:
breadcrumb = tuple(d["breadcrumb"])
metadata = d["metadata"]
Expand All @@ -92,7 +119,7 @@ def from_iterable(cls, iterable: Iterable[Dict[str, Any]]):

return mapping

def to_list(self) -> List[Dict[str, Any]]:
def to_list(self) -> list[dict[str, Any]]:
"""Convert mapping to a JSON-encodable list."""
return [
{"breadcrumb": list(k), "metadata": v.to_dict()} for k, v in self.items()
Expand All @@ -112,11 +139,11 @@ def root(self):
@classmethod
def get_standard_metadata(
cls,
schema: Optional[Dict[str, Any]] = None,
schema_name: Optional[str] = None,
key_properties: Optional[List[str]] = None,
valid_replication_keys: Optional[List[str]] = None,
replication_method: Optional[str] = None,
schema: dict[str, Any] | None = None,
schema_name: str | None = None,
key_properties: list[str] | None = None,
valid_replication_keys: list[str] | None = None,
replication_method: str | None = None,
):
"""Get default metadata for a stream."""
mapping = cls()
Expand Down Expand Up @@ -211,18 +238,18 @@ class CatalogEntry(BaseCatalogEntry):
tap_stream_id: str
metadata: MetadataMapping
schema: Schema
stream: Optional[str] = None
key_properties: Optional[List[str]] = None
replication_key: Optional[str] = None
is_view: Optional[bool] = None
database: Optional[str] = None
table: Optional[str] = None
row_count: Optional[int] = None
stream_alias: Optional[str] = None
replication_method: Optional[str] = None
stream: str | None = None
key_properties: list[str] | None = None
replication_key: str | None = None
is_view: bool | None = None
database: str | None = None
table: str | None = None
row_count: int | None = None
stream_alias: str | None = None
replication_method: str | None = None

@classmethod
def from_dict(cls, stream: Dict[str, Any]):
def from_dict(cls, stream: dict[str, Any]):
"""Create a catalog entry from a dictionary."""
return cls(
tap_stream_id=stream["tap_stream_id"],
Expand All @@ -249,15 +276,15 @@ class Catalog(Dict[str, CatalogEntry], BaseCatalog):
"""Singer catalog mapping of stream entries."""

@classmethod
def from_dict(cls, data: Dict[str, List[Dict[str, Any]]]) -> "Catalog":
def from_dict(cls, data: dict[str, list[dict[str, Any]]]) -> Catalog:
"""Create a catalog from a dictionary."""
instance = cls()
for stream in data.get("streams", []):
entry = CatalogEntry.from_dict(stream)
instance[entry.tap_stream_id] = entry
return instance

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the catalog.
Returns:
Expand All @@ -266,14 +293,91 @@ def to_dict(self) -> Dict[str, Any]:
return cast(Dict[str, Any], super().to_dict())

@property
def streams(self) -> List[CatalogEntry]:
def streams(self) -> list[CatalogEntry]:
"""Get catalog entries."""
return list(self.values())

def add_stream(self, entry: CatalogEntry) -> None:
"""Add a stream entry to the catalog."""
self[entry.tap_stream_id] = entry

def get_stream(self, stream_id: str) -> Optional[CatalogEntry]:
def get_stream(self, stream_id: str) -> CatalogEntry | None:
"""Retrieve a stream entry from the catalog."""
return self.get(stream_id)


@dataclass
class BaseBatchFileEncoding:
"""Base class for batch file encodings."""

registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {}
__encoding_format__: ClassVar[str] = "OVERRIDE_ME"

# Base encoding fields
format: str = field(init=False)
"""The format of the batch file."""

def __init_subclass__(cls, **kwargs: Any) -> None:
"""Register subclasses."""
super().__init_subclass__(**kwargs)
cls.registered_encodings[cls.__encoding_format__] = cls

def __post_init__(self) -> None:
self.format = self.__encoding_format__

@classmethod
def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding:
"""Create an encoding from a dictionary."""
encoding_format = data.pop("format")
encoding_cls = cls.registered_encodings[encoding_format]
return encoding_cls(**data)

@classmethod
def parse(cls, value: str, **kwargs: Any):
"""Parse a JSON Lines encoding from a dictionary.
Args:
value: A dictionary containing the encoding.
kwargs: Additional keyword arguments for `json.loads`.
"""
return cls.from_dict(json.loads(value, **kwargs))


@dataclass
class JSONLinesEncoding(BaseBatchFileEncoding):
"""JSON Lines encoding for batch files."""

__encoding_format__ = "jsonl"

compression: str | None = None


@dataclass
class SDKBatchMessage(Message):
"""Singer batch message in the Meltano SDK flavor."""

type: Literal[SingerMessageType.BATCH] = field(init=False)
"""The message type."""

stream: str
"""The stream name."""

encoding: BaseBatchFileEncoding
"""The file encoding of the batch."""

manifest: list[str] = field(default_factory=list)
"""The manifest of files in the batch."""

def __post_init__(self):
if isinstance(self.encoding, dict):
self.encoding = BaseBatchFileEncoding.from_dict(self.encoding)

self.type = SingerMessageType.BATCH

def asdict(self):
"""Return a dictionary representation of the message.
Returns:
A dictionary with the defined message fields.
"""
return asdict(self)
11 changes: 1 addition & 10 deletions singer_sdk/io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import abc
import enum
import json
import logging
import sys
Expand All @@ -12,19 +11,11 @@
from typing import Counter as CounterType

from singer_sdk.helpers._compat import final
from singer_sdk.helpers._singer import SingerMessageType

logger = logging.getLogger(__name__)


class SingerMessageType(str, enum.Enum):
"""Singer specification message types."""

RECORD = "RECORD"
SCHEMA = "SCHEMA"
STATE = "STATE"
ACTIVATE_VERSION = "ACTIVATE_VERSION"


class SingerReader(metaclass=abc.ABCMeta):
"""Interface for all plugins reading Singer messages from stdin."""

Expand Down
58 changes: 58 additions & 0 deletions tests/core/test_singer_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from dataclasses import asdict

import pytest

from singer_sdk.helpers._singer import (
BaseBatchFileEncoding,
JSONLinesEncoding,
SDKBatchMessage,
SingerMessageType,
)


@pytest.mark.parametrize(
"encoding,expected",
[
(JSONLinesEncoding("gzip"), {"compression": "gzip", "format": "jsonl"}),
(JSONLinesEncoding(), {"compression": None, "format": "jsonl"}),
],
ids=["jsonl-compression-gzip", "jsonl-compression-none"],
)
def test_encoding_as_dict(encoding: BaseBatchFileEncoding, expected: dict) -> None:
"""Test encoding as dict."""
assert asdict(encoding) == expected


@pytest.mark.parametrize(
"message,expected",
[
(
SDKBatchMessage(
stream="test_stream",
encoding=JSONLinesEncoding("gzip"),
manifest=[
"path/to/file1.jsonl.gz",
"path/to/file2.jsonl.gz",
],
),
{
"type": SingerMessageType.BATCH,
"stream": "test_stream",
"encoding": {"compression": "gzip", "format": "jsonl"},
"manifest": [
"path/to/file1.jsonl.gz",
"path/to/file2.jsonl.gz",
],
},
)
],
ids=["batch-message-jsonl"],
)
def test_batch_message_as_dict(message, expected):
"""Test batch message as dict."""

dumped = message.asdict()
assert dumped == expected

dumped.pop("type")
assert message.__class__(**dumped) == message

0 comments on commit 4d28791

Please sign in to comment.