Skip to content

Commit

Permalink
Add SDK Batch message type
Browse files Browse the repository at this point in the history
Exclude 'if TYPE_CHECKING:' blocks from coverage

Remove BaseBatchFileEncoding.parse method

Add abstract IO method for processing batch messages

Move compression attribute to parent class

Drop unused import
  • Loading branch information
edgarrmondragon committed Aug 30, 2022
1 parent cf03626 commit 080eefe
Show file tree
Hide file tree
Showing 7 changed files with 370 additions and 71 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ exclude_lines = [
"raise NotImplementedError",
"if __name__ == .__main__.:",
'''class .*\bProtocol\):''',
"if TYPE_CHECKING:",
]
fail_under = 82

Expand Down
192 changes: 153 additions & 39 deletions singer_sdk/helpers/_singer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,41 @@
from __future__ import annotations

import enum
import logging
from dataclasses import dataclass, fields
import sys
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, Tuple, Union, cast

from singer.catalog import Catalog as BaseCatalog
from singer.catalog import CatalogEntry as BaseCatalogEntry
from singer.messages import Message

from singer_sdk.helpers._schema import SchemaPlus

if TYPE_CHECKING:
from typing_extensions import TypeAlias

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

Breadcrumb = Tuple[str, ...]

logger = logging.getLogger(__name__)


class SingerMessageType(str, enum.Enum):
    """Singer specification message types.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (for example ``SingerMessageType.BATCH == "BATCH"``).
    """

    RECORD = "RECORD"
    SCHEMA = "SCHEMA"
    STATE = "STATE"
    ACTIVATE_VERSION = "ACTIVATE_VERSION"
    BATCH = "BATCH"


class SelectionMask(Dict[Breadcrumb, bool]):
"""Boolean mask for property selection in schemas and records."""

Expand Down Expand Up @@ -40,28 +63,28 @@ class InclusionType(str, Enum):
AUTOMATIC = "automatic"
UNSUPPORTED = "unsupported"

inclusion: Optional[InclusionType] = None
selected: Optional[bool] = None
selected_by_default: Optional[bool] = None
inclusion: InclusionType | None = None
selected: bool | None = None
selected_by_default: bool | None = None

