Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fail stop feature for DAGs #29406

Merged
merged 5 commits into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import TYPE_CHECKING, Any, NamedTuple, Sized

if TYPE_CHECKING:
from airflow.models import DagRun
from airflow.models import DAG, DagRun


class AirflowException(Exception):
Expand Down Expand Up @@ -207,6 +207,22 @@ def __init__(self, *args, **kwargs):
warnings.warn("DagFileExists is deprecated and will be removed.", DeprecationWarning, stacklevel=2)


class DagInvalidTriggerRule(AirflowException):
    """Raise when a dag has 'fail_stop' enabled yet has a non-default trigger rule."""

    @classmethod
    def check(cls, dag: DAG | None, trigger_rule: str):
        """Raise this exception if *dag* is fail-stop and *trigger_rule* is not the default.

        A no-op when no dag is given or the dag is not fail-stop.
        """
        # Imported lazily to avoid a circular import with airflow.models.
        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE

        if dag is None or not dag.fail_stop:
            return
        if trigger_rule != DEFAULT_TRIGGER_RULE:
            raise cls()

    def __str__(self) -> str:
        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE

        return f"A 'fail-stop' dag can only have {DEFAULT_TRIGGER_RULE} trigger rule"


class DuplicateTaskIdFound(AirflowException):
"""Raise when a Task with duplicate task_id is defined in the same DAG."""

Expand Down
4 changes: 3 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from sqlalchemy.orm.exc import NoResultFound

from airflow.configuration import conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred
from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning, TaskDeferred
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
Expand Down Expand Up @@ -801,6 +801,8 @@ def __init__(
dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)

DagInvalidTriggerRule.check(dag, trigger_rule)

self.task_id = task_group.child_id(task_id) if task_group else task_id
if not self.__from_mapped and task_group:
task_group.add(self)
Expand Down
12 changes: 12 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
AirflowDagInconsistent,
AirflowException,
AirflowSkipException,
DagInvalidTriggerRule,
DuplicateTaskIdFound,
RemovedInAirflow3Warning,
TaskNotFound,
Expand Down Expand Up @@ -357,6 +358,9 @@ class DAG(LoggingMixin):
Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
e.g: {"dag_owner": "https://airflow.apache.org/"}
:param auto_register: Automatically register this DAG when it is used in a ``with`` block
:param fail_stop: Fails currently running tasks when task in DAG fails.
**Warning**: A fail stop dag can only have tasks with the default trigger rule ("all_success").
An exception will be thrown if any task in a fail stop dag has a non default trigger rule.
"""

_comps = {
Expand Down Expand Up @@ -419,6 +423,7 @@ def __init__(
tags: list[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
fail_stop: bool = False,
):
from airflow.utils.task_group import TaskGroup

Expand Down Expand Up @@ -602,6 +607,8 @@ def __init__(
self.is_paused_upon_creation = is_paused_upon_creation
self.auto_register = auto_register

self.fail_stop = fail_stop

self.jinja_environment_kwargs = jinja_environment_kwargs
self.render_template_as_native_obj = render_template_as_native_obj

Expand Down Expand Up @@ -2353,6 +2360,8 @@ def add_task(self, task: Operator) -> None:

:param task: the task you want to add
"""
DagInvalidTriggerRule.check(self, task.trigger_rule)

from airflow.utils.task_group import TaskGroupContext

