diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 78f4341839..d2e4838ed8 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -41,7 +41,7 @@ def __init__( task_config=task_config, **kwargs, ) - self._query_template = query_template + self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t") @property def query_template(self) -> str: diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index ab558ca534..a012e38d99 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -64,14 +64,26 @@ def test_local_exec(): snowflake_task = SnowflakeTask( name="flytekit.demo.snowflake_task.query2", inputs=kwtypes(ds=str), - query_template=query_template, + query_template="select 1\n", # the schema literal's backend uri will be equal to the value of .raw_output_data output_schema_type=FlyteSchema, ) assert len(snowflake_task.interface.inputs) == 1 + assert snowflake_task.query_template == "select 1\\n" assert len(snowflake_task.interface.outputs) == 1 # will not run locally with pytest.raises(Exception): snowflake_task() + + +def test_sql_template(): + snowflake_task = SnowflakeTask( + name="flytekit.demo.snowflake_task.query2", + inputs=kwtypes(ds=str), + query_template="""select 1 from\t + custom where column = 1""", + output_schema_type=FlyteSchema, + ) + assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1"