Skip to content

Commit

Permalink
Add @task.short_circuit TaskFlow decorator (#25752)
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-fell authored Sep 7, 2022
1 parent 87108d7 commit ebef9ed
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 18 deletions.
3 changes: 3 additions & 0 deletions airflow/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from airflow.decorators.external_python import external_python_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
from airflow.decorators.short_circuit import short_circuit_task
from airflow.decorators.task_group import task_group
from airflow.models.dag import dag
from airflow.providers_manager import ProvidersManager
Expand All @@ -37,6 +38,7 @@
"virtualenv_task",
"external_python_task",
"branch_task",
"short_circuit_task",
]


Expand All @@ -47,6 +49,7 @@ class TaskDecoratorCollection:
virtualenv = staticmethod(virtualenv_task)
external_python = staticmethod(external_python_task)
branch = staticmethod(branch_task)
short_circuit = staticmethod(short_circuit_task)

__call__: Any = python # Alias '@task' to '@task.python'.

Expand Down
20 changes: 20 additions & 0 deletions airflow/decorators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ __all__ = [
"virtualenv_task",
"external_python_task",
"branch_task",
"short_circuit_task",
]

class TaskDecoratorCollection:
Expand Down Expand Up @@ -171,6 +172,25 @@ class TaskDecoratorCollection:
"""
@overload
def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
@overload
def short_circuit(
    self,
    *,
    multiple_outputs: Optional[bool] = None,
    ignore_downstream_trigger_rules: bool = True,
    **kwargs,
) -> TaskDecorator:
    """Create a decorator to wrap the decorated callable into a ShortCircuitOperator.

    :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
        Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
    :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task
        will be skipped. This is the default behavior. If set to False, the direct, downstream task(s)
        will be skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
        Defaults to True.
    """
@overload
def short_circuit(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [START decorator_signature]
def docker(
self,
Expand Down
83 changes: 83 additions & 0 deletions airflow/decorators/short_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Callable, Optional, Sequence

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.operators.python import ShortCircuitOperator


class _ShortCircuitDecoratedOperator(DecoratedOperator, ShortCircuitOperator):
    """
    Operator class used by the ``@task.short_circuit`` TaskFlow decorator.

    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable (templated)
    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
    """

    # Only the call arguments are listed as templated fields; both are rendered with the "py" renderer.
    template_fields: Sequence[str] = ("op_args", "op_kwargs")
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # The arguments are not mutated, so a shallow copy is sufficient. Some objects
    # cannot be deep-copied at all (e.g. protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    custom_operator_name: str = "@task.short_circuit"

    def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
        # The same three values are handed over twice: once bundled as
        # ``kwargs_to_upstream`` and once as regular keyword arguments.
        callable_spec = dict(python_callable=python_callable, op_args=op_args, op_kwargs=op_kwargs)
        super().__init__(kwargs_to_upstream=callable_spec, **callable_spec, **kwargs)


def short_circuit_task(
    python_callable: Optional[Callable] = None,
    multiple_outputs: Optional[bool] = None,
    **kwargs,
) -> TaskDecorator:
    """Wraps a function into a ShortCircuitOperator.

    Accepts kwargs for operator kwarg. Can be reused in a single DAG.

    This function is only used during type checking or auto-completion.

    :param python_callable: Function to decorate
    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.

    :meta private:
    """
    # Delegate to the shared factory, selecting the short-circuit operator class;
    # any remaining keyword arguments are forwarded unchanged.
    return task_decorator_factory(
        decorated_operator_class=_ShortCircuitDecoratedOperator,
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        **kwargs,
    )
59 changes: 59 additions & 0 deletions airflow/example_dags/example_short_circuit_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the `@task.short_circuit()` TaskFlow decorator."""
import pendulum

from airflow.decorators import dag, task
from airflow.models.baseoperator import chain
from airflow.operators.empty import EmptyOperator
from airflow.utils.trigger_rule import TriggerRule