@classmethod
def from_dict(cls, value: Dict[str, Any]):
def from_dict(cls, value: dict[str, Any]):
"""Parse metadata dictionary."""
return cls(
**{
field.name: value.get(field.name.replace("_", "-"))
for field in fields(cls)
object_field.name: value.get(object_field.name.replace("_", "-"))
for object_field in fields(cls)
}
)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert metadata to a JSON-encodeable dictionary."""
result = {}

for field in fields(self):
value = getattr(self, field.name)
for object_field in fields(self):
value = getattr(self, object_field.name)
if value is not None:
result[field.name.replace("_", "-")] = value
result[object_field.name.replace("_", "-")] = value

return result

Expand All @@ -70,19 +93,22 @@ def to_dict(self) -> Dict[str, Any]:
class StreamMetadata(Metadata):
"""Stream metadata."""

table_key_properties: Optional[List[str]] = None
forced_replication_method: Optional[str] = None
valid_replication_keys: Optional[List[str]] = None
schema_name: Optional[str] = None
table_key_properties: list[str] | None = None
forced_replication_method: str | None = None
valid_replication_keys: list[str] | None = None
schema_name: str | None = None


class MetadataMapping(Dict[Breadcrumb, Union[Metadata, StreamMetadata]]):
AnyMetadata: TypeAlias = Union[Metadata, StreamMetadata]


class MetadataMapping(Dict[Breadcrumb, AnyMetadata]):
"""Stream metadata mapping."""

@classmethod
def from_iterable(cls, iterable: Iterable[Dict[str, Any]]):
def from_iterable(cls, iterable: Iterable[dict[str, Any]]):
"""Create a metadata mapping from an iterable of metadata dictionaries."""
mapping = cls()
mapping: dict[Breadcrumb, AnyMetadata] = cls()
for d in iterable:
breadcrumb = tuple(d["breadcrumb"])
metadata = d["metadata"]
Expand All @@ -93,7 +119,7 @@ def from_iterable(cls, iterable: Iterable[Dict[str, Any]]):

return mapping

def to_list(self) -> List[Dict[str, Any]]:
def to_list(self) -> list[dict[str, Any]]:
"""Convert mapping to a JSON-encodable list."""
return [
{"breadcrumb": list(k), "metadata": v.to_dict()} for k, v in self.items()
Expand All @@ -113,11 +139,11 @@ def root(self):
@classmethod
def get_standard_metadata(
cls,
schema: Optional[Dict[str, Any]] = None,
schema_name: Optional[str] = None,
key_properties: Optional[List[str]] = None,
valid_replication_keys: Optional[List[str]] = None,
replication_method: Optional[str] = None,
schema: dict[str, Any] | None = None,
schema_name: str | None = None,
key_properties: list[str] | None = None,
valid_replication_keys: list[str] | None = None,
replication_method: str | None = None,
):
"""Get default metadata for a stream."""
mapping = cls()
Expand Down Expand Up @@ -212,18 +238,18 @@ class CatalogEntry(BaseCatalogEntry):
tap_stream_id: str
metadata: MetadataMapping
schema: SchemaPlus
stream: Optional[str] = None
key_properties: Optional[List[str]] = None
replication_key: Optional[str] = None
is_view: Optional[bool] = None
database: Optional[str] = None
table: Optional[str] = None
row_count: Optional[int] = None
stream_alias: Optional[str] = None
replication_method: Optional[str] = None
stream: str | None = None
key_properties: list[str] | None = None
replication_key: str | None = None
is_view: bool | None = None
database: str | None = None
table: str | None = None
row_count: int | None = None
stream_alias: str | None = None
replication_method: str | None = None

@classmethod
def from_dict(cls, stream: Dict[str, Any]):
def from_dict(cls, stream: dict[str, Any]):
"""Create a catalog entry from a dictionary."""
return cls(
tap_stream_id=stream["tap_stream_id"],
Expand All @@ -250,15 +276,15 @@ class Catalog(Dict[str, CatalogEntry], BaseCatalog):
"""Singer catalog mapping of stream entries."""

@classmethod
def from_dict(cls, data: Dict[str, List[Dict[str, Any]]]) -> "Catalog":
def from_dict(cls, data: dict[str, list[dict[str, Any]]]) -> Catalog:
"""Create a catalog from a dictionary."""
instance = cls()
for stream in data.get("streams", []):
entry = CatalogEntry.from_dict(stream)
instance[entry.tap_stream_id] = entry
return instance

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the catalog.
Returns:
Expand All @@ -267,14 +293,102 @@ def to_dict(self) -> Dict[str, Any]:
return cast(Dict[str, Any], super().to_dict())

@property
def streams(self) -> List[CatalogEntry]:
def streams(self) -> list[CatalogEntry]:
"""Get catalog entries."""
return list(self.values())

def add_stream(self, entry: CatalogEntry) -> None:
"""Add a stream entry to the catalog."""
self[entry.tap_stream_id] = entry

def get_stream(self, stream_id: str) -> Optional[CatalogEntry]:
def get_stream(self, stream_id: str) -> CatalogEntry | None:
"""Retrieve a stream entry from the catalog."""
return self.get(stream_id)


class BatchFileFormat(str, enum.Enum):
    """Batch file format.

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    JSONL = "jsonl"
    """JSON Lines format."""


@dataclass
class BaseBatchFileEncoding:
    """Base class for batch file encodings.

    Every subclass is recorded in ``registered_encodings`` under its
    ``__encoding_format__`` value at class-creation time. ``from_dict`` uses
    that registry to choose the concrete class for a serialized encoding.
    """

    # Registry shared by the whole hierarchy: format name -> encoding class.
    registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {}
    # Subclasses override this with the format name they handle.
    __encoding_format__: ClassVar[str] = "OVERRIDE_ME"

    # Base encoding fields
    format: str = field(init=False)
    """The format of the batch file."""

    compression: str | None = None
    """The compression of the batch file."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register subclasses.

        Args:
            kwargs: Keyword arguments forwarded to the parent hook.
        """
        super().__init_subclass__(**kwargs)
        cls.registered_encodings[cls.__encoding_format__] = cls

    def __post_init__(self) -> None:
        """Set ``format`` from the class-level ``__encoding_format__``."""
        self.format = self.__encoding_format__

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding:
        """Create an encoding from a dictionary.

        The input dictionary is left untouched, so the same payload can be
        parsed more than once.

        Args:
            data: Serialized encoding. Must contain a ``"format"`` key; the
                remaining items are passed as keyword arguments to the
                registered encoding class.

        Returns:
            An instance of the encoding class registered for the format.

        Raises:
            KeyError: If ``"format"`` is missing or names an unregistered
                format.
        """
        # Work on a shallow copy: popping from ``data`` itself would strip the
        # "format" key from the caller's dictionary.
        options = dict(data)
        encoding_format = options.pop("format")
        encoding_cls = cls.registered_encodings[encoding_format]
        return encoding_cls(**options)


@dataclass
class JSONLinesEncoding(BaseBatchFileEncoding):
    """JSON Lines encoding for batch files."""

    # Format name under which the base class's ``__init_subclass__`` hook
    # registers this class, and the value ``__post_init__`` copies to ``format``.
    __encoding_format__ = "jsonl"


