AIP-72: Handle externally updated TI state in Supervisor
- Updated logic to handle externally updated TI state in the Supervisor. These states could have been changed externally via the UI, CLI, API, etc.
- Renamed `FASTEST_HEARTBEAT_INTERVAL` to `MIN_HEARTBEAT_INTERVAL` and `SLOWEST_HEARTBEAT_INTERVAL` to `HEARTBEAT_THRESHOLD` for better clarity

This is part of my efforts to port LocalTaskJob tests to Supervisor: apache#44356.

This ports over `TestLocalTaskJob.test_mark_{success,failure}_no_kill`.

This PR also allows retrying heartbeats:

- Added `_last_successful_heartbeat` and `_last_heartbeat_attempt` to track successful heartbeats and retry attempts separately.
- `MIN_HEARTBEAT_INTERVAL` is now respected between heartbeat attempts, even after failures.
- The number of retries is configurable via `MAX_FAILED_HEARTBEATS` (see the sketch below).
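
A minimal sketch of the retry flow described above (illustrative only, not the committed implementation): the constant and attribute names come from the bullets in this message, while the `HeartbeatTracker` class and the `send_heartbeat` callable are hypothetical stand-ins for the Supervisor and its API client.

```python
import time

MIN_HEARTBEAT_INTERVAL = 5   # seconds to wait between heartbeat attempts
MAX_FAILED_HEARTBEATS = 3    # consecutive failures tolerated before giving up


class HeartbeatTracker:
    """Hypothetical helper showing the bookkeeping for heartbeat retries."""

    def __init__(self) -> None:
        self._last_heartbeat_attempt: float = 0.0
        self._last_successful_heartbeat: float = 0.0
        self._failed_heartbeats: int = 0

    def maybe_heartbeat(self, send_heartbeat) -> None:
        # Respect MIN_HEARTBEAT_INTERVAL between attempts, even after a failure.
        if time.monotonic() - self._last_heartbeat_attempt < MIN_HEARTBEAT_INTERVAL:
            return
        self._last_heartbeat_attempt = time.monotonic()
        try:
            send_heartbeat()
        except Exception:
            self._failed_heartbeats += 1
            if self._failed_heartbeats >= MAX_FAILED_HEARTBEATS:
                # Give up after too many consecutive failures.
                raise RuntimeError("Too many failed heartbeats; terminating task")
            return
        # Success: reset the failure counter and record the timestamp.
        self._failed_heartbeats = 0
        self._last_successful_heartbeat = time.monotonic()
```

In this sketch the monitoring loop would simply call `maybe_heartbeat(...)` on every iteration; the actual Supervisor calls its heartbeat logic from `_monitor_subprocess` and kills the task process when the server reports a conflicting state.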
kaxil committed Nov 28, 2024
1 parent ab2bd2d commit 0893557
Showing 3 changed files with 133 additions and 39 deletions.
88 changes: 61 additions & 27 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -42,7 +42,7 @@
import structlog
from pydantic import TypeAdapter

from airflow.sdk.api.client import Client
from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
DeferTask,
@@ -54,6 +54,8 @@
)

if TYPE_CHECKING:
from selectors import SelectorKey

from structlog.typing import FilteringBoundLogger, WrappedLogger


@@ -62,9 +64,11 @@
log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")

# TODO: Pull this from config
SLOWEST_HEARTBEAT_INTERVAL: int = 30
# (previously `[scheduler] local_task_job_heartbeat_sec` with the following as fallback if it is 0:
# `[scheduler] scheduler_zombie_task_threshold`)
HEARTBEAT_THRESHOLD: int = 30
# Don't heartbeat more often than this
FASTEST_HEARTBEAT_INTERVAL: int = 5
MIN_HEARTBEAT_INTERVAL: int = 5


@overload
@@ -423,14 +427,10 @@ def _monitor_subprocess(self):
This function:
- Polls the subprocess for output
- Sends heartbeats to the client to keep the task alive
- Checks if the subprocess has exited
- Waits for activity on file objects (e.g., subprocess stdout, stderr) using the selector.
- Processes events triggered on the monitored file objects, such as data availability or EOF.
- Sends heartbeats to ensure the process is alive and checks if the subprocess has exited.
"""
# Until we have a selector for the process, don't poll for more than 10s, just in case it exists but
# doesn't produce any output
max_poll_interval = 10

while self._exit_code is None or len(self.selector.get_map()):
last_heartbeat_ago = time.monotonic() - self._last_heartbeat
# Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
@@ -439,22 +439,42 @@
0, # Make sure this value is never negative,
min(
# Ensure we heartbeat _at most_ 75% of the way through the zombie threshold time
SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
max_poll_interval,
HEARTBEAT_THRESHOLD - last_heartbeat_ago * 0.75,
MIN_HEARTBEAT_INTERVAL,
),
)
# Block until events are ready or the timeout is reached
# This listens for activity (e.g., subprocess output) on registered file objects
events = self.selector.select(timeout=max_wait_time)
for key, _ in events:
socket_handler = key.data
need_more = socket_handler(key.fileobj)

if not need_more:
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]
self._process_file_object_events(events)

self._check_subprocess_exit()
self._send_heartbeat_if_needed()

def _process_file_object_events(self, events: list[tuple[SelectorKey, int]]):
"""
Process selector events by invoking handlers for each file object.
For each file object event, this method retrieves the associated handler and processes
the event. If the handler indicates that the file object no longer needs
monitoring (e.g., EOF or closed), the file object is unregistered and closed.
"""
for key, _ in events:
# Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
socket_handler = key.data

# Example of handler behavior:
# If the subprocess writes "Hello, World!" to stdout:
# - `socket_handler` reads and processes the message.
# - If EOF is reached, the handler returns False to signal no more reads are expected.
need_more = socket_handler(key.fileobj)

# If the handler signals that the file object is no longer needed (EOF, closed, etc.)
# unregister it from the selector to stop monitoring and close it cleanly.
if not need_more:
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]

def _check_subprocess_exit(self):
"""Check if the subprocess has exited."""
if self._exit_code is None:
@@ -466,14 +486,28 @@ def _check_subprocess_exit(self):

def _send_heartbeat_if_needed(self):
"""Send a heartbeat to the client if heartbeat interval has passed."""
if time.monotonic() - self._last_heartbeat >= FASTEST_HEARTBEAT_INTERVAL:
try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self._last_heartbeat = time.monotonic()
except Exception:
log.warning("Failed to send heartbeat", exc_info=True)
# TODO: If we couldn't heartbeat for X times the interval, kill ourselves
pass
# Respect the minimum interval between heartbeat attempts
if (time.monotonic() - self._last_heartbeat) < MIN_HEARTBEAT_INTERVAL:
return

try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self._last_heartbeat = time.monotonic()
except ServerResponseError as e:
if e.response.status_code == 409:
log.error("Server indicated we shouldn't be running anymore", detail=e.detail)
elif e.response.status_code == 404:
log.error("Task Instance not found")
else:
# TODO: Handle other errors
raise

# If heartbeating raises an error, kill the subprocess
self.kill(signal.SIGTERM)
except Exception:
log.warning("Failed to send heartbeat", exc_info=True)
# TODO: If we couldn't heartbeat for X times the interval, kill ourselves
pass

@property
def final_state(self):
9 changes: 8 additions & 1 deletion task_sdk/tests/conftest.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn
@@ -72,7 +73,7 @@ def test_dags_dir():


@pytest.fixture
def captured_logs():
def captured_logs(request):
import structlog

from airflow.sdk.log import configure_logging, reset_logging
@@ -81,6 +82,12 @@ def captured_logs():
reset_logging()
configure_logging(enable_pretty_log=False)

# Get log level from test parameter, defaulting to INFO if not provided
log_level = getattr(request, "param", logging.INFO)

# We want to capture all logs, but we don't want to see them in the test output
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))

# But we need to remove the last processor (the one that turns JSON into text, as we want the
# event dict for tests)
cur_processors = structlog.get_config()["processors"]
75 changes: 64 additions & 11 deletions task_sdk/tests/execution_time/test_supervisor.py
@@ -30,7 +30,6 @@

import httpx
import pytest
import structlog
from uuid6 import uuid7

from airflow.sdk.api import client as sdk_client
@@ -64,9 +63,6 @@ def lineno():
@pytest.mark.usefixtures("disable_capturing")
class TestWatchedSubprocess:
def test_reading_from_pipes(self, captured_logs, time_machine):
# Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))

def subprocess_main():
# This is run in the subprocess!

@@ -177,9 +173,6 @@ def subprocess_main():
assert rc == -9

def test_last_chance_exception_handling(self, capfd):
# Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))

def subprocess_main():
# The real main() in task_runner catches exceptions! This is what would happen if we had a syntax
# or import error for instance - a very early exception
@@ -210,7 +203,7 @@ def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch):
"""Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency"""
import airflow.sdk.execution_time.supervisor

monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "FASTEST_HEARTBEAT_INTERVAL", 0.1)
monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)

def subprocess_main():
sys.stdin.readline()
@@ -241,9 +234,6 @@ def subprocess_main():
def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
"""Test running a simple DAG in a subprocess and capturing the output."""

# Ignore anything lower than INFO for this test.
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))

instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
time_machine.move_to(instant, tick=False)

@@ -299,6 +289,69 @@ def handle_request(request: httpx.Request) -> httpx.Response:
"previous_state": "running",
}

@pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"])
def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker):
"""
Test that ensures that the Supervisor does not cause the task to fail if the Task Instance is no longer
in the running state.
"""
import airflow.sdk.execution_time.supervisor

monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)

def subprocess_main():
sys.stdin.readline()
sleep(5)

ti_id = uuid7()

# Track the number of requests to simulate mixed responses
request_count = {"count": 0}

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/heartbeat":
request_count["count"] += 1
if request_count["count"] == 1:
# First request succeeds
return httpx.Response(status_code=204)
else:
# Second request returns a conflict status code
return httpx.Response(
409,
json={
"reason": "not_running",
"message": "TI is no longer in the running state and task should terminate",
"current_state": "success",
},
)
return httpx.Response(status_code=204)

proc = WatchedSubprocess.start(
path=os.devnull,
ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
)

# Wait for the subprocess to finish
assert proc.wait() == -signal.SIGTERM

# Verify the number of requests made
assert request_count["count"] == 2
assert captured_logs == [
{
"detail": {
"current_state": "success",
"message": "TI is no longer in the running state and task should terminate",
"reason": "not_running",
},
"event": "Server indicated we shouldn't be running anymore",
"level": "error",
"logger": "supervisor",
"timestamp": mocker.ANY,
}
]


class TestHandleRequest:
@pytest.fixture