@dag(start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=['example'])
def example_short_circuit_decorator():
    # Part 1: two independent chains, one gated by a truthy short-circuit task and one by a
    # falsy short-circuit task. The START/END marker comments delimit the snippet that the
    # how-to documentation includes, so they must stay in place.
    # [START howto_operator_short_circuit]
    @task.short_circuit()
    def check_condition(condition):
        return condition

    ds_true = [EmptyOperator(task_id='true_' + str(i)) for i in [1, 2]]
    ds_false = [EmptyOperator(task_id='false_' + str(i)) for i in [1, 2]]

    condition_is_true = check_condition.override(task_id="condition_is_true")(condition=True)
    condition_is_false = check_condition.override(task_id="condition_is_false")(condition=False)

    chain(condition_is_true, *ds_true)
    chain(condition_is_false, *ds_false)
    # [END howto_operator_short_circuit]

    # Part 2: the same decorated function, overridden with ignore_downstream_trigger_rules=False
    # and called with a falsy condition. task_7 is the only task given a non-default trigger
    # rule (ALL_DONE).
    # [START howto_operator_short_circuit_trigger_rules]
    [task_1, task_2, task_3, task_4, task_5, task_6] = [
        EmptyOperator(task_id=f"task_{i}") for i in range(1, 7)
    ]

    task_7 = EmptyOperator(task_id="task_7", trigger_rule=TriggerRule.ALL_DONE)

    short_circuit = check_condition.override(task_id="short_circuit", ignore_downstream_trigger_rules=False)(
        condition=False
    )

    chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7)
    # [END howto_operator_short_circuit_trigger_rules]


# Call the @dag-decorated function and bind its result at module level.
example_dag = example_short_circuit_decorator()
4 changes: 0 additions & 4 deletions airflow/example_dags/example_short_circuit_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
catchup=False,
tags=['example'],
) as dag:
# [START howto_operator_short_circuit]
cond_true = ShortCircuitOperator(
task_id='condition_is_True',
python_callable=lambda: True,
Expand All @@ -47,9 +46,7 @@

chain(cond_true, *ds_true)
chain(cond_false, *ds_false)
# [END howto_operator_short_circuit]

# [START howto_operator_short_circuit_trigger_rules]
[task_1, task_2, task_3, task_4, task_5, task_6] = [
EmptyOperator(task_id=f"task_{i}") for i in range(1, 7)
]
Expand All @@ -61,4 +58,3 @@
)

chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7)
# [END howto_operator_short_circuit_trigger_rules]
33 changes: 19 additions & 14 deletions docs/apache-airflow/howto/operator/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,24 @@ If you want the context related to datetime objects like ``data_interval_start``
.. _howto/operator:ShortCircuitOperator:

ShortCircuitOperator
========================
====================

Use the ``@task.short_circuit`` decorator to control whether a pipeline continues
if a condition is satisfied or a truthy value is obtained.

.. warning::
The ``@task.short_circuit`` decorator is recommended over the classic :class:`~airflow.operators.python.ShortCircuitOperator`
to short-circuit pipelines via Python callables.

Use the :class:`~airflow.operators.python.ShortCircuitOperator` to control whether a pipeline continues
if a condition is satisfied or a truthy value is obtained. The evaluation of this condition and truthy value
is done via the output of a ``python_callable``. If the ``python_callable`` returns True or a truthy value,
The evaluation of this condition and truthy value
is done via the output of the decorated function. If the decorated function returns True or a truthy value,
the pipeline is allowed to continue and an :ref:`XCom <concepts:xcom>` of the output will be pushed. If the
output is False or a falsy value, the pipeline will be short-circuited based on the configured
short-circuiting (more on this later). In the example below, the tasks that follow the "condition_is_True"
ShortCircuitOperator will execute while the tasks downstream of the "condition_is_False" ShortCircuitOperator
will be skipped.
short-circuiting (more on this later). In the example below, the tasks that follow the "condition_is_true"
task will execute while the tasks downstream of the "condition_is_false" task will be skipped.