@dataclass
class SDKBatchMessage(Message):
    """Singer batch message in the Meltano SDK flavor."""

    type: Literal[SingerMessageType.BATCH] = field(init=False)
    """The message type."""

    stream: str
    """The stream name."""

    encoding: BaseBatchFileEncoding
    """The file encoding of the batch."""

    manifest: list[str] = field(default_factory=list)
    """The manifest of files in the batch."""

    def __post_init__(self) -> None:
        """Normalize ``encoding`` and set the message type.

        ``encoding`` may be given as a plain dictionary, in which case it is
        converted to the matching ``BaseBatchFileEncoding`` subclass.
        """
        if isinstance(self.encoding, dict):
            # Pass a copy so the caller's encoding dictionary is never
            # modified while it is being parsed.
            self.encoding = BaseBatchFileEncoding.from_dict(dict(self.encoding))

        self.type = SingerMessageType.BATCH

    def asdict(self) -> dict[str, Any]:
        """Return a dictionary representation of the message.

        Returns:
            A dictionary with the defined message fields.
        """
        # Inside the method body ``asdict`` resolves to the module-level
        # ``dataclasses.asdict`` import, not to this method.
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SDKBatchMessage:
        """Create a message from a dictionary.

        The input dictionary is left untouched.

        Args:
            data: The dictionary to create the message from. Must contain a
                ``"type"`` key, which is dropped because the type is set by
                ``__post_init__``.

        Returns:
            The created message.

        Raises:
            KeyError: If ``"type"`` is missing from ``data``.
        """
        # Shallow copy first: popping "type" from ``data`` itself would alter
        # the caller's dictionary.
        payload = dict(data)
        payload.pop("type")
        return cls(**payload)
18 changes: 8 additions & 10 deletions singer_sdk/io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import abc
import enum
import json
import logging
import sys
Expand All @@ -12,19 +11,11 @@
from typing import Counter as CounterType

from singer_sdk.helpers._compat import final
from singer_sdk.helpers._singer import SingerMessageType

logger = logging.getLogger(__name__)


class SingerMessageType(str, enum.Enum):
"""Singer specification message types."""

RECORD = "RECORD"
SCHEMA = "SCHEMA"
STATE = "STATE"
ACTIVATE_VERSION = "ACTIVATE_VERSION"


class SingerReader(metaclass=abc.ABCMeta):
"""Interface for all plugins reading Singer messages from stdin."""

Expand Down Expand Up @@ -95,6 +86,9 @@ def _process_lines(self, file_input: IO[str]) -> CounterType[str]:
elif record_type == SingerMessageType.STATE:
self._process_state_message(line_dict)

elif record_type == SingerMessageType.BATCH:
self._process_batch_message(line_dict)

else:
self._process_unknown_message(line_dict)

Expand All @@ -118,6 +112,10 @@ def _process_state_message(self, message_dict: dict) -> None:
def _process_activate_version_message(self, message_dict: dict) -> None:
...

@abc.abstractmethod
def _process_batch_message(self, message_dict: dict) -> None:
    """Process a BATCH message; concrete readers must implement this.

    Args:
        message_dict: The message dictionary that ``_process_lines``
            dispatches when the message type is ``SingerMessageType.BATCH``.
    """
    ...

def _process_unknown_message(self, message_dict: dict) -> None:
"""Internal method to process unknown message types from a Singer tap.
Expand Down
14 changes: 14 additions & 0 deletions singer_sdk/mapper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def _process_state_message(self, message_dict: dict) -> None:
def _process_activate_version_message(self, message_dict: dict) -> None:
self._write_messages(self.map_activate_version_message(message_dict))

def _process_batch_message(self, message_dict: dict) -> None:
    """Map an incoming BATCH message and write out the resulting messages."""
    mapped_messages = self.map_batch_message(message_dict)
    self._write_messages(mapped_messages)

@abc.abstractmethod
def map_schema_message(self, message_dict: dict) -> Iterable[singer.Message]:
"""Map a schema message to zero or more new messages.
Expand Down Expand Up @@ -89,6 +92,17 @@ def map_activate_version_message(
"""
...

def map_batch_message(
    self,
    message_dict: dict,
) -> Iterable[singer.Message]:
    """Map a batch message to zero or more new messages.

    This default implementation does not support BATCH messages. Mappers
    that need to handle them should override this method.

    Args:
        message_dict: A BATCH message JSON dictionary.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    # Fail loudly instead of falling through and returning ``None``, which
    # does not satisfy the declared ``Iterable`` return type.
    raise NotImplementedError("BATCH messages are not supported by mappers.")

@classproperty
def cli(cls) -> Callable:
"""Execute standard CLI handler for inline mappers.
Expand Down
Loading

0 comments on commit 080eefe

Please sign in to comment.