diff --git a/pyproject.toml b/pyproject.toml index b521d30ec7..2ac48884f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ exclude_lines = [ "raise NotImplementedError", "if __name__ == .__main__.:", '''class .*\bProtocol\):''', + "if TYPE_CHECKING:", ] fail_under = 82 diff --git a/singer_sdk/helpers/_singer.py b/singer_sdk/helpers/_singer.py index 8b74643d38..4abf502549 100644 --- a/singer_sdk/helpers/_singer.py +++ b/singer_sdk/helpers/_singer.py @@ -1,18 +1,41 @@ +from __future__ import annotations + +import enum import logging -from dataclasses import dataclass, fields +import sys +from dataclasses import asdict, dataclass, field, fields from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, Tuple, Union, cast from singer.catalog import Catalog as BaseCatalog from singer.catalog import CatalogEntry as BaseCatalogEntry +from singer.messages import Message from singer_sdk.helpers._schema import SchemaPlus +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + if sys.version_info >= (3, 8): + from typing import Literal + else: + from typing_extensions import Literal + Breadcrumb = Tuple[str, ...] logger = logging.getLogger(__name__) +class SingerMessageType(str, enum.Enum): + """Singer specification message types.""" + + RECORD = "RECORD" + SCHEMA = "SCHEMA" + STATE = "STATE" + ACTIVATE_VERSION = "ACTIVATE_VERSION" + BATCH = "BATCH" + + class SelectionMask(Dict[Breadcrumb, bool]): """Boolean mask for property selection in schemas and records.""" @@ -40,28 +63,28 @@ class InclusionType(str, Enum): AUTOMATIC = "automatic" UNSUPPORTED = "unsupported" - inclusion: Optional[InclusionType] = None - selected: Optional[bool] = None - selected_by_default: Optional[bool] = None + inclusion: InclusionType | None = None + selected: bool | None = None + selected_by_default: bool | None = None @classmethod - def from_dict(cls, value: Dict[str, Any]): + def from_dict(cls, value: dict[str, Any]): """Parse metadata dictionary.""" return cls( **{ - field.name: value.get(field.name.replace("_", "-")) - for field in fields(cls) + object_field.name: value.get(object_field.name.replace("_", "-")) + for object_field in fields(cls) } ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert metadata to a JSON-encodeable dictionary.""" result = {} - for field in fields(self): - value = getattr(self, field.name) + for object_field in fields(self): + value = getattr(self, object_field.name) if value is not None: - result[field.name.replace("_", "-")] = value + result[object_field.name.replace("_", "-")] = value return result @@ -70,19 +93,22 @@ def to_dict(self) -> Dict[str, Any]: class StreamMetadata(Metadata): """Stream metadata.""" - table_key_properties: Optional[List[str]] = None - forced_replication_method: Optional[str] = None - valid_replication_keys: Optional[List[str]] = None - schema_name: Optional[str] = None + table_key_properties: list[str] | None = None + forced_replication_method: str | None = None + valid_replication_keys: list[str] | None = None + schema_name: str | None = None -class MetadataMapping(Dict[Breadcrumb, Union[Metadata, StreamMetadata]]): +AnyMetadata: TypeAlias = Union[Metadata, StreamMetadata] + + +class MetadataMapping(Dict[Breadcrumb, AnyMetadata]): """Stream metadata mapping.""" @classmethod - def from_iterable(cls, iterable: Iterable[Dict[str, Any]]): + def from_iterable(cls, iterable: Iterable[dict[str, Any]]): """Create a metadata mapping from an iterable of metadata dictionaries.""" - mapping = cls() + mapping: dict[Breadcrumb, AnyMetadata] = cls() for d in iterable: breadcrumb = tuple(d["breadcrumb"]) metadata = d["metadata"] @@ -93,7 +119,7 @@ def from_iterable(cls, iterable: Iterable[Dict[str, Any]]): return mapping - def to_list(self) -> List[Dict[str, Any]]: + def to_list(self) -> list[dict[str, Any]]: """Convert mapping to a JSON-encodable list.""" return [ {"breadcrumb": list(k), "metadata": v.to_dict()} for k, v in self.items() @@ -113,11 +139,11 @@ def root(self): @classmethod def get_standard_metadata( cls, - schema: Optional[Dict[str, Any]] = None, - schema_name: Optional[str] = None, - key_properties: Optional[List[str]] = None, - valid_replication_keys: Optional[List[str]] = None, - replication_method: Optional[str] = None, + schema: dict[str, Any] | None = None, + schema_name: str | None = None, + key_properties: list[str] | None = None, + valid_replication_keys: list[str] | None = None, + replication_method: str | None = None, ): """Get default metadata for a stream.""" mapping = cls() @@ -212,18 +238,18 @@ class CatalogEntry(BaseCatalogEntry): tap_stream_id: str metadata: MetadataMapping schema: SchemaPlus - stream: Optional[str] = None - key_properties: Optional[List[str]] = None - replication_key: Optional[str] = None - is_view: Optional[bool] = None - database: Optional[str] = None - table: Optional[str] = None - row_count: Optional[int] = None - stream_alias: Optional[str] = None - replication_method: Optional[str] = None + stream: str | None = None + key_properties: list[str] | None = None + replication_key: str | None = None + is_view: bool | None = None + database: str | None = None + table: str | None = None + row_count: int | None = None + stream_alias: str | None = None + replication_method: str | None = None @classmethod - def from_dict(cls, stream: Dict[str, Any]): + def from_dict(cls, stream: dict[str, Any]): """Create a catalog entry from a dictionary.""" return cls( tap_stream_id=stream["tap_stream_id"], @@ -250,7 +276,7 @@ class Catalog(Dict[str, CatalogEntry], BaseCatalog): """Singer catalog mapping of stream entries.""" @classmethod - def from_dict(cls, data: Dict[str, List[Dict[str, Any]]]) -> "Catalog": + def from_dict(cls, data: dict[str, list[dict[str, Any]]]) -> Catalog: """Create a catalog from a dictionary.""" instance = cls() for stream in data.get("streams", []): @@ -258,7 +284,7 @@ def from_dict(cls, data: Dict[str, List[Dict[str, Any]]]) -> "Catalog": instance[entry.tap_stream_id] = entry return instance - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return a dictionary representation of the catalog. Returns: @@ -267,7 +293,7 @@ def to_dict(self) -> Dict[str, Any]: return cast(Dict[str, Any], super().to_dict()) @property - def streams(self) -> List[CatalogEntry]: + def streams(self) -> list[CatalogEntry]: """Get catalog entries.""" return list(self.values()) @@ -275,6 +301,94 @@ def add_stream(self, entry: CatalogEntry) -> None: """Add a stream entry to the catalog.""" self[entry.tap_stream_id] = entry - def get_stream(self, stream_id: str) -> Optional[CatalogEntry]: + def get_stream(self, stream_id: str) -> CatalogEntry | None: """Retrieve a stream entry from the catalog.""" return self.get(stream_id) + + +class BatchFileFormat(str, enum.Enum): + """Batch file format.""" + + JSONL = "jsonl" + """JSON Lines format.""" + + +@dataclass +class BaseBatchFileEncoding: + """Base class for batch file encodings.""" + + registered_encodings: ClassVar[dict[str, type[BaseBatchFileEncoding]]] = {} + __encoding_format__: ClassVar[str] = "OVERRIDE_ME" + + # Base encoding fields + format: str = field(init=False) + """The format of the batch file.""" + + compression: str | None = None + """The compression of the batch file.""" + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Register subclasses.""" + super().__init_subclass__(**kwargs) + cls.registered_encodings[cls.__encoding_format__] = cls + + def __post_init__(self) -> None: + self.format = self.__encoding_format__ + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BaseBatchFileEncoding: + """Create an encoding from a dictionary.""" + encoding_format = data.pop("format") + encoding_cls = cls.registered_encodings[encoding_format] + return encoding_cls(**data) + + +@dataclass +class JSONLinesEncoding(BaseBatchFileEncoding): + """JSON Lines encoding for batch files.""" + + __encoding_format__ = "jsonl" + + +@dataclass +class SDKBatchMessage(Message): + """Singer batch message in the Meltano SDK flavor.""" + + type: Literal[SingerMessageType.BATCH] = field(init=False) + """The message type.""" + + stream: str + """The stream name.""" + + encoding: BaseBatchFileEncoding + """The file encoding of the batch.""" + + manifest: list[str] = field(default_factory=list) + """The manifest of files in the batch.""" + + def __post_init__(self): + if isinstance(self.encoding, dict): + self.encoding = BaseBatchFileEncoding.from_dict(self.encoding) + + self.type = SingerMessageType.BATCH + + def asdict(self): + """Return a dictionary representation of the message. + + Returns: + A dictionary with the defined message fields. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SDKBatchMessage: + """Create an encoding from a dictionary. + + Args: + data: The dictionary to create the message from. + + Returns: + The created message. + """ + data.pop("type") + return cls(**data) diff --git a/singer_sdk/io_base.py b/singer_sdk/io_base.py index f5da20d19a..32f48ae36a 100644 --- a/singer_sdk/io_base.py +++ b/singer_sdk/io_base.py @@ -3,7 +3,6 @@ from __future__ import annotations import abc -import enum import json import logging import sys @@ -12,19 +11,11 @@ from typing import Counter as CounterType from singer_sdk.helpers._compat import final +from singer_sdk.helpers._singer import SingerMessageType logger = logging.getLogger(__name__) -class SingerMessageType(str, enum.Enum): - """Singer specification message types.""" - - RECORD = "RECORD" - SCHEMA = "SCHEMA" - STATE = "STATE" - ACTIVATE_VERSION = "ACTIVATE_VERSION" - - class SingerReader(metaclass=abc.ABCMeta): """Interface for all plugins reading Singer messages from stdin.""" @@ -95,6 +86,9 @@ def _process_lines(self, file_input: IO[str]) -> CounterType[str]: elif record_type == SingerMessageType.STATE: self._process_state_message(line_dict) + elif record_type == SingerMessageType.BATCH: + self._process_batch_message(line_dict) + else: self._process_unknown_message(line_dict) @@ -118,6 +112,10 @@ def _process_state_message(self, message_dict: dict) -> None: def _process_activate_version_message(self, message_dict: dict) -> None: ... + @abc.abstractmethod + def _process_batch_message(self, message_dict: dict) -> None: + ... + def _process_unknown_message(self, message_dict: dict) -> None: """Internal method to process unknown message types from a Singer tap. diff --git a/singer_sdk/mapper_base.py b/singer_sdk/mapper_base.py index c09d39255f..abe218e5a6 100644 --- a/singer_sdk/mapper_base.py +++ b/singer_sdk/mapper_base.py @@ -50,6 +50,9 @@ def _process_state_message(self, message_dict: dict) -> None: def _process_activate_version_message(self, message_dict: dict) -> None: self._write_messages(self.map_activate_version_message(message_dict)) + def _process_batch_message(self, message_dict: dict) -> None: + self._write_messages(self.map_batch_message(message_dict)) + @abc.abstractmethod def map_schema_message(self, message_dict: dict) -> Iterable[singer.Message]: """Map a schema message to zero or more new messages. @@ -89,6 +92,17 @@ def map_activate_version_message( """ ... + def map_batch_message( + self, + message_dict: dict, + ) -> Iterable[singer.Message]: + """Map a version message to zero or more new messages. + + Args: + message_dict: An ACTIVATE_VERSION message JSON dictionary. + """ + pass + @classproperty def cli(cls) -> Callable: """Execute standard CLI handler for inline mappers. diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index 6c5208981d..8956bb27aa 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -3,8 +3,11 @@ import abc import copy import datetime +import gzip +import itertools import json import logging +import os from os import PathLike from pathlib import Path from types import MappingProxyType @@ -14,6 +17,7 @@ Dict, Generator, Iterable, + Iterator, List, Mapping, Optional, @@ -23,6 +27,7 @@ Union, cast, ) +from uuid import uuid4 import pendulum import requests @@ -35,9 +40,12 @@ from singer_sdk.helpers._flattening import get_flattening_options from singer_sdk.helpers._schema import SchemaPlus from singer_sdk.helpers._singer import ( + BaseBatchFileEncoding, Catalog, CatalogEntry, + JSONLinesEncoding, MetadataMapping, + SDKBatchMessage, SelectionMask, ) from singer_sdk.helpers._state import ( @@ -62,10 +70,32 @@ REPLICATION_LOG_BASED = "LOG_BASED" FactoryType = TypeVar("FactoryType", bound="Stream") +_T = TypeVar("_T") METRICS_LOG_LEVEL_SETTING = "metrics_log_level" +def lazy_chunked_generator( + iterable: Iterable[_T], + chunk_size: int, +) -> Generator[Iterator[_T], None, None]: + """Yield a generator for each chunk of the given iterable. + + Args: + iterable: The iterable to chunk. + chunk_size: The size of each chunk. + + Yields: + A generator for each chunk of the given iterable. + """ + iterator = iter(iterable) + while True: + chunk = list(itertools.islice(iterator, chunk_size)) + if not chunk: + break + yield iter(chunk) + + class Stream(metaclass=abc.ABCMeta): """Abstract base class for tap streams.""" @@ -79,6 +109,9 @@ class Stream(metaclass=abc.ABCMeta): # Internal API cost aggregator _sync_costs: Dict[str, int] = {} + # Batch attributes + batch_size: int = 100 + def __init__( self, tap: TapBaseClass, @@ -803,6 +836,25 @@ def _write_record_message(self, record: dict) -> None: for record_message in self._generate_record_messages(record): singer.write_message(record_message) + def _write_batch_message( + self, + encoding: BaseBatchFileEncoding, + manifest: List[str], + ) -> None: + """Write out a BATCH message. + + Args: + encoding: The encoding to use for the batch. + manifest: A list of filenames for the batch. + """ + singer.write_message( + SDKBatchMessage( + stream=self.name, + encoding=encoding, + manifest=manifest, + ) + ) + @property def _metric_logging_function(self) -> Optional[Callable]: """Return the metrics logging function. @@ -973,21 +1025,15 @@ def finalize_state_progress_markers(self, state: Optional[dict] = None) -> None: def _process_record( self, record: dict, - selected: bool, - record_context: Optional[dict] = None, child_context: Optional[dict] = None, partition_context: Optional[dict] = None, - count: int = 0, ) -> None: """Process a record. Args: record: The record to process. - selected: Whether the stream is selected. - record_context: The record context. child_context: The child context. partition_context: The partition context. - count: The current record count per stream. """ partition_context = partition_context or {} child_context = copy.copy( @@ -1002,20 +1048,22 @@ def _process_record( if self.stream_maps[0].get_filter_result(record): self._sync_children(child_context) - if selected: - if (count - 1) % self.STATE_MSG_FREQUENCY == 0: - self._write_state_message() - self._write_record_message(record) - self._increment_stream_state(record, context=record_context) - - def _sync_records(self, context: Optional[dict] = None) -> None: + def _sync_records( + self, + context: Optional[dict] = None, + write_messages: bool = True, + ) -> Generator[dict, Any, Any]: """Sync records, emitting RECORD and STATE messages. Args: context: Stream partition or context dictionary. + write_messages: Whether to write Singer messages to stdout. Raises: InvalidStreamSortException: TODO + + Yields: + Each record from the source. """ record_count = 0 current_context: Optional[dict] @@ -1042,11 +1090,8 @@ def _sync_records(self, context: Optional[dict] = None) -> None: try: self._process_record( record, - selected, - record_context=current_context, child_context=child_context, partition_context=state_partition_context, - count=record_count, ) except InvalidStreamSortException as ex: log_sort_error( @@ -1060,8 +1105,20 @@ def _sync_records(self, context: Optional[dict] = None) -> None: ) raise ex - record_count += 1 - partition_record_count += 1 + if selected: + if ( + record_count - 1 + ) % self.STATE_MSG_FREQUENCY == 0 and write_messages: + self._write_state_message() + if write_messages: + self._write_record_message(record) + self._increment_stream_state(record, context=current_context) + + yield record + + record_count += 1 + partition_record_count += 1 + if current_context == state_partition_context: # Finalize per-partition state only if 1:1 with context finalize_state_progress_markers(state) @@ -1070,8 +1127,20 @@ def _sync_records(self, context: Optional[dict] = None) -> None: # Otherwise will be finalized by tap at end of sync. finalize_state_progress_markers(self.stream_state) self._write_record_count_log(record_count=record_count, context=context) - # Reset interim bookmarks before emitting final STATE message: - self._write_state_message() + + if write_messages: + # Reset interim bookmarks before emitting final STATE message: + self._write_state_message() + + def _sync_batches(self, context: Optional[dict] = None) -> None: + """Sync batches, emitting BATCH messages. + + Args: + context: Stream partition or context dictionary. + """ + for encoding, manifest in self.get_batches(context): + self._write_batch_message(encoding=encoding, manifest=manifest) + self._write_state_message() # Public methods ("final", not recommended to be overridden) @@ -1097,8 +1166,16 @@ def sync(self, context: Optional[dict] = None) -> None: # Send a SCHEMA message to the downstream target: if self.selected: self._write_schema_message() - # Sync the records themselves: - self._sync_records(context) + + # TODO: This is a temporary hack to toggle BATCH mode during development. + batch_mode = os.getenv("SINGER_BATCH_MODE", "false") == "true" + + if batch_mode: + self._sync_batches(context=context) + else: + # Sync the records themselves: + for _ in self._sync_records(context=context): + pass def _sync_children(self, child_context: dict) -> None: for child_stream in self.child_streams: @@ -1209,6 +1286,33 @@ def get_records( """ pass + def get_batches( + self, + context: Optional[dict] = None, + ) -> Iterable[Tuple[BaseBatchFileEncoding, List[str]]]: + """Batch generator function. + + Developers are encouraged to override this method to customize batching + behavior for databases, bulk APIs, etc. + + Args: + context: Stream partition or context dictionary. + + Yields: + A tuple of (encoding, manifest) for each batch. + """ + for chunk in lazy_chunked_generator( + self._sync_records(context, write_messages=False), + self.batch_size, + ): + filename = f"output/{self.name}-{uuid4()}.json.gz" + + # TODO: Determine compression from config. + with gzip.open(filename, "wb") as f: + f.writelines((json.dumps(record) + "\n").encode() for record in chunk) + + yield JSONLinesEncoding(compression="gzip"), [filename] + def post_process(self, row: dict, context: Optional[dict] = None) -> Optional[dict]: """As needed, append or transform raw data to match expected structure. diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index dad17f09cd..b045062407 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -16,6 +16,7 @@ from singer_sdk.exceptions import RecordsWitoutSchemaException from singer_sdk.helpers._classproperty import classproperty from singer_sdk.helpers._compat import final +from singer_sdk.helpers._singer import BaseBatchFileEncoding from singer_sdk.helpers.capabilities import CapabilitiesEnum, PluginCapabilities from singer_sdk.io_base import SingerMessageType, SingerReader from singer_sdk.mapper import PluginMapper @@ -266,6 +267,7 @@ def _process_lines(self, file_input: IO[str]) -> Counter[str]: self.logger.info( f"Target '{self.name}' completed reading {line_count} lines of input " f"({counter[SingerMessageType.RECORD]} records, " + f"({counter[SingerMessageType.BATCH]} batch manifests, " f"{counter[SingerMessageType.STATE]} state messages)." ) @@ -400,6 +402,15 @@ def _process_activate_version_message(self, message_dict: dict) -> None: sink = self.get_sink(stream_name) sink.activate_version(message_dict["version"]) + def _process_batch_message(self, message_dict: dict) -> None: + """Handle the optional BATCH message extension. + + Args: + message_dict: TODO + """ + encoding = BaseBatchFileEncoding.from_dict(message_dict["encoding"]) + self.logger.info("Processing record batch encoded as %s", encoding) + # Sink drain methods @final diff --git a/tests/core/test_singer_messages.py b/tests/core/test_singer_messages.py new file mode 100644 index 0000000000..185606631e --- /dev/null +++ b/tests/core/test_singer_messages.py @@ -0,0 +1,57 @@ +from dataclasses import asdict + +import pytest + +from singer_sdk.helpers._singer import ( + BaseBatchFileEncoding, + JSONLinesEncoding, + SDKBatchMessage, + SingerMessageType, +) + + +@pytest.mark.parametrize( + "encoding,expected", + [ + (JSONLinesEncoding("gzip"), {"compression": "gzip", "format": "jsonl"}), + (JSONLinesEncoding(), {"compression": None, "format": "jsonl"}), + ], + ids=["jsonl-compression-gzip", "jsonl-compression-none"], +) +def test_encoding_as_dict(encoding: BaseBatchFileEncoding, expected: dict) -> None: + """Test encoding as dict.""" + assert asdict(encoding) == expected + + +@pytest.mark.parametrize( + "message,expected", + [ + ( + SDKBatchMessage( + stream="test_stream", + encoding=JSONLinesEncoding("gzip"), + manifest=[ + "path/to/file1.jsonl.gz", + "path/to/file2.jsonl.gz", + ], + ), + { + "type": SingerMessageType.BATCH, + "stream": "test_stream", + "encoding": {"compression": "gzip", "format": "jsonl"}, + "manifest": [ + "path/to/file1.jsonl.gz", + "path/to/file2.jsonl.gz", + ], + }, + ) + ], + ids=["batch-message-jsonl"], +) +def test_batch_message_as_dict(message, expected): + """Test batch message as dict.""" + + dumped = message.asdict() + assert dumped == expected + + assert message.from_dict(dumped) == message