diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index edaa4e15d8579..b761118a5ce25 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -124,7 +124,7 @@ def __init__( self._extra_options = extra_options or {} self._extra_headers = extra_headers or {} - self._batch_id: int | str + self._batch_id: int | str | None = None self.retry_args = retry_args self.deferrable = deferrable @@ -170,6 +170,7 @@ def execute(self, context: Context) -> Any: polling_interval=self._polling_interval, extra_options=self._extra_options, extra_headers=self._extra_headers, + execution_timeout=self.execution_timeout, ), method_name="execute_complete", ) @@ -217,8 +218,12 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: for log_line in event["log_lines"]: self.log.info(log_line) - if event["status"] == "error": + if event["status"] == "timeout": + self.hook.delete_batch(event["batch_id"]) + + if event["status"] in ["error", "timeout"]: raise AirflowException(event["response"]) + self.log.info( "%s completed with response %s", self.task_id, diff --git a/airflow/providers/apache/livy/triggers/livy.py b/airflow/providers/apache/livy/triggers/livy.py index d6203b4324d49..298d1e5f876e4 100644 --- a/airflow/providers/apache/livy/triggers/livy.py +++ b/airflow/providers/apache/livy/triggers/livy.py @@ -20,6 +20,7 @@ from __future__ import annotations import asyncio +from datetime import datetime, timedelta, timezone from typing import Any, AsyncIterator from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook @@ -54,6 +55,7 @@ def __init__( extra_options: dict[str, Any] | None = None, extra_headers: dict[str, Any] | None = None, livy_hook_async: LivyAsyncHook | None = None, + execution_timeout: timedelta | None = None, ): super().__init__() self._batch_id = batch_id @@ -63,6 +65,7 @@ def __init__( self._extra_options = extra_options self._extra_headers = extra_headers self._livy_hook_async = livy_hook_async + self._execution_timeout = execution_timeout def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize LivyTrigger arguments and classpath.""" @@ -76,6 +79,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "extra_options": self._extra_options, "extra_headers": self._extra_headers, "livy_hook_async": self._livy_hook_async, + "execution_timeout": self._execution_timeout, }, ) @@ -113,16 +117,37 @@ async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]: :param batch_id: id of the batch session to monitor. """ + if self._execution_timeout is not None: + timeout_datetime = datetime.now(timezone.utc) + self._execution_timeout + else: + timeout_datetime = None + batch_execution_timed_out = False hook = self._get_async_hook() state = await hook.get_batch_state(batch_id) self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value) while state["batch_state"] not in hook.TERMINAL_STATES: self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value) + batch_execution_timed_out = ( + timeout_datetime is not None and datetime.now(timezone.utc) > timeout_datetime + ) + if batch_execution_timed_out: + break self.log.info("Sleeping for %s seconds", self._polling_interval) await asyncio.sleep(self._polling_interval) state = await hook.get_batch_state(batch_id) - self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value) log_lines = await hook.dump_batch_logs(batch_id) + if batch_execution_timed_out: + self.log.info( + "Batch with id %s did not terminate, but it reached execution timeout.", + batch_id, + ) + return { + "status": "timeout", + "batch_id": batch_id, + "response": f"Batch {batch_id} timed out", + "log_lines": log_lines, + } + self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value) if state["batch_state"] != BatchState.SUCCESS: return { "status": "error", diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py index 02e8231eb2896..4e128cbec8f61 100644 --- a/tests/providers/apache/livy/operators/test_livy.py +++ b/tests/providers/apache/livy/operators/test_livy.py @@ -280,6 +280,19 @@ def test_execution_with_extra_options_deferrable( task.execute(context=self.mock_context) assert task.hook.extra_options == extra_options + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch") + def test_when_kill_is_called_right_after_construction_it_should_not_raise_attribute_error( + self, mock_delete_batch + ): + task = LivyOperator( + livy_conn_id="livyunittest", + file="sparkapp", + dag=self.dag, + task_id="livy_example", + ) + task.kill() + mock_delete_batch.assert_not_called() + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch") @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH) @@ -380,6 +393,30 @@ def test_execute_complete_error(self, mock_post): ) self.mock_context["ti"].xcom_push.assert_not_called() + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch") + def test_execute_complete_timeout(self, mock_delete, mock_post): + task = LivyOperator( + livy_conn_id="livyunittest", + file="sparkapp", + dag=self.dag, + task_id="livy_example", + polling_interval=1, + deferrable=True, + ) + with pytest.raises(AirflowException): + task.execute_complete( + context=self.mock_context, + event={ + "status": "timeout", + "log_lines": ["mock log"], + "batch_id": BATCH_ID, + "response": "mock timeout", + }, + ) + mock_delete.assert_called_once_with(BATCH_ID) + self.mock_context["ti"].xcom_push.assert_not_called() + @pytest.mark.db_test def test_spark_params_templating(create_task_instance_of_operator): diff --git a/tests/providers/apache/livy/triggers/test_livy.py b/tests/providers/apache/livy/triggers/test_livy.py index ac1464ffd45ac..df85a84bac338 100644 --- a/tests/providers/apache/livy/triggers/test_livy.py +++ b/tests/providers/apache/livy/triggers/test_livy.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +from datetime import timedelta from unittest import mock import pytest @@ -46,6 +47,7 @@ def test_livy_trigger_serialization(self): "extra_options": None, "extra_headers": None, "livy_hook_async": None, + "execution_timeout": None, } @pytest.mark.asyncio @@ -195,3 +197,31 @@ async def test_livy_trigger_poll_for_termination_state(self, mock_dump_batch_log # TriggerEvent was not returned assert task.done() is False asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state") + @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs") + async def test_livy_trigger_poll_for_termination_timeout( + self, mock_dump_batch_logs, mock_get_batch_state + ): + """ + Test if poll_for_termination() returns timeout response when execution times out. + """ + mock_get_batch_state.return_value = {"batch_state": BatchState.RUNNING} + mock_dump_batch_logs.return_value = ["mock_log"] + trigger = LivyTrigger( + batch_id=1, + spark_params={}, + livy_conn_id=LivyHook.default_conn_name, + polling_interval=1, + execution_timeout=timedelta(seconds=0), + ) + + task = await trigger.poll_for_termination(1) + + assert task == { + "status": "timeout", + "batch_id": 1, + "response": "Batch 1 timed out", + "log_lines": ["mock_log"], + }