diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py index 8ec0187978db6..7158c45d44d91 100644 --- a/airflow/callbacks/callback_requests.py +++ b/airflow/callbacks/callback_requests.py @@ -19,6 +19,8 @@ import json from typing import TYPE_CHECKING +from airflow.utils.state import TaskInstanceState + if TYPE_CHECKING: from airflow.models.taskinstance import SimpleTaskInstance @@ -68,22 +70,33 @@ class TaskCallbackRequest(CallbackRequest): :param full_filepath: File Path to use to run the callback :param simple_task_instance: Simplified Task Instance representation - :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback :param msg: Additional Message that can be used for logging to determine failure/zombie :param processor_subdir: Directory used by Dag Processor when parsed the dag. + :param task_callback_type: e.g. whether on success, on failure, on retry. """ def __init__( self, full_filepath: str, simple_task_instance: SimpleTaskInstance, - is_failure_callback: bool | None = True, processor_subdir: str | None = None, msg: str | None = None, + task_callback_type: TaskInstanceState | None = None, ): super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) self.simple_task_instance = simple_task_instance - self.is_failure_callback = is_failure_callback + self.task_callback_type = task_callback_type + + @property + def is_failure_callback(self) -> bool: + """Returns True if the callback is a failure callback.""" + if self.task_callback_type is None: + return True + return self.task_callback_type in { + TaskInstanceState.FAILED, + TaskInstanceState.UP_FOR_RETRY, + TaskInstanceState.UPSTREAM_FAILED, + } def to_json(self) -> str: from airflow.serialization.serialized_objects import BaseSerialization diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 84049de4e2675..3cc2fe142b867 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -47,7 +47,7 @@ from airflow.models.dagwarning import DagWarning, DagWarningType from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import TaskInstance, TaskInstance as TI +from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, _run_finished_callback from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.email import get_email_address_list, send_email @@ -808,8 +808,26 @@ def _execute_dag_callbacks(cls, dagbag: DagBag, request: DagCallbackRequest, ses @provide_session def _execute_task_callbacks( cls, dagbag: DagBag | None, request: TaskCallbackRequest, unit_test_mode: bool, session: Session - ): - if not request.is_failure_callback: + ) -> None: + """ + Execute the task callbacks. + + :param dagbag: the DagBag to use to get the task instance + :param request: the task callback request + :param session: the session to use + """ + try: + callback_type = TaskInstanceState(request.task_callback_type) + except ValueError: + callback_type = None + is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED) + + # previously we ignored any request besides failures. now if given callback type directly, + # then we respect it and execute it. additionally because in this scenario the callback + # is submitted remotely, we assume there is no need to mess with state; we simply run + # the callback + + if not is_remote and not request.is_failure_callback: return simple_ti = request.simple_task_instance @@ -820,6 +838,7 @@ def _execute_task_callbacks( map_index=simple_ti.map_index, session=session, ) + if not ti: return @@ -841,8 +860,16 @@ def _execute_task_callbacks( if task: ti.refresh_from_task(task) - ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session) - cls.logger().info("Executed failure callback for %s in state %s", ti, ti.state) + if callback_type is TaskInstanceState.SUCCESS: + context = ti.get_template_context(session=session) + if TYPE_CHECKING: + assert ti.task + callbacks = ti.task.on_success_callback + _run_finished_callback(callbacks=callbacks, context=context) + cls.logger().info("Executed callback for %s in state %s", ti, ti.state) + elif not is_remote or callback_type is TaskInstanceState.FAILED: + ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session) + cls.logger().info("Executed callback for %s in state %s", ti, ti.state) session.flush() @classmethod diff --git a/airflow/exceptions.py b/airflow/exceptions.py index cdf04be3b3948..dc59f91841133 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -372,7 +372,10 @@ class TaskDeferred(BaseException): Signal an operator moving to deferred state. Special exception raised to signal that the operator it was raised from - wishes to defer until a trigger fires. + wishes to defer until a trigger fires. Triggers can send execution back to task or end the task instance + directly. If the trigger should end the task instance itself, ``method_name`` does not matter, + and can be None; otherwise, provide the name of the method that should be used when + resuming execution in the task. """ def __init__( diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 514386d7d358e..77423bfc3b99e 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1764,7 +1764,10 @@ def defer( Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. This is achieved by raising a special exception (TaskDeferred) - which is caught in the main _execute_task wrapper. + which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end + the task instance directly. If the trigger will end the task instance itself, ``method_name`` should + be None; otherwise, provide the name of the method that should be used when resuming execution in + the task. """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 96b6ecf035d43..f384bfcd84ea8 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -46,8 +46,10 @@ AirflowClusterPolicyViolation, AirflowDagCycleException, AirflowDagDuplicatedIdException, + AirflowException, RemovedInAirflow3Warning, ) +from airflow.listeners.listener import get_listener_manager from airflow.models.base import Base from airflow.stats import Stats from airflow.utils import timezone @@ -512,6 +514,16 @@ def _bag_dag(self, *, dag, root_dag, recursive): settings.dag_policy(dag) for task in dag.tasks: + # The listeners are not supported when ending a task via a trigger on asynchronous operators. + if getattr(task, "end_from_trigger", False) and get_listener_manager().has_listeners: + raise AirflowException( + "Listeners are not supported with end_from_trigger=True for deferrable operators. " + "Task %s in DAG %s has end_from_trigger=True with listeners from plugins. " + "Set end_from_trigger=False to use listeners.", + task.task_id, + dag.dag_id, + ) + settings.task_policy(task) except (AirflowClusterPolicyViolation, AirflowClusterPolicySkipDag): raise diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ca5868850e7b3..47f1624b78c6d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -94,7 +94,6 @@ from airflow.models.dagbag import DagBag from airflow.models.dataset import DatasetAliasModel, DatasetModel from airflow.models.log import Log -from airflow.models.mappedoperator import MappedOperator from airflow.models.param import process_params from airflow.models.renderedtifields import get_serialized_template_fields from airflow.models.taskfail import TaskFail @@ -699,6 +698,8 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C :meta private: """ + from airflow.models.mappedoperator import MappedOperator + task_to_execute = task_instance.task if TYPE_CHECKING: @@ -1288,6 +1289,8 @@ def _record_task_map_for_downstreams( :meta private: """ + from airflow.models.mappedoperator import MappedOperator + if task.dag.__class__ is AttributeRemoved: task.dag = dag # required after deserialization @@ -3454,6 +3457,8 @@ def render_templates( the unmapped, fully rendered BaseOperator. The original ``self.task`` before replacement is returned. """ + from airflow.models.mappedoperator import MappedOperator + if not context: context = self.get_template_context() original_task = self.task @@ -3989,6 +3994,8 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: """Whether given operator is *further* mapped inside a task group.""" + from airflow.models.mappedoperator import MappedOperator + if isinstance(operator, MappedOperator): return True task_group = operator.task_group diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index 77506fe9f8555..971610d5f6b88 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -203,14 +203,7 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED ) ): - # Add the event's payload into the kwargs for the task - next_kwargs = task_instance.next_kwargs or {} - next_kwargs["event"] = event.payload - task_instance.next_kwargs = next_kwargs - # Remove ourselves as its trigger - task_instance.trigger_id = None - # Finally, mark it as scheduled so it gets re-queued - task_instance.state = TaskInstanceState.SCHEDULED + event.handle_submit(task_instance=task_instance) @classmethod @internal_api_call diff --git a/airflow/sensors/date_time.py b/airflow/sensors/date_time.py index b0763ebd40a87..490c23486008a 100644 --- a/airflow/sensors/date_time.py +++ b/airflow/sensors/date_time.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, NoReturn, Sequence +from typing import TYPE_CHECKING, Any, NoReturn, Sequence from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger @@ -85,18 +85,21 @@ class DateTimeSensorAsync(DateTimeSensor): It is a drop-in replacement for DateTimeSensor. :param target_time: datetime after which the job succeeds. (templated) + :param end_from_trigger: End the task directly from the triggerer without going into the worker. """ - def __init__(self, **kwargs) -> None: + def __init__(self, *, end_from_trigger: bool = False, **kwargs) -> None: super().__init__(**kwargs) + self.end_from_trigger = end_from_trigger def execute(self, context: Context) -> NoReturn: - trigger = DateTimeTrigger(moment=timezone.parse(self.target_time)) self.defer( - trigger=trigger, method_name="execute_complete", + trigger=DateTimeTrigger( + moment=timezone.parse(self.target_time), end_from_trigger=self.end_from_trigger + ), ) - def execute_complete(self, context, event=None) -> None: - """Execute when the trigger fires - returns immediately.""" + def execute_complete(self, context: Context, event: Any = None) -> None: + """Handle the event when the trigger fires and return immediately.""" return None diff --git a/airflow/sensors/time_delta.py b/airflow/sensors/time_delta.py index 226d520aa0ee3..d068fad9bf5a5 100644 --- a/airflow/sensors/time_delta.py +++ b/airflow/sensors/time_delta.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, Any, NoReturn from airflow.exceptions import AirflowSkipException from airflow.sensors.base import BaseSensorOperator @@ -59,6 +59,7 @@ class TimeDeltaSensorAsync(TimeDeltaSensor): Will defers itself to avoid taking up a worker slot while it is waiting. :param delta: time length to wait after the data interval before succeeding. + :param end_from_trigger: End the task directly from the triggerer without going into the worker. .. seealso:: For more information on how to use this sensor, take a look at the guide: @@ -66,6 +67,10 @@ class TimeDeltaSensorAsync(TimeDeltaSensor): """ + def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None: + super().__init__(delta=delta, **kwargs) + self.end_from_trigger = end_from_trigger + def execute(self, context: Context) -> bool | NoReturn: target_dttm = context["data_interval_end"] target_dttm += self.delta @@ -73,7 +78,7 @@ def execute(self, context: Context) -> bool | NoReturn: # If the target datetime is in the past, return immediately return True try: - trigger = DateTimeTrigger(moment=target_dttm) + trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger) except (TypeError, ValueError) as e: if self.soft_fail: raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e @@ -81,6 +86,6 @@ def execute(self, context: Context) -> bool | NoReturn: self.defer(trigger=trigger, method_name="execute_complete") - def execute_complete(self, context, event=None) -> None: - """Execute for when the trigger fires - return immediately.""" + def execute_complete(self, context: Context, event: Any = None) -> None: + """Handle the event when the trigger fires and return immediately.""" return None diff --git a/airflow/sensors/time_sensor.py b/airflow/sensors/time_sensor.py index 91c1354782593..6c5c91d15d244 100644 --- a/airflow/sensors/time_sensor.py +++ b/airflow/sensors/time_sensor.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, Any, NoReturn from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger @@ -56,14 +56,16 @@ class TimeSensorAsync(BaseSensorOperator): This frees up a worker slot while it is waiting. :param target_time: time after which the job succeeds + :param end_from_trigger: End the task directly from the triggerer without going into the worker. .. seealso:: For more information on how to use this sensor, take a look at the guide: :ref:`howto/operator:TimeSensorAsync` """ - def __init__(self, *, target_time: datetime.time, **kwargs) -> None: + def __init__(self, *, end_from_trigger: bool = False, target_time: datetime.time, **kwargs) -> None: super().__init__(**kwargs) + self.end_from_trigger = end_from_trigger self.target_time = target_time aware_time = timezone.coerce_datetime( @@ -73,12 +75,11 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None: self.target_datetime = timezone.convert_to_utc(aware_time) def execute(self, context: Context) -> NoReturn: - trigger = DateTimeTrigger(moment=self.target_datetime) self.defer( - trigger=trigger, + trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger), method_name="execute_complete", ) - def execute_complete(self, context, event=None) -> None: - """Execute when the trigger fires - returns immediately.""" + def execute_complete(self, context: Context, event: Any = None) -> None: + """Handle the event when the trigger fires and return immediately.""" return None diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index 5dacee3364c54..190e2983cee07 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -17,11 +17,24 @@ from __future__ import annotations import abc +import logging from dataclasses import dataclass from datetime import timedelta -from typing import Any, AsyncIterator +from typing import TYPE_CHECKING, Any, AsyncIterator +from airflow.callbacks.callback_requests import TaskCallbackRequest +from airflow.callbacks.database_callback_sink import DatabaseCallbackSink +from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models import TaskInstance + +log = logging.getLogger(__name__) @dataclass @@ -137,3 +150,106 @@ def __eq__(self, other): if isinstance(other, TriggerEvent): return other.payload == self.payload return False + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None: + """ + Handle the submit event for a given task instance. + + This function sets the next method and next kwargs of the task instance, + as well as its state to scheduled. It also adds the event's payload + into the kwargs for the task. + + :param task_instance: The task instance to handle the submit event for. + :param session: The session to be used for the database callback sink. + """ + # Get the next kwargs of the task instance, or an empty dictionary if it doesn't exist + next_kwargs = task_instance.next_kwargs or {} + + # Add the event's payload into the kwargs for the task + next_kwargs["event"] = self.payload + + # Update the next kwargs of the task instance + task_instance.next_kwargs = next_kwargs + + # Remove ourselves as its trigger + task_instance.trigger_id = None + + # Set the state of the task instance to scheduled + task_instance.state = TaskInstanceState.SCHEDULED + + +class BaseTaskEndEvent(TriggerEvent): + """ + Base event class to end the task without resuming on worker. + + :meta private: + """ + + task_instance_state: TaskInstanceState + + def __init__(self, *, xcoms: dict[str, Any] | None = None, **kwargs) -> None: + """ + Initialize the class with the specified parameters. + + :param xcoms: A dictionary of XComs or None. + :param kwargs: Additional keyword arguments. + """ + if "payload" in kwargs: + raise ValueError("Param 'payload' not supported for this class.") + super().__init__(payload=self.task_instance_state) + self.xcoms = xcoms + + @provide_session + def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION) -> None: + """ + Submit event for the given task instance. + + Marks the task with the state `task_instance_state` and optionally pushes xcom if applicable. + + :param task_instance: The task instance to be submitted. + :param session: The session to be used for the database callback sink. + """ + # Mark the task with terminal state and prevent it from resuming on worker + task_instance.trigger_id = None + task_instance.state = self.task_instance_state + self._submit_callback_if_necessary(task_instance=task_instance, session=session) + self._push_xcoms_if_necessary(task_instance=task_instance) + + def _submit_callback_if_necessary(self, *, task_instance: TaskInstance, session) -> None: + """Submit a callback request if the task state is SUCCESS or FAILED.""" + if self.task_instance_state in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED): + request = TaskCallbackRequest( + full_filepath=task_instance.dag_model.fileloc, + simple_task_instance=SimpleTaskInstance.from_ti(task_instance), + task_callback_type=self.task_instance_state, + ) + log.info("Sending callback: %s", request) + try: + DatabaseCallbackSink().send(callback=request, session=session) + except Exception: + log.exception("Failed to send callback.") + + def _push_xcoms_if_necessary(self, *, task_instance: TaskInstance) -> None: + """Pushes XComs to the database if they are provided.""" + if self.xcoms: + for key, value in self.xcoms.items(): + task_instance.xcom_push(key=key, value=value) + + +class TaskSuccessEvent(BaseTaskEndEvent): + """Yield this event in order to end the task successfully.""" + + task_instance_state = TaskInstanceState.SUCCESS + + +class TaskFailedEvent(BaseTaskEndEvent): + """Yield this event in order to end the task with failure.""" + + task_instance_state = TaskInstanceState.FAILED + + +class TaskSkippedEvent(BaseTaskEndEvent): + """Yield this event in order to end the task with status 'skipped'.""" + + task_instance_state = TaskInstanceState.SKIPPED diff --git a/airflow/triggers/temporal.py b/airflow/triggers/temporal.py index 79e8f39dd76e7..64c3afe8162c3 100644 --- a/airflow/triggers/temporal.py +++ b/airflow/triggers/temporal.py @@ -22,7 +22,7 @@ import pendulum -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.triggers.base import BaseTrigger, TaskSuccessEvent, TriggerEvent from airflow.utils import timezone @@ -34,9 +34,13 @@ class DateTimeTrigger(BaseTrigger): a few seconds. The provided datetime MUST be in UTC. + + :param moment: when to yield event + :param end_from_trigger: whether the trigger should mark the task successful after time condition + reached or resume the task after time condition reached. """ - def __init__(self, moment: datetime.datetime): + def __init__(self, moment: datetime.datetime, *, end_from_trigger: bool = False) -> None: super().__init__() if not isinstance(moment, datetime.datetime): raise TypeError(f"Expected datetime.datetime type for moment. Got {type(moment)}") @@ -45,9 +49,13 @@ def __init__(self, moment: datetime.datetime): raise ValueError("You cannot pass naive datetimes") else: self.moment: pendulum.DateTime = timezone.convert_to_utc(moment) + self.end_from_trigger = end_from_trigger def serialize(self) -> tuple[str, dict[str, Any]]: - return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) + return ( + "airflow.triggers.temporal.DateTimeTrigger", + {"moment": self.moment, "end_from_trigger": self.end_from_trigger}, + ) async def run(self) -> AsyncIterator[TriggerEvent]: """ @@ -70,9 +78,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: while self.moment > pendulum.instance(timezone.utcnow()): self.log.info("sleeping 1 second...") await asyncio.sleep(1) - # Send our single event and then we're done - self.log.info("yielding event with payload %r", self.moment) - yield TriggerEvent(self.moment) + if self.end_from_trigger: + self.log.info("Sensor time condition reached; marking task successful and exiting") + yield TaskSuccessEvent() + else: + self.log.info("yielding event with payload %r", self.moment) + yield TriggerEvent(self.moment) class TimeDeltaTrigger(DateTimeTrigger): @@ -84,7 +95,11 @@ class TimeDeltaTrigger(DateTimeTrigger): While this is its own distinct class here, it will serialise to a DateTimeTrigger class, since they're operationally the same. + + :param delta: how long to wait + :param end_from_trigger: whether the trigger should mark the task successful after time condition + reached or resume the task after time condition reached. """ - def __init__(self, delta: datetime.timedelta): - super().__init__(moment=timezone.utcnow() + delta) + def __init__(self, delta: datetime.timedelta, *, end_from_trigger: bool = False) -> None: + super().__init__(moment=timezone.utcnow() + delta, end_from_trigger=end_from_trigger) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index aba81d5008d64..51265fae1807c 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -97,6 +97,59 @@ When writing a deferrable operators these are the main points to consider: return +Exiting deferred task from Triggers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + .. versionadded:: 2.10.0 + +If you want to exit your task directly from the triggerer without going into the worker, you can specify the instance level attribute ``end_from_trigger`` with the attributes of your deferrable operator, as discussed above. This can save some resources needed to start a new worker. + +Triggers can have two options: they can either send execution back to the worker or end the task instance directly. If the trigger ends the task instance itself, the ``method_name`` does not matter and can be ``None``. Otherwise, provide ``method_name`` that should be used when resuming execution in the task. + +.. code-block:: python + + class WaitFiveHourSensorAsync(BaseSensorOperator): + # this sensor always exits from trigger. + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.end_from_trigger = True + + def execute(self, context: Context) -> NoReturn: + self.defer( + method_name=None, + trigger=WaitFiveHourTrigger(duration=timedelta(hours=5), end_from_trigger=self.end_from_trigger), + ) + + +``TaskSuccessEvent`` and ``TaskFailureEvent`` are the two events that can be used to end the task instance directly. This marks the task with the state ``task_instance_state`` and optionally pushes xcom if applicable. Here's an example of how to use these events: + +.. code-block:: python + + + class WaitFiveHourTrigger(BaseTrigger): + def __init__(self, duration: timedelta, *, end_from_trigger: bool = False): + super().__init__() + self.duration = duration + self.end_from_trigger = end_from_trigger + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "your_module.WaitFiveHourTrigger", + {"duration": self.duration, "end_from_trigger": self.end_from_trigger}, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + await asyncio.sleep(self.duration.total_seconds()) + if self.end_from_trigger: + yield TaskSuccessEvent() + else: + yield TriggerEvent({"duration": self.duration}) + +In the above example, the trigger will end the task instance directly if ``end_from_trigger`` is set to ``True`` by yielding ``TaskSuccessEvent``. Otherwise, it will resume the task instance with the method specified in the operator. + +.. note:: + Exiting from the trigger works only when listeners are not integrated for the deferrable operator. Currently, when deferrable operator has the ``end_from_trigger`` attribute set to ``True`` and listeners are integrated it raises an exception during parsing to indicate this limitation. While writing the custom trigger, ensure that the trigger is not set to end the task instance directly if the listeners are added from plugins. If the ``end_from_trigger`` attribute is changed to different attribute by author of trigger, the DAG parsing would not raise any exception and the listeners dependent on this task would not work. This limitation will be addressed in future releases. + Writing Triggers ~~~~~~~~~~~~~~~~ diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py index b9a00376ff0bb..90305c67c1182 100644 --- a/tests/callbacks/test_callback_requests.py +++ b/tests/callbacks/test_callback_requests.py @@ -78,7 +78,6 @@ def test_from_json(self, input, request_class): full_filepath="filepath", simple_task_instance=SimpleTaskInstance.from_ti(ti=ti), processor_subdir="/test_dir", - is_failure_callback=True, ) json_str = input.to_json() result = request_class.from_json(json_str=json_str) @@ -94,7 +93,6 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create full_filepath="filepath", simple_task_instance=SimpleTaskInstance.from_ti(ti), processor_subdir="/test_dir", - is_failure_callback=True, ) json_str = input.to_json() result = TaskCallbackRequest.from_json(json_str) diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index 6be2086f34112..5a8ef28df0a64 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -26,13 +26,20 @@ from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.models import TaskInstance, Trigger +from airflow.models import TaskInstance, Trigger, XCom from airflow.operators.empty import EmptyOperator from airflow.serialization.serialized_objects import BaseSerialization -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.triggers.base import ( + BaseTrigger, + TaskFailedEvent, + TaskSkippedEvent, + TaskSuccessEvent, + TriggerEvent, +) from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.xcom import XCOM_RETURN_KEY from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -113,7 +120,6 @@ def test_submit_event(session, create_task_instance): Trigger.submit_event(trigger.id, TriggerEvent(42), session=session) # commit changes made by submit event and expire all cache to read from db. session.flush() - session.expunge_all() # Check that the task instance is now scheduled updated_task_instance = session.query(TaskInstance).one() assert updated_task_instance.state == State.SCHEDULED @@ -123,7 +129,7 @@ def test_submit_event(session, create_task_instance): def test_submit_failure(session, create_task_instance): """ Tests that failures submitted to a trigger fail their dependent - task instances. + task instances if not using a TaskEndEvent. """ # Make a trigger trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) @@ -144,6 +150,56 @@ def test_submit_failure(session, create_task_instance): assert updated_task_instance.next_method == "__fail__" +@pytest.mark.parametrize( + "event_cls, expected", + [ + (TaskSuccessEvent, "success"), + (TaskFailedEvent, "failed"), + (TaskSkippedEvent, "skipped"), + ], +) +def test_submit_event_task_end(session, create_task_instance, event_cls, expected): + """ + Tests that events inheriting BaseTaskEndEvent *don't* re-wake their dependent + but mark them in the appropriate terminal state and send xcom + """ + # Make a trigger + trigger = Trigger(classpath="does.not.matter", kwargs={}) + trigger.id = 1 + session.add(trigger) + session.commit() + # Make a TaskInstance that's deferred and waiting on it + task_instance = create_task_instance( + session=session, execution_date=timezone.utcnow(), state=State.DEFERRED + ) + task_instance.trigger_id = trigger.id + session.commit() + + def get_xcoms(ti): + return XCom.get_many(dag_ids=[ti.dag_id], task_ids=[ti.task_id], run_id=ti.run_id).all() + + # now for the real test + # first check initial state + ti: TaskInstance = session.query(TaskInstance).one() + assert ti.state == "deferred" + assert get_xcoms(ti) == [] + + session.flush() + # now, for each type, submit event + # verify that (1) task ends in right state and (2) xcom is pushed + Trigger.submit_event( + trigger.id, event_cls(xcoms={XCOM_RETURN_KEY: "xcomret", "a": "b", "c": "d"}), session=session + ) + # commit changes made by submit event and expire all cache to read from db. + session.flush() + # Check that the task instance is now correct + ti = session.query(TaskInstance).one() + assert ti.state == expected + assert ti.next_kwargs is None + actual_xcoms = {x.key: x.value for x in get_xcoms(ti)} + assert actual_xcoms == {"return_value": "xcomret", "a": "b", "c": "d"} + + def test_assign_unassigned(session, create_task_instance): """ Tests that unassigned triggers of all appropriate states are assigned. diff --git a/tests/sensors/test_time_sensor.py b/tests/sensors/test_time_sensor.py index 54a0212a247a9..d26fc7bf39005 100644 --- a/tests/sensors/test_time_sensor.py +++ b/tests/sensors/test_time_sensor.py @@ -63,8 +63,8 @@ def test_task_is_deferred(self): assert isinstance(exc_info.value.trigger, DateTimeTrigger) assert exc_info.value.trigger.moment == timezone.datetime(2020, 7, 7, 10) - assert exc_info.value.method_name == "execute_complete" assert exc_info.value.kwargs is None + assert exc_info.value.method_name == "execute_complete" def test_target_time_aware(self): with DAG("test_target_time_aware", start_date=timezone.datetime(2020, 1, 1, 23, 0)): diff --git a/tests/triggers/test_temporal.py b/tests/triggers/test_temporal.py index 6e8d32c467e63..90f00a694e5b3 100644 --- a/tests/triggers/test_temporal.py +++ b/tests/triggers/test_temporal.py @@ -26,6 +26,7 @@ from airflow.triggers.base import TriggerEvent from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import utcnow @@ -56,7 +57,7 @@ def test_datetime_trigger_serialization(): trigger = DateTimeTrigger(moment) classpath, kwargs = trigger.serialize() assert classpath == "airflow.triggers.temporal.DateTimeTrigger" - assert kwargs == {"moment": moment} + assert kwargs == {"moment": moment, "end_from_trigger": False} def test_timedelta_trigger_serialization(): @@ -74,15 +75,16 @@ def test_timedelta_trigger_serialization(): @pytest.mark.parametrize( - "tz", + "tz, end_from_trigger", [ - timezone.parse_timezone("UTC"), - timezone.parse_timezone("Europe/Paris"), - timezone.parse_timezone("America/Toronto"), + (pendulum.timezone("UTC"), True), + (pendulum.timezone("UTC"), False), # only really need to test one + (pendulum.timezone("Europe/Paris"), True), + (pendulum.timezone("America/Toronto"), True), ], ) @pytest.mark.asyncio -async def test_datetime_trigger_timing(tz): +async def test_datetime_trigger_timing(tz, end_from_trigger): """ Tests that the DateTimeTrigger only goes off on or after the appropriate time. @@ -91,7 +93,7 @@ async def test_datetime_trigger_timing(tz): future_moment = pendulum.instance((timezone.utcnow() + datetime.timedelta(seconds=60)).astimezone(tz)) # Create a task that runs the trigger for a short time then cancels it - trigger = DateTimeTrigger(future_moment) + trigger = DateTimeTrigger(future_moment, end_from_trigger=end_from_trigger) trigger_task = asyncio.create_task(trigger.run().__anext__()) await asyncio.sleep(0.5) @@ -100,14 +102,15 @@ async def test_datetime_trigger_timing(tz): trigger_task.cancel() # Now, make one waiting for en event in the past and do it again - trigger = DateTimeTrigger(past_moment) + trigger = DateTimeTrigger(past_moment, end_from_trigger=end_from_trigger) trigger_task = asyncio.create_task(trigger.run().__anext__()) await asyncio.sleep(0.5) assert trigger_task.done() is True result = trigger_task.result() assert isinstance(result, TriggerEvent) - assert result.payload == past_moment + expected_payload = TaskInstanceState.SUCCESS if end_from_trigger else past_moment + assert result.payload == expected_payload @mock.patch("airflow.triggers.temporal.timezone.utcnow")