Skip to content

Commit

Permalink
Fix skipped tasks handling in mapped operators
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Dec 13, 2024
1 parent caa90a1 commit 3918d28
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
2 changes: 2 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
op.is_setup = is_setup
op.is_teardown = is_teardown
op.on_failure_fail_dagrun = on_failure_fail_dagrun
op.downstream_task_ids = self.downstream_task_ids
op.upstream_task_ids = self.upstream_task_ids
return op

# After a mapped operator is serialized, there's no real way to actually
Expand Down
4 changes: 1 addition & 3 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,6 @@ def _skip(
raise ValueError("dag_run is required")

task_ids_list = [d.task_id for d in task_list]
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

if task_id is not None:
from airflow.models.xcom import XCom
Expand All @@ -177,8 +175,8 @@ def _skip(
session=session,
)

@staticmethod
def skip_all_except(
self,
ti: TaskInstance | TaskInstancePydantic,
branch_task_ids: None | str | Iterable[str],
):
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def get_tasks_to_skip():

to_skip = get_tasks_to_skip()

# this let's us avoid an intermediate list unless debug logging
# this lets us avoid an intermediate list unless debug logging
if self.log.getEffectiveLevel() <= logging.DEBUG:
self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))

Expand Down
78 changes: 39 additions & 39 deletions airflow/ti_deps/deps/not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.db import LazySelectSequence


class NotPreviouslySkippedDep(BaseTIDep):
Expand All @@ -38,7 +39,6 @@ def _get_dep_statuses(self, ti, session, dep_context):
XCOM_SKIPMIXIN_FOLLOWED,
XCOM_SKIPMIXIN_KEY,
XCOM_SKIPMIXIN_SKIPPED,
SkipMixin,
)
from airflow.utils.state import TaskInstanceState

Expand All @@ -49,46 +49,46 @@ def _get_dep_statuses(self, ti, session, dep_context):
finished_task_ids = {t.task_id for t in finished_tis}

for parent in upstream:
if isinstance(parent, SkipMixin):
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue

prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session)
prev_result = ti.xcom_pull(
task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session, map_indexes=ti.map_index
)
if isinstance(prev_result, LazySelectSequence):
prev_result = next(iter(prev_result))

if prev_result is None:
# This can happen if the parent task has not yet run.
continue
if prev_result is None:
# This can happen if the parent task has not yet run.
continue

should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif (
XCOM_SKIPMIXIN_SKIPPED in prev_result
and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]
):
# Skip any tasks that are in "skipped"
should_skip = True
should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif XCOM_SKIPMIXIN_SKIPPED in prev_result and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]:
# Skip any tasks that are in "skipped"
should_skip = True

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
return
if not past_depends_met:
yield self._failing_status(
reason="Task should be skipped but the past depends are not met"
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
)
return

0 comments on commit 3918d28

Please sign in to comment.