diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 1881511f9bbae..698c469dbb717 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -18,25 +18,20 @@ from __future__ import annotations import signal -from typing import TYPE_CHECKING import psutil -from sqlalchemy.exc import OperationalError from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.jobs.base_job import BaseJob from airflow.listeners.events import register_task_instance_state_events from airflow.listeners.listener import get_listener_manager -from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance -from airflow.sentry import Sentry from airflow.stats import Stats from airflow.task.task_runner import get_task_runner from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State @@ -165,7 +160,7 @@ def handle_task_exit(self, return_code: int) -> None: if not self.task_instance.test_mode: if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True): - self._run_mini_scheduler_on_child_tasks() + self.task_instance.schedule_downstream_tasks() def on_kill(self): self.task_runner.terminate() @@ -230,58 +225,6 @@ def heartbeat_callback(self, session=None): self.terminating = True self._state_change_checks += 1 - @provide_session - @Sentry.enrich_errors - def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: - try: - # Re-select the row with a lock - dag_run = with_row_locks( - session.query(DagRun).filter_by( - dag_id=self.dag_id, - run_id=self.task_instance.run_id, - ), - session=session, - ).one() - - task = self.task_instance.task - if TYPE_CHECKING: - assert task.dag - - # Get a partial DAG with just the specific tasks we want to examine. - # In order for dep checks to work correctly, we include ourself (so - # TriggerRuleDep can check the state of the task we just executed). - partial_dag = task.dag.partial_subset( - task.downstream_task_ids, - include_downstream=True, - include_upstream=False, - include_direct_upstream=True, - ) - - dag_run.dag = partial_dag - info = dag_run.task_instance_scheduling_decisions(session) - - skippable_task_ids = { - task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids - } - - schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] - for schedulable_ti in schedulable_tis: - if not hasattr(schedulable_ti, "task"): - schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) - - num = dag_run.schedule_tis(schedulable_tis) - self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) - - session.commit() - except OperationalError as e: - # Any kind of DB error here is _non fatal_ as this block is just an optimisation. - self.log.info( - "Skipping mini scheduling run due to exception: %s", - e.statement, - exc_info=True, - ) - session.rollback() - @staticmethod def _enable_task_listeners(): """ diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 62cc22f379cdd..9c591bf364b5d 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -619,13 +619,18 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence try: total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session) except NotFullyPopulated as e: - self.log.info( - "Cannot expand %r for run %s; missing upstream values: %s", - self, - run_id, - sorted(e.missing), - ) total_length = None + # partial dags comes from the mini scheduler. It's + # possible that the upstream tasks are not yet done, + # but we don't have upstream of upstreams in partial dags, + # so we ignore this exception. + if not self.dag or not self.dag.partial: + self.log.error( + "Cannot expand %r for run %s; missing upstream values: %s", + self, + run_id, + sorted(e.missing), + ) state: TaskInstanceState | None = None unmapped_ti: TaskInstance | None = ( @@ -646,10 +651,15 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence # The unmapped task instance still exists and is unfinished, i.e. we # haven't tried to run it before. if total_length is None: - # If the map length cannot be calculated (due to unavailable - # upstream sources), fail the unmapped task. - unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED - indexes_to_map: Iterable[int] = () + if self.dag and self.dag.partial: + # If the DAG is partial, it's likely that the upstream tasks + # are not done yet, so we do nothing + indexes_to_map: Iterable[int] = () + else: + # If the map length cannot be calculated (due to unavailable + # upstream sources), fail the unmapped task. + unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED + indexes_to_map = () elif total_length < 1: # If the upstream maps this to a zero-length value, simply mark # the unmapped task instance as SKIPPED (if needed). diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d4453ca8426d9..4388d592e9d11 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2572,6 +2572,67 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum return filters[0] return or_(*filters) + @Sentry.enrich_errors + @provide_session + def schedule_downstream_tasks(self, session=None): + """ + The mini-scheduler for scheduling downstream tasks of this task instance + :meta: private + """ + from sqlalchemy.exc import OperationalError + + from airflow.models import DagRun + + try: + # Re-select the row with a lock + dag_run = with_row_locks( + session.query(DagRun).filter_by( + dag_id=self.dag_id, + run_id=self.run_id, + ), + session=session, + ).one() + + task = self.task + if TYPE_CHECKING: + assert task.dag + + # Get a partial DAG with just the specific tasks we want to examine. + # In order for dep checks to work correctly, we include ourself (so + # TriggerRuleDep can check the state of the task we just executed). + partial_dag = task.dag.partial_subset( + task.downstream_task_ids, + include_downstream=True, + include_upstream=False, + include_direct_upstream=True, + ) + + dag_run.dag = partial_dag + info = dag_run.task_instance_scheduling_decisions(session) + + skippable_task_ids = { + task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids + } + + schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] + for schedulable_ti in schedulable_tis: + if not hasattr(schedulable_ti, "task"): + schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) + + num = dag_run.schedule_tis(schedulable_tis, session=session) + self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) + + session.flush() + + except OperationalError as e: + # Any kind of DB error here is _non fatal_ as this block is just an optimisation. + self.log.info( + "Skipping mini scheduling run due to exception: %s", + e.statement, + exc_info=True, + ) + session.rollback() + # State of the task instance. # Stores string version of the task state. diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 61bce800c8908..26f5e629a7373 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -739,7 +739,6 @@ def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag) ti2_l.refresh_from_db() assert ti2_k.state == State.SUCCESS assert ti2_l.state == State.NONE - assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text failed_deps = list(ti2_l.get_failed_dep_statuses()) assert len(failed_deps) == 1 diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index f4edebcb08e9d..7f7cca0da2a39 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3614,3 +3614,86 @@ def get_extra_env(): echo_task = dag.get_task("echo") assert "get_extra_env" in echo_task.upstream_task_ids + + +def test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker, caplog, session): + """ + This tests that when scheduling child tasks of a task and there's a mapped downstream task, + if the mapped downstream task has upstreams that are not yet done, the mapped downstream task is + not marked as `upstream_failed' + """ + with dag_maker() as dag: + + @dag.task + def second_task(): + return [0, 1, 2] + + @dag.task + def first_task(): + print(2) + + @dag.task + def middle_task(id): + return id + + middle = middle_task.expand(id=second_task()) + + @dag.task + def last_task(): + print(3) + + [first_task(), middle] >> last_task() + + dag_run = dag_maker.create_dagrun() + first_ti = dag_run.get_task_instance(task_id="first_task") + second_ti = dag_run.get_task_instance(task_id="second_task") + first_ti.state = State.SUCCESS + second_ti.state = State.RUNNING + session.merge(first_ti) + session.merge(second_ti) + session.commit() + first_ti.schedule_downstream_tasks(session=session) + middle_ti = dag_run.get_task_instance(task_id="middle_task") + assert middle_ti.state != State.UPSTREAM_FAILED + assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text + + +def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session): + """Test that mini scheduler expands mapped task""" + with dag_maker() as dag: + + @dag.task + def second_task(): + return [0, 1, 2] + + @dag.task + def first_task(): + print(2) + + @dag.task + def middle_task(id): + return id + + middle = middle_task.expand(id=second_task()) + + @dag.task + def last_task(): + print(3) + + [first_task(), middle] >> last_task() + + dr = dag_maker.create_dagrun() + + first_ti = dr.get_task_instance(task_id="first_task") + first_ti.state = State.SUCCESS + session.merge(first_ti) + session.commit() + second_task = dag.get_task("second_task") + second_ti = dr.get_task_instance(task_id="second_task") + second_ti.refresh_from_task(second_task) + second_ti.run() + second_ti.schedule_downstream_tasks(session=session) + for i in range(3): + middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i) + assert middle_ti.state == State.SCHEDULED + assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text