From b1a44b4e3db8a3d3c13dfc810ad02f5785f82eca Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 27 Nov 2024 12:13:44 +0000 Subject: [PATCH] AIP-72: Improve Supervisor and Task Instance State Validation (#44405) --- .../tests/execution_time/test_supervisor.py | 38 ++++++++++++++++++- .../routes/test_task_instances.py | 13 ++++--- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 1083364f289d5..9f582074586a5 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -28,11 +28,13 @@ from typing import TYPE_CHECKING from unittest.mock import MagicMock +import httpx import pytest import structlog from uuid6 import uuid7 from airflow.sdk.api import client as sdk_client +from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity from airflow.sdk.execution_time.comms import ( @@ -46,6 +48,8 @@ from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise from airflow.utils import timezone as tz +from task_sdk.tests.api.test_client import make_client + if TYPE_CHECKING: import kgb @@ -73,7 +77,7 @@ def subprocess_main(): print("I'm a short message") sys.stdout.write("Message ") print("stderr message", file=sys.stderr) - # We need a short sleep for the main process to process things. I worry this timining will be + # We need a short sleep for the main process to process things. I worry this timing will be # fragile, but I can't think of a better way. This lets the stdout be read (partial line) and the # stderr full line be read sleep(0.1) @@ -265,6 +269,38 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): "timestamp": "2024-11-07T12:34:56.078901Z", } in captured_logs + def test_supervisor_handles_already_running_task(self): + """Test that Supervisor prevents starting a Task Instance that is already running.""" + ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1) + + # Mock API Server response indicating the TI is already running + # The API Server would return a 409 Conflict status code if the TI is not + # in a "queued" state. + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti.id}/state": + return httpx.Response( + 409, + json={ + "reason": "invalid_state", + "message": "TI was not in a state where it could be marked as running", + "previous_state": "running", + }, + ) + + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError, match="Server returned error") as err: + WatchedSubprocess.start(path=os.devnull, ti=ti, client=client) + + assert err.value.response.status_code == 409 + assert err.value.detail == { + "reason": "invalid_state", + "message": "TI was not in a state where it could be marked as running", + "previous_state": "running", + } + class TestHandleRequest: @pytest.fixture diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index efb48ccb533aa..d2285e7e3a92f 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -25,7 +25,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from tests_common.test_utils.db import clear_db_runs @@ -79,14 +79,17 @@ def test_ti_update_state_to_running(self, client, session, create_task_instance) assert ti.pid == 100 assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00" - def test_ti_update_state_conflict_if_not_queued(self, client, session, create_task_instance): + @pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED]) + def test_ti_update_state_conflict_if_not_queued( + self, client, session, create_task_instance, initial_ti_state + ): """ Test that a 409 error is returned when the Task Instance is not in a state where it can be marked as running. In this case, the Task Instance is first in NONE state so it cannot be marked as running. """ ti = create_task_instance( task_id="test_ti_update_state_conflict_if_not_queued", - state=State.NONE, + state=initial_ti_state, ) session.commit() @@ -105,12 +108,12 @@ def test_ti_update_state_conflict_if_not_queued(self, client, session, create_ta assert response.json() == { "detail": { "message": "TI was not in a state where it could be marked as running", - "previous_state": State.NONE, + "previous_state": initial_ti_state, "reason": "invalid_state", } } - assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == State.NONE + assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state @pytest.mark.parametrize( ("state", "end_date", "expected_state"),