.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_operator.py
.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_decorator.py
:language: python
:dedent: 4
:start-after: [START howto_operator_short_circuit]
Expand All @@ -155,14 +160,14 @@ set to False, the direct downstream tasks are skipped but the specified ``trigge
downstream tasks are respected. In this short-circuiting configuration, the operator assumes the direct
downstream task(s) were purposely meant to be skipped but perhaps not other subsequent tasks. This
configuration is especially useful if only *part* of a pipeline should be short-circuited rather than all
tasks which follow the ShortCircuitOperator task.
tasks which follow the short-circuiting task.

In the example below, notice that the ShortCircuitOperator task is configured to respect downstream trigger
rules. This means while the tasks that follow the "short_circuit" ShortCircuitOperator task will be skipped
since the ``python_callable`` returns False, "task_7" will still execute as its set to execute when upstream
In the example below, notice that the "short_circuit" task is configured to respect downstream trigger
rules. This means while the tasks that follow the "short_circuit" task will be skipped
since the decorated function returns False, "task_7" will still execute as it's set to execute when upstream
tasks have completed running regardless of status (i.e. the ``TriggerRule.ALL_DONE`` trigger rule).

.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_operator.py
.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_decorator.py
:language: python
:dedent: 4
:start-after: [START howto_operator_short_circuit_trigger_rules]
Expand All @@ -173,7 +178,7 @@ tasks have completed running regardless of status (i.e. the ``TriggerRule.ALL_DO
Passing in arguments
^^^^^^^^^^^^^^^^^^^^

Both the ``op_args`` and ``op_kwargs`` arguments can be used in same way as described for the PythonOperator.
Pass extra arguments to the ``@task.short_circuit``-decorated function as you would with a normal Python function.


Templating
Expand Down
72 changes: 72 additions & 0 deletions tests/decorators/test_short_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pendulum import datetime

from airflow.decorators import task
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule

DEFAULT_DATE = datetime(2022, 8, 17)


def test_short_circuit_decorator(dag_maker):
    """Check the task states produced by ``@task.short_circuit`` in three scenarios.

    The DAG contains a falsy short-circuit task, a truthy short-circuit task, and a falsy
    short-circuit task created with ``ignore_downstream_trigger_rules=False``. Every task is
    run and each task instance's state is compared against an expected mapping.
    """
    with dag_maker():

        @task
        def empty():
            ...

        @task.short_circuit()
        def short_circuit(condition):
            return condition

        # Scenario 1: falsy condition -> the downstream task is expected to be skipped.
        short_circuit_false = short_circuit.override(task_id="short_circuit_false")(condition=False)
        task_1 = empty.override(task_id="task_1")()
        short_circuit_false >> task_1

        # Scenario 2: truthy condition -> the downstream task is expected to run.
        short_circuit_true = short_circuit.override(task_id="short_circuit_true")(condition=True)
        task_2 = empty.override(task_id="task_2")()
        short_circuit_true >> task_2

        # Scenario 3: falsy condition while respecting downstream trigger rules -> the direct
        # downstream tasks are expected to be skipped, but task_5 (ALL_DONE) is expected to run.
        short_circuit_respect_trigger_rules = short_circuit.override(
            task_id="short_circuit_respect_trigger_rules", ignore_downstream_trigger_rules=False
        )(condition=False)
        task_3 = empty.override(task_id="task_3")()
        task_4 = empty.override(task_id="task_4")()
        task_5 = empty.override(task_id="task_5", trigger_rule=TriggerRule.ALL_DONE)()
        short_circuit_respect_trigger_rules >> [task_3, task_4] >> task_5

    dr = dag_maker.create_dagrun()

    # Run each task in the order dag_maker.dag.tasks yields them.
    for t in dag_maker.dag.tasks:
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    # Expected final state per task_id.
    task_state_mapping = {
        "short_circuit_false": State.SUCCESS,
        "task_1": State.SKIPPED,
        "short_circuit_true": State.SUCCESS,
        "task_2": State.SUCCESS,
        "short_circuit_respect_trigger_rules": State.SUCCESS,
        "task_3": State.SKIPPED,
        "task_4": State.SKIPPED,
        "task_5": State.SUCCESS,
    }

    tis = dr.get_task_instances()
    for ti in tis:
        assert ti.state == task_state_mapping[ti.task_id]

0 comments on commit ebef9ed

Please sign in to comment.