Skip to content

Commit

Permalink
Enable ending the task directly from the triggerer without going into…
Browse files Browse the repository at this point in the history
… the worker.

Make is_failure_callback as property of TaskCallbackRequest

Add execute_complete as method name in test

Apply suggestions from code review

Add other example in deferrable.rst and fix test

Apply suggestions from code review

Fix circular import from other PRs

Fix the documentation

Raise Exception for Listeners with end_from_trigger=True

Add the notification to listeners

Fix the PR comments

Fix the PR comments

Add the documentation

Refactor rename end_task to end_from_trigger

Add method name as __trigger_exit__

fix PR comments

Resolve PR comments

Fix PR comments

Fix the if else condition in implementation

Fix mypy errors

remove session from handle submit

Make task to go to terminal state from triggerer without needing a worker

Make task to go to terminal state from triggerer without needing a worker

Resolve PR comments

Fix PR comments

Fix the if else condition in implementation

Fix the tests

Fix mypy errors

add session in handle submit

remove session from handle submit

Make task to go to terminal state from triggerer without needing a worker

Make task to go to terminal state from triggerer without needing a worker
  • Loading branch information
sunank200 committed Jul 24, 2024
1 parent 9ec9eb7 commit d42afd2
Show file tree
Hide file tree
Showing 16 changed files with 370 additions and 58 deletions.
19 changes: 16 additions & 3 deletions airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import json
from typing import TYPE_CHECKING

from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance

Expand Down Expand Up @@ -68,22 +70,33 @@ class TaskCallbackRequest(CallbackRequest):
:param full_filepath: File Path to use to run the callback
:param simple_task_instance: Simplified Task Instance representation
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by the Dag Processor when parsing the dag.
:param task_callback_type: e.g. whether on success, on failure, on retry.
"""

def __init__(
self,
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
is_failure_callback: bool | None = True,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.is_failure_callback = is_failure_callback
self.task_callback_type = task_callback_type

@property
def is_failure_callback(self) -> bool:
    """
    Tell whether this request represents a failure callback.

    A request without a ``task_callback_type`` is reported as a failure callback. Otherwise the
    result is whether the type is one of ``FAILED``, ``UP_FOR_RETRY`` or ``UPSTREAM_FAILED``.
    """
    callback_type = self.task_callback_type
    if callback_type is not None:
        failure_states = {
            TaskInstanceState.FAILED,
            TaskInstanceState.UP_FOR_RETRY,
            TaskInstanceState.UPSTREAM_FAILED,
        }
        return callback_type in failure_states
    # No explicit callback type was recorded on the request.
    return True

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization
Expand Down
39 changes: 34 additions & 5 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, _run_finished_callback
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.email import get_email_address_list, send_email
Expand Down Expand Up @@ -763,8 +763,28 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
if callbacks and context:
DAG.execute_callback(callbacks, context, dag.dag_id)

def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session):
if not request.is_failure_callback:
def _execute_task_callbacks(
self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session
) -> None:
"""
Execute the task callbacks.
:param dagbag: the DagBag to use to get the task instance
:param request: the task callback request
:param session: the session to use
"""
try:
callback_type = TaskInstanceState(request.task_callback_type)
except ValueError:
callback_type = None
is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)

# Previously we ignored any request besides failures. Now, if a callback type is given directly,
# we respect it and execute it. Additionally, because in this scenario the callback
# is submitted remotely, we assume there is no need to change the task state; we simply run
# the callback.

if not is_remote and not request.is_failure_callback:
return

simple_ti = request.simple_task_instance
Expand All @@ -775,6 +795,7 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe
map_index=simple_ti.map_index,
session=session,
)

if not ti:
return

Expand All @@ -796,8 +817,16 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe
if task:
ti.refresh_from_task(task)

ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
self.log.info("Executed failure callback for %s in state %s", ti, ti.state)
if callback_type is TaskInstanceState.SUCCESS:
context = ti.get_template_context(session=session)
if TYPE_CHECKING:
assert ti.task
callbacks = ti.task.on_success_callback
_run_finished_callback(callbacks=callbacks, context=context)
self.log.info("Executed callback for %s in state %s", ti, ti.state)
elif not is_remote or callback_type is TaskInstanceState.FAILED:
ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
self.log.info("Executed callback for %s in state %s", ti, ti.state)
session.flush()

@classmethod
Expand Down
5 changes: 4 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,10 @@ class TaskDeferred(BaseException):
Signal an operator moving to deferred state.
Special exception raised to signal that the operator it was raised from
wishes to defer until a trigger fires.
wishes to defer until a trigger fires. Triggers can send execution back to task or end the task instance
directly. If the trigger should end the task instance itself, ``method_name`` does not matter,
and can be None; otherwise, provide the name of the method that should be used when
resuming execution in the task.
"""

