Skip to content

Commit

Permalink
AIP-72: Improve Supervisor and Task Instance State Validation (#44405)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored Nov 27, 2024
1 parent 21933a7 commit b1a44b4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
38 changes: 37 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
from typing import TYPE_CHECKING
from unittest.mock import MagicMock

import httpx
import pytest
import structlog
from uuid6 import uuid7

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
from airflow.sdk.execution_time.comms import (
Expand All @@ -46,6 +48,8 @@
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
from airflow.utils import timezone as tz

from task_sdk.tests.api.test_client import make_client

if TYPE_CHECKING:
import kgb

Expand Down Expand Up @@ -73,7 +77,7 @@ def subprocess_main():
print("I'm a short message")
sys.stdout.write("Message ")
print("stderr message", file=sys.stderr)
# We need a short sleep for the main process to process things. I worry this timining will be
# We need a short sleep for the main process to process things. I worry this timing will be
# fragile, but I can't think of a better way. This lets the stdout be read (partial line) and the
# stderr full line be read
sleep(0.1)
Expand Down Expand Up @@ -265,6 +269,38 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs

def test_supervisor_handles_already_running_task(self):
    """Check that the Supervisor refuses to start a Task Instance that is already running."""
    task_instance = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)
    state_endpoint = f"/task-instances/{task_instance.id}/state"

    # Body of the 409 Conflict the mocked API Server sends back; the API Server
    # answers this way whenever the TI is not in a "queued" state.
    conflict_detail = {
        "reason": "invalid_state",
        "message": "TI was not in a state where it could be marked as running",
        "previous_state": "running",
    }

    def fake_api_server(request: httpx.Request) -> httpx.Response:
        # Only the state-update endpoint reports the conflict; any other
        # request is answered with an empty 204.
        if request.url.path != state_endpoint:
            return httpx.Response(status_code=204)
        return httpx.Response(409, json=conflict_detail)

    client = make_client(transport=httpx.MockTransport(fake_api_server))

    with pytest.raises(ServerResponseError, match="Server returned error") as exc_info:
        WatchedSubprocess.start(path=os.devnull, ti=task_instance, client=client)

    raised = exc_info.value
    assert raised.response.status_code == 409
    assert raised.detail == conflict_detail


class TestHandleRequest:
@pytest.fixture
Expand Down
13 changes: 8 additions & 5 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState

from tests_common.test_utils.db import clear_db_runs

Expand Down Expand Up @@ -79,14 +79,17 @@ def test_ti_update_state_to_running(self, client, session, create_task_instance)
assert ti.pid == 100
assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00"

def test_ti_update_state_conflict_if_not_queued(self, client, session, create_task_instance):
@pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED])
def test_ti_update_state_conflict_if_not_queued(
self, client, session, create_task_instance, initial_ti_state
):
"""
Test that a 409 error is returned when the Task Instance is not in a state where it can be marked as
running. The test is parametrized over every Task Instance state other than QUEUED.
"""
ti = create_task_instance(
task_id="test_ti_update_state_conflict_if_not_queued",
state=State.NONE,
state=initial_ti_state,
)
session.commit()

Expand All @@ -105,12 +108,12 @@ def test_ti_update_state_conflict_if_not_queued(self, client, session, create_ta
assert response.json() == {
"detail": {
"message": "TI was not in a state where it could be marked as running",
"previous_state": State.NONE,
"previous_state": initial_ti_state,
"reason": "invalid_state",
}
}

assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == State.NONE
assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state

@pytest.mark.parametrize(
("state", "end_date", "expected_state"),
Expand Down

0 comments on commit b1a44b4

Please sign in to comment.