Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix schedule_downstream_tasks bug (#42582) #43299

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3865,21 +3865,15 @@ def _schedule_downstream_tasks(
assert task
assert task.dag

# Get a partial DAG with just the specific tasks we want to examine.
# In order for dep checks to work correctly, we include ourself (so
# TriggerRuleDep can check the state of the task we just executed).
partial_dag = task.dag.partial_subset(
task.downstream_task_ids,
include_downstream=True,
include_upstream=False,
include_direct_upstream=True,
)

dag_run.dag = partial_dag
# Previously, this section used task.dag.partial_subset to retrieve a partial DAG.
# However, this approach is unsafe as it can result in incomplete or incorrect task execution,
# leading to potential bad cases. As a result, the operation has been removed.
# For more details, refer to the discussion in PR #42582: https://github.com/apache/airflow/pull/42582
dag_run.dag = task.dag
info = dag_run.task_instance_scheduling_decisions(session)

skippable_task_ids = {
task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
task_id for task_id in task.dag.task_ids if task_id not in task.downstream_task_ids
}

schedulable_tis = [
Expand Down
76 changes: 75 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
from airflow.notifications.basenotifier import BaseNotifier
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.sensors.base import BaseSensorOperator
from airflow.sensors.python import PythonSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
Expand Down Expand Up @@ -5224,6 +5224,80 @@ def last_task():
assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text


@pytest.mark.skip_if_database_isolation_mode
def test_one_success_task_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session):
    """Test mini-scheduler behavior for a task with the ``one_success`` trigger rule.

    Regression test: the mini scheduler must evaluate trigger rules against the
    full DAG, so ``task_one_success`` is only scheduled once at least one of its
    upstream tasks has actually succeeded — not merely finished (e.g. skipped).
    """
    # DAG shape:
    #   branch -> task_run  ----\
    #   branch -> task_skip ----+--> task_one_success --> task_2
    #   task_1 -----------------------------------------> task_2
    #   task_skip ---------------------------------------> task_2
    with dag_maker() as dag:
        branch = BranchPythonOperator(task_id="branch", python_callable=lambda: "task_run")
        task_run = BashOperator(task_id="task_run", bash_command="echo 0")
        task_skip = BashOperator(task_id="task_skip", bash_command="echo 0")
        task_1 = BashOperator(task_id="task_1", bash_command="echo 0")
        task_one_success = BashOperator(
            task_id="task_one_success", bash_command="echo 0", trigger_rule="one_success"
        )
        task_2 = BashOperator(task_id="task_2", bash_command="echo 0")

        task_1 >> task_2
        branch >> task_skip
        branch >> task_run
        task_run >> task_one_success
        task_skip >> task_one_success
        task_one_success >> task_2
        task_skip >> task_2

    dr = dag_maker.create_dagrun()

    # Simulate a first round of execution: the branch chose "task_run",
    # so task_skip was skipped; task_1 succeeded independently.
    branch = dr.get_task_instance(task_id="branch")
    task_1 = dr.get_task_instance(task_id="task_1")
    task_skip = dr.get_task_instance(task_id="task_skip")
    branch.state = State.SUCCESS
    task_1.state = State.SUCCESS
    task_skip.state = State.SKIPPED
    session.merge(branch)
    session.merge(task_1)
    session.merge(task_skip)
    session.commit()
    task_1.refresh_from_task(dag.get_task("task_1"))
    # Run the mini scheduler from task_1's perspective.
    task_1.schedule_downstream_tasks(session=session)

    branch = dr.get_task_instance(task_id="branch")
    task_run = dr.get_task_instance(task_id="task_run")
    task_skip = dr.get_task_instance(task_id="task_skip")
    task_1 = dr.get_task_instance(task_id="task_1")
    task_one_success = dr.get_task_instance(task_id="task_one_success")
    task_2 = dr.get_task_instance(task_id="task_2")
    assert branch.state == State.SUCCESS
    assert task_run.state == State.NONE
    assert task_skip.state == State.SKIPPED
    assert task_1.state == State.SUCCESS
    # task_one_success should not be scheduled: no upstream has succeeded yet
    # (task_run is still NONE and a skipped upstream does not satisfy one_success).
    assert task_one_success.state == State.NONE
    # task_2 is skipped because its upstream task_skip was skipped
    # (default all_success trigger rule).
    assert task_2.state == State.SKIPPED
    assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text

    # Second round: task_run now succeeds, satisfying one_success.
    task_run = dr.get_task_instance(task_id="task_run")
    task_run.state = State.SUCCESS
    session.merge(task_run)
    session.commit()
    task_run.refresh_from_task(dag.get_task("task_run"))
    task_run.schedule_downstream_tasks(session=session)

    branch = dr.get_task_instance(task_id="branch")
    task_run = dr.get_task_instance(task_id="task_run")
    task_skip = dr.get_task_instance(task_id="task_skip")
    task_1 = dr.get_task_instance(task_id="task_1")
    task_one_success = dr.get_task_instance(task_id="task_one_success")
    task_2 = dr.get_task_instance(task_id="task_2")
    assert branch.state == State.SUCCESS
    assert task_run.state == State.SUCCESS
    assert task_skip.state == State.SKIPPED
    assert task_1.state == State.SUCCESS
    # task_one_success should now be scheduled: task_run succeeded
    assert task_one_success.state == State.SCHEDULED
    assert task_2.state == State.SKIPPED
    assert "1 downstream tasks scheduled from follow-on schedule" in caplog.text


@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
def test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker, session):
with dag_maker(session=session):
Expand Down
Loading