Skip to content

Commit

Permalink
Fix bug in LivyOperator when its trigger times out (#38916)
Browse files Browse the repository at this point in the history
When a LivyOperator was instantiated with deferrable=True and its batch job ran longer than the configured execution_timeout, Airflow would detect the timeout, cancel the trigger, and then try to kill the task through its 'on_kill' method.

However, that call would fail with an AttributeError, because the batch_id attribute was never assigned a value in the operator's constructor.

From now on, the LivyTrigger times itself out before Airflow does and sends an event to the LivyOperator signaling that a timeout happened. This way, the operator can stop the running Livy batch job and fail the task instance gracefully.
  • Loading branch information
mateuslatrova authored Apr 14, 2024
1 parent c2f96ff commit bf5ab8f
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 3 deletions.
9 changes: 7 additions & 2 deletions airflow/providers/apache/livy/operators/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}

self._batch_id: int | str
self._batch_id: int | str | None = None
self.retry_args = retry_args
self.deferrable = deferrable

Expand Down Expand Up @@ -170,6 +170,7 @@ def execute(self, context: Context) -> Any:
polling_interval=self._polling_interval,
extra_options=self._extra_options,
extra_headers=self._extra_headers,
execution_timeout=self.execution_timeout,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -217,8 +218,12 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
for log_line in event["log_lines"]:
self.log.info(log_line)

if event["status"] == "error":
if event["status"] == "timeout":
self.hook.delete_batch(event["batch_id"])

if event["status"] in ["error", "timeout"]:
raise AirflowException(event["response"])

self.log.info(
"%s completed with response %s",
self.task_id,
Expand Down
27 changes: 26 additions & 1 deletion airflow/providers/apache/livy/triggers/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

import asyncio
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator

from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
execution_timeout: timedelta | None = None,
):
super().__init__()
self._batch_id = batch_id
Expand All @@ -63,6 +65,7 @@ def __init__(
self._extra_options = extra_options
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
self._execution_timeout = execution_timeout

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
Expand All @@ -76,6 +79,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"extra_options": self._extra_options,
"extra_headers": self._extra_headers,
"livy_hook_async": self._livy_hook_async,
"execution_timeout": self._execution_timeout,
},
)

Expand Down Expand Up @@ -113,16 +117,37 @@ async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]:
:param batch_id: id of the batch session to monitor.
"""
if self._execution_timeout is not None:
timeout_datetime = datetime.now(timezone.utc) + self._execution_timeout
else:
timeout_datetime = None
batch_execution_timed_out = False
hook = self._get_async_hook()
state = await hook.get_batch_state(batch_id)
self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
while state["batch_state"] not in hook.TERMINAL_STATES:
self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
batch_execution_timed_out = (
timeout_datetime is not None and datetime.now(timezone.utc) > timeout_datetime
)
if batch_execution_timed_out:
break
self.log.info("Sleeping for %s seconds", self._polling_interval)
await asyncio.sleep(self._polling_interval)
state = await hook.get_batch_state(batch_id)
self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
log_lines = await hook.dump_batch_logs(batch_id)
if batch_execution_timed_out:
self.log.info(
"Batch with id %s did not terminate, but it reached execution timeout.",
batch_id,
)
return {
"status": "timeout",
"batch_id": batch_id,
"response": f"Batch {batch_id} timed out",
"log_lines": log_lines,
}
self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
if state["batch_state"] != BatchState.SUCCESS:
return {
"status": "error",
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,19 @@ def test_execution_with_extra_options_deferrable(
task.execute(context=self.mock_context)
assert task.hook.extra_options == extra_options

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
def test_when_kill_is_called_right_after_construction_it_should_not_raise_attribute_error(
    self, mock_delete_batch
):
    """Calling kill() on a freshly built operator must not raise and must not delete a batch."""
    operator_kwargs = {
        "livy_conn_id": "livyunittest",
        "file": "sparkapp",
        "dag": self.dag,
        "task_id": "livy_example",
    }
    operator = LivyOperator(**operator_kwargs)

    # execute() was never called, so no batch has been submitted by this operator.
    operator.kill()

    mock_delete_batch.assert_not_called()

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
Expand Down Expand Up @@ -380,6 +393,30 @@ def test_execute_complete_error(self, mock_post):
)
self.mock_context["ti"].xcom_push.assert_not_called()

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
def test_execute_complete_timeout(self, mock_delete, mock_post):
    """A "timeout" event must raise, delete the reported batch, and push nothing to XCom."""
    operator = LivyOperator(
        livy_conn_id="livyunittest",
        file="sparkapp",
        dag=self.dag,
        task_id="livy_example",
        polling_interval=1,
        deferrable=True,
    )
    # Shape of the event passed to execute_complete for a timed-out batch.
    timeout_event = {
        "status": "timeout",
        "log_lines": ["mock log"],
        "batch_id": BATCH_ID,
        "response": "mock timeout",
    }

    with pytest.raises(AirflowException):
        operator.execute_complete(context=self.mock_context, event=timeout_event)

    mock_delete.assert_called_once_with(BATCH_ID)
    self.mock_context["ti"].xcom_push.assert_not_called()


@pytest.mark.db_test
def test_spark_params_templating(create_task_instance_of_operator):
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/apache/livy/triggers/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
from datetime import timedelta
from unittest import mock

import pytest
Expand Down Expand Up @@ -46,6 +47,7 @@ def test_livy_trigger_serialization(self):
"extra_options": None,
"extra_headers": None,
"livy_hook_async": None,
"execution_timeout": None,
}

@pytest.mark.asyncio
Expand Down Expand Up @@ -195,3 +197,31 @@ async def test_livy_trigger_poll_for_termination_state(self, mock_dump_batch_log
# TriggerEvent was not returned
assert task.done() is False
asyncio.get_event_loop().stop()

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs")
async def test_livy_trigger_poll_for_termination_timeout(
    self, mock_dump_batch_logs, mock_get_batch_state
):
    """
    Verify that poll_for_termination() yields the timeout payload once execution times out.
    """
    batch_id = 1
    # The mocked batch always reports RUNNING, so it never reaches a terminal state.
    mock_get_batch_state.return_value = {"batch_state": BatchState.RUNNING}
    mock_dump_batch_logs.return_value = ["mock_log"]
    # A zero-length execution timeout is used so the deadline passes without waiting.
    trigger = LivyTrigger(
        batch_id=batch_id,
        spark_params={},
        livy_conn_id=LivyHook.default_conn_name,
        polling_interval=1,
        execution_timeout=timedelta(seconds=0),
    )

    response = await trigger.poll_for_termination(batch_id)

    expected_response = {
        "status": "timeout",
        "batch_id": 1,
        "response": "Batch 1 timed out",
        "log_lines": ["mock_log"],
    }
    assert response == expected_response

0 comments on commit bf5ab8f

Please sign in to comment.