Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle reserved words in table names #232

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions target_snowflake/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from snowflake.sqlalchemy import URL
from snowflake.sqlalchemy.base import SnowflakeIdentifierPreparer
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql import text

from target_snowflake.snowflake_types import NUMBER, TIMESTAMP_NTZ, VARIANT

Expand Down Expand Up @@ -191,7 +191,7 @@ def prepare_column(
# Make quoted column names upper case because we create them that way
# and the metadata that SQLAlchemy returns is case insensitive only for non-quoted
# column names so these will look like they don't exist yet.
if '"' in formatter.format_collation(column_name):
if '"' in formatter.quote(column_name):
column_name = column_name.upper()

try:
Expand All @@ -208,6 +208,18 @@ def prepare_column(
)
raise

@staticmethod
def get_column_add_ddl(
    table_name: str,
    column_name: str,
    column_type: sqlalchemy.types.TypeEngine,
) -> sqlalchemy.DDL:
    """Build the ADD COLUMN DDL, escaping the fully qualified table name.

    The table name is quoted part-by-part so that reserved words
    (e.g. ``ORDER``) remain valid identifiers, then the base
    ``SQLConnector`` implementation renders the actual statement.
    """
    escaped_name = SnowflakeConnector._escape_full_table_name(table_name)
    return SQLConnector.get_column_add_ddl(escaped_name, column_name, column_type)

@staticmethod
def get_column_rename_ddl(
table_name: str,
Expand All @@ -218,9 +230,9 @@ def get_column_rename_ddl(
# Since we build the ddl manually we can't rely on SQLAlchemy to
# quote column names automatically.
return SQLConnector.get_column_rename_ddl(
table_name,
formatter.format_collation(column_name),
formatter.format_collation(new_column_name),
SnowflakeConnector._escape_full_table_name(table_name),
formatter.quote(column_name),
formatter.quote(new_column_name),
)

@staticmethod
Expand All @@ -244,11 +256,12 @@ def get_column_alter_ddl(
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
# Since we build the ddl manually we can't rely on SQLAlchemy to
# quote column names automatically.
escaped_full_table_name = SnowflakeConnector._escape_full_table_name(table_name)
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s SET DATA TYPE %(column_type)s",
{
"table_name": table_name,
"column_name": formatter.format_collation(column_name),
"table_name": escaped_full_table_name,
"column_name": formatter.quote(column_name),
"column_type": column_type,
},
)
Expand Down Expand Up @@ -310,7 +323,7 @@ def schema_exists(self, schema_name: str) -> bool:
# Make quoted schema names upper case because we create them that way
# and the metadata that SQLAlchemy returns is case insensitive only for
# non-quoted schema names so these will look like they don't exist yet.
if '"' in formatter.format_collation(schema_name):
if '"' in formatter.quote(schema_name):
schema_name = schema_name.upper()
return schema_name in schema_names

Expand Down Expand Up @@ -342,7 +355,7 @@ def _get_column_selections(
) -> list:
column_selections = []
for property_name, property_def in schema["properties"].items():
clean_property_name = formatter.format_collation(property_name)
clean_property_name = formatter.quote(property_name)
clean_alias = clean_property_name
if '"' in clean_property_name:
clean_alias = clean_property_name.upper()
Expand Down Expand Up @@ -372,8 +385,8 @@ def _get_merge_from_stage_statement( # noqa: ANN202
)

# use UPPER from here onwards
formatted_properties = [formatter.format_collation(col) for col in schema["properties"]]
formatted_key_properties = [formatter.format_collation(col) for col in key_properties]
formatted_properties = [formatter.quote(col) for col in schema["properties"]]
formatted_key_properties = [formatter.quote(col) for col in key_properties]
join_expr = " and ".join(
[f"d.{key} = s.{key}" for key in formatted_key_properties],
)
Expand All @@ -386,9 +399,10 @@ def _get_merge_from_stage_statement( # noqa: ANN202
)
dedup_cols = ", ".join(list(formatted_key_properties))
dedup = f"QUALIFY ROW_NUMBER() OVER (PARTITION BY {dedup_cols} ORDER BY SEQ8() DESC) = 1"
escaped_full_table_name = self._escape_full_table_name(full_table_name)
return (
text(
f"merge into {quoted_name(full_table_name, quote=True)} d using " # noqa: ISC003
f"merge into {escaped_full_table_name} d using " # noqa: ISC003
+ f"(select {json_casting_selects} from '@~/target-snowflake/{sync_id}'" # noqa: S608
+ f"(file_format => {file_format}) {dedup}) s "
+ f"on {join_expr} "
Expand All @@ -411,9 +425,10 @@ def _get_copy_statement(self, full_table_name, schema, sync_id, file_format): #
column_selections,
"col_alias",
)
escaped_full_table_name = self._escape_full_table_name(full_table_name)
return (
text(
f"copy into {full_table_name} {col_alias_selects} from " # noqa: ISC003
f"copy into {escaped_full_table_name} {col_alias_selects} from " # noqa: ISC003
+ f"(select {json_casting_selects} from " # noqa: S608
+ f"'@~/target-snowflake/{sync_id}')"
+ f"file_format = (format_name='{file_format}')",
Expand Down Expand Up @@ -634,3 +649,9 @@ def _adapt_column_type(
sql_type,
)
raise

@staticmethod
def _escape_full_table_name(full_table_name: str) -> str:
    """Quote every component of a fully qualified table name.

    Splits ``full_table_name`` into database, schema, and table parts and
    quotes each one with the Snowflake identifier preparer so reserved
    words (e.g. ``ORDER``) are safe to use in hand-built SQL.

    Args:
        full_table_name: Name in ``db.schema.table`` form. The database
            and/or schema components may be absent.

    Returns:
        The dotted name with each present component quoted.
    """
    formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
    db_name, schema_name, table_name = SQLConnector().parse_full_table_name(full_table_name)
    # parse_full_table_name returns None for the db/schema components of a
    # partially qualified name; quoting None would raise, so only join the
    # components that are actually present.
    parts = (db_name, schema_name, table_name)
    return ".".join(formatter.quote(part) for part in parts if part)
4 changes: 2 additions & 2 deletions target_snowflake/sinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def conform_name(
name: str,
object_type: str | None = None,
) -> str:
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
if object_type and object_type != "column":
return super().conform_name(name=name, object_type=object_type)
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
if '"' not in formatter.format_collation(name.lower()):
if '"' not in formatter.quote(name.lower()):
name = name.lower()
return name

Expand Down
57 changes: 57 additions & 0 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,61 @@ def setup(self) -> None:
)


class SnowflakeTargetExistingReservedNameTableAlter(TargetFileTestTemplate):
    """Replay the reserved-word stream against a pre-created ``"ORDER"`` table.

    This sends a schema that will request altering from TIMESTAMP_NTZ to
    VARCHAR, so the target must issue ALTER statements against a table whose
    name is a Snowflake reserved word.
    """

    name = "existing_reserved_name_table_alter"

    @property
    def singer_filepath(self) -> Path:
        """Return the path of the singer file that drives this test."""
        current_dir = Path(__file__).resolve().parent
        return current_dir / "target_test_streams" / "reserved_words_in_table.singer"

    def setup(self) -> None:
        """Pre-create the reserved-word table so the target has to alter it."""
        connector = self.target.default_sink_class.connector_class(self.target.config)
        # Look the config values up outside the f-string: reusing the same
        # quote character inside an f-string is a SyntaxError before
        # Python 3.12 (PEP 701).
        database = self.target.config["database"]
        schema = self.target.config["default_target_schema"]
        table = f'{database}.{schema}."ORDER"'.upper()
        connector.connection.execute(
            f"""
            CREATE OR REPLACE TABLE {table} (
                ID VARCHAR(16777216),
                COL_STR VARCHAR(16777216),
                COL_TS TIMESTAMP_NTZ(9),
                COL_INT STRING,
                COL_BOOL BOOLEAN,
                COL_VARIANT VARIANT,
                _SDC_BATCHED_AT TIMESTAMP_NTZ(9),
                _SDC_DELETED_AT VARCHAR(16777216),
                _SDC_EXTRACTED_AT TIMESTAMP_NTZ(9),
                _SDC_RECEIVED_AT TIMESTAMP_NTZ(9),
                _SDC_SEQUENCE NUMBER(38,0),
                _SDC_TABLE_VERSION NUMBER(38,0),
                PRIMARY KEY (ID)
            )
            """,
        )


class SnowflakeTargetReservedWordsInTable(TargetFileTestTemplate):
    """Sync a stream whose table name is a Snowflake reserved word.

    Contains reserved words from
    https://docs.snowflake.com/en/sql-reference/reserved-keywords
    Syncs records then alters schema by adding a non-reserved word column.
    """

    name = "reserved_words_in_table"

    @property
    def singer_filepath(self) -> Path:
        """Return the path of the singer file that drives this test."""
        current_dir = Path(__file__).resolve().parent
        return current_dir / "target_test_streams" / "reserved_words_in_table.singer"

    def validate(self) -> None:
        """Assert exactly one row with the expected column count was loaded."""
        connector = self.target.default_sink_class.connector_class(self.target.config)
        # Look the config values up outside the f-string: reusing the same
        # quote character inside an f-string is a SyntaxError before
        # Python 3.12 (PEP 701).
        database = self.target.config["database"]
        schema = self.target.config["default_target_schema"]
        table = f'{database}.{schema}."ORDER"'.upper()
        result = connector.connection.execute(
            f"select * from {table}",
        )
        assert result.rowcount == 1
        row = result.first()
        assert len(row) == 13, f"Row has unexpected length {len(row)}"


class SnowflakeTargetTypeEdgeCasesTest(TargetFileTestTemplate):
name = "type_edge_cases"

Expand Down Expand Up @@ -540,6 +595,8 @@ def singer_filepath(self) -> Path:
SnowflakeTargetColonsInColName,
SnowflakeTargetExistingTable,
SnowflakeTargetExistingTableAlter,
SnowflakeTargetExistingReservedNameTableAlter,
SnowflakeTargetReservedWordsInTable,
SnowflakeTargetTypeEdgeCasesTest,
SnowflakeTargetColumnOrderMismatch,
],
Expand Down
2 changes: 2 additions & 0 deletions tests/target_test_streams/reserved_words_in_table.singer
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{ "type": "SCHEMA", "stream": "ORDER", "schema": { "properties": { "id": { "type": [ "string", "null" ] }, "col_str": { "type": [ "string", "null" ] }, "col_ts": { "format": "date-time", "type": [ "string", "null" ] }, "col_int": { "type": "integer" }, "col_bool": { "type": [ "boolean", "null" ] }, "col_variant": {"type": "object"} }, "type": "object" }, "key_properties": [ "id" ], "bookmark_properties": [ "col_ts" ] }
{ "type": "RECORD", "stream": "ORDER", "record": { "id": "123", "col_str": "foo", "col_ts": "2023-06-13 11:50:04.072", "col_int": 5, "col_bool": true, "col_variant": {"key": "val"} }, "time_extracted": "2023-06-14T18:08:23.074716+00:00" }
Loading