Skip to content

Commit

Permalink
Support building DAGs out of topologically unsorted YAML files (#307)
Browse files Browse the repository at this point in the history
Previously, any YAML file that declared an upstream task after its
downstream task would fail to build, whether or not it used dynamic task mapping.

Example of DAG that would fail:
```
test_expand:
  default_args:
    owner: "custom_owner"
    start_date: 2 days
  description: "test expand"
  schedule_interval: "0 3 * * *"
  default_view: "graph"
  tasks:
    process:
      operator: airflow.operators.python_operator.PythonOperator
      python_callable_name: expand_task
      python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
      partial:
        op_kwargs:
          test_id: "test"
      expand:
        op_args:
          request.output
      dependencies: [request]
    request:
      operator: airflow.operators.python.PythonOperator
      python_callable_name: example_task_mapping
      python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
```

In this example, the upstream (parent) task "request" is defined after
the downstream (child) task "process". Before this change, this DAG
would fail.

This change fixes the problem by sorting the tasks topologically with
Kahn's algorithm:
https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

It has asymptotic complexity O(N + D), where N is the total number of
tasks, and D is the total number of dependencies. This complexity seems
acceptable.

An alternative to the current approach would be to create all the tasks
without dependencies as a starting point and add the dependencies once
all tasks were made - similar to what we did in
https://github.com/astronomer/astronomer-cosmos. However, this approach
would require a bigger refactor of the DAG factory and may have issues
with dynamic task mapping.

Closes: #225
  • Loading branch information
tatiana authored Dec 6, 2024
1 parent 2e9fd4e commit a6bf015
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 15 deletions.
50 changes: 49 additions & 1 deletion dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,8 @@ def build(self) -> Dict[str, Union[str, DAG]]:

# create dictionary to track tasks and set dependencies
tasks_dict: Dict[str, BaseOperator] = {}
for task_name, task_conf in tasks.items():
tasks_tuples = self.topological_sort_tasks(tasks)
for task_name, task_conf in tasks_tuples:
task_conf["task_id"]: str = task_name
operator: str = task_conf["operator"]
task_conf["dag"]: DAG = dag
Expand All @@ -848,6 +849,53 @@ def build(self) -> Dict[str, Union[str, DAG]]:

return {"dag_id": dag_params["dag_id"], "dag": dag}

@staticmethod
def topological_sort_tasks(tasks_configs: dict[str, Any]) -> list[tuple[str, Any]]:
    """
    Sort the tasks topologically using Kahn's algorithm:
    (https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm)

    Every task is emitted after all the tasks named in its ``dependencies`` list. Ready tasks are
    processed first-in, first-out, so tasks without dependencies come out in declaration order.

    The complexity is O(N + D) where N: total tasks and D: number of dependencies.

    :param tasks_configs: mapping of task name to task configuration; the optional ``dependencies``
        entry of a configuration lists the names of its upstream tasks
    :returns: topologically sorted list containing tuples (task name, task config)
    :raises ValueError: if the task dependencies contain a cycle
    """
    # Function-scope import so the block does not rely on a module-level import being present.
    from collections import deque

    # Step 1: Build the downstream (adjacency) tasks list and the upstream dependencies (in-degree) count
    downstream_tasks = {task_name: [] for task_name in tasks_configs}
    upstream_dependencies_count = dict.fromkeys(tasks_configs, 0)

    for task_name, task_conf in tasks_configs.items():
        # `or ()` tolerates a `dependencies:` key that is present but empty (None) in the YAML
        for upstream_task in task_conf.get("dependencies") or ():
            # there are cases when dependencies contains references to TaskGroups and not Tasks - we skip those
            if upstream_task in tasks_configs:
                downstream_tasks[upstream_task].append(task_name)
                upstream_dependencies_count[task_name] += 1

    # Step 2: Find all tasks with no dependencies
    # A deque gives O(1) pops from the left; list.pop(0) would be O(N) per pop.
    tasks_without_dependencies = deque(
        task for task, count in upstream_dependencies_count.items() if not count
    )
    sorted_tasks = []

    # Step 3: Perform topological sort
    while tasks_without_dependencies:
        current = tasks_without_dependencies.popleft()
        sorted_tasks.append((current, tasks_configs[current]))

        for child in downstream_tasks[current]:
            upstream_dependencies_count[child] -= 1
            if upstream_dependencies_count[child] == 0:
                tasks_without_dependencies.append(child)

    # If not all tasks are processed, the remaining ones wait on each other: there is a cycle
    if len(sorted_tasks) != len(tasks_configs):
        raise ValueError("Cycle detected in task dependencies!")

    return sorted_tasks

@staticmethod
def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable:
"""
Expand Down
14 changes: 7 additions & 7 deletions dev/dags/example_dynamic_task_mapping.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ test_expand:
schedule_interval: "0 3 * * *"
default_view: "graph"
tasks:
request:
operator: airflow.operators.python.PythonOperator
python_callable_name: example_task_mapping
python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
process:
operator: airflow.operators.python_operator.PythonOperator
python_callable_name: expand_task
python_callable_name: consume_value
python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
partial:
op_kwargs:
test_id: "test"
fixed_param: "test"
expand:
op_args:
request.output
request.output
dependencies: [request]
request:
operator: airflow.operators.python.PythonOperator
python_callable_name: make_list
python_callable_file: $CONFIG_ROOT_DIR/expand_tasks.py
12 changes: 6 additions & 6 deletions dev/dags/expand_tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
def example_task_mapping():
    """Return the sample payload: three single-element lists holding 1, 2 and 3."""
    return [[number] for number in (1, 2, 3)]
def make_list():
    """Return the sample payload: four single-element lists holding 1 through 4."""
    return [[number] for number in range(1, 5)]


def expand_task(x, test_id):
    """Print ``test_id`` and then ``x`` on separate lines, and return ``x`` wrapped in a list."""
    print(test_id, x, sep="\n")
    return [x]
def consume_value(expanded_param, fixed_param):
    """Print ``fixed_param`` and then ``expanded_param``, and return ``expanded_param`` in a list."""
    for value in (fixed_param, expanded_param):
        print(value)
    return [expanded_param]
2 changes: 1 addition & 1 deletion tests/fixtures/dag_factory_kubernetes_pod_operator.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ example_dag:
task_id: 'passing-task'
get_logs: True
in_cluster: False
dependencies: ['task_1']
dependencies: []
task_2:
operator: airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator
namespace: 'default'
Expand Down
62 changes: 62 additions & 0 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,3 +1058,65 @@ def test_make_nested_task_groups():
del sub_task_group["parent_group"]
assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__
assert sub_task_group == expected["sub_task_group"].__dict__


class TestTopologicalSortTasks:
    """Unit tests for ``DagBuilder.topological_sort_tasks``."""

    def test_basic_topological_sort(self):
        # A linear chain that is already declared in dependency order must come back unchanged.
        chain = {
            "task1": {"dependencies": []},
            "task2": {"dependencies": ["task1"]},
            "task3": {"dependencies": ["task2"]},
        }
        ordered = dagbuilder.DagBuilder.topological_sort_tasks(chain)
        assert ordered == [
            ("task1", {"dependencies": []}),
            ("task2", {"dependencies": ["task1"]}),
            ("task3", {"dependencies": ["task2"]}),
        ]

    def test_no_dependencies(self):
        # Independent tasks are expected back in the order they were declared.
        independent = {
            "task1": {"dependencies": []},
            "task2": {"dependencies": []},
            "task3": {"dependencies": []},
        }
        ordered = dagbuilder.DagBuilder.topological_sort_tasks(independent)
        assert ordered == [
            ("task1", {"dependencies": []}),
            ("task2", {"dependencies": []}),
            ("task3", {"dependencies": []}),
        ]

    def test_empty_input(self):
        # No tasks in, no tasks out.
        assert dagbuilder.DagBuilder.topological_sort_tasks({}) == []

    def test_cyclic_dependencies(self):
        # task1 -> task2 -> task3 -> task1 can never be ordered, so the sort must refuse it.
        cyclic = {
            "task1": {"dependencies": ["task3"]},
            "task2": {"dependencies": ["task1"]},
            "task3": {"dependencies": ["task2"]},
        }
        with pytest.raises(ValueError, match="Cycle detected"):
            dagbuilder.DagBuilder.topological_sort_tasks(cyclic)

    def test_multiple_dependencies(self):
        # Diamond shape: task1 fans out to task2 and task3, which both feed task4.
        diamond = {
            "task1": {"dependencies": []},
            "task2": {"dependencies": ["task1"]},
            "task3": {"dependencies": ["task1"]},
            "task4": {"dependencies": ["task2", "task3"]},
        }
        ordered = dagbuilder.DagBuilder.topological_sort_tasks(diamond)
        position = {name: index for index, (name, _) in enumerate(ordered)}
        # Every upstream task must appear before the task that depends on it.
        expected_edges = [
            ("task1", "task2"),
            ("task1", "task3"),
            ("task2", "task4"),
            ("task3", "task4"),
        ]
        for upstream, downstream in expected_edges:
            assert position[upstream] < position[downstream]

0 comments on commit a6bf015

Please sign in to comment.