Skip to content

Commit

Permalink
AIP-72: Port task success overtime to the Supervisor
Browse files Browse the repository at this point in the history
This PR ports the overtime feature on `LocalTaskJob` (added in apache#39890) to the Supervisor.
It allows to terminate Task process to terminate when it exceeding the configured success overtime threshold which is useful when we add Listenener to the Task process.

closes apache#44356

Also added `TaskState` to update state and send end_date from task process to the supervisor.
  • Loading branch information
kaxil committed Dec 3, 2024
1 parent a242ff6 commit 48ca10c
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 16 deletions.
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from __future__ import annotations

from datetime import datetime
from typing import Annotated, Literal, Union

from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -101,6 +102,7 @@ class TaskState(BaseModel):
"""

state: TerminalTIState
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"


Expand Down
37 changes: 30 additions & 7 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
GetVariable,
GetXCom,
StartupDetails,
TaskState,
ToSupervisor,
)

Expand Down Expand Up @@ -265,9 +266,9 @@ class WatchedSubprocess:
client: Client

_process: psutil.Process
_exit_code: int | None = None
_terminal_state: str | None = None
_final_state: str | None = None
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand All @@ -277,6 +278,13 @@ class WatchedSubprocess:
# does not hang around forever.
failed_heartbeats: int = attrs.field(default=0, init=False)

# Maximum possible time (in seconds) that task will have for execution of auxiliary processes
# like listeners after task is marked as success.
# TODO: This should be come from airflow.cfg: [core] task_success_overtime
task_success_overtime_threshold: float = attrs.field(default=20.0, init=False)
_overtime: float = attrs.field(default=0.0, init=False)
_task_end_datetime: datetime | None = attrs.field(default=None, init=False)

selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)

procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
Expand Down Expand Up @@ -500,6 +508,20 @@ def _monitor_subprocess(self):

self._send_heartbeat_if_needed()

self._handle_task_overtime_if_needed()

def _handle_task_overtime_if_needed(self):
"""Handle termination of auxiliary processes if the task exceeds the configured success overtime."""
if self._terminal_state != TerminalTIState.SUCCESS:
return

now = datetime.now(tz=timezone.utc)
self._overtime = (now - (self._task_end_datetime or now)).total_seconds()

if self._overtime > self.task_success_overtime_threshold:
log.warning("Task success overtime reached; terminating process", ti_id=self.ti_id)
self.kill(signal.SIGTERM, force=True)

def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
"""
Service subprocess events by processing socket activity and checking for process exit.
Expand Down Expand Up @@ -631,9 +653,11 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
log.exception("Unable to decode message", line=line)
continue

# if isinstance(msg, TaskState):
# self._terminal_state = msg.state
if isinstance(msg, GetConnection):
resp = None
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_datetime = msg.end_date
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetVariable):
Expand All @@ -645,7 +669,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
else:
log.error("Unhandled request", msg=msg)
continue
Expand Down
14 changes: 10 additions & 4 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@

import os
import sys
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, TextIO

import attrs
import structlog
from pydantic import ConfigDict, TypeAdapter

from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, ToSupervisor, ToTask
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState, ToSupervisor, ToTask

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -158,11 +159,14 @@ def run(ti: RuntimeTaskInstance, log: Logger):
if TYPE_CHECKING:
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)

msg = None
try:
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
next_method = defer.method_name
Expand All @@ -173,9 +177,8 @@ def run(ti: RuntimeTaskInstance, log: Logger):
next_method=next_method,
trigger_timeout=timeout,
)
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
except AirflowSkipException:
...
msg = TaskState(state=TerminalTIState.SKIPPED)
except AirflowRescheduleException:
...
except (AirflowFailException, AirflowSensorTimeout):
Expand All @@ -189,6 +192,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Handle TI handle failure
raise

if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def finalize(log: Logger): ...

Expand Down
125 changes: 124 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
from typing import TYPE_CHECKING
from unittest.mock import MagicMock

import attrs
import httpx
import psutil
import pytest
from uuid6 import uuid7

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
Expand All @@ -62,6 +63,34 @@ def lineno():
return inspect.currentframe().f_back.f_lineno


@pytest.fixture
def mock_watched_subprocess(mocker) -> WatchedSubprocess:
"""
Mocks the WatchedSubprocess class with slots=False to allow for easier mocking of the class.
It is not possible to mock a class with slots=True, so we need to define a new class with slots=False
to allow for easier mocking of the class.
Reference: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes
"""

