From ebef9ed3fa4a9a1e69b4405945e7cd939f499ee5 Mon Sep 17 00:00:00 2001 From: Josh Fell <48934154+josh-fell@users.noreply.github.com> Date: Wed, 7 Sep 2022 19:17:34 -0400 Subject: [PATCH] Add ``@task.short_circuit`` TaskFlow decorator (#25752) --- airflow/decorators/__init__.py | 3 + airflow/decorators/__init__.pyi | 20 +++++ airflow/decorators/short_circuit.py | 83 +++++++++++++++++++ .../example_short_circuit_decorator.py | 59 +++++++++++++ .../example_short_circuit_operator.py | 4 - docs/apache-airflow/howto/operator/python.rst | 33 ++++---- tests/decorators/test_short_circuit.py | 72 ++++++++++++++++ 7 files changed, 256 insertions(+), 18 deletions(-) create mode 100644 airflow/decorators/short_circuit.py create mode 100644 airflow/example_dags/example_short_circuit_decorator.py create mode 100644 tests/decorators/test_short_circuit.py diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py index 6004a397e4bfc..ad5d6431e5bcd 100644 --- a/airflow/decorators/__init__.py +++ b/airflow/decorators/__init__.py @@ -22,6 +22,7 @@ from airflow.decorators.external_python import external_python_task from airflow.decorators.python import python_task from airflow.decorators.python_virtualenv import virtualenv_task +from airflow.decorators.short_circuit import short_circuit_task from airflow.decorators.task_group import task_group from airflow.models.dag import dag from airflow.providers_manager import ProvidersManager @@ -37,6 +38,7 @@ "virtualenv_task", "external_python_task", "branch_task", + "short_circuit_task", ] @@ -47,6 +49,7 @@ class TaskDecoratorCollection: virtualenv = staticmethod(virtualenv_task) external_python = staticmethod(external_python_task) branch = staticmethod(branch_task) + short_circuit = staticmethod(short_circuit_task) __call__: Any = python # Alias '@task' to '@task.python'. diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index b5992cf51302e..e684860f4a366 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -44,6 +44,7 @@ __all__ = [ "virtualenv_task", "external_python_task", "branch_task", + "short_circuit_task", ] class TaskDecoratorCollection: @@ -171,6 +172,25 @@ class TaskDecoratorCollection: """ @overload def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... + @overload + def short_circuit( + self, + *, + multiple_outputs: Optional[bool] = None, + ignore_downstream_trigger_rules: bool = True, + **kwargs, + ) -> TaskDecorator: + """Create a decorator to wrap the decorated callable into a ShortCircuitOperator. + + :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values. + Dict will unroll to XCom values with keys as XCom keys. Defaults to False. + :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task + will be skipped. This is the default behavior. If set to False, the direct, downstream task(s) + will be skipped but the ``trigger_rule`` defined for a other downstream tasks will be respected. + Defaults to True. + """ + @overload + def short_circuit(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... # [START decorator_signature] def docker( self, diff --git a/airflow/decorators/short_circuit.py b/airflow/decorators/short_circuit.py new file mode 100644 index 0000000000000..f3aec185b714f --- /dev/null +++ b/airflow/decorators/short_circuit.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Callable, Optional, Sequence + +from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory +from airflow.operators.python import ShortCircuitOperator + + +class _ShortCircuitDecoratedOperator(DecoratedOperator, ShortCircuitOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :param op_kwargs: a dictionary of keyword arguments that will get unpacked + in your function (templated) + :param op_args: a list of positional arguments that will get unpacked when + calling your callable (templated) + :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to + multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. + """ + + template_fields: Sequence[str] = ('op_args', 'op_kwargs') + template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects (e.g protobuf). + shallow_copy_attrs: Sequence[str] = ('python_callable',) + + custom_operator_name: str = '@task.short_circuit' + + def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None: + kwargs_to_upstream = { + "python_callable": python_callable, + "op_args": op_args, + "op_kwargs": op_kwargs, + } + super().__init__( + kwargs_to_upstream=kwargs_to_upstream, + python_callable=python_callable, + op_args=op_args, + op_kwargs=op_kwargs, + **kwargs, + ) + + +def short_circuit_task( + python_callable: Optional[Callable] = None, + multiple_outputs: Optional[bool] = None, + **kwargs, +) -> TaskDecorator: + """Wraps a function into an ShortCircuitOperator. + + Accepts kwargs for operator kwarg. Can be reused in a single DAG. + + This function is only used only used during type checking or auto-completion. + + :param python_callable: Function to decorate + :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to + multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. + + :meta private: + """ + return task_decorator_factory( + python_callable=python_callable, + multiple_outputs=multiple_outputs, + decorated_operator_class=_ShortCircuitDecoratedOperator, + **kwargs, + ) diff --git a/airflow/example_dags/example_short_circuit_decorator.py b/airflow/example_dags/example_short_circuit_decorator.py new file mode 100644 index 0000000000000..4e7e098624fb4 --- /dev/null +++ b/airflow/example_dags/example_short_circuit_decorator.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG demonstrating the usage of the `@task.short_circuit()` TaskFlow decorator.""" +import pendulum + +from airflow.decorators import dag, task +from airflow.models.baseoperator import chain +from airflow.operators.empty import EmptyOperator +from airflow.utils.trigger_rule import TriggerRule + + +@dag(start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=['example']) +def example_short_circuit_decorator(): + # [START howto_operator_short_circuit] + @task.short_circuit() + def check_condition(condition): + return condition + + ds_true = [EmptyOperator(task_id='true_' + str(i)) for i in [1, 2]] + ds_false = [EmptyOperator(task_id='false_' + str(i)) for i in [1, 2]] + + condition_is_true = check_condition.override(task_id="condition_is_true")(condition=True) + condition_is_false = check_condition.override(task_id="condition_is_false")(condition=False) + + chain(condition_is_true, *ds_true) + chain(condition_is_false, *ds_false) + # [END howto_operator_short_circuit] + + # [START howto_operator_short_circuit_trigger_rules] + [task_1, task_2, task_3, task_4, task_5, task_6] = [ + EmptyOperator(task_id=f"task_{i}") for i in range(1, 7) + ] + + task_7 = EmptyOperator(task_id="task_7", trigger_rule=TriggerRule.ALL_DONE) + + short_circuit = check_condition.override(task_id="short_circuit", ignore_downstream_trigger_rules=False)( + condition=False + ) + + chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7) + # [END howto_operator_short_circuit_trigger_rules] + + +example_dag = example_short_circuit_decorator() diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py index 2278de30e6294..3fc9f1bd00df0 100644 --- a/airflow/example_dags/example_short_circuit_operator.py +++ b/airflow/example_dags/example_short_circuit_operator.py @@ -31,7 +31,6 @@ catchup=False, tags=['example'], ) as dag: - # [START howto_operator_short_circuit] cond_true = ShortCircuitOperator( task_id='condition_is_True', python_callable=lambda: True, @@ -47,9 +46,7 @@ chain(cond_true, *ds_true) chain(cond_false, *ds_false) - # [END howto_operator_short_circuit] - # [START howto_operator_short_circuit_trigger_rules] [task_1, task_2, task_3, task_4, task_5, task_6] = [ EmptyOperator(task_id=f"task_{i}") for i in range(1, 7) ] @@ -61,4 +58,3 @@ ) chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7) - # [END howto_operator_short_circuit_trigger_rules] diff --git a/docs/apache-airflow/howto/operator/python.rst b/docs/apache-airflow/howto/operator/python.rst index b61ea77df192b..7128a2a5e00be 100644 --- a/docs/apache-airflow/howto/operator/python.rst +++ b/docs/apache-airflow/howto/operator/python.rst @@ -129,19 +129,24 @@ If you want the context related to datetime objects like ``data_interval_start`` .. _howto/operator:ShortCircuitOperator: ShortCircuitOperator -======================== +==================== + +Use the ``@task.short_circuit`` decorator to control whether a pipeline continues +if a condition is satisfied or a truthy value is obtained. + +.. warning:: + The ``@task.short_circuit`` decorator is recommended over the classic :class:`~airflow.operators.python.ShortCircuitOperator` + to short-circuit pipelines via Python callables. -Use the :class:`~airflow.operators.python.ShortCircuitOperator` to control whether a pipeline continues -if a condition is satisfied or a truthy value is obtained. The evaluation of this condition and truthy value -is done via the output of a ``python_callable``. If the ``python_callable`` returns True or a truthy value, +The evaluation of this condition and truthy value +is done via the output of the decorated function. If the decorated function returns True or a truthy value, the pipeline is allowed to continue and an :ref:`XCom ` of the output will be pushed. If the output is False or a falsy value, the pipeline will be short-circuited based on the configured -short-circuiting (more on this later). In the example below, the tasks that follow the "condition_is_True" -ShortCircuitOperator will execute while the tasks downstream of the "condition_is_False" ShortCircuitOperator -will be skipped. +short-circuiting (more on this later). In the example below, the tasks that follow the "condition_is_true" +task will execute while the tasks downstream of the "condition_is_false" task will be skipped. -.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_operator.py +.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_decorator.py :language: python :dedent: 4 :start-after: [START howto_operator_short_circuit] @@ -155,14 +160,14 @@ set to False, the direct downstream tasks are skipped but the specified ``trigge downstream tasks are respected. In this short-circuiting configuration, the operator assumes the direct downstream task(s) were purposely meant to be skipped but perhaps not other subsequent tasks. This configuration is especially useful if only *part* of a pipeline should be short-circuited rather than all -tasks which follow the ShortCircuitOperator task. +tasks which follow the short-circuiting task. -In the example below, notice that the ShortCircuitOperator task is configured to respect downstream trigger -rules. This means while the tasks that follow the "short_circuit" ShortCircuitOperator task will be skipped -since the ``python_callable`` returns False, "task_7" will still execute as its set to execute when upstream +In the example below, notice that the "short_circuit" task is configured to respect downstream trigger +rules. This means while the tasks that follow the "short_circuit" task will be skipped +since the decorated function returns False, "task_7" will still execute as its set to execute when upstream tasks have completed running regardless of status (i.e. the ``TriggerRule.ALL_DONE`` trigger rule). -.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_operator.py +.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_decorator.py :language: python :dedent: 4 :start-after: [START howto_operator_short_circuit_trigger_rules] @@ -173,7 +178,7 @@ tasks have completed running regardless of status (i.e. the ``TriggerRule.ALL_DO Passing in arguments ^^^^^^^^^^^^^^^^^^^^ -Both the ``op_args`` and ``op_kwargs`` arguments can be used in same way as described for the PythonOperator. +Pass extra arguments to the ``@task.short_circuit``-decorated function as you would with a normal Python function. Templating diff --git a/tests/decorators/test_short_circuit.py b/tests/decorators/test_short_circuit.py new file mode 100644 index 0000000000000..c79da558ded28 --- /dev/null +++ b/tests/decorators/test_short_circuit.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from pendulum import datetime + +from airflow.decorators import task +from airflow.utils.state import State +from airflow.utils.trigger_rule import TriggerRule + +DEFAULT_DATE = datetime(2022, 8, 17) + + +def test_short_circuit_decorator(dag_maker): + with dag_maker(): + + @task + def empty(): + ... + + @task.short_circuit() + def short_circuit(condition): + return condition + + short_circuit_false = short_circuit.override(task_id="short_circuit_false")(condition=False) + task_1 = empty.override(task_id="task_1")() + short_circuit_false >> task_1 + + short_circuit_true = short_circuit.override(task_id="short_circuit_true")(condition=True) + task_2 = empty.override(task_id="task_2")() + short_circuit_true >> task_2 + + short_circuit_respect_trigger_rules = short_circuit.override( + task_id="short_circuit_respect_trigger_rules", ignore_downstream_trigger_rules=False + )(condition=False) + task_3 = empty.override(task_id="task_3")() + task_4 = empty.override(task_id="task_4")() + task_5 = empty.override(task_id="task_5", trigger_rule=TriggerRule.ALL_DONE)() + short_circuit_respect_trigger_rules >> [task_3, task_4] >> task_5 + + dr = dag_maker.create_dagrun() + + for t in dag_maker.dag.tasks: + t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + task_state_mapping = { + "short_circuit_false": State.SUCCESS, + "task_1": State.SKIPPED, + "short_circuit_true": State.SUCCESS, + "task_2": State.SUCCESS, + "short_circuit_respect_trigger_rules": State.SUCCESS, + "task_3": State.SKIPPED, + "task_4": State.SKIPPED, + "task_5": State.SUCCESS, + } + + tis = dr.get_task_instances() + for ti in tis: + assert ti.state == task_state_mapping[ti.task_id]