From 2721dc5d00ceef91617ca91d372ad6db6f71f830 Mon Sep 17 00:00:00 2001
From: Ken Payne
Date: Wed, 19 Oct 2022 23:28:55 +0100
Subject: [PATCH] fix: create schema and table on `add_sink` (#1036)

* start on schema and table creation on

* linting

* add default schema name

* add schema to table metadata

* Add missing import for `singer_sdk.helpers._catalog`

* undo connection module

* fix copy-paste formatting

* fix test

* more connector changes

* fix docstring

* add schema creation test

* remove create_table_with_records method

* Update singer_sdk/sinks/sql.py

Co-authored-by: Aaron ("AJ") Steers

Co-authored-by: Edgar R. M
Co-authored-by: Aaron ("AJ") Steers
---
 singer_sdk/sinks/core.py  |   8 +++
 singer_sdk/sinks/sql.py   | 118 +++++++++++++++++---------
 singer_sdk/streams/sql.py |  73 ++++++++++++++++-------
 singer_sdk/target_base.py |   7 ++-
 tests/core/test_sqlite.py |  36 ++++++++++++
 5 files changed, 152 insertions(+), 90 deletions(-)

diff --git a/singer_sdk/sinks/core.py b/singer_sdk/sinks/core.py
index d3f8badad..18487546c 100644
--- a/singer_sdk/sinks/core.py
+++ b/singer_sdk/sinks/core.py
@@ -423,6 +423,14 @@ def activate_version(self, new_version: int) -> None:
             "Ignoring."
         )
 
+    def setup(self) -> None:
+        """Perform any setup actions at the beginning of a Stream.
+
+        Setup is executed once per Sink instance, after instantiation. If a Schema
+        change is detected, a new Sink is instantiated and this method is called again.
+        """
+        pass
+
     def clean_up(self) -> None:
         """Perform any clean up actions required at end of a stream.
 
diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py
index 5f37a0236..c3455d5df 100644
--- a/singer_sdk/sinks/sql.py
+++ b/singer_sdk/sinks/sql.py
@@ -10,7 +10,7 @@
 from singer_sdk.plugin_base import PluginBase
 from singer_sdk.sinks.batch import BatchSink
-from singer_sdk.streams.sql import SQLConnector
+from singer_sdk.streams import SQLConnector
 
 
 class SQLSink(BatchSink):
@@ -38,11 +38,7 @@ def __init__(
             connector: Optional connector to reuse.
         """
         self._connector: SQLConnector
-        if connector:
-            self._connector = connector
-        else:
-            self._connector = self.connector_class(dict(target.config))
-
+        self._connector = connector or self.connector_class(dict(target.config))
         super().__init__(target, stream_name, schema, key_properties)
 
     @property
@@ -65,103 +61,93 @@ def connection(self) -> sqlalchemy.engine.Connection:
 
     @property
     def table_name(self) -> str:
-        """Returns the table name, with no schema or database part.
+        """Return the table name, with no schema or database part.
 
         Returns:
             The target table name.
         """
         parts = self.stream_name.split("-")
-
-        if len(parts) == 1:
-            return self.stream_name
-        else:
-            return parts[-1]
+        return self.stream_name if len(parts) == 1 else parts[-1]
 
     @property
     def schema_name(self) -> Optional[str]:
-        """Returns the schema name or `None` if using names with no schema part.
+        """Return the schema name or `None` if using names with no schema part.
 
         Returns:
             The target schema name.
         """
-        return None  # Assumes single-schema target context.
+        parts = self.stream_name.split("-")
+        if len(parts) in {2, 3}:
+            # Stream name is a two-part or three-part identifier.
+            # Use the second-to-last part as the schema name.
+            return parts[-2]
+
+        # Schema name not detected.
+        return None
 
     @property
     def database_name(self) -> Optional[str]:
-        """Returns the DB name or `None` if using names with no database part.
+        """Return the DB name or `None` if using names with no database part.
 
         Returns:
             The target database name.
""" return None # Assumes single-DB target context. - def process_batch(self, context: dict) -> None: - """Process a batch with the given batch context. - - Writes a batch to the SQL target. Developers may override this method - in order to provide a more efficient upload/upsert process. + @property + def full_table_name(self) -> str: + """Return the fully qualified table name. - Args: - context: Stream partition or context dictionary. + Returns: + The fully qualified table name. """ - # If duplicates are merged, these can be tracked via - # :meth:`~singer_sdk.Sink.tally_duplicate_merged()`. - self.connector.prepare_table( - full_table_name=self.full_table_name, - schema=self.schema, - primary_keys=self.key_properties, - as_temp_table=False, - ) - self.bulk_insert_records( - full_table_name=self.full_table_name, - schema=self.schema, - records=context["records"], + return self.connector.get_fully_qualified_name( + table_name=self.table_name, + schema_name=self.schema_name, + db_name=self.database_name, ) @property - def full_table_name(self) -> str: - """Gives the fully qualified table name. + def full_schema_name(self) -> str: + """Return the fully qualified schema name. Returns: - The fully qualified table name. + The fully qualified schema name. """ return self.connector.get_fully_qualified_name( - self.table_name, - self.schema_name, - self.database_name, + schema_name=self.schema_name, db_name=self.database_name ) - def create_table_with_records( - self, - full_table_name: Optional[str], - schema: dict, - records: Iterable[Dict[str, Any]], - primary_keys: Optional[List[str]] = None, - partition_keys: Optional[List[str]] = None, - as_temp_table: bool = False, - ) -> None: - """Create an empty table. + def setup(self) -> None: + """Set up Sink. - Args: - full_table_name: the target table name. - schema: the JSON schema for the new table. - records: records to load. - primary_keys: list of key properties. - partition_keys: list of partition keys. - as_temp_table: True to create a temp table. + This method is called on Sink creation, and creates the required Schema and + Table entities in the target database. """ - full_table_name = full_table_name or self.full_table_name - if primary_keys is None: - primary_keys = self.key_properties - partition_keys = partition_keys or None + if self.schema_name: + self.connector.prepare_schema(self.schema_name) self.connector.prepare_table( - full_table_name=full_table_name, - primary_keys=primary_keys, - schema=schema, - as_temp_table=as_temp_table, + full_table_name=self.full_table_name, + schema=self.schema, + primary_keys=self.key_properties, + as_temp_table=False, ) + + def process_batch(self, context: dict) -> None: + """Process a batch with the given batch context. + + Writes a batch to the SQL target. Developers may override this method + in order to provide a more efficient upload/upsert process. + + Args: + context: Stream partition or context dictionary. + """ + # If duplicates are merged, these can be tracked via + # :meth:`~singer_sdk.Sink.tally_duplicate_merged()`. 
         self.bulk_insert_records(
-            full_table_name=full_table_name, schema=schema, records=records
+            full_table_name=self.full_table_name,
+            schema=self.schema,
+            records=context["records"],
         )
 
     def generate_insert_statement(
diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py
index 6a0392485..eead34c97 100644
--- a/singer_sdk/streams/sql.py
+++ b/singer_sdk/streams/sql.py
@@ -26,7 +26,6 @@ class SQLConnector:
     The connector class serves as a wrapper around the SQL connection.
 
     The functions of the connector are:
-
     - connecting to the source
     - generating SQLAlchemy connection and engine objects
    - discovering schema catalog entries
@@ -76,6 +75,7 @@ def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
 
         By default this will create using the sqlalchemy `stream_results=True` option
         described here:
+
         https://docs.sqlalchemy.org/en/14/core/connections.html#using-server-side-cursors-a-k-a-stream-results
 
         Developers may override this method if their provider does not support
@@ -191,7 +191,6 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
 
         Developers may override this method to accept additional input argument types,
         to support non-standard types, or to provide custom typing logic.
-
         If overriding this method, developers should call the default implementation
         from the base class for all unhandled cases.
@@ -205,7 +204,7 @@
     @staticmethod
     def get_fully_qualified_name(
-        table_name: str,
+        table_name: str | None = None,
         schema_name: str | None = None,
         db_name: str | None = None,
         delimiter: str = ".",
@@ -219,23 +218,23 @@
             delimiter: Generally: '.' for SQL names and '-' for Singer names.
 
         Raises:
-            ValueError: If table_name is not provided or if neither schema_name or
-                db_name are provided.
+            ValueError: If none of the three name parts are supplied.
 
         Returns:
             The fully qualified name as a string.
         """
-        if db_name and schema_name:
-            result = delimiter.join([db_name, schema_name, table_name])
-        elif db_name:
-            result = delimiter.join([db_name, table_name])
-        elif schema_name:
-            result = delimiter.join([schema_name, table_name])
-        elif table_name:
-            result = table_name
-        else:
+        parts = []
+
+        if db_name:
+            parts.append(db_name)
+        if schema_name:
+            parts.append(schema_name)
+        if table_name:
+            parts.append(table_name)
+
+        if not parts:
             raise ValueError(
-                "Could not generate fully qualified name for stream: "
+                "Could not generate fully qualified name: "
                 + ":".join(
                     [
                         db_name or "(unknown-db)",
@@ -245,7 +244,7 @@
                 )
             )
 
-        return result
+        return delimiter.join(parts)
 
     @property
     def _dialect(self) -> sqlalchemy.engine.Dialect:
@@ -487,6 +486,18 @@ def table_exists(self, full_table_name: str) -> bool:
             sqlalchemy.inspect(self._engine).has_table(full_table_name),
         )
 
+    def schema_exists(self, schema_name: str) -> bool:
+        """Determine if the target database schema already exists.
+
+        Args:
+            schema_name: The target database schema name.
+
+        Returns:
+            True if the database schema exists, False if not.
+ """ + schema_names = sqlalchemy.inspect(self._engine).get_schema_names() + return schema_name in schema_names + def get_table_columns( self, full_table_name: str, column_names: list[str] | None = None ) -> dict[str, sqlalchemy.Column]: @@ -547,6 +558,14 @@ def column_exists(self, full_table_name: str, column_name: str) -> bool: """ return column_name in self.get_table_columns(full_table_name) + def create_schema(self, schema_name: str) -> None: + """Create target schema. + + Args: + schema_name: The target schema to create. + """ + self._engine.execute(sqlalchemy.schema.CreateSchema(schema_name)) + def create_empty_table( self, full_table_name: str, @@ -573,7 +592,8 @@ def create_empty_table( _ = partition_keys # Not supported in generic implementation. - meta = sqlalchemy.MetaData() + _, schema_name, table_name = self.parse_full_table_name(full_table_name) + meta = sqlalchemy.MetaData(schema=schema_name) columns: list[sqlalchemy.Column] = [] primary_keys = primary_keys or [] try: @@ -592,7 +612,7 @@ def create_empty_table( ) ) - _ = sqlalchemy.Table(full_table_name, meta, *columns) + _ = sqlalchemy.Table(table_name, meta, *columns) meta.create_all(self._engine) def _create_empty_column( @@ -630,6 +650,16 @@ def _create_empty_column( ) ) + def prepare_schema(self, schema_name: str) -> None: + """Create the target database schema. + + Args: + schema_name: The target schema name. + """ + schema_exists = self.schema_exists(schema_name) + if not schema_exists: + self.create_schema(schema_name) + def prepare_table( self, full_table_name: str, @@ -788,6 +818,7 @@ def _sort_types( For example, [Smallint, Integer, Datetime, String, Double] would become [Unicode, String, Double, Integer, Smallint, Datetime]. + String types will be listed first, then decimal types, then integer types, then bool types, and finally datetime and date. Higher precision, scale, and length will be sorted earlier. @@ -823,7 +854,7 @@ def _get_type_sort_key( def _get_column_type( self, full_table_name: str, column_name: str ) -> sqlalchemy.types.TypeEngine: - """Gets the SQL type of the declared column. + """Get the SQL type of the declared column. Args: full_table_name: The name of the table. @@ -937,7 +968,7 @@ def _singer_catalog_entry(self) -> CatalogEntry: @property def connector(self) -> SQLConnector: - """The connector object. + """Return a connector object. Returns: The connector object. @@ -946,7 +977,7 @@ def connector(self) -> SQLConnector: @property def metadata(self) -> MetadataMapping: - """The Singer metadata. + """Return the Singer metadata. Metadata from an input catalog will override standard metadata. diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index 0b21b04c5..a1ddd5e78 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -224,14 +224,15 @@ def add_sink( """ self.logger.info(f"Initializing '{self.name}' target sink...") sink_class = self.get_sink_class(stream_name=stream_name) - result = sink_class( + sink = sink_class( target=self, stream_name=stream_name, schema=schema, key_properties=key_properties, ) - self._sinks_active[stream_name] = result - return result + sink.setup() + self._sinks_active[stream_name] = sink + return sink def _assert_sink_exists(self, stream_name: str) -> None: """Raise a RecordsWithoutSchemaException exception if stream doesn't exist. 
diff --git a/tests/core/test_sqlite.py b/tests/core/test_sqlite.py
index 5e76d8168..b9c7c82ca 100644
--- a/tests/core/test_sqlite.py
+++ b/tests/core/test_sqlite.py
@@ -10,6 +10,7 @@
 from uuid import uuid4
 
 import pytest
+import sqlalchemy
 
 from samples.sample_tap_sqlite import SQLiteConnector, SQLiteTap
 from samples.sample_target_csv.csv_target import SampleTargetCSV
@@ -265,6 +266,41 @@ def test_sync_sqlite_to_sqlite(
     assert line_num > 0, "No lines read."
 
 
+def test_sqlite_schema_addition(
+    sqlite_target_test_config: dict, sqlite_sample_target: SQLTarget
+):
+    """Test that SQL-based targets attempt to create a new schema if one is included in the stream name."""
+    schema_name = f"test_schema_{str(uuid4()).split('-')[-1]}"
+    table_name = f"zzz_tmp_{str(uuid4()).split('-')[-1]}"
+    test_stream_name = f"{schema_name}-{table_name}"
+    schema_message = {
+        "type": "SCHEMA",
+        "stream": test_stream_name,
+        "schema": {
+            "type": "object",
+            "properties": {"col_a": th.StringType().to_dict()},
+        },
+    }
+    tap_output = "\n".join(
+        json.dumps(msg)
+        for msg in [
+            schema_message,
+            {
+                "type": "RECORD",
+                "stream": test_stream_name,
+                "record": {"col_a": "samplerow1"},
+            },
+        ]
+    )
+    # SQLite doesn't support schema creation.
+    with pytest.raises(sqlalchemy.exc.OperationalError) as excinfo:
+        target_sync_test(
+            sqlite_sample_target, input=StringIO(tap_output), finalize=True
+        )
+    # Check that the target at least tried to create the schema.
+    assert excinfo.value.statement == f"CREATE SCHEMA {schema_name}"
+
+
 def test_sqlite_column_addition(sqlite_sample_target: SQLTarget):
     """End-to-end-to-end test for SQLite tap and target.
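
For reviewers, a minimal usage sketch of the behavior this patch adds (not part of the
patch itself). It assumes singer-sdk at this commit; the "my_schema"/"my_table" names
are illustrative only, and the closing comment paraphrases the new `Sink.setup()` flow
rather than additional runnable API calls.

    from singer_sdk.streams import SQLConnector

    # After this patch, any subset of the three name parts may be supplied;
    # the parts that are present are joined in db -> schema -> table order.
    fqn = SQLConnector.get_fully_qualified_name(
        table_name="my_table", schema_name="my_schema"
    )
    assert fqn == "my_schema.my_table"

    # Schema-only names, as used by the new SQLSink.full_schema_name property:
    assert SQLConnector.get_fully_qualified_name(schema_name="my_schema") == "my_schema"

    # A SQLSink bound to stream "my_schema-my_table" now resolves
    # schema_name == "my_schema" and table_name == "my_table"; Target.add_sink()
    # invokes the new Sink.setup() hook, which calls connector.prepare_schema()
    # before connector.prepare_table().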