Skip to content

Commit

Permalink
Simplify DagRun.verify_integrity (#26894)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Oct 21, 2022
1 parent ad6b4dc commit eeb39f1
Showing 1 changed file with 24 additions and 40 deletions.
64 changes: 24 additions & 40 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@
from airflow.models.dag import DAG
from airflow.models.operator import Operator


CreatedTasksType = TypeVar("CreatedTasksType")
CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks]


class TISchedulingDecision(NamedTuple):
Expand Down Expand Up @@ -854,11 +854,7 @@ def _emit_duration_stats_for_finished_state(self):
Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)

@provide_session
def verify_integrity(
self,
*,
session: Session = NEW_SESSION,
):
def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
"""
Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required.
Expand All @@ -885,14 +881,12 @@ def task_filter(task: Operator) -> bool:
)

created_counts: dict[str, int] = defaultdict(int)

# Get task creator function
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

# Create the missing tasks, including mapped tasks
tasks = self._create_tasks(dag, task_creator, task_filter, session=session)

self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)
tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)

def _check_for_removed_or_restored_tasks(
self, dag: DAG, ti_mutation_hook, *, session: Session
Expand Down Expand Up @@ -978,7 +972,7 @@ def _get_task_creator(
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[True],
) -> Callable[[Operator, tuple[int, ...]], Iterator[dict[str, Any]]]:
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]:
...

@overload
Expand All @@ -987,15 +981,15 @@ def _get_task_creator(
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[False],
) -> Callable[[Operator, tuple[int, ...]], Iterator[TI]]:
) -> Callable[[Operator, Iterable[int]], Iterator[TI]]:
...

def _get_task_creator(
self,
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[True, False],
) -> Callable[[Operator, tuple[int, ...]], Iterator[dict[str, Any]] | Iterator[TI]]:
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
"""
Get the task creator function.
Expand All @@ -1008,7 +1002,7 @@ def _get_task_creator(
"""
if hook_is_noop:

def create_ti_mapping(task: Operator, indexes: tuple[int, ...]) -> Iterator[dict[str, Any]]:
def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
created_counts[task.task_type] += 1
for map_index in indexes:
yield TI.insert_mapping(self.run_id, task, map_index=map_index)
Expand All @@ -1017,7 +1011,7 @@ def create_ti_mapping(task: Operator, indexes: tuple[int, ...]) -> Iterator[dict

else:

def create_ti(task: Operator, indexes: tuple[int, ...]) -> Iterator[TI]:
def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
for map_index in indexes:
ti = TI(task, run_id=self.run_id, map_index=map_index)
ti_mutation_hook(ti)
Expand All @@ -1029,36 +1023,26 @@ def create_ti(task: Operator, indexes: tuple[int, ...]) -> Iterator[TI]:

def _create_tasks(
self,
dag: DAG,
task_creator: Callable[[Operator, tuple[int, ...]], CreatedTasksType],
task_filter: Callable[[Operator], bool],
tasks: Iterable[Operator],
task_creator: TaskCreator,
*,
session: Session,
) -> CreatedTasksType:
) -> CreatedTasks:
"""
Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
:param dag: DAG object corresponding to the dagrun
:param task_creator: a function that creates tasks
:param task_filter: a function that filters tasks to create
:param session: the session to use
:param tasks: Tasks to create jobs for in the DAG run
:param task_creator: Function to create task instances
"""

def expand_mapped_literals(task: Operator) -> tuple[Operator, Sequence[int]]:
for task in tasks:
if not task.is_mapped:
return (task, (-1,))
task = cast("MappedOperator", task)
count = task.get_mapped_ti_count(self.run_id, session=session)
if not count:
return (task, (-1,))
return (task, range(count))

tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))

tasks: CreatedTasksType = itertools.chain.from_iterable( # type: ignore
itertools.starmap(task_creator, tasks_and_map_idxs) # type: ignore
)
return tasks
yield from task_creator(task, (-1,))
continue
count = cast(MappedOperator, task).get_mapped_ti_count(self.run_id, session=session)
if count:
yield from task_creator(task, range(count))
continue
yield from task_creator(task, (-1,))

def _create_task_instances(
self,
Expand Down

0 comments on commit eeb39f1

Please sign in to comment.