Skip to content

Commit

Permalink
Optimize SnowflakeSqlApiOperator execution in deferrable mode (#36850)
Browse files Browse the repository at this point in the history
  • Loading branch information
vatsrahul1001 authored Jan 18, 2024
1 parent 1ece3f6 commit 27dc7e8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
15 changes: 15 additions & 0 deletions airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,21 @@ def execute(self, context: Context) -> None:
if self.do_xcom_push:
context["ti"].xcom_push(key="query_ids", value=self.query_ids)

succeeded_query_ids = []
for query_id in self.query_ids:
self.log.info("Retrieving status for query id %s", query_id)
statement_status = self._hook.get_sql_api_query_status(query_id)
if statement_status.get("status") == "running":
break
elif statement_status.get("status") == "success":
succeeded_query_ids.append(query_id)
else:
raise AirflowException(f"{statement_status.get('status')}: {statement_status.get('message')}")

if len(self.query_ids) == len(succeeded_query_ids):
self.log.info("%s completed successfully.", self.task_id)
return

if self.deferrable:
self.defer(
timeout=self.execution_timeout,
Expand Down
67 changes: 66 additions & 1 deletion tests/providers/snowflake/operators/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def test_snowflake_sql_api_to_fails_when_one_query_fails(

@pytest.mark.parametrize("mock_sql, statement_count", [(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)])
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query")
def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, statement_count):
def test_snowflake_sql_api_execute_operator_async(
self, mock_execute_query, mock_sql, statement_count, mock_get_sql_api_query_status
):
"""
Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired
when the SnowflakeSqlApiOperator is executed.
Expand All @@ -266,6 +268,9 @@ def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql,
deferrable=True,
)

mock_execute_query.return_value = ["uuid1"]
mock_get_sql_api_query_status.side_effect = [{"status": "running"}]

with pytest.raises(TaskDeferred) as exc:
operator.execute(create_context(operator))

Expand Down Expand Up @@ -311,3 +316,63 @@ def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event):
with mock.patch.object(operator.log, "info") as mock_log_info:
operator.execute_complete(context=None, event=mock_event)
mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)

@mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
def test_snowflake_sql_api_execute_operator_failed_before_defer(
    self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
):
    """Check that a failed query raises from ``execute`` and the task is never deferred.

    ``mock_defer`` replaces ``SnowflakeSqlApiOperator.defer``; the other two mocks
    are supplied from outside this block (presumably pytest fixtures -- confirm).
    """
    task = SnowflakeSqlApiOperator(
        task_id=TASK_ID,
        snowflake_conn_id="snowflake_default",
        sql=SQL_MULTIPLE_STMTS,
        statement_count=4,
        do_xcom_push=False,
        deferrable=True,
    )

    # One query id is submitted and its only status lookup reports an error.
    mock_execute_query.return_value = ["uuid1"]
    mock_get_sql_api_query_status.side_effect = [{"status": "error"}]

    with pytest.raises(AirflowException):
        task.execute(create_context(task))

    # The error must surface before the operator hands off to the triggerer.
    mock_defer.assert_not_called()

@mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
def test_snowflake_sql_api_execute_operator_succeeded_before_defer(
    self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
):
    """Check that ``execute`` does not defer when the submitted query already succeeded.

    ``mock_defer`` replaces ``SnowflakeSqlApiOperator.defer``; the other two mocks
    are supplied from outside this block (presumably pytest fixtures -- confirm).
    """
    task = SnowflakeSqlApiOperator(
        task_id=TASK_ID,
        snowflake_conn_id="snowflake_default",
        sql=SQL_MULTIPLE_STMTS,
        statement_count=4,
        do_xcom_push=False,
        deferrable=True,
    )

    # One query id is submitted and its only status lookup reports success.
    mock_execute_query.return_value = ["uuid1"]
    mock_get_sql_api_query_status.side_effect = [{"status": "success"}]

    context = create_context(task)
    task.execute(context)

    # Nothing is left to wait for, so no deferral should have been requested.
    mock_defer.assert_not_called()

@mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
def test_snowflake_sql_api_execute_operator_running_before_defer(
    self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
):
    """Check that ``execute`` defers when the submitted query is still running.

    ``mock_defer`` replaces ``SnowflakeSqlApiOperator.defer``; the other two mocks
    are supplied from outside this block (presumably pytest fixtures -- confirm).
    """
    task = SnowflakeSqlApiOperator(
        task_id=TASK_ID,
        snowflake_conn_id="snowflake_default",
        sql=SQL_MULTIPLE_STMTS,
        statement_count=4,
        do_xcom_push=False,
        deferrable=True,
    )

    # One query id is submitted and its only status lookup reports it as running.
    mock_execute_query.return_value = ["uuid1"]
    mock_get_sql_api_query_status.side_effect = [{"status": "running"}]

    context = create_context(task)
    task.execute(context)

    # ``defer`` is patched, so the call is recorded instead of raising TaskDeferred.
    mock_defer.assert_called()

0 comments on commit 27dc7e8

Please sign in to comment.