From d52213a5b936270ab02a93cb1ab301d7c20daea3 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 4 Dec 2024 13:54:03 +0000 Subject: [PATCH] Support building DAGs out of topologically unsorted YAML files Closes: #225 --- dagfactory/dagbuilder.py | 63 ++++++++++++++++++++--- dev/dags/example_dynamic_task_mapping.yml | 14 ++--- dev/dags/expand_tasks.py | 12 ++--- 3 files changed, 68 insertions(+), 21 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 3eabf900..cc86dc13 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -67,13 +67,16 @@ ) from airflow.kubernetes.secret import Secret from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator -except ImportError: # pragma: no cover - from airflow.contrib.kubernetes.pod import Port - from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv - from airflow.contrib.kubernetes.secret import Secret - from airflow.contrib.kubernetes.volume import Volume - from airflow.contrib.kubernetes.volume_mount import VolumeMount - from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator +except (ImportError, ModuleNotFoundError): # pragma: no cover + try: + from airflow.contrib.kubernetes.pod import Port + from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv + from airflow.contrib.kubernetes.secret import Secret + from airflow.contrib.kubernetes.volume import Volume + from airflow.contrib.kubernetes.volume_mount import VolumeMount + from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator + except ModuleNotFoundError: + pass from airflow.utils.task_group import TaskGroup from kubernetes.client.models import V1Container, V1Pod @@ -820,7 +823,8 @@ def build(self) -> Dict[str, Union[str, DAG]]: # create dictionary to track tasks and set dependencies tasks_dict: Dict[str, BaseOperator] = {} - for task_name, task_conf in tasks.items(): + tasks_tuples 
@staticmethod
def topological_sort_tasks(tasks_configs: dict[str, Any]) -> list[tuple[str, Any]]:
    """
    Sort the tasks topologically, so each task comes after the tasks it depends on.

    Uses Kahn's algorithm:
    (https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm)

    The complexity is O(N + D) where N: total tasks and D: number of dependencies.

    :param tasks_configs: mapping of task name to task configuration. The optional
        ``dependencies`` entry of a configuration lists the names of its upstream tasks.
    :returns: list of ``(task_name, task_conf)`` tuples, ordered so that every task
        appears after all of its upstream tasks. Tasks that are not constrained
        relative to each other keep the order they have in ``tasks_configs``.
    :raises ValueError: if the dependencies between the tasks contain a cycle.
    """
    # Local import: keeps this method self-contained without touching the module imports.
    from collections import deque

    # Step 1: Build the downstream (adjacency) tasks list and the upstream dependencies (in-degree) count
    downstream_tasks = {task_id: [] for task_id in tasks_configs}
    upstream_dependencies_count = dict.fromkeys(tasks_configs, 0)

    for task_id, task_conf in tasks_configs.items():
        # `or []` also covers a "dependencies" key that is present but null/empty in the YAML.
        for upstream_task in task_conf.get("dependencies") or []:
            if upstream_task not in downstream_tasks:
                # NOTE(review): the name is not a key of tasks_configs (presumably a task
                # group, declared outside the tasks mapping - confirm). It cannot constrain
                # the order of the tasks sorted here, so it is skipped instead of raising
                # a bare KeyError.
                continue
            downstream_tasks[upstream_task].append(task_id)
            upstream_dependencies_count[task_id] += 1

    # Step 2: Find all tasks with no dependencies
    # A deque gives O(1) pops from the left; list.pop(0) would be O(N) per pop.
    tasks_without_dependencies = deque(
        task_id for task_id, count in upstream_dependencies_count.items() if not count
    )
    sorted_tasks = []

    # Step 3: Perform topological sort
    while tasks_without_dependencies:
        current = tasks_without_dependencies.popleft()
        sorted_tasks.append((current, tasks_configs[current]))

        for child in downstream_tasks[current]:
            upstream_dependencies_count[child] -= 1
            if upstream_dependencies_count[child] == 0:
                tasks_without_dependencies.append(child)

    # If not all tasks are processed, there is a cycle (not applicable for DAGs)
    if len(sorted_tasks) != len(tasks_configs):
        raise ValueError("Cycle detected in task dependencies!")

    return sorted_tasks
def make_list():
    """Return four single-element lists: ``[[1], [2], [3], [4]]``."""
    return [[number] for number in range(1, 5)]


def consume_value(expanded_param, fixed_param):
    """Print ``fixed_param`` and then ``expanded_param``; return ``[expanded_param]``."""
    for value in (fixed_param, expanded_param):
        print(value)
    return [expanded_param]