if not self.start_date and not task.start_date:
Expand Down Expand Up @@ -3055,6 +3064,7 @@ def get_serialized_fields(cls):
"has_on_success_callback",
"has_on_failure_callback",
"auto_register",
"fail_stop",
}
cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
return cls.__serialized_fields
Expand Down Expand Up @@ -3530,6 +3540,7 @@ def dag(
tags: list[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
fail_stop: bool = False,
) -> Callable[[Callable], Callable[..., DAG]]:
"""
Python dag decorator. Wraps a function into an Airflow DAG.
Expand Down Expand Up @@ -3583,6 +3594,7 @@ def factory(*args, **kwargs):
schedule=schedule,
owner_links=owner_links,
auto_register=auto_register,
fail_stop=fail_stop,
) as dag_obj:
# Set DAG documentation from function documentation if it exists and doc_md is not set.
if f.__doc__ and not dag_obj.doc_md:
Expand Down
19 changes: 19 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,21 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
)


def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: str):
    """Force-stop every unfinished task instance in *tis* except the one to ignore.

    Used for 'fail_stop' dags: when one task fails, currently RUNNING siblings
    are forced to FAILED, and any other unfinished siblings are set to SKIPPED.
    Finished tasks (SUCCESS or FAILED) are left untouched.

    :param tis: task instances of a single dag run
    :param session: SQLAlchemy ORM session used to persist the state changes
    :param task_id_to_ignore: task_id of the instance that triggered the stop
        (compared against ``ti.task_id``, hence a ``str`` — the original ``int``
        annotation was incorrect)
    """
    for ti in tis:
        # Skip the triggering task and anything that already finished.
        if ti.task_id == task_id_to_ignore or ti.state in (
            TaskInstanceState.SUCCESS,
            TaskInstanceState.FAILED,
        ):
            continue
        if ti.state == TaskInstanceState.RUNNING:
            # Log before force-failing so the abrupt state change is traceable.
            log.info("Forcing task %s to fail", ti.task_id)
            ti.error(session)
        else:
            log.info("Setting task %s to SKIPPED", ti.task_id)
            ti.set_state(state=TaskInstanceState.SKIPPED, session=session)


def clear_task_instances(
tis: list[TaskInstance],
session: Session,
Expand Down Expand Up @@ -1896,6 +1911,10 @@ def handle_failure(
email_for_state = operator.attrgetter("email_on_failure")
callbacks = task.on_failure_callback if task else None
callback_type = "on_failure"

if task and task.dag and task.dag.fail_stop:
tis = self.get_dagrun(session).get_task_instances()
stop_all_tasks_in_dag(tis, session, self.task_id)
else:
if self.state == State.QUEUED:
# We increase the try_number so as to fail the task if it fails to start after sometime
Expand Down
38 changes: 37 additions & 1 deletion tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import pytest

from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
Expand Down Expand Up @@ -163,6 +163,42 @@ def test_illegal_args_forbidden(self):
illegal_argument_1234="hello?",
)

def test_trigger_rule_validation(self):
    """Operator construction enforces the fail-stop trigger-rule constraint."""
    from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE

    fail_stop_dag = DAG(
        dag_id="test_dag_trigger_rule_validation", start_date=DEFAULT_DATE, fail_stop=True
    )
    non_fail_stop_dag = DAG(
        dag_id="test_dag_trigger_rule_validation", start_date=DEFAULT_DATE, fail_stop=False
    )

    # An operator with the default trigger rule in a fail-stop dag is allowed.
    # No try/except needed: an unexpected DagInvalidTriggerRule would propagate
    # and fail the test on its own (avoids the `assert False` anti-pattern,
    # which is stripped under `python -O`).
    BaseOperator(
        task_id="test_valid_trigger_rule", dag=fail_stop_dag, trigger_rule=DEFAULT_TRIGGER_RULE
    )

    # An operator with a non-default trigger rule in a non fail-stop dag is allowed.
    BaseOperator(
        task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.DUMMY
    )

    # An operator with a non-default trigger rule in a fail-stop dag must raise.
    with pytest.raises(DagInvalidTriggerRule):
        BaseOperator(
            task_id="test_invalid_trigger_rule", dag=fail_stop_dag, trigger_rule=TriggerRule.DUMMY
        )

@pytest.mark.parametrize(
("content", "context", "expected_output"),
[
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,42 @@ def test_create_dagrun_job_id_is_set(self):
)
assert dr.creating_job_id == job_id

def test_dag_add_task_checks_trigger_rule(self):
    """DAG.add_task enforces the fail-stop trigger-rule constraint."""
    # Imports consolidated at the top of the test instead of scattered mid-body.
    from airflow.exceptions import DagInvalidTriggerRule
    from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
    from airflow.utils.trigger_rule import TriggerRule

    task_with_non_default_trigger_rule = EmptyOperator(
        task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.DUMMY
    )
    non_fail_stop_dag = DAG(
        dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=False
    )
    # A non fail-stop dag accepts any trigger rule. No try/except needed: an
    # unexpected DagInvalidTriggerRule would propagate and fail the test on its
    # own (avoids the `assert False` anti-pattern, stripped under `python -O`).
    non_fail_stop_dag.add_task(task_with_non_default_trigger_rule)

    fail_stop_dag = DAG(
        dag_id="test_dag_add_task_checks_trigger_rule", start_date=DEFAULT_DATE, fail_stop=True
    )
    task_with_default_trigger_rule = EmptyOperator(
        task_id="task_with_default_trigger_rule", trigger_rule=DEFAULT_TRIGGER_RULE
    )
    # A fail-stop dag accepts the default trigger rule.
    fail_stop_dag.add_task(task_with_default_trigger_rule)

    # A fail-stop dag must reject a non-default trigger rule.
    with pytest.raises(DagInvalidTriggerRule):
        fail_stop_dag.add_task(task_with_non_default_trigger_rule)

def test_dag_add_task_sets_default_task_group(self):
dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", start_date=DEFAULT_DATE)
task_without_task_group = EmptyOperator(task_id="task_without_group_id")
Expand Down
55 changes: 55 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2556,6 +2556,61 @@ def test_handle_failure_task_undefined(self, create_task_instance):
del ti.task
ti.handle_failure("test ti.task undefined")

@provide_session
def test_handle_failure_fail_stop(self, create_dummy_dag, session=None):
    """Failing one task in a fail-stop dag force-fails its running siblings and skips pending ones."""
    start_date = timezone.datetime(2016, 6, 1)
    clear_db_runs()

    dag, task1 = create_dummy_dag(
        dag_id="test_handle_failure_fail_stop",
        schedule=None,
        start_date=start_date,
        task_id="task1",
        trigger_rule="all_success",
        with_dagrun_type=DagRunType.MANUAL,
        session=session,
        fail_stop=True,  # the feature under test
    )
    dr = dag.create_dagrun(
        run_id="test_ff",
        run_type=DagRunType.MANUAL,
        execution_date=timezone.utcnow(),
        state=None,
        session=session,
    )

    # Already-successful task: the fail-stop sweep must leave it untouched.
    ti1 = dr.get_task_instance(task1.task_id, session=session)
    ti1.task = task1
    ti1.state = State.SUCCESS

    # One sibling task instance per pre-existing state we want to exercise.
    states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED, State.DEFERRED]
    tasks = []
    for i in range(len(states)):
        op = EmptyOperator(
            task_id=f"reg_Task{i}",
            dag=dag,
        )
        ti = TI(task=op, run_id=dr.run_id)
        ti.state = states[i]
        session.add(ti)
        tasks.append(ti)

    # The task whose failure triggers the fail-stop behaviour.
    fail_task = EmptyOperator(
        task_id="fail_Task",
        dag=dag,
    )
    ti_ff = TI(task=fail_task, run_id=dr.run_id)
    ti_ff.state = State.FAILED
    session.add(ti_ff)
    session.flush()
    ti_ff.handle_failure("test retry handling")

    assert ti1.state == State.SUCCESS
    assert ti_ff.state == State.FAILED
    # Expected mapping: RUNNING -> forced FAILED; FAILED stays FAILED;
    # QUEUED/SCHEDULED/DEFERRED (not yet finished) -> SKIPPED.
    exp_states = [State.FAILED, State.FAILED, State.SKIPPED, State.SKIPPED, State.SKIPPED]
    for i in range(len(states)):
        assert tasks[i].state == exp_states[i]

def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
def fail():
raise AirflowFailException("hopeless")
Expand Down