Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adapt to changes in stream key properties #280

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ snowflake-connector-python = { version = "<4.0.0", extras = ["secure-local-stora
sqlalchemy = "~=2.0.31"

[tool.poetry.dependencies.singer-sdk]
version = "~=0.41.0"
version = "~=0.42.0a2"

[tool.poetry.group.dev.dependencies]
coverage = ">=7.2.7"
pytest = ">=7.4.3"
pytest-xdist = ">=3.3.1"

[tool.poetry.group.dev.dependencies.singer-sdk]
version="~=0.41.0"
version = "~=0.42.0a2"
extras = ["testing"]

[tool.ruff]
Expand Down
68 changes: 50 additions & 18 deletions target_snowflake/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, cast

import snowflake.sqlalchemy.custom_types as sct
import sqlalchemy
import sqlalchemy as sa
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from singer_sdk import typing as th
Expand Down Expand Up @@ -97,7 +97,7 @@ def get_table_columns(
self,
full_table_name: str,
column_names: list[str] | None = None,
) -> dict[str, sqlalchemy.Column]:
) -> dict[str, sa.Column]:
"""Return a list of table columns.

Args:
Expand All @@ -110,11 +110,11 @@ def get_table_columns(
if full_table_name in self.table_cache:
return self.table_cache[full_table_name]
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
inspector = sqlalchemy.inspect(self._engine)
inspector = sa.inspect(self._engine)
columns = inspector.get_columns(table_name, schema_name)

parsed_columns = {
col_meta["name"]: sqlalchemy.Column(
col_meta["name"]: sa.Column(
col_meta["name"],
self._convert_type(col_meta["type"]),
nullable=col_meta.get("nullable", False),
Expand Down Expand Up @@ -224,7 +224,7 @@ def create_engine(self) -> Engine:
}
if self.auth_method == SnowflakeAuthMethod.KEY_PAIR:
connect_args["private_key"] = self.get_private_key()
engine = sqlalchemy.create_engine(
engine = sa.create_engine(
self.sqlalchemy_url,
connect_args=connect_args,
echo=False,
Expand All @@ -240,7 +240,7 @@ def prepare_column(
self,
full_table_name: str,
column_name: str,
sql_type: sqlalchemy.types.TypeEngine,
sql_type: sa.types.TypeEngine,
) -> None:
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
# Make quoted column names upper case because we create them that way
Expand All @@ -263,12 +263,44 @@ def prepare_column(
)
raise

def prepare_primary_key(
    self,
    *,
    full_table_name: str | FullyQualifiedName,
    primary_keys: Sequence[str],
) -> None:
    """Adapt the table's primary key constraint to match ``primary_keys``.

    Reflects the current table from the database, compares its primary key
    columns (in order) with the requested ones, and drops and/or creates
    the constraint as needed. A no-op when the key already matches.

    Args:
        full_table_name: Fully qualified name of the target table.
        primary_keys: Ordered column names that should form the primary key.
            May be empty, in which case any existing primary key is dropped.
    """
    _, schema_name, table_name = self.parse_full_table_name(full_table_name)
    meta = sa.MetaData(schema=schema_name)
    table = sa.Table(table_name, meta, schema=schema_name)
    inspector = sa.inspect(self._engine)
    inspector.reflect_table(table, None)

    current_pk_cols = [col.name for col in table.primary_key.columns]
    # Normalize to a list: ``primary_keys`` is any Sequence (often a tuple),
    # and ``tuple != list`` even with identical elements. Without this, an
    # unchanged key arriving as a tuple would be dropped and re-created, and
    # two "empty" keys of different sequence types would compare unequal and
    # emit an AddConstraint for an empty PrimaryKeyConstraint (broken DDL).
    desired_pk_cols = list(primary_keys)

    # Nothing to do: key already matches (including both being empty).
    if current_pk_cols == desired_pk_cols:
        return

    # Drop the existing primary key, if the table has one.
    if current_pk_cols:
        with self._connect() as conn, conn.begin():
            conn.execute(sa.schema.DropConstraint(table.primary_key).against(table))

    # Add the new primary key, if one was requested.
    if desired_pk_cols:
        new_pk = sa.PrimaryKeyConstraint(*desired_pk_cols)
        with self._connect() as conn, conn.begin():
            table.append_constraint(new_pk)
            conn.execute(sa.schema.AddConstraint(new_pk).against(table))

@staticmethod
def get_column_rename_ddl(
table_name: str,
column_name: str,
new_column_name: str,
) -> sqlalchemy.DDL:
) -> sa.DDL:
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
# Since we build the ddl manually we can't rely on SQLAlchemy to
# quote column names automatically.
Expand All @@ -282,8 +314,8 @@ def get_column_rename_ddl(
def get_column_alter_ddl(
table_name: str,
column_name: str,
column_type: sqlalchemy.types.TypeEngine,
) -> sqlalchemy.DDL:
column_type: sa.types.TypeEngine,
) -> sa.DDL:
"""Get the alter column DDL statement.

Override this if your database uses a different syntax for altering columns.
Expand All @@ -299,7 +331,7 @@ def get_column_alter_ddl(
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
# Since we build the ddl manually we can't rely on SQLAlchemy to
# quote column names automatically.
return sqlalchemy.DDL(
return sa.DDL(
"ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s SET DATA TYPE %(column_type)s",
{
"table_name": table_name,
Expand All @@ -317,7 +349,7 @@ def _conform_max_length(jsonschema_type): # noqa: ANN205, ANN001
return jsonschema_type

@staticmethod
def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
"""Return a JSON Schema representation of the provided type.

Uses custom Snowflake types from [snowflake-sqlalchemy](https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/src/snowflake/sqlalchemy/custom_types.py)
Expand All @@ -337,9 +369,9 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
# define type maps
string_submaps = [
TypeMap(eq, TIMESTAMP_NTZ(), "date-time"),
TypeMap(contains, sqlalchemy.types.TIME(), "time"),
TypeMap(eq, sqlalchemy.types.DATE(), "date"),
TypeMap(eq, sqlalchemy.types.VARCHAR(maxlength), None),
TypeMap(contains, sa.types.TIME(), "time"),
TypeMap(eq, sa.types.DATE(), "date"),
TypeMap(eq, sa.types.VARCHAR(maxlength), None),
]
type_maps = [
TypeMap(th._jsonschema_type_check, NUMBER(), ("integer",)), # noqa: SLF001
Expand All @@ -354,12 +386,12 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
else:
target_type = evaluate_typemaps(type_maps, jsonschema_type, target_type)

return cast(sqlalchemy.types.TypeEngine, target_type)
return cast(sa.types.TypeEngine, target_type)

def schema_exists(self, schema_name: str) -> bool:
if schema_name in self.schema_cache:
return True
schema_names = sqlalchemy.inspect(self._engine).get_schema_names()
schema_names = sa.inspect(self._engine).get_schema_names()
self.schema_cache = schema_names
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
# Make quoted schema names upper case because we create them that way
Expand Down Expand Up @@ -664,7 +696,7 @@ def _adapt_column_type(
self,
full_table_name: str,
column_name: str,
sql_type: sqlalchemy.types.TypeEngine,
sql_type: sa.types.TypeEngine,
) -> None:
"""Adapt table column type to support the new JSON schema type.

Expand All @@ -679,7 +711,7 @@ def _adapt_column_type(
try:
super()._adapt_column_type(full_table_name, column_name, sql_type)
except Exception:
current_type: sqlalchemy.types.TypeEngine = self._get_column_type(
current_type: sa.types.TypeEngine = self._get_column_type(
full_table_name,
column_name,
)
Expand Down
2 changes: 2 additions & 0 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TargetInvalidSchemaTest,
TargetNoPrimaryKeys,
TargetOptionalAttributes,
TargetPrimaryKeyUpdates,
TargetRecordBeforeSchemaTest,
TargetRecordMissingKeyProperty,
TargetRecordMissingRequiredProperty,
Expand Down Expand Up @@ -581,6 +582,7 @@ def singer_filepath(self) -> Path:
# TODO: Not available in the SDK yet
# TargetMultipleStateMessages,
TargetNoPrimaryKeys, # Implicitly asserts no pk is handled
TargetPrimaryKeyUpdates, # Implicitly asserts pk updates are handled
TargetOptionalAttributes, # Implicitly asserts nullable fields handled
SnowflakeTargetRecordBeforeSchemaTest,
SnowflakeTargetRecordMissingKeyProperty,
Expand Down
Loading