Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't separate setup/teardowns from normal task #30008

Merged
merged 2 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,6 @@
"is_mapped": { "type": "boolean" },
"prefix_group_id": { "type": "boolean" },
"children": { "$ref": "#/definitions/dict" },
"setup_children": { "$ref": "#/definitions/dict" },
"teardown_children": { "$ref": "#/definitions/dict" },
"tooltip": { "type": "string" },
"ui_color": { "type": "string" },
"ui_fgcolor": { "type": "string" },
Expand Down
19 changes: 0 additions & 19 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,13 +1323,6 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
"tooltip": task_group.tooltip,
"ui_color": task_group.ui_color,
"ui_fgcolor": task_group.ui_fgcolor,
"setup_children": {
label: child.serialize_for_task_group() for label, child in task_group.setup_children.items()
},
"teardown_children": {
label: child.serialize_for_task_group()
for label, child in task_group.teardown_children.items()
},
"children": {
label: child.serialize_for_task_group() for label, child in task_group.children.items()
},
Expand Down Expand Up @@ -1380,18 +1373,6 @@ def set_ref(task: Operator) -> Operator:
task.task_group = weakref.proxy(group)
return task

group.setup_children = {
label: set_ref(task_dict[val])
if _type == DAT.OP
else cls.deserialize_task_group(val, group, task_dict, dag=dag)
for label, (_type, val) in encoded_group["setup_children"].items()
}
group.teardown_children = {
label: set_ref(task_dict[val])
if _type == DAT.OP
else cls.deserialize_task_group(val, group, task_dict, dag=dag)
for label, (_type, val) in encoded_group["teardown_children"].items()
}
group.children = {
label: set_ref(task_dict[val])
if _type == DAT.OP
Expand Down
36 changes: 11 additions & 25 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ def __init__(

self.children: dict[str, DAGNode] = {}

self.setup_children: dict[str, DAGNode] = {}
self.teardown_children: dict[str, DAGNode] = {}

if parent_group:
parent_group.add(self)

Expand Down Expand Up @@ -199,16 +196,12 @@ def parent_group(self) -> TaskGroup | None:
return self.task_group

def __iter__(self):
for child in self.all_children.values():
for child in self.children.values():
if isinstance(child, TaskGroup):
yield from child
else:
yield child

@property
def all_children(self) -> dict[str, DAGNode]:
return {**self.setup_children, **self.children, **self.teardown_children}

def add(self, task: DAGNode) -> None:
"""Add a task to this TaskGroup.

Expand All @@ -224,7 +217,7 @@ def add(self, task: DAGNode) -> None:
task.task_group = weakref.proxy(self)
key = task.node_id

if key in self.all_children:
if key in self.children:
node_type = "Task" if hasattr(task, "task_id") else "Task Group"
raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

Expand All @@ -241,27 +234,20 @@ def add(self, task: DAGNode) -> None:
if SetupTeardownContext.is_setup:
if isinstance(task, AbstractOperator):
setattr(task, "_is_setup", True)
self.setup_children[key] = task
elif SetupTeardownContext.is_teardown:
if isinstance(task, AbstractOperator):
setattr(task, "_is_teardown", True)
self.teardown_children[key] = task
else:
self.children[key] = task

self.children[key] = task

def _remove(self, task: DAGNode) -> None:
key = task.node_id

if key in self.children:
del self.children[key]
elif key in self.setup_children:
del self.setup_children[key]
elif key in self.teardown_children:
del self.teardown_children[key]
else:
if key not in self.children:
raise KeyError(f"Node id {key!r} not part of this task group")

self.used_group_ids.remove(key)
del self.children[key]

@property
def group_id(self) -> str | None:
Expand Down Expand Up @@ -352,7 +338,7 @@ def __exit__(self, _type, _value, _tb):

def has_task(self, task: BaseOperator) -> bool:
"""Returns True if this TaskGroup or its children TaskGroups contains the given task."""
if task.task_id in self.all_children:
if task.task_id in self.children:
return True

return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))
Expand Down Expand Up @@ -431,7 +417,7 @@ def build_map(task_group):

def get_child_by_label(self, label: str) -> DAGNode:
"""Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)"""
return self.all_children[self.child_id(label)]
return self.children[self.child_id(label)]

def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Required by DAGNode."""
Expand All @@ -450,12 +436,12 @@ def topological_sort(self, _include_subdag_tasks: bool = False):
# not have to pre-compute the "in-degree" of the nodes.
from airflow.operators.subdag import SubDagOperator # Avoid circular import

graph_unsorted = copy.copy(self.all_children)
graph_unsorted = copy.copy(self.children)

graph_sorted: list[DAGNode] = []

# special case
if len(self.all_children) == 0:
if len(self.children) == 0:
return graph_sorted

# Run until the unsorted graph is empty.
Expand Down Expand Up @@ -521,7 +507,7 @@ def iter_tasks(self) -> Iterator[AbstractOperator]:
while groups_to_visit:
visiting = groups_to_visit.pop(0)

for child in visiting.all_children.values():
for child in visiting.children.values():
if isinstance(child, AbstractOperator):
yield child
elif isinstance(child, TaskGroup):
Expand Down
60 changes: 20 additions & 40 deletions tests/decorators/test_setup_teardown.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ def mytask():
with dag_maker() as dag:
mytask()

setup_task = dag.task_group.setup_children["mytask"]
assert len(dag.task_group.children) == 1
setup_task = dag.task_group.children["mytask"]
assert setup_task._is_setup
assert len(dag.task_group.setup_children) == 1
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 0

def test_marking_functions_as_teardown_task(self, dag_maker):
@teardown
Expand All @@ -43,11 +41,9 @@ def mytask():
with dag_maker() as dag:
mytask()

teardown_task = dag.task_group.teardown_children["mytask"]
assert len(dag.task_group.children) == 1
teardown_task = dag.task_group.children["mytask"]
assert teardown_task._is_teardown
assert len(dag.task_group.setup_children) == 0
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 1

def test_marking_decorated_functions_as_setup_task(self, dag_maker):
@setup
Expand All @@ -58,23 +54,19 @@ def mytask():
with dag_maker() as dag:
mytask()

setup_task = dag.task_group.setup_children["mytask"]
assert len(dag.task_group.children) == 1
setup_task = dag.task_group.children["mytask"]
assert setup_task._is_setup
assert len(dag.task_group.setup_children) == 1
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 0

def test_marking_operator_as_setup_task(self, dag_maker):
from airflow.operators.bash import BashOperator

with dag_maker() as dag:
BashOperator.as_setup(task_id="mytask", bash_command='echo "I am a setup task"')

setup_task = dag.task_group.setup_children["mytask"]
assert len(dag.task_group.children) == 1
setup_task = dag.task_group.children["mytask"]
assert setup_task._is_setup
assert len(dag.task_group.setup_children) == 1
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 0

def test_marking_decorated_functions_as_teardown_task(self, dag_maker):
@teardown
Expand All @@ -85,23 +77,19 @@ def mytask():
with dag_maker() as dag:
mytask()

teardown_task = dag.task_group.teardown_children["mytask"]
assert len(dag.task_group.children) == 1
teardown_task = dag.task_group.children["mytask"]
assert teardown_task._is_teardown
assert len(dag.task_group.setup_children) == 0
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 1

def test_marking_operator_as_teardown_task(self, dag_maker):
from airflow.operators.bash import BashOperator

with dag_maker() as dag:
BashOperator.as_teardown(task_id="mytask", bash_command='echo "I am a setup task"')

teardown_task = dag.task_group.teardown_children["mytask"]
assert len(dag.task_group.children) == 1
teardown_task = dag.task_group.children["mytask"]
assert teardown_task._is_teardown
assert len(dag.task_group.setup_children) == 0
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 1

def test_setup_taskgroup(self, dag_maker):
@setup
Expand All @@ -116,14 +104,10 @@ def mytask():
with dag_maker() as dag:
mygroup()

assert len(dag.task_group.setup_children) == 1
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 0
setup_task_group = dag.task_group.setup_children["mygroup"]
assert len(setup_task_group.setup_children) == 1
assert len(setup_task_group.children) == 0
assert len(setup_task_group.teardown_children) == 0
setup_task = setup_task_group.setup_children["mygroup.mytask"]
assert len(dag.task_group.children) == 1
setup_task_group = dag.task_group.children["mygroup"]
assert len(setup_task_group.children) == 1
setup_task = setup_task_group.children["mygroup.mytask"]
assert setup_task._is_setup

def test_teardown_taskgroup(self, dag_maker):
Expand All @@ -139,12 +123,8 @@ def mytask():
with dag_maker() as dag:
mygroup()

assert len(dag.task_group.setup_children) == 0
assert len(dag.task_group.children) == 0
assert len(dag.task_group.teardown_children) == 1
teardown_task_group = dag.task_group.teardown_children["mygroup"]
assert len(teardown_task_group.setup_children) == 0
assert len(teardown_task_group.children) == 0
assert len(teardown_task_group.teardown_children) == 1
teardown_task = teardown_task_group.teardown_children["mygroup.mytask"]
assert len(dag.task_group.children) == 1
teardown_task_group = dag.task_group.children["mygroup"]
assert len(teardown_task_group.children) == 1
teardown_task = teardown_task_group.children["mygroup.mytask"]
assert teardown_task._is_teardown
Loading