Skip to content

Commit

Permalink
fix: get task dependencies without serializing task tree to string (a…
Browse files Browse the repository at this point in the history
…pache#41494)

Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski authored and Artuz37 committed Aug 19, 2024
1 parent b218c8f commit 4ea24cd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 74 deletions.
54 changes: 10 additions & 44 deletions airflow/providers/openlineage/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@
import datetime
import json
import logging
import re
from contextlib import redirect_stdout, suppress
from contextlib import suppress
from functools import wraps
from importlib import metadata
from io import StringIO
from typing import TYPE_CHECKING, Any, Callable, Iterable

import attrs
Expand All @@ -35,7 +33,7 @@
from airflow import __version__ as AIRFLOW_VERSION
from airflow.datasets import Dataset
from airflow.exceptions import AirflowProviderDeprecationWarning # TODO: move this maybe to Airflow's logic?
from airflow.models import DAG, BaseOperator, MappedOperator
from airflow.models import DAG, BaseOperator, MappedOperator, Operator
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.plugins.facets import (
AirflowDagRunFacet,
Expand Down Expand Up @@ -441,16 +439,6 @@ def get_airflow_state_run_facet(dag_run: DagRun) -> dict[str, AirflowStateRunFac
}


def _safe_get_dag_tree_view(dag: DAG) -> list[str]:
# get_tree_view() has been added in Airflow 2.8.2
if hasattr(dag, "get_tree_view"):
return dag.get_tree_view().splitlines()

with redirect_stdout(StringIO()) as stdout:
dag.tree_view()
return stdout.getvalue().splitlines()


def _get_parsed_dag_tree(dag: DAG) -> dict:
"""
Get DAG's tasks hierarchy representation.
Expand All @@ -476,37 +464,15 @@ def _get_parsed_dag_tree(dag: DAG) -> dict:
"task_6": {}
}
"""
lines = _safe_get_dag_tree_view(dag)
task_dict: dict[str, dict] = {}
parent_map: dict[int, tuple[str, dict]] = {}

for line in lines:
stripped_line = line.strip()
if not stripped_line:
continue

# Determine the level by counting the leading spaces, assuming 4 spaces per level
# as defined in airflow.models.dag.DAG._generate_tree_view()
level = (len(line) - len(stripped_line)) // 4
# airflow.models.baseoperator.BaseOperator.__repr__ or
# airflow.models.mappedoperator.MappedOperator.__repr__ is used in DAG tree
# <Task({op_class}): {task_id}> or <Mapped({op_class}): {task_id}>
match = re.match(r"^<(?:Task|Mapped)\(.+\): (.+)>$", stripped_line)
if not match:
return {}
current_task_id = match[1]

if level == 0: # It's a root task
task_dict[current_task_id] = {}
parent_map[level] = (current_task_id, task_dict[current_task_id])
else:
# Find the immediate parent task
parent_task, parent_dict = parent_map[(level - 1)]
# Create new dict for the current task
parent_dict[current_task_id] = {}
# Update this task in the parent map
parent_map[level] = (current_task_id, parent_dict[current_task_id])

def get_downstream(task: Operator, current_dict: dict):
current_dict[task.task_id] = {}
for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id):
get_downstream(tmp_task, current_dict[task.task_id])

task_dict: dict = {}
for t in sorted(dag.roots, key=lambda x: x.task_id):
get_downstream(t, task_dict)
return task_dict


Expand Down
74 changes: 44 additions & 30 deletions tests/providers/openlineage/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperator import BaseOperator, chain
from airflow.models.dagrun import DagRun
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance, TaskInstanceState
Expand All @@ -35,7 +35,6 @@
_get_parsed_dag_tree,
_get_task_groups_details,
_get_tasks_details,
_safe_get_dag_tree_view,
get_airflow_dag_run_facet,
get_airflow_job_facet,
get_fully_qualified_class_name,
Expand Down Expand Up @@ -342,34 +341,6 @@ def test_get_tasks_details_empty_dag():
assert _get_tasks_details(DAG("test_dag", schedule=None, start_date=datetime.datetime(2024, 6, 1))) == {}


def test_dag_tree_level_indent():
"""Tests the correct indentation of tasks in a DAG tree view.
Test verifies that the tree view of the DAG correctly represents the hierarchical structure
of the tasks with proper indentation. The expected indentation increases by 4 spaces for each
subsequent level in the DAG. The test asserts that the generated tree view matches the expected
lines with correct indentation.
"""
with DAG(dag_id="dag", schedule=None, start_date=datetime.datetime(2024, 6, 1)) as dag:
task_0 = EmptyOperator(task_id="task_0")
task_1 = EmptyOperator(task_id="task_1")
task_2 = EmptyOperator(task_id="task_2")
task_3 = EmptyOperator(task_id="task_3")

task_0 >> task_1 >> task_2
task_3 >> task_2

indent = 4 * " "
expected_lines = [
"<Task(EmptyOperator): task_0>",
indent + "<Task(EmptyOperator): task_1>",
2 * indent + "<Task(EmptyOperator): task_2>",
"<Task(EmptyOperator): task_3>",
indent + "<Task(EmptyOperator): task_2>",
]
assert _safe_get_dag_tree_view(dag) == expected_lines


def test_get_dag_tree():
class TestMappedOperator(BaseOperator):
def __init__(self, value, **kwargs):
Expand Down Expand Up @@ -462,6 +433,49 @@ def sum_values(values: list[int]) -> int:
assert result == expected


def test_get_dag_tree_large_dag():
class LongEmptyOperator(EmptyOperator):
# lets make repr really long :)
def __repr__(self) -> str:
return str(self.__dict__) * 200

with DAG("aaa_big_get_tree_view", schedule=None) as dag:
first_set = [LongEmptyOperator(task_id=f"hello_{i}_{'a' * 230}") for i in range(900)]
chain(*first_set)

last_task_in_first_set = first_set[-1]

chain(
last_task_in_first_set, [LongEmptyOperator(task_id=f"world_{i}_{'a' * 230}") for i in range(900)]
)

chain(
last_task_in_first_set, [LongEmptyOperator(task_id=f"this_{i}_{'a' * 230}") for i in range(900)]
)

chain(last_task_in_first_set, [LongEmptyOperator(task_id=f"is_{i}_{'a' * 230}") for i in range(900)])

chain(
last_task_in_first_set, [LongEmptyOperator(task_id=f"silly_{i}_{'a' * 230}") for i in range(900)]
)

chain(
last_task_in_first_set, [LongEmptyOperator(task_id=f"stuff_{i}_{'a' * 230}") for i in range(900)]
)

result = _get_parsed_dag_tree(dag)

def dfs_depth(d: dict, depth: int = 0) -> int:
max_depth = depth
for v in d.values():
if isinstance(v, dict):
max_depth = max(max_depth, dfs_depth(v, depth + 1))
return max_depth

assert len(result) == 1
assert dfs_depth(result, 901)


def test_get_dag_tree_empty_dag():
assert (
_get_parsed_dag_tree(
Expand Down

0 comments on commit 4ea24cd

Please sign in to comment.