From 7bb58fe8687179c2719c64c5e2573b028562f32b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?=
Date: Tue, 22 Oct 2024 12:11:42 -0600
Subject: [PATCH 1/3] feat: Adapt to changes in stream key properties

---
 poetry.lock    | 11 +++++------
 pyproject.toml |  4 ++--
 tests/core.py  |  2 ++
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 1d68146..674d44c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1465,13 +1465,13 @@ files = [

 [[package]]
 name = "singer-sdk"
-version = "0.41.0"
+version = "0.42.0a2"
 description = "A framework for building Singer taps"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "singer_sdk-0.41.0-py3-none-any.whl", hash = "sha256:9570377d043239c04d38d4193e0c6e164949d07382234c5895e5ea1ba273e260"},
-    {file = "singer_sdk-0.41.0.tar.gz", hash = "sha256:be3a4b0ae034eda445e7dd9378999f9d19d8135fd14c9135e3bc5deaf5dbd3ad"},
+    {file = "singer_sdk-0.42.0a2-py3-none-any.whl", hash = "sha256:28eb7a06d8c68e7a54c3e986f2f2bc62a241c9f3fda48edabb6b80cf7c3c0542"},
+    {file = "singer_sdk-0.42.0a2.tar.gz", hash = "sha256:25d4a3107a26bf8813fed3d9101b9b49e58e4ad37a3ca6a9a590888413f71f12"},
 ]

 [package.dependencies]
@@ -1493,11 +1493,10 @@ PyYAML = ">=6.0"
 referencing = ">=0.30.0"
 requests = ">=2.25.1"
 setuptools = "<=70.3.0"
-simpleeval = ">=0.9.13"
+simpleeval = {version = ">=0.9.13", markers = "python_version >= \"3.9\""}
 simplejson = ">=3.17.6"
 sqlalchemy = ">=1.4,<3.0"
 typing-extensions = ">=4.5.0"
-urllib3 = ">=1.26,<2"

 [package.extras]
 docs = ["furo (>=2024.5.6)", "myst-parser (>=3)", "pytest (>=7.2.1)", "sphinx (>=7)", "sphinx-copybutton (>=0.5.2)", "sphinx-inline-tabs (>=2023.4.21)", "sphinx-notfound-page (>=1.0.0)", "sphinx-reredirects (>=0.1.5)"]
@@ -1775,4 +1774,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9"
-content-hash = "fd170f2abc878aacb4f7a22bf6dc9c4052c8c574ca6583801ac9f8c5d95a73e1"
+content-hash = "3ef78b39fb58252d15d87d4ce296b8af83c6f4044dfcff48841fb7d6a78d8c72"
diff --git a/pyproject.toml b/pyproject.toml
index 5b4e5c7..f37d6b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ snowflake-connector-python = { version = "<4.0.0", extras = ["secure-local-stora
 sqlalchemy = "~=2.0.31"

 [tool.poetry.dependencies.singer-sdk]
-version = "~=0.41.0"
+version = "~=0.42.0a2"

 [tool.poetry.group.dev.dependencies]
 coverage = ">=7.2.7"
@@ -30,7 +30,7 @@ pytest = ">=7.4.3"
 pytest-xdist = ">=3.3.1"

 [tool.poetry.group.dev.dependencies.singer-sdk]
-version="~=0.41.0"
+version = "~=0.42.0a2"
 extras = ["testing"]

 [tool.ruff]
diff --git a/tests/core.py b/tests/core.py
index 79dc433..5077d2d 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -16,6 +16,7 @@
     TargetInvalidSchemaTest,
     TargetNoPrimaryKeys,
     TargetOptionalAttributes,
+    TargetPrimaryKeyUpdates,
     TargetRecordBeforeSchemaTest,
     TargetRecordMissingKeyProperty,
     TargetRecordMissingRequiredProperty,
@@ -581,6 +582,7 @@ def singer_filepath(self) -> Path:
     # TODO: Not available in the SDK yet
     # TargetMultipleStateMessages,
     TargetNoPrimaryKeys,  # Implicitly asserts no pk is handled
+    TargetPrimaryKeyUpdates,  # Implicitly asserts pk updates are handled
     TargetOptionalAttributes,  # Implicitly asserts nullable fields handled
     SnowflakeTargetRecordBeforeSchemaTest,
     SnowflakeTargetRecordMissingKeyProperty,

From 7a91444531cce210acde218d80c5897784181a56 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?=
Date: Tue, 22 Oct 2024 12:19:13 -0600
Subject: [PATCH 2/3] Implement it

---
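Note: the change below drops and recreates the table's PRIMARY KEY constraint whenever the reflected key columns no longer match the stream's key properties. For reference, a minimal standalone sketch of that reflect/compare/swap pattern follows; the sync_primary_key helper, the connection URL, and the schema/table/key names are placeholders for illustration, and it uses plain autoload_with reflection rather than the connector's internals.

# sync_primary_key.py -- illustration only; names and URL are placeholders.
from collections.abc import Sequence

import sqlalchemy as sa


def sync_primary_key(engine: sa.Engine, schema: str, table_name: str, keys: Sequence[str]) -> None:
    """Make the table's PRIMARY KEY constraint match ``keys``."""
    # Reflect the current table definition, including its primary key constraint.
    table = sa.Table(table_name, sa.MetaData(), schema=schema, autoload_with=engine)
    current = [col.name for col in table.primary_key.columns]

    if current == list(keys):
        return  # already in sync, nothing to do

    with engine.connect() as conn, conn.begin():
        if current:
            # Drop the existing constraint before installing the new one.
            conn.execute(sa.schema.DropConstraint(table.primary_key))
        if keys:
            new_pk = sa.PrimaryKeyConstraint(*keys)
            table.append_constraint(new_pk)
            conn.execute(sa.schema.AddConstraint(new_pk))


if __name__ == "__main__":
    # Placeholder URL; any SQLAlchemy engine that supports these DDL constructs works.
    engine = sa.create_engine("snowflake://user:password@account/database/schema?warehouse=wh")
    sync_primary_key(engine, "public", "orders", ["order_id", "line_no"])

Snowflake treats primary keys as informational rather than enforced, so dropping and re-adding the constraint is a metadata-only operation and does not rewrite the table.
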
 target_snowflake/connector.py | 66 +++++++++++++++++++++++++----------
 1 file changed, 48 insertions(+), 18 deletions(-)

diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py
index 9662e2c..14f5927 100644
--- a/target_snowflake/connector.py
+++ b/target_snowflake/connector.py
@@ -7,7 +7,7 @@
 from typing import TYPE_CHECKING, Any, cast

 import snowflake.sqlalchemy.custom_types as sct
-import sqlalchemy
+import sqlalchemy as sa
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from singer_sdk import typing as th
@@ -97,7 +97,7 @@ def get_table_columns(
         self,
         full_table_name: str,
         column_names: list[str] | None = None,
-    ) -> dict[str, sqlalchemy.Column]:
+    ) -> dict[str, sa.Column]:
         """Return a list of table columns.

         Args:
@@ -110,11 +110,11 @@ def get_table_columns(
         if full_table_name in self.table_cache:
             return self.table_cache[full_table_name]
         _, schema_name, table_name = self.parse_full_table_name(full_table_name)
-        inspector = sqlalchemy.inspect(self._engine)
+        inspector = sa.inspect(self._engine)
         columns = inspector.get_columns(table_name, schema_name)

         parsed_columns = {
-            col_meta["name"]: sqlalchemy.Column(
+            col_meta["name"]: sa.Column(
                 col_meta["name"],
                 self._convert_type(col_meta["type"]),
                 nullable=col_meta.get("nullable", False),
@@ -224,7 +224,7 @@ def create_engine(self) -> Engine:
         }
         if self.auth_method == SnowflakeAuthMethod.KEY_PAIR:
             connect_args["private_key"] = self.get_private_key()
-        engine = sqlalchemy.create_engine(
+        engine = sa.create_engine(
             self.sqlalchemy_url,
             connect_args=connect_args,
             echo=False,
@@ -240,7 +240,7 @@ def prepare_column(
         self,
         full_table_name: str,
         column_name: str,
-        sql_type: sqlalchemy.types.TypeEngine,
+        sql_type: sa.types.TypeEngine,
     ) -> None:
         formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
         # Make quoted column names upper case because we create them that way
@@ -263,12 +263,42 @@ def prepare_column(
             )
             raise

+    def prepare_primary_key(self, *, full_table_name: str | FullyQualifiedName, primary_keys: Sequence[str]) -> None:
+        _, schema_name, table_name = self.parse_full_table_name(full_table_name)
+        meta = sa.MetaData(schema=schema_name)
+        meta.reflect(bind=self._engine, only=[table_name])
+        table = meta.tables[full_table_name]  # type: ignore[index]
+        current_pk_cols = [col.name for col in table.primary_key.columns]
+
+        # Nothing to do
+        if current_pk_cols == primary_keys:
+            return
+
+        new_pk = sa.PrimaryKeyConstraint(*primary_keys)
+
+        # If table has no primary key, add the provided one
+        if not current_pk_cols:
+            with self._connect() as conn, conn.begin():
+                table.append_constraint(new_pk)
+                conn.execute(sa.schema.AddConstraint(new_pk).against(table))
+            return
+
+        # Drop the existing primary key
+        with self._connect() as conn, conn.begin():
+            conn.execute(sa.schema.DropConstraint(table.primary_key).against(table))
+
+        # Add the new primary key
+        if primary_keys:
+            with self._connect() as conn, conn.begin():
+                table.append_constraint(new_pk)
+                conn.execute(sa.schema.AddConstraint(new_pk).against(table))
+
     @staticmethod
     def get_column_rename_ddl(
         table_name: str,
         column_name: str,
         new_column_name: str,
-    ) -> sqlalchemy.DDL:
+    ) -> sa.DDL:
         formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
         # Since we build the ddl manually we can't rely on SQLAlchemy to
         # quote column names automatically.
@@ -282,8 +312,8 @@ def get_column_rename_ddl(
     def get_column_alter_ddl(
         table_name: str,
         column_name: str,
-        column_type: sqlalchemy.types.TypeEngine,
-    ) -> sqlalchemy.DDL:
+        column_type: sa.types.TypeEngine,
+    ) -> sa.DDL:
         """Get the alter column DDL statement.

         Override this if your database uses a different syntax for altering columns.
@@ -299,7 +329,7 @@ def get_column_alter_ddl(
         formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
         # Since we build the ddl manually we can't rely on SQLAlchemy to
         # quote column names automatically.
-        return sqlalchemy.DDL(
+        return sa.DDL(
             "ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s SET DATA TYPE %(column_type)s",
             {
                 "table_name": table_name,
@@ -317,7 +347,7 @@ def _conform_max_length(jsonschema_type):  # noqa: ANN205, ANN001
         return jsonschema_type

     @staticmethod
-    def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
+    def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
         """Return a JSON Schema representation of the provided type.

         Uses custom Snowflake types from [snowflake-sqlalchemy](https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/src/snowflake/sqlalchemy/custom_types.py)
@@ -337,9 +367,9 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
         # define type maps
         string_submaps = [
             TypeMap(eq, TIMESTAMP_NTZ(), "date-time"),
-            TypeMap(contains, sqlalchemy.types.TIME(), "time"),
-            TypeMap(eq, sqlalchemy.types.DATE(), "date"),
-            TypeMap(eq, sqlalchemy.types.VARCHAR(maxlength), None),
+            TypeMap(contains, sa.types.TIME(), "time"),
+            TypeMap(eq, sa.types.DATE(), "date"),
+            TypeMap(eq, sa.types.VARCHAR(maxlength), None),
         ]
         type_maps = [
             TypeMap(th._jsonschema_type_check, NUMBER(), ("integer",)),  # noqa: SLF001
@@ -354,12 +384,12 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
         else:
             target_type = evaluate_typemaps(type_maps, jsonschema_type, target_type)

-        return cast(sqlalchemy.types.TypeEngine, target_type)
+        return cast(sa.types.TypeEngine, target_type)

     def schema_exists(self, schema_name: str) -> bool:
         if schema_name in self.schema_cache:
             return True
-        schema_names = sqlalchemy.inspect(self._engine).get_schema_names()
+        schema_names = sa.inspect(self._engine).get_schema_names()
         self.schema_cache = schema_names
         formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
         # Make quoted schema names upper case because we create them that way
@@ -664,7 +694,7 @@ def _adapt_column_type(
         self,
         full_table_name: str,
         column_name: str,
-        sql_type: sqlalchemy.types.TypeEngine,
+        sql_type: sa.types.TypeEngine,
     ) -> None:
         """Adapt table column type to support the new JSON schema type.

@@ -679,7 +709,7 @@ def _adapt_column_type(
         try:
             super()._adapt_column_type(full_table_name, column_name, sql_type)
         except Exception:
-            current_type: sqlalchemy.types.TypeEngine = self._get_column_type(
+            current_type: sa.types.TypeEngine = self._get_column_type(
                 full_table_name,
                 column_name,
             )

From 925defbcea37b8f05f850da0f3df2794c1b3ab9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?=
Date: Tue, 22 Oct 2024 12:29:19 -0600
Subject: [PATCH 3/3] Use `inspector.reflect_table`

---
 target_snowflake/connector.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py
index 14f5927..d660b8e 100644
--- a/target_snowflake/connector.py
+++ b/target_snowflake/connector.py
@@ -266,8 +266,10 @@ def prepare_column(
     def prepare_primary_key(self, *, full_table_name: str | FullyQualifiedName, primary_keys: Sequence[str]) -> None:
         _, schema_name, table_name = self.parse_full_table_name(full_table_name)
         meta = sa.MetaData(schema=schema_name)
-        meta.reflect(bind=self._engine, only=[table_name])
-        table = meta.tables[full_table_name]  # type: ignore[index]
+        table = sa.Table(table_name, meta, schema=schema_name)
+        inspector = sa.inspect(self._engine)
+        inspector.reflect_table(table, None)
+
         current_pk_cols = [col.name for col in table.primary_key.columns]

         # Nothing to do
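
Note: the final patch reflects the target table directly into a Table object with Inspector.reflect_table, instead of reflecting the whole schema with MetaData.reflect and looking the table back up by key (the dictionary lookup needed the type: ignore[index] because full_table_name may be a FullyQualifiedName rather than a plain string key). A hypothetical call site for the new hook would look roughly like this; the connector variable, table name, and key names are made up for illustration:

# Hypothetical call; identifiers below are not taken from the patches.
connector.prepare_primary_key(
    full_table_name="MY_DB.PUBLIC.ORDERS",
    primary_keys=["ORDER_ID", "LINE_NO"],
)

Calling it again with the same keys is a no-op, and passing an empty list drops the existing primary key without adding a replacement.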