@attrs.define(slots=False)
class MockWatchedSubprocess(WatchedSubprocess):
pass

mock_process = mocker.Mock()
mock_process.pid = 12345

mock_watched_subprocess = WatchedSubprocess(
ti_id=uuid7(),
pid=mock_process.pid,
stdin=mocker.Mock(),
process=mock_process,
client=mocker.Mock(),
)

return mock_watched_subprocess


@pytest.mark.usefixtures("disable_capturing")
class TestWatchedSubprocess:
def test_reading_from_pipes(self, captured_logs, time_machine):
Expand Down Expand Up @@ -478,6 +507,100 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t
"timestamp": mocker.ANY,
} in captured_logs

@pytest.mark.parametrize(
"terminal_state, task_end_datetime, overtime_threshold, expected_kill",
[
# The current date is fixed at tz.datetime(2024, 12, 1, 10, 10, 20)
# Current time minus 5 seconds | Threshold: 10s
pytest.param(
None,
tz.datetime(2024, 12, 1, 10, 10, 15),
10,
False,
id="no_terminal_state",
),
# Terminal state is not SUCCESS, while we are above the threshold, it should not kill the process
pytest.param(
TerminalTIState.SKIPPED,
tz.datetime(2024, 12, 1, 10, 10, 0),
1,
False,
id="non_success_state",
),
# Current time minus 5 seconds | Threshold: 10s
pytest.param(
TerminalTIState.SUCCESS,
tz.datetime(2024, 12, 1, 10, 10, 15),
10,
False,
id="below_threshold",
),
# Current time minus 10 seconds | Threshold: 9s
pytest.param(
TerminalTIState.SUCCESS,
tz.datetime(2024, 12, 1, 10, 10, 10),
9,
True,
id="above_threshold",
),
# End datetime is None | Threshold: 20s
pytest.param(
TerminalTIState.SUCCESS,
None,
20,
False,
id="task_end_datetime_none",
),
],
)
def test_overtime_handling(
self,
mocker,
terminal_state,
task_end_datetime,
overtime_threshold,
expected_kill,
time_machine,
):
"""Test handling of overtime under various conditions."""
# Mocking logger since we are only interested that it is called with the expected message
# and not the actual log output
mock_logger = mocker.patch("airflow.sdk.execution_time.supervisor.log")

# Mock the kill method at the class level so we can assert it was called with the correct signal
mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")

mock_watched_subprocess = WatchedSubprocess(
ti_id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
process=mocker.Mock(),
client=mocker.Mock(),
)

# Fix the current datetime
instant = tz.datetime(2024, 12, 1, 10, 10, 20)
time_machine.move_to(instant, tick=False)

# Set the terminal state and task end datetime
mock_watched_subprocess._terminal_state = terminal_state
mock_watched_subprocess._task_end_datetime = task_end_datetime
mock_watched_subprocess.task_success_overtime_threshold = overtime_threshold

# Call the method under test
mock_watched_subprocess._handle_task_overtime_if_needed()

# Validate process kill behavior and log messages
if expected_kill:
mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
mock_logger.warning.assert_called_once_with(
"Task success overtime reached; terminating process",
ti_id=TI_ID,
)
else:
mock_kill.assert_not_called()
mock_logger.warning.assert_not_called()


class TestWatchedSubprocessKill:
@pytest.fixture
Expand Down
22 changes: 18 additions & 4 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run
from airflow.utils import timezone
from airflow.utils.state import TerminalTIState


class TestCommsDecoder:
Expand Down Expand Up @@ -78,7 +79,7 @@ def test_parse(test_dags_dir: Path):
assert isinstance(ti.task.dag, DAG)


def test_run_basic(test_dags_dir: Path):
def test_run_basic(test_dags_dir: Path, time_machine, mocked_supervisor_comms):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
Expand All @@ -87,10 +88,23 @@ def test_run_basic(test_dags_dir: Path):
)

ti = parse(what)
run(ti, log=mock.MagicMock())

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
# Mocking the communication interface
mock_supervisor_comms.send_request = mock.Mock()
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
)


def test_run_deferred_basic(test_dags_dir: Path, time_machine):
def test_run_deferred_basic(test_dags_dir: Path, time_machine, mocked_supervisor_comms):
"""Test that a task can transition to a deferred state."""
what = StartupDetails(
ti=TaskInstance(
Expand Down

0 comments on commit 48ca10c

Please sign in to comment.