Support Task Flow and enhance dynamic task mapping #314

Merged · 10 commits · Dec 6, 2024
243 changes: 165 additions & 78 deletions dagfactory/dagbuilder.py
@@ -3,6 +3,7 @@
from __future__ import annotations

# pylint: disable=ungrouped-imports
import inspect
import os
import re
from copy import deepcopy
@@ -452,55 +453,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if task_params.get("init_containers") is not None
else None
)
if utils.check_dict_key(task_params, "execution_timeout_secs"):
task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
del task_params["execution_timeout_secs"]

if utils.check_dict_key(task_params, "sla_secs"):
task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
del task_params["sla_secs"]

if utils.check_dict_key(task_params, "execution_delta_secs"):
task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
del task_params["execution_delta_secs"]

if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]

# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback") and version.parse(
AIRFLOW_VERSION
) >= version.parse("2.0.0"):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])

if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])

# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
for variable in variables:
if Variable.get(variable["variable"], default_var=None) is not None:
task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
del task_params["variables_as_arguments"]

if (
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
DagBuilder.adjust_general_task_params(task_params)

expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
# expand available only in airflow >= 2.3.0
@@ -518,23 +471,6 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
task_params.update(partial_kwargs)

if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse(
"2.4.0"
):
if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
task_params["outlets"], "datasets"
):
file = task_params["outlets"]["file"]
datasets_filter = task_params["outlets"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del task_params["outlets"]["file"]
del task_params["outlets"]["datasets"]
else:
datasets_uri = task_params["outlets"]

task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

task: Union[BaseOperator, MappedOperator] = (
operator_obj(**task_params)
if not expand_kwargs
@@ -826,23 +762,34 @@ def build(self) -> Dict[str, Union[str, DAG]]:
tasks_tuples = self.topological_sort_tasks(tasks)
for task_name, task_conf in tasks_tuples:
task_conf["task_id"]: str = task_name
operator: str = task_conf["operator"]
task_conf["dag"]: DAG = dag
# add task to task_group

if task_groups_dict and task_conf.get("task_group_name"):
task_conf["task_group"] = task_groups_dict[task_conf.get("task_group_name")]
# Dynamic task mapping available only in Airflow >= 2.3.0
if (task_conf.get("expand") or task_conf.get("partial")) and version.parse(AIRFLOW_VERSION) < version.parse(
"2.3.0"
):
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")

# replace 'task_id.output' or 'XComArg(task_id)' with XComArg(task_instance) object
if task_conf.get("expand") and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
task_conf = self.replace_expand_values(task_conf, tasks_dict)
params: Dict[str, Any] = {k: v for k, v in task_conf.items() if k not in SYSTEM_PARAMS}
task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task(operator=operator, task_params=params)
tasks_dict[task.task_id]: BaseOperator = task

if "operator" in task_conf:
operator: str = task_conf["operator"]

# Dynamic task mapping available only in Airflow >= 2.3.0
if task_conf.get("expand"):
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
else:
task_conf = self.replace_expand_values(task_conf, tasks_dict)

task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task(operator=operator, task_params=params)
tasks_dict[task.task_id]: BaseOperator = task

elif "decorator" in task_conf:
task = DagBuilder.make_decorator(
decorator_import_path=task_conf["decorator"], task_params=params, tasks_dict=tasks_dict
)
tasks_dict[task_name]: BaseOperator = task
else:
raise DagFactoryConfigException("Tasks must define either 'operator' or 'decorator")

# set task dependencies after creating tasks
self.set_dependencies(tasks, tasks_dict, dag_params.get("task_groups", {}), task_groups_dict)

@@ -895,6 +842,146 @@ def topological_sort_tasks(tasks_configs: dict[str, Any]) -> list[tuple[str, Any]]

return sorted_tasks

@staticmethod
def adjust_general_task_params(task_params: dict[str, Any]):
"""Adjust the task params argument in place."""
if utils.check_dict_key(task_params, "execution_timeout_secs"):
task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
del task_params["execution_timeout_secs"]

if utils.check_dict_key(task_params, "sla_secs"):
task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
del task_params["sla_secs"]

if utils.check_dict_key(task_params, "execution_delta_secs"):
task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
del task_params["execution_delta_secs"]

if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]

# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback"):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])

if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])

# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
for variable in variables:
if Variable.get(variable["variable"], default_var=None) is not None:
task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
del task_params["variables_as_arguments"]

if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
task_params["outlets"], "datasets"
):
file = task_params["outlets"]["file"]
datasets_filter = task_params["outlets"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del task_params["outlets"]["file"]
del task_params["outlets"]["datasets"]
else:
datasets_uri = task_params["outlets"]

task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

@staticmethod
def make_decorator(
decorator_import_path: str, task_params: Dict[str, Any], tasks_dict: dict[str, Any]
) -> BaseOperator:
"""
Takes a decorator import path and task params and creates an instance of the decorated task.

:returns: instance of the task created by the decorator
"""
# Check mandatory fields
mandatory_keys_set1 = {"python_callable_name", "python_callable_file"}

# Fetch the Python callable
if mandatory_keys_set1.issubset(task_params):
python_callable: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
)
# Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator
del task_params["python_callable_name"]
del task_params["python_callable_file"]
elif "python_callable" in task_params:
python_callable: Callable = import_string(task_params["python_callable"])
else:
raise DagFactoryException(
"Failed to create task. Decorator-based tasks require "
"`python_callable_name` and `python_callable_file` parameters.\n"
"Optionally, you can load the python_callable with the special "
"PyYAML notation:\n"
"  python_callable: !!python/name:my_module.my_func"
)

task_params["python_callable"] = python_callable

decorator: Callable[..., BaseOperator] = import_string(decorator_import_path)
task_params.pop("decorator")

DagBuilder.adjust_general_task_params(task_params)

callable_args_keys = inspect.getfullargspec(python_callable).args
callable_kwargs = {}
decorator_kwargs = dict(**task_params)
for arg_key, arg_value in task_params.items():
if arg_key in callable_args_keys:
decorator_kwargs.pop(arg_key)
if isinstance(arg_value, str) and arg_value.startswith("+"):
upstream_task_name = arg_value.split("+")[-1]
callable_kwargs[arg_key] = tasks_dict[upstream_task_name]
else:
callable_kwargs[arg_key] = arg_value

expand_kwargs = decorator_kwargs.pop("expand", {})
partial_kwargs = decorator_kwargs.pop("partial", {})

if ("map_index_template" in decorator_kwargs) and (version.parse(AIRFLOW_VERSION) < version.parse("2.7.0")):
raise DagFactoryConfigException(
"The dynamic task mapping argument `map_index_template` is only supported since Airflow 2.7"
)

if expand_kwargs and partial_kwargs:
if callable_kwargs:
raise DagFactoryConfigException(
"When using dynamic task mapping, all the task arguments should be defined in expand and partial."
)
DagBuilder.replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
DagBuilder.replace_kwargs_values_as_tasks(partial_kwargs, tasks_dict)
return decorator(**decorator_kwargs).partial(**partial_kwargs).expand(**expand_kwargs)
elif expand_kwargs:
DagBuilder.replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
return decorator(**decorator_kwargs).expand(**expand_kwargs)
else:
return decorator(**decorator_kwargs)(**callable_kwargs)

@staticmethod
def replace_kwargs_values_as_tasks(kwargs: dict[str, Any], tasks_dict: dict[str, Any]):
"""Replace `+task_name` string values with the corresponding task objects."""
for key, value in kwargs.items():
if isinstance(value, str) and value.startswith("+"):
upstream_task_name = value.split("+")[-1]
kwargs[key] = tasks_dict[upstream_task_name]

@staticmethod
def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable:
"""
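For reference, the decorator path above is roughly equivalent to hand-written TaskFlow code. The sketch below is illustrative only (the DAG id and callables are hypothetical, and it assumes Airflow 2.4+ for the `schedule` argument); it shows how `make_decorator` maps a YAML task onto `airflow.decorators.task`, and how a `+task_name` value is swapped for the upstream task before calling `.expand()`:

import pendulum
from airflow import DAG
from airflow.decorators import task

def double(number: int) -> int:
    return number * 2

with DAG(
    dag_id="taskflow_equivalent_sketch",  # hypothetical DAG id
    start_date=pendulum.datetime(2024, 1, 1),
    schedule=None,
):
    # decorator + python_callable + plain kwargs
    # -> decorator(**decorator_kwargs)(**callable_kwargs)
    doubled_once = task(double)(number=2)

    @task
    def build_numbers_list() -> list[int]:
        return [2, 4, 6]

    numbers = build_numbers_list()

    # expand: {number: +build_numbers_list} -> the "+" string is replaced by
    # the upstream task object, then passed to .expand()
    doubled_many = task(double).expand(number=numbers)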
16 changes: 16 additions & 0 deletions dev/dags/example_map_index_template.py
@@ -0,0 +1,16 @@
import os
from pathlib import Path

# The following import is here so Airflow parses this file
# from airflow import DAG
import dagfactory

DEFAULT_CONFIG_ROOT_DIR = "/usr/local/airflow/dags/"
CONFIG_ROOT_DIR = Path(os.getenv("CONFIG_ROOT_DIR", DEFAULT_CONFIG_ROOT_DIR))

config_file = str(CONFIG_ROOT_DIR / "example_map_index_template.yml")
example_dag_factory = dagfactory.DagFactory(config_file)

# Creating task dependencies
example_dag_factory.clean_dags(globals())
example_dag_factory.generate_dags(globals())
18 changes: 18 additions & 0 deletions dev/dags/example_map_index_template.yml
@@ -0,0 +1,18 @@
# Requires Airflow 2.7 or higher
example_map_index_template:
default_args:
owner: "custom_owner"
start_date: 2 days
description: "Example of a TaskFlow-powered DAG that uses map_index_template with dynamic task mapping"
schedule_interval: "0 3 * * *"
default_view: "graph"
tasks:
dynamic_task_with_named_mapping:
decorator: airflow.decorators.task
python_callable: sample.extract_last_name
map_index_template: "{{ custom_mapping_key }}"
expand:
full_name:
- Lucy Black
- Vera Santos
- Marks Spencer
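The config references `sample.extract_last_name`, which is not part of this diff. A plausible sketch of such a callable, assuming the standard Airflow 2.7+ pattern of writing `custom_mapping_key` into the task context so `map_index_template` can render it:

from airflow.operators.python import get_current_context

def extract_last_name(full_name: str) -> str:
    """Return the last name and expose it for map_index_template rendering."""
    last_name = full_name.split(" ")[-1]
    # map_index_template ("{{ custom_mapping_key }}") is rendered against the
    # task context after execution, so the key is stored there.
    context = get_current_context()
    context["custom_mapping_key"] = last_name
    return last_name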
16 changes: 16 additions & 0 deletions dev/dags/example_taskflow.py
@@ -0,0 +1,16 @@
import os
from pathlib import Path

# The following import is here so Airflow parses this file
# from airflow import DAG
import dagfactory

DEFAULT_CONFIG_ROOT_DIR = "/usr/local/airflow/dags/"
CONFIG_ROOT_DIR = Path(os.getenv("CONFIG_ROOT_DIR", DEFAULT_CONFIG_ROOT_DIR))

config_file = str(CONFIG_ROOT_DIR / "example_taskflow.yml")
example_dag_factory = dagfactory.DagFactory(config_file)

# Creating task dependencies
example_dag_factory.clean_dags(globals())
example_dag_factory.generate_dags(globals())
52 changes: 52 additions & 0 deletions dev/dags/example_taskflow.yml
@@ -0,0 +1,52 @@
example_taskflow:
default_args:
owner: "custom_owner"
start_date: 2 days
description: "Example of a TaskFlow-powered DAG that includes dynamic task mapping"
schedule_interval: "0 3 * * *"
default_view: "graph"
tasks:
some_number:
decorator: airflow.decorators.task
python_callable: sample.some_number
numbers_list:
decorator: airflow.decorators.task
python_callable_name: build_numbers_list
python_callable_file: $CONFIG_ROOT_DIR/sample.py
another_numbers_list:
decorator: airflow.decorators.task
python_callable: sample.build_numbers_list
double_number_from_arg:
decorator: airflow.decorators.task
python_callable: sample.double
number: 2
double_number_from_task:
decorator: airflow.decorators.task
python_callable: sample.double
number: +some_number # the prefix + tells DagFactory to resolve this value as the task `some_number`, previously defined
double_number_with_dynamic_task_mapping_static:
decorator: airflow.decorators.task
python_callable: sample.double
expand:
number:
- 1
- 3
- 5
double_number_with_dynamic_task_mapping_taskflow:
decorator: airflow.decorators.task
python_callable: sample.double
expand:
number: +numbers_list # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
multiply_with_multiple_parameters:
decorator: airflow.decorators.task
python_callable: sample.multiply
expand:
a: +numbers_list # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
b: +another_numbers_list # the prefix + tells DagFactory to resolve this value as the task `another_numbers_list`, previously defined
double_number_with_dynamic_task_and_partial:
decorator: airflow.decorators.task
python_callable: sample.double_with_label
expand:
number: +numbers_list # the prefix + tells DagFactory to resolve this value as the task `numbers_list`, previously defined
partial:
label: True
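This config depends on a `sample.py` module that is not included in the diff. A minimal sketch of the callables it would need, inferred from the task definitions above (return values are illustrative):

# sample.py -- hypothetical module backing example_taskflow.yml

def some_number() -> int:
    return 4

def build_numbers_list() -> list[int]:
    return [2, 4, 6]

def double(number: int) -> int:
    return number * 2

def multiply(a: int, b: int) -> int:
    return a * b

def double_with_label(number: int, label: bool = False):
    result = number * 2
    return f"{number} x 2 = {result}" if label else result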