Skip to content

Commit

Permalink
Handle special characters in snowflake query (#1221)
Browse files Browse the repository at this point in the history
* Handle special characters in snowflake query

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Oct 7, 2022
1 parent ea1437a commit 91db60e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
task_config=task_config,
**kwargs,
)
self._query_template = query_template
self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t")

@property
def query_template(self) -> str:
Expand Down
14 changes: 13 additions & 1 deletion plugins/flytekit-snowflake/tests/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,26 @@ def test_local_exec():
snowflake_task = SnowflakeTask(
name="flytekit.demo.snowflake_task.query2",
inputs=kwtypes(ds=str),
query_template=query_template,
query_template="select 1\n",
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

assert len(snowflake_task.interface.inputs) == 1
assert snowflake_task.query_template == "select 1\\n"
assert len(snowflake_task.interface.outputs) == 1

# will not run locally
with pytest.raises(Exception):
snowflake_task()


def test_sql_template():
snowflake_task = SnowflakeTask(
name="flytekit.demo.snowflake_task.query2",
inputs=kwtypes(ds=str),
query_template="""select 1 from\t
custom where column = 1""",
output_schema_type=FlyteSchema,
)
assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1"

0 comments on commit 91db60e

Please sign in to comment.