From 27dc7e80df3ecf5aa61718334f32a1d128b0125c Mon Sep 17 00:00:00 2001 From: vatsrahul1001 <43964496+vatsrahul1001@users.noreply.github.com> Date: Thu, 18 Jan 2024 19:33:04 +0530 Subject: [PATCH] Optimize `SnowflakeSqlApiOperator` execution in deferrable mode (#36850) --- .../snowflake/operators/snowflake.py | 15 +++++ .../snowflake/operators/test_snowflake.py | 67 ++++++++++++++++++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 9e0bf3d1cfbb5..f7890b87e1b98 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -514,6 +514,21 @@ def execute(self, context: Context) -> None: if self.do_xcom_push: context["ti"].xcom_push(key="query_ids", value=self.query_ids) + succeeded_query_ids = [] + for query_id in self.query_ids: + self.log.info("Retrieving status for query id %s", query_id) + statement_status = self._hook.get_sql_api_query_status(query_id) + if statement_status.get("status") == "running": + break + elif statement_status.get("status") == "success": + succeeded_query_ids.append(query_id) + else: + raise AirflowException(f"{statement_status.get('status')}: {statement_status.get('message')}") + + if len(self.query_ids) == len(succeeded_query_ids): + self.log.info("%s completed successfully.", self.task_id) + return + if self.deferrable: self.defer( timeout=self.execution_timeout, diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 07df5fb147a41..7f429277b9268 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -253,7 +253,9 @@ def test_snowflake_sql_api_to_fails_when_one_query_fails( @pytest.mark.parametrize("mock_sql, statement_count", [(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)]) 
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query") - def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, statement_count): + def test_snowflake_sql_api_execute_operator_async( + self, mock_execute_query, mock_sql, statement_count, mock_get_sql_api_query_status + ): """ Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired when the SnowflakeSqlApiOperator is executed. @@ -266,6 +268,9 @@ def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, deferrable=True, ) + mock_execute_query.return_value = ["uuid1"] + mock_get_sql_api_query_status.side_effect = [{"status": "running"}] + with pytest.raises(TaskDeferred) as exc: operator.execute(create_context(operator)) @@ -311,3 +316,63 @@ def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event): with mock.patch.object(operator.log, "info") as mock_log_info: operator.execute_complete(context=None, event=mock_event) mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) + + @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer") + def test_snowflake_sql_api_execute_operator_failed_before_defer( + self, mock_defer, mock_execute_query, mock_get_sql_api_query_status + ): + """Asserts that a task is not deferred when it has failed""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id="snowflake_default", + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + do_xcom_push=False, + deferrable=True, + ) + mock_execute_query.return_value = ["uuid1"] + mock_get_sql_api_query_status.side_effect = [{"status": "error"}] + with pytest.raises(AirflowException): + operator.execute(create_context(operator)) + assert not mock_defer.called + + @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer") + def test_snowflake_sql_api_execute_operator_succeeded_before_defer( + self, mock_defer, 
mock_execute_query, mock_get_sql_api_query_status + ): + """Asserts that a task is not deferred when it has succeeded""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id="snowflake_default", + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + do_xcom_push=False, + deferrable=True, + ) + mock_execute_query.return_value = ["uuid1"] + mock_get_sql_api_query_status.side_effect = [{"status": "success"}] + operator.execute(create_context(operator)) + + assert not mock_defer.called + + @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer") + def test_snowflake_sql_api_execute_operator_running_before_defer( + self, mock_defer, mock_execute_query, mock_get_sql_api_query_status + ): + """Asserts that a task is deferred when it is running""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id="snowflake_default", + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + do_xcom_push=False, + deferrable=True, + ) + mock_execute_query.return_value = ["uuid1"] + mock_get_sql_api_query_status.side_effect = [{"status": "running"}] + operator.execute(create_context(operator)) + + assert mock_defer.called