Skip to content

Commit

Permalink
Support building DAGs out of topologically unsorted YAML files
Browse files Browse the repository at this point in the history
Closes: #225
  • Loading branch information
tatiana committed Dec 4, 2024
1 parent 72bc85b commit d52213a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 21 deletions.
63 changes: 55 additions & 8 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,16 @@
)
from airflow.kubernetes.secret import Secret
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
except ImportError: # pragma: no cover
from airflow.contrib.kubernetes.pod import Port
from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
from airflow.contrib.kubernetes.secret import Secret
from airflow.contrib.kubernetes.volume import Volume
from airflow.contrib.kubernetes.volume_mount import VolumeMount
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
except (ImportError, ModuleNotFoundError): # pragma: no cover
try:
from airflow.contrib.kubernetes.pod import Port
from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
from airflow.contrib.kubernetes.secret import Secret
from airflow.contrib.kubernetes.volume import Volume
from airflow.contrib.kubernetes.volume_mount import VolumeMount
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
except ModuleNotFoundError:
pass

from airflow.utils.task_group import TaskGroup
from kubernetes.client.models import V1Container, V1Pod
Expand Down Expand Up @@ -820,7 +823,8 @@ def build(self) -> Dict[str, Union[str, DAG]]:

# create dictionary to track tasks and set dependencies
tasks_dict: Dict[str, BaseOperator] = {}
for task_name, task_conf in tasks.items():
tasks_tuples = self.topological_sort_tasks(tasks)
for task_name, task_conf in tasks_tuples:
task_conf["task_id"]: str = task_name
operator: str = task_conf["operator"]
task_conf["dag"]: DAG = dag
Expand All @@ -844,6 +848,49 @@ def build(self) -> Dict[str, Union[str, DAG]]:

return {"dag_id": dag_params["dag_id"], "dag": dag}

@staticmethod
def topological_sort_tasks(tasks_configs: dict[str, Any]) -> list[tuple(str, Any)]:
"""
Use the Kahn's algorithm to sort topologically the tasks:
(https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm)
The complexity is O(N + D) where N: total tasks and D: number of dependencies.
"""
# Step 1: Build the downstream (adjacency) tasks list and the upstream dependencies (in-degree) count
downstream_tasks = {}
upstream_dependencies_count = {}

for task_id, _ in tasks_configs.items():
downstream_tasks[task_id] = []
upstream_dependencies_count[task_id] = 0

for task_id, task_conf in tasks_configs.items():
for upstream_task in task_conf.get("dependencies", []):
downstream_tasks[upstream_task].append(task_id)
upstream_dependencies_count[task_id] += 1

# Step 2: Find all tasks with no dependencies
tasks_without_dependencies = [
task for task in upstream_dependencies_count if not upstream_dependencies_count[task]
]
sorted_tasks = []

# Step 3: Perform topological sort
while tasks_without_dependencies:
current = tasks_without_dependencies.pop(0)
sorted_tasks.append((current, tasks_configs[current]))

for child in downstream_tasks[current]:
upstream_dependencies_count[child] -= 1
if upstream_dependencies_count[child] == 0:
tasks_without_dependencies.append(child)

# If not all tasks are processed, there is a cycle (not applicable for DAGs)
if len(sorted_tasks) != len(tasks_configs):
raise ValueError("Cycle detected in task dependencies!")

return sorted_tasks

@staticmethod
def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable:
"""
Expand Down
14 changes: 7 additions & 7 deletions dev/dags/example_dynamic_task_mapping.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ test_expand:
schedule_interval: "0 3 * * *"
default_view: "graph"
tasks:
request:
operator: airflow.operators.python.PythonOperator
python_callable_name: example_task_mapping
python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
process:
operator: airflow.operators.python_operator.PythonOperator
python_callable_name: expand_task
python_callable_name: consume_value
python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
partial:
op_kwargs:
test_id: "test"
fixed_param: "test"
expand:
op_args:
request.output
request.output
dependencies: [request]
request:
operator: airflow.operators.python.PythonOperator
python_callable_name: make_list
python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
12 changes: 6 additions & 6 deletions dev/dags/expand_tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
def example_task_mapping():
return [[1], [2], [3]]
def make_list():
return [[1], [2], [3], [4]]


def expand_task(x, test_id):
print(test_id)
print(x)
return [x]
def consume_value(expanded_param, fixed_param):
print(fixed_param)
print(expanded_param)
return [expanded_param]

0 comments on commit d52213a

Please sign in to comment.