Enable ending the task directly from the triggerer without going into the worker. (#40084)
sunank200 authored Jul 25, 2024
1 parent 6f97525 commit 0fba616
Showing 17 changed files with 368 additions and 60 deletions.
19 changes: 16 additions & 3 deletions airflow/callbacks/callback_requests.py
@@ -19,6 +19,8 @@
import json
from typing import TYPE_CHECKING

from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance

@@ -68,22 +70,33 @@ class TaskCallbackRequest(CallbackRequest):
:param full_filepath: File Path to use to run the callback
:param simple_task_instance: Simplified Task Instance representation
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by the DAG processor when it parsed the DAG.
:param task_callback_type: e.g. whether on success, on failure, on retry.
"""

def __init__(
self,
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
is_failure_callback: bool | None = True,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.is_failure_callback = is_failure_callback
self.task_callback_type = task_callback_type

@property
def is_failure_callback(self) -> bool:
"""Returns True if the callback is a failure callback."""
if self.task_callback_type is None:
return True
return self.task_callback_type in {
TaskInstanceState.FAILED,
TaskInstanceState.UP_FOR_RETRY,
TaskInstanceState.UPSTREAM_FAILED,
}

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization
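To make the new contract concrete, here is a minimal sketch of how is_failure_callback is derived from task_callback_type under the class above; simple_ti stands in for a real SimpleTaskInstance and is assumed, not shown:

import pendulum  # not needed here, only the two imports below are used
from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.utils.state import TaskInstanceState

# simple_ti is assumed to come from SimpleTaskInstance.from_ti(ti) in real code.
success_request = TaskCallbackRequest(
    full_filepath="/dags/example.py",
    simple_task_instance=simple_ti,
    task_callback_type=TaskInstanceState.SUCCESS,
)
assert success_request.is_failure_callback is False  # SUCCESS is not a failure state

# Requests without an explicit callback type keep the old failure semantics.
legacy_request = TaskCallbackRequest(
    full_filepath="/dags/example.py",
    simple_task_instance=simple_ti,
)
assert legacy_request.is_failure_callback is True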
37 changes: 32 additions & 5 deletions airflow/dag_processing/processor.py
@@ -47,7 +47,7 @@
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, _run_finished_callback
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.email import get_email_address_list, send_email
@@ -808,8 +808,26 @@ def _execute_dag_callbacks(cls, dagbag: DagBag, request: DagCallbackRequest, ses
@provide_session
def _execute_task_callbacks(
cls, dagbag: DagBag | None, request: TaskCallbackRequest, unit_test_mode: bool, session: Session
):
if not request.is_failure_callback:
) -> None:
"""
Execute the task callbacks.
:param dagbag: the DagBag to use to get the task instance
:param request: the task callback request
:param session: the session to use
"""
try:
callback_type = TaskInstanceState(request.task_callback_type)
except ValueError:
callback_type = None
is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)

# Previously we ignored any request besides failures. Now, if a callback type is
# given directly, we respect it and execute it. Additionally, because in this
# scenario the callback is submitted remotely, we assume there is no need to
# touch the task's state; we simply run the callback.

if not is_remote and not request.is_failure_callback:
return

simple_ti = request.simple_task_instance
@@ -820,6 +838,7 @@ def _execute_task_callbacks(
map_index=simple_ti.map_index,
session=session,
)

if not ti:
return

@@ -841,8 +860,16 @@ def _execute_task_callbacks(
if task:
ti.refresh_from_task(task)

ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session)
cls.logger().info("Executed failure callback for %s in state %s", ti, ti.state)
if callback_type is TaskInstanceState.SUCCESS:
context = ti.get_template_context(session=session)
if TYPE_CHECKING:
assert ti.task
callbacks = ti.task.on_success_callback
_run_finished_callback(callbacks=callbacks, context=context)
cls.logger().info("Executed callback for %s in state %s", ti, ti.state)
elif not is_remote or callback_type is TaskInstanceState.FAILED:
ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session)
cls.logger().info("Executed callback for %s in state %s", ti, ti.state)
session.flush()

@classmethod
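Condensed, the dispatch above works roughly as follows. This is a sketch using the names in this diff; session handling, the DagBag lookup, and task refreshing are omitted:

from airflow.models.taskinstance import _run_finished_callback
from airflow.utils.state import TaskInstanceState

def dispatch_task_callback(request, ti, unit_test_mode):
    # Requests carrying an explicit SUCCESS/FAILED type were submitted
    # remotely (e.g. by the triggerer), so task state is left untouched.
    try:
        callback_type = TaskInstanceState(request.task_callback_type)
    except ValueError:
        callback_type = None
    is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)

    if not is_remote and not request.is_failure_callback:
        return  # legacy non-failure requests are still ignored
    if callback_type is TaskInstanceState.SUCCESS:
        context = ti.get_template_context()
        _run_finished_callback(callbacks=ti.task.on_success_callback, context=context)
    else:
        # Legacy failure requests and remote FAILED requests both land here.
        ti.handle_failure(error=request.msg, test_mode=unit_test_mode)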
5 changes: 4 additions & 1 deletion airflow/exceptions.py
@@ -372,7 +372,10 @@ class TaskDeferred(BaseException):
Signal an operator moving to deferred state.
Special exception raised to signal that the operator it was raised from
wishes to defer until a trigger fires.
wishes to defer until a trigger fires. Triggers can send execution back to the task or end the task instance
directly. If the trigger should end the task instance itself, ``method_name`` does not matter,
and can be None; otherwise, provide the name of the method that should be used when
resuming execution in the task.
"""

def __init__(
5 changes: 4 additions & 1 deletion airflow/models/baseoperator.py
@@ -1764,7 +1764,10 @@ def defer(
Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
This is achieved by raising a special exception (TaskDeferred)
which is caught in the main _execute_task wrapper.
which is caught in the main _execute_task wrapper. Triggers can send execution back to the task or end
the task instance directly. If the trigger will end the task instance itself, ``method_name`` should
be None; otherwise, provide the name of the method that should be used when resuming execution in
the task.
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

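A minimal sketch of the two deferral styles this docstring describes; the WaitUntil operator is hypothetical, while DateTimeTrigger's end_from_trigger parameter is the one added in this PR:

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import DateTimeTrigger

class WaitUntil(BaseOperator):  # hypothetical operator, for illustration only
    def __init__(self, *, moment, end_from_trigger=False, **kwargs):
        super().__init__(**kwargs)
        self.moment = moment
        self.end_from_trigger = end_from_trigger

    def execute(self, context):
        if self.end_from_trigger:
            # The trigger ends the task instance itself, so no resume method
            # ever runs on a worker and method_name can be None.
            self.defer(
                trigger=DateTimeTrigger(moment=self.moment, end_from_trigger=True),
                method_name=None,
            )
        # Otherwise execution resumes on a worker in the named method.
        self.defer(
            trigger=DateTimeTrigger(moment=self.moment),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        return None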
12 changes: 12 additions & 0 deletions airflow/models/dagbag.py
@@ -46,8 +46,10 @@
AirflowClusterPolicyViolation,
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
AirflowException,
RemovedInAirflow3Warning,
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base
from airflow.stats import Stats
from airflow.utils import timezone
@@ -512,6 +514,16 @@ def _bag_dag(self, *, dag, root_dag, recursive):
settings.dag_policy(dag)

for task in dag.tasks:
# The listeners are not supported when ending a task via a trigger on asynchronous operators.
if getattr(task, "end_from_trigger", False) and get_listener_manager().has_listeners:
raise AirflowException(
"Listeners are not supported with end_from_trigger=True for deferrable operators. "
"Task %s in DAG %s has end_from_trigger=True with listeners from plugins. "
"Set end_from_trigger=False to use listeners.",
task.task_id,
dag.dag_id,
)

settings.task_policy(task)
except (AirflowClusterPolicyViolation, AirflowClusterPolicySkipDag):
raise
9 changes: 8 additions & 1 deletion airflow/models/taskinstance.py
@@ -94,7 +94,6 @@
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskfail import TaskFail
@@ -699,6 +698,8 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
:meta private:
"""
from airflow.models.mappedoperator import MappedOperator

task_to_execute = task_instance.task

if TYPE_CHECKING:
@@ -1288,6 +1289,8 @@ def _record_task_map_for_downstreams(
:meta private:
"""
from airflow.models.mappedoperator import MappedOperator

if task.dag.__class__ is AttributeRemoved:
task.dag = dag # required after deserialization

@@ -3454,6 +3457,8 @@ def render_templates(
the unmapped, fully rendered BaseOperator. The original ``self.task``
before replacement is returned.
"""
from airflow.models.mappedoperator import MappedOperator

if not context:
context = self.get_template_context()
original_task = self.task
@@ -3989,6 +3994,8 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp

def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
from airflow.models.mappedoperator import MappedOperator

if isinstance(operator, MappedOperator):
return True
task_group = operator.task_group
9 changes: 1 addition & 8 deletions airflow/models/trigger.py
@@ -203,14 +203,7 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
)
):
# Add the event's payload into the kwargs for the task
next_kwargs = task_instance.next_kwargs or {}
next_kwargs["event"] = event.payload
task_instance.next_kwargs = next_kwargs
# Remove ourselves as its trigger
task_instance.trigger_id = None
# Finally, mark it as scheduled so it gets re-queued
task_instance.state = TaskInstanceState.SCHEDULED
event.handle_submit(task_instance=task_instance)

@classmethod
@internal_api_call
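The bookkeeping removed here moves behind the event's handle_submit hook, so event classes can override what submission does. Presumably the base TriggerEvent keeps the old behavior; a sketch mirroring the removed lines:

from airflow.utils.state import TaskInstanceState

# `self` is the TriggerEvent carrying `payload`.
def handle_submit(self, *, task_instance):
    # Add the event's payload into the kwargs for the task.
    next_kwargs = task_instance.next_kwargs or {}
    next_kwargs["event"] = self.payload
    task_instance.next_kwargs = next_kwargs
    # Remove ourselves as its trigger.
    task_instance.trigger_id = None
    # Finally, mark it as scheduled so it gets re-queued.
    task_instance.state = TaskInstanceState.SCHEDULED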
15 changes: 9 additions & 6 deletions airflow/sensors/date_time.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn, Sequence
from typing import TYPE_CHECKING, Any, NoReturn, Sequence

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
@@ -85,18 +85,21 @@ class DateTimeSensorAsync(DateTimeSensor):
It is a drop-in replacement for DateTimeSensor.
:param target_time: datetime after which the job succeeds. (templated)
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
"""

def __init__(self, **kwargs) -> None:
def __init__(self, *, end_from_trigger: bool = False, **kwargs) -> None:
super().__init__(**kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=timezone.parse(self.target_time))
self.defer(
trigger=trigger,
method_name="execute_complete",
trigger=DateTimeTrigger(
moment=timezone.parse(self.target_time), end_from_trigger=self.end_from_trigger
),
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
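Usage might look like this; the DAG id and target time below are illustrative:

import pendulum
from airflow.models.dag import DAG
from airflow.sensors.date_time import DateTimeSensorAsync

with DAG(
    dag_id="example_end_from_trigger",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
):
    # With end_from_trigger=True the triggerer marks the task successful
    # itself once the moment passes, skipping the round-trip to a worker.
    DateTimeSensorAsync(
        task_id="wait_until",
        target_time="2024-07-26T09:00:00+00:00",
        end_from_trigger=True,
    )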
13 changes: 9 additions & 4 deletions airflow/sensors/time_delta.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.exceptions import AirflowSkipException
from airflow.sensors.base import BaseSensorOperator
@@ -59,28 +59,33 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
Will defer itself to avoid taking up a worker slot while it is waiting.
:param delta: time length to wait after the data interval before succeeding.
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensorAsync`
"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
super().__init__(delta=delta, **kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> bool | NoReturn:
target_dttm = context["data_interval_end"]
target_dttm += self.delta
if timezone.utcnow() > target_dttm:
# If the target datetime is in the past, return immediately
return True
try:
trigger = DateTimeTrigger(moment=target_dttm)
trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger)
except (TypeError, ValueError) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
raise

self.defer(trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None) -> None:
"""Execute for when the trigger fires - return immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
13 changes: 7 additions & 6 deletions airflow/sensors/time_sensor.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
@@ -56,14 +56,16 @@ class TimeSensorAsync(BaseSensorOperator):
This frees up a worker slot while it is waiting.
:param target_time: time after which the job succeeds
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeSensorAsync`
"""

def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
def __init__(self, *, end_from_trigger: bool = False, target_time: datetime.time, **kwargs) -> None:
super().__init__(**kwargs)
self.end_from_trigger = end_from_trigger
self.target_time = target_time

aware_time = timezone.coerce_datetime(
@@ -73,12 +75,11 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
self.target_datetime = timezone.convert_to_utc(aware_time)

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=self.target_datetime)
self.defer(
trigger=trigger,
trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger),
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None