def __init__(
Expand Down
5 changes: 4 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,7 +1764,10 @@ def defer(
Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
This is achieved by raising a special exception (TaskDeferred)
which is caught in the main _execute_task wrapper.
which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end
the task instance directly. If the trigger will end the task instance itself, ``method_name`` should
be None; otherwise, provide the name of the method that should be used when resuming execution in
the task.
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

Expand Down
12 changes: 12 additions & 0 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
AirflowClusterPolicyViolation,
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
AirflowException,
RemovedInAirflow3Warning,
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base
from airflow.stats import Stats
from airflow.utils import timezone
Expand Down Expand Up @@ -512,6 +514,16 @@ def _bag_dag(self, *, dag, root_dag, recursive):
settings.dag_policy(dag)

for task in dag.tasks:
    # Listeners are not supported when ending a task via a trigger on asynchronous operators.
    if getattr(task, "end_from_trigger", False) and get_listener_manager().has_listeners:
        # Exceptions do not interpolate logging-style ``%s`` arguments, so the task and DAG ids
        # are formatted into the message here instead of being passed as extra arguments.
        raise AirflowException(
            "Listeners are not supported with end_from_trigger=True for deferrable operators. "
            f"Task {task.task_id} in DAG {dag.dag_id} has end_from_trigger=True "
            "with listeners from plugins. "
            "Set end_from_trigger=False to use listeners."
        )

    settings.task_policy(task)
except (AirflowClusterPolicyViolation, AirflowClusterPolicySkipDag):
raise
Expand Down
9 changes: 8 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskfail import TaskFail
Expand Down Expand Up @@ -699,6 +698,8 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
:meta private:
"""
from airflow.models.mappedoperator import MappedOperator

task_to_execute = task_instance.task

if TYPE_CHECKING:
Expand Down Expand Up @@ -1288,6 +1289,8 @@ def _record_task_map_for_downstreams(
:meta private:
"""
from airflow.models.mappedoperator import MappedOperator

if task.dag is ATTRIBUTE_REMOVED:
task.dag = dag # required after deserialization

Expand Down Expand Up @@ -3454,6 +3457,8 @@ def render_templates(
the unmapped, fully rendered BaseOperator. The original ``self.task``
before replacement is returned.
"""
from airflow.models.mappedoperator import MappedOperator

if not context:
context = self.get_template_context()
original_task = self.task
Expand Down Expand Up @@ -3989,6 +3994,8 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp

def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
from airflow.models.mappedoperator import MappedOperator

if isinstance(operator, MappedOperator):
return True
task_group = operator.task_group
Expand Down
9 changes: 1 addition & 8 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,7 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
)
):
# Add the event's payload into the kwargs for the task
next_kwargs = task_instance.next_kwargs or {}
next_kwargs["event"] = event.payload
task_instance.next_kwargs = next_kwargs
# Remove ourselves as its trigger
task_instance.trigger_id = None
# Finally, mark it as scheduled so it gets re-queued
task_instance.state = TaskInstanceState.SCHEDULED
event.handle_submit(task_instance=task_instance)

@classmethod
@internal_api_call
Expand Down
15 changes: 9 additions & 6 deletions airflow/sensors/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn, Sequence
from typing import TYPE_CHECKING, Any, NoReturn, Sequence

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
Expand Down Expand Up @@ -85,18 +85,21 @@ class DateTimeSensorAsync(DateTimeSensor):
It is a drop-in replacement for DateTimeSensor.
:param target_time: datetime after which the job succeeds. (templated)
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
"""

def __init__(self, **kwargs) -> None:
def __init__(self, *, end_from_trigger: bool = False, **kwargs) -> None:
super().__init__(**kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=timezone.parse(self.target_time))
self.defer(
trigger=trigger,
method_name="execute_complete",
trigger=DateTimeTrigger(
moment=timezone.parse(self.target_time), end_from_trigger=self.end_from_trigger
),
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
13 changes: 9 additions & 4 deletions airflow/sensors/time_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.exceptions import AirflowSkipException
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -59,28 +59,33 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
Defers itself to avoid taking up a worker slot while it is waiting.
:param delta: time length to wait after the data interval before succeeding.
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensorAsync`
"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
    """
    Set up the sensor.

    :param end_from_trigger: stored on the instance; ``execute`` passes it to the trigger it builds.
    :param delta: forwarded to the parent initializer.
    :param kwargs: remaining keyword arguments, forwarded to the parent initializer.
    """
    parent_kwargs = {"delta": delta, **kwargs}
    super().__init__(**parent_kwargs)
    self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> bool | NoReturn:
target_dttm = context["data_interval_end"]
target_dttm += self.delta
if timezone.utcnow() > target_dttm:
# If the target datetime is in the past, return immediately
return True
try:
trigger = DateTimeTrigger(moment=target_dttm)
trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger)
except (TypeError, ValueError) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
raise

self.defer(trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None) -> None:
"""Execute for when the trigger fires - return immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
13 changes: 7 additions & 6 deletions airflow/sensors/time_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
Expand Down Expand Up @@ -56,14 +56,16 @@ class TimeSensorAsync(BaseSensorOperator):
This frees up a worker slot while it is waiting.
:param target_time: time after which the job succeeds
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeSensorAsync`
"""

def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
def __init__(self, *, end_from_trigger: bool = False, target_time: datetime.time, **kwargs) -> None:
super().__init__(**kwargs)
self.end_from_trigger = end_from_trigger
self.target_time = target_time

aware_time = timezone.coerce_datetime(
Expand All @@ -73,12 +75,11 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
self.target_datetime = timezone.convert_to_utc(aware_time)

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=self.target_datetime)
self.defer(
trigger=trigger,
trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger),
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
Loading

0 comments on commit d42afd2

Please sign in to comment.