Support TaskFlow and enhance dynamic task mapping (#314)
Implement support for [Airflow TaskFlow](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/taskflow.html), available since Airflow 2.0.

# How to test

The following example defines a task that generates a list of numbers and another that consumes the list, using Airflow dynamic task mapping to dynamically create an independent task instance that doubles each number.
```
example_taskflow:
  default_args:
    owner: "custom_owner"
    start_date: 2 days
  description: "Example of TaskFlow powered DAG that includes dynamic task mapping."
  schedule_interval: "0 3 * * *"
  default_view: "graph"
  tasks:

    numbers_list:
      decorator: airflow.decorators.task
      python_callable: sample.build_numbers_list

    double_number_with_dynamic_task_mapping_taskflow:
      decorator: airflow.decorators.task
      python_callable: sample.double
      expand:
          number: +numbers_list  # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
```
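
For reference, the YAML above expresses roughly the same DAG you would write by hand with the TaskFlow API. A minimal sketch (the `start_date` and other `@dag` arguments below are illustrative, not taken from this PR):
```
import pendulum
from airflow.decorators import dag, task


@dag(schedule_interval="0 3 * * *", start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def example_taskflow():
    @task
    def build_numbers_list():
        return [2, 4, 6]

    @task
    def double(number: int):
        return 2 * number

    # expand() creates one mapped "double" task instance per element of the list
    double.expand(number=build_numbers_list())


example_taskflow()
```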

Given the following `sample.py`:
```
def build_numbers_list():
    return [2, 4, 6]


def double(number: int):
    result = 2 * number
    print(result)
    return result
```
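
With this input, Airflow creates three mapped instances of `double` (map indexes 0, 1, and 2), one per element of the returned list.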

In the UI, it is shown as:
![Screenshot 2024-12-06 at 11 53 04](https://github.com/user-attachments/assets/0643002a-2530-4bc1-af39-16fb3f48d4d4)

And:

![Screenshot 2024-12-06 at 11 52 28](https://github.com/user-attachments/assets/2c2ed46a-4ee8-438a-836d-3112b4737c6a)

# Scope

This PR includes several use cases of [dynamic task mapping](https://airflow.apache.org/docs/apache-airflow/2.10.3/authoring-and-scheduling/dynamic-task-mapping.html):
1. Simple mapping
2. Task-generated mapping
3. Repeated mapping
4. Adding parameters that do not expand (`partial`)
5. Mapping over multiple parameters
6. Named mapping (`map_index_template`)
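
Cases 4 and 5, for instance, correspond to the standard TaskFlow `partial`/`expand` calls. The sketch below illustrates the Airflow code these configurations map onto (illustrative only; task names and arguments are assumptions, not code from this PR):
```
import pendulum
from airflow.decorators import dag, task


@dag(schedule_interval=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def mapping_cases():
    @task
    def numbers():
        return [2, 4, 6]

    @task
    def double_with_label(number: int, label: bool = False):
        return f"2 x {number} = {2 * number}" if label else 2 * number

    @task
    def multiply(a: int, b: int):
        return a * b

    ns = numbers()
    # Case 4: `label` must not be mapped over, so it is passed through partial()
    double_with_label.partial(label=True).expand(number=ns)
    # Case 5: expanding over two arguments maps over their cross product
    multiply.expand(a=ns, b=ns)


mapping_cases()
```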

The following dynamic task mapping cases were not tested but are
expected to work:
* Mapping with non-TaskFlow operators
* Mapping over the result of classic operators
* Filtering items from a mapped task

The following dynamic task mapping cases were not tested and are not expected to work (they were considered out of scope for this ticket):
* Assigning multiple parameters to a non-TaskFlow operator
* Mapping over a task group
* Transforming expanding data
* Combining upstream data (aka “zipping”)

# Tests

The feature is tested by running the example DAGs introduced in this PR, which validate various TaskFlow and dynamic task mapping scenarios and double as documentation.

As with other parts of DAG Factory, we can and should improve the
overall unit test coverage.

Two example DAG files were added, containing multiple examples of TaskFlow and dynamic task mapping. This is how they are displayed in the Airflow UI:
<img width="1501" alt="Screenshot 2024-12-06 at 16 11 10"
src="https://github.com/user-attachments/assets/c4d12520-31f5-4b9d-b191-dd37523299e1">
<img width="1500" alt="Screenshot 2024-12-06 at 16 11 42"
src="https://github.com/user-attachments/assets/ab08749f-aedb-4c8f-9df1-8f0d0451477d">
<img width="1510" alt="Screenshot 2024-12-06 at 16 11 32"
src="https://github.com/user-attachments/assets/591e949a-49da-49f6-8d4d-1458fbb88d7f">



# Docs

This PR does not contain user-facing docs other than the README.
However, we'll address this as part of #278.

# Related issues

This PR closes two open tickets:

Closes: #302 (support named mapping, via the `map_index_template` argument)

Example usage of `map_index_template`:
```
    dynamic_task_with_named_mapping:
      decorator: airflow.decorators.task
      python_callable: sample.extract_last_name
      map_index_template: "{{ custom_mapping_key }}"
      expand:
        full_name:
          - Lucy Black
          - Vera Santos
          - Marks Spencer
```
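
For named mapping to work, the callable is expected to inject `custom_mapping_key` into the task context, which Airflow uses to render `map_index_template` after each mapped instance runs. A plausible sketch of `sample.extract_last_name` (an assumption; the actual file may differ):
```
from airflow.operators.python import get_current_context


def extract_last_name(full_name: str):
    first_name, last_name = full_name.split(" ", 1)
    # Expose the first name to the "{{ custom_mapping_key }}" template,
    # which Airflow (2.7+) renders once the task instance finishes.
    context = get_current_context()
    context["custom_mapping_key"] = first_name
    return last_name
```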

Closes: #301 (Mapping over multiple parameters)

Example of mapping over multiple parameters:
```
    multiply_with_multiple_parameters:
      decorator: airflow.decorators.task
      python_callable: sample.multiply
      expand:
          a: +numbers_list  # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
          b: +another_numbers_list # the prefix + tells DagFactory to resolve this value as the task `another_numbers_list`, previously defined
```
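
Under the hood, DagFactory resolves the `+` prefix by swapping the string for the task object created earlier (see `replace_kwargs_values_as_tasks` in the diff below). A minimal sketch of that behavior:
```
def resolve_task_references(kwargs: dict, tasks_dict: dict) -> None:
    """Replace "+task_name" strings with previously created task objects
    (a sketch of DagBuilder.replace_kwargs_values_as_tasks)."""
    for key, value in kwargs.items():
        if isinstance(value, str) and value.startswith("+"):
            kwargs[key] = tasks_dict[value.split("+")[-1]]
```
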
tatiana authored Dec 6, 2024
1 parent 80b885e commit 1f6525c
Showing 7 changed files with 316 additions and 79 deletions.
243 changes: 165 additions & 78 deletions dagfactory/dagbuilder.py
@@ -3,6 +3,7 @@
from __future__ import annotations

# pylint: disable=ungrouped-imports
import inspect
import os
import re
from copy import deepcopy
@@ -452,55 +453,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if task_params.get("init_containers") is not None
else None
)
if utils.check_dict_key(task_params, "execution_timeout_secs"):
task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
del task_params["execution_timeout_secs"]

if utils.check_dict_key(task_params, "sla_secs"):
task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
del task_params["sla_secs"]

if utils.check_dict_key(task_params, "execution_delta_secs"):
task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
del task_params["execution_delta_secs"]

if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]

# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback") and version.parse(
AIRFLOW_VERSION
) >= version.parse("2.0.0"):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])

if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])

# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
for variable in variables:
if Variable.get(variable["variable"], default_var=None) is not None:
task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
del task_params["variables_as_arguments"]

if (
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
DagBuilder.adjust_general_task_params(task_params)

expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
# expand available only in airflow >= 2.3.0
@@ -518,23 +471,6 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
        if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
            task_params.update(partial_kwargs)

        if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse(
            "2.4.0"
        ):
            if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
                task_params["outlets"], "datasets"
            ):
                file = task_params["outlets"]["file"]
                datasets_filter = task_params["outlets"]["datasets"]
                datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

                del task_params["outlets"]["file"]
                del task_params["outlets"]["datasets"]
            else:
                datasets_uri = task_params["outlets"]

            task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

        task: Union[BaseOperator, MappedOperator] = (
            operator_obj(**task_params)
            if not expand_kwargs
@@ -826,23 +762,34 @@ def build(self) -> Dict[str, Union[str, DAG]]:
        tasks_tuples = self.topological_sort_tasks(tasks)
        for task_name, task_conf in tasks_tuples:
            task_conf["task_id"]: str = task_name
            operator: str = task_conf["operator"]
            task_conf["dag"]: DAG = dag
            # add task to task_group

            if task_groups_dict and task_conf.get("task_group_name"):
                task_conf["task_group"] = task_groups_dict[task_conf.get("task_group_name")]
            # Dynamic task mapping available only in Airflow >= 2.3.0
            if (task_conf.get("expand") or task_conf.get("partial")) and version.parse(AIRFLOW_VERSION) < version.parse(
                "2.3.0"
            ):
                raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")

            # replace 'task_id.output' or 'XComArg(task_id)' with XComArg(task_instance) object
            if task_conf.get("expand") and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
                task_conf = self.replace_expand_values(task_conf, tasks_dict)
            params: Dict[str, Any] = {k: v for k, v in task_conf.items() if k not in SYSTEM_PARAMS}
            task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task(operator=operator, task_params=params)
            tasks_dict[task.task_id]: BaseOperator = task

            if "operator" in task_conf:
                operator: str = task_conf["operator"]

                # Dynamic task mapping available only in Airflow >= 2.3.0
                if task_conf.get("expand"):
                    if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
                        raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
                    else:
                        task_conf = self.replace_expand_values(task_conf, tasks_dict)

                task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task(operator=operator, task_params=params)
                tasks_dict[task.task_id]: BaseOperator = task

            elif "decorator" in task_conf:
                task = DagBuilder.make_decorator(
                    decorator_import_path=task_conf["decorator"], task_params=params, tasks_dict=tasks_dict
                )
                tasks_dict[task_name]: BaseOperator = task
            else:
                raise DagFactoryConfigException("Tasks must define either 'operator' or 'decorator'")

        # set task dependencies after creating tasks
        self.set_dependencies(tasks, tasks_dict, dag_params.get("task_groups", {}), task_groups_dict)

@@ -895,6 +842,146 @@ def topological_sort_tasks(tasks_configs: dict[str, Any]) -> list[tuple[str, Any]]

        return sorted_tasks

    @staticmethod
    def adjust_general_task_params(task_params: dict[str, Any]):
        """Adjusts in place the task params argument"""
        if utils.check_dict_key(task_params, "execution_timeout_secs"):
            task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
            del task_params["execution_timeout_secs"]

        if utils.check_dict_key(task_params, "sla_secs"):
            task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
            del task_params["sla_secs"]

        if utils.check_dict_key(task_params, "execution_delta_secs"):
            task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
            del task_params["execution_delta_secs"]

        if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
            task_params, "execution_date_fn_file"
        ):
            task_params["execution_date_fn"]: Callable = utils.get_python_callable(
                task_params["execution_date_fn_name"],
                task_params["execution_date_fn_file"],
            )
            del task_params["execution_date_fn_name"]
            del task_params["execution_date_fn_file"]

        # on_execute_callback is an Airflow 2.0 feature
        if utils.check_dict_key(task_params, "on_execute_callback"):
            task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

        if utils.check_dict_key(task_params, "on_failure_callback"):
            task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

        if utils.check_dict_key(task_params, "on_success_callback"):
            task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])

        if utils.check_dict_key(task_params, "on_retry_callback"):
            task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])

        # use variables as arguments on operator
        if utils.check_dict_key(task_params, "variables_as_arguments"):
            variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
            for variable in variables:
                if Variable.get(variable["variable"], default_var=None) is not None:
                    task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
            del task_params["variables_as_arguments"]

        if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
            if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
                task_params["outlets"], "datasets"
            ):
                file = task_params["outlets"]["file"]
                datasets_filter = task_params["outlets"]["datasets"]
                datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

                del task_params["outlets"]["file"]
                del task_params["outlets"]["datasets"]
            else:
                datasets_uri = task_params["outlets"]

            task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

    @staticmethod
    def make_decorator(
        decorator_import_path: str, task_params: Dict[str, Any], tasks_dict: dict[str, Any]
    ) -> BaseOperator:
        """
        Takes a decorator and params and creates an instance of that decorator.
        :returns: instance of operator object
        """
        # Check mandatory fields
        mandatory_keys_set1 = set(["python_callable_name", "python_callable_file"])

        # Fetch the Python callable
        if set(mandatory_keys_set1).issubset(task_params):
            python_callable: Callable = utils.get_python_callable(
                task_params["python_callable_name"],
                task_params["python_callable_file"],
            )
            # Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator
            del task_params["python_callable_name"]
            del task_params["python_callable_file"]
        elif "python_callable" in task_params:
            python_callable: Callable = import_string(task_params["python_callable"])
        else:
            raise DagFactoryException(
                "Failed to create task. Decorator-based tasks require "
                "`python_callable_name` and `python_callable_file` parameters.\n"
                "Optionally, you can load the python_callable with the special "
                "PyYAML notation:\n"
                "    python_callable: !!python/name:my_module.my_func"
            )

        task_params["python_callable"] = python_callable

        decorator: Callable[..., BaseOperator] = import_string(decorator_import_path)
        task_params.pop("decorator")

        DagBuilder.adjust_general_task_params(task_params)

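        # Split task_params: arguments matching the Python callable's signature are
        # passed to the callable; everything else configures the @task decorator.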
        callable_args_keys = inspect.getfullargspec(python_callable).args
        callable_kwargs = {}
        decorator_kwargs = dict(**task_params)
        for arg_key, arg_value in task_params.items():
            if arg_key in callable_args_keys:
                decorator_kwargs.pop(arg_key)
                if isinstance(arg_value, str) and arg_value.startswith("+"):
                    upstream_task_name = arg_value.split("+")[-1]
                    callable_kwargs[arg_key] = tasks_dict[upstream_task_name]
                else:
                    callable_kwargs[arg_key] = arg_value

        expand_kwargs = decorator_kwargs.pop("expand", {})
        partial_kwargs = decorator_kwargs.pop("partial", {})

        if ("map_index_template" in decorator_kwargs) and (version.parse(AIRFLOW_VERSION) < version.parse("2.7.0")):
            raise DagFactoryConfigException(
                "The dynamic task mapping argument `map_index_template` is only supported since Airflow 2.7"
            )

        if expand_kwargs and partial_kwargs:
            if callable_kwargs:
                raise DagFactoryConfigException(
                    "When using dynamic task mapping, all the task arguments should be defined in expand and partial."
                )
            DagBuilder.replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
            DagBuilder.replace_kwargs_values_as_tasks(partial_kwargs, tasks_dict)
            return decorator(**decorator_kwargs).partial(**partial_kwargs).expand(**expand_kwargs)
        elif expand_kwargs:
            DagBuilder.replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
            return decorator(**decorator_kwargs).expand(**expand_kwargs)
        else:
            return decorator(**decorator_kwargs)(**callable_kwargs)

    @staticmethod
    def replace_kwargs_values_as_tasks(kwargs: dict[str, Any], tasks_dict: dict[str, Any]):
        for key, value in kwargs.items():
            if isinstance(value, str) and value.startswith("+"):
                upstream_task_name = value.split("+")[-1]
                kwargs[key] = tasks_dict[upstream_task_name]

    @staticmethod
    def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable:
        """
16 changes: 16 additions & 0 deletions dev/dags/example_map_index_template.py
@@ -0,0 +1,16 @@
import os
from pathlib import Path

# The following import is here so Airflow parses this file
# from airflow import DAG
import dagfactory

DEFAULT_CONFIG_ROOT_DIR = "/usr/local/airflow/dags/"
CONFIG_ROOT_DIR = Path(os.getenv("CONFIG_ROOT_DIR", DEFAULT_CONFIG_ROOT_DIR))

config_file = str(CONFIG_ROOT_DIR / "example_map_index_template.yml")
example_dag_factory = dagfactory.DagFactory(config_file)

# Creating task dependencies
example_dag_factory.clean_dags(globals())
example_dag_factory.generate_dags(globals())
18 changes: 18 additions & 0 deletions dev/dags/example_map_index_template.yml
@@ -0,0 +1,18 @@
# Requires Airflow 2.7 or higher
example_map_index_template:
  default_args:
    owner: "custom_owner"
    start_date: 2 days
  description: "Example of TaskFlow powered DAG that includes dynamic task mapping"
  schedule_interval: "0 3 * * *"
  default_view: "graph"
  tasks:
    dynamic_task_with_named_mapping:
      decorator: airflow.decorators.task
      python_callable: sample.extract_last_name
      map_index_template: "{{ custom_mapping_key }}"
      expand:
        full_name:
          - Lucy Black
          - Vera Santos
          - Marks Spencer
16 changes: 16 additions & 0 deletions dev/dags/example_taskflow.py
@@ -0,0 +1,16 @@
import os
from pathlib import Path

# The following import is here so Airflow parses this file
# from airflow import DAG
import dagfactory

DEFAULT_CONFIG_ROOT_DIR = "/usr/local/airflow/dags/"
CONFIG_ROOT_DIR = Path(os.getenv("CONFIG_ROOT_DIR", DEFAULT_CONFIG_ROOT_DIR))

config_file = str(CONFIG_ROOT_DIR / "example_taskflow.yml")
example_dag_factory = dagfactory.DagFactory(config_file)

# Creating task dependencies
example_dag_factory.clean_dags(globals())
example_dag_factory.generate_dags(globals())
52 changes: 52 additions & 0 deletions dev/dags/example_taskflow.yml
@@ -0,0 +1,52 @@
example_taskflow:
  default_args:
    owner: "custom_owner"
    start_date: 2 days
  description: "Example of TaskFlow powered DAG that includes dynamic task mapping"
  schedule_interval: "0 3 * * *"
  default_view: "graph"
  tasks:
    some_number:
      decorator: airflow.decorators.task
      python_callable: sample.some_number
    numbers_list:
      decorator: airflow.decorators.task
      python_callable_name: build_numbers_list
      python_callable_file: $CONFIG_ROOT_DIR/sample.py
    another_numbers_list:
      decorator: airflow.decorators.task
      python_callable: sample.build_numbers_list
    double_number_from_arg:
      decorator: airflow.decorators.task
      python_callable: sample.double
      number: 2
    double_number_from_task:
      decorator: airflow.decorators.task
      python_callable: sample.double
      number: +some_number  # the prefix + leads to resolving this value as the task `some_number`, previously defined
    double_number_with_dynamic_task_mapping_static:
      decorator: airflow.decorators.task
      python_callable: sample.double
      expand:
        number:
          - 1
          - 3
          - 5
    double_number_with_dynamic_task_mapping_taskflow:
      decorator: airflow.decorators.task
      python_callable: sample.double
      expand:
        number: +numbers_list  # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
    multiply_with_multiple_parameters:
      decorator: airflow.decorators.task
      python_callable: sample.multiply
      expand:
        a: +numbers_list  # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
        b: +another_numbers_list  # the prefix + tells DagFactory to resolve this value as the task `another_numbers_list`, previously defined
    double_number_with_dynamic_task_and_partial:
      decorator: airflow.decorators.task
      python_callable: sample.double_with_label
      expand:
        number: +numbers_list  # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
      partial:
        label: True
