From 78fa95ca330f7e331118111ce4ff13ab1ce09c19 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Tue, 9 Aug 2022 15:17:41 +0100
Subject: [PATCH] Don't mistakenly take a lock on DagRun via ti.refresh_from_db
 (#25312)

In 2.2.0 we made TI.dag_run automatically join-loaded, which is fine
for most cases, but `refresh_from_db` doesn't need it (we don't access
anything under ti.dag_run), and it's possible that when
`lock_for_update=True` is passed we are locking more than we want to
and _might_ cause deadlocks.

Even if it doesn't, selecting more than we need is wasteful.

(cherry picked from commit be2b53eaaf6fc136db8f3fa3edd797a6c529409a)
---
 airflow/models/taskinstance.py   | 28 ++++++++++++++++++----------
 tests/jobs/test_scheduler_job.py |  8 +++++---
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 33fe7a3f539e3..afcc469febcee 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -308,6 +308,7 @@ def clear_task_instances(
             if dag_run_state == DagRunState.QUEUED:
                 dr.last_scheduling_decision = None
                 dr.start_date = None
+    session.flush()
 
 
 class _LazyXComAccessIterator(collections.abc.Iterator):
@@ -879,28 +880,35 @@ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
         """
         self.log.debug("Refreshing TaskInstance %s from DB", self)
 
-        qry = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == self.run_id,
-            TaskInstance.map_index == self.map_index,
+        if self in session:
+            session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())
+
+        qry = (
+            # To avoid joining any relationships, by default select all
+            # columns, not the object. This also means we get (effectively) a
+            # namedtuple back, not a TI object
+            session.query(*TaskInstance.__table__.columns).filter(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id == self.task_id,
+                TaskInstance.run_id == self.run_id,
+                TaskInstance.map_index == self.map_index,
+            )
         )
 
         if lock_for_update:
             for attempt in run_with_db_retries(logger=self.log):
                 with attempt:
-                    ti: Optional[TaskInstance] = qry.with_for_update().first()
+                    ti: Optional[TaskInstance] = qry.with_for_update().one_or_none()
         else:
-            ti = qry.first()
+            ti = qry.one_or_none()
         if ti:
             # Fields ordered per model definition
             self.start_date = ti.start_date
             self.end_date = ti.end_date
             self.duration = ti.duration
             self.state = ti.state
-            # Get the raw value of try_number column, don't read through the
-            # accessor here otherwise it will be incremented by one already.
-            self.try_number = ti._try_number
+            # Since we selected columns, not the object, this is the raw value
+            self.try_number = ti.try_number
             self.max_tries = ti.max_tries
             self.hostname = ti.hostname
             self.unixname = ti.unixname
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index dfd66770c0797..c9c1ab166a188 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -457,7 +457,8 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker)
         (ti1,) = dr1.task_instances
         ti1.state = State.SCHEDULED
 
-        self.scheduler_job._critical_section_execute_task_instances(session)
+        self.scheduler_job._critical_section_enqueue_task_instances(session)
+        session.flush()
         ti1.refresh_from_db(session=session)
         assert State.SCHEDULED == ti1.state
         session.rollback()
@@ -1315,8 +1316,9 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state(
         session.commit()
 
         with patch.object(BaseExecutor, 'queue_command') as mock_queue_command:
-            self.scheduler_job._enqueue_task_instances_with_queued_state([ti])
-        ti.refresh_from_db()
+            self.scheduler_job._enqueue_task_instances_with_queued_state([ti], session=session)
+            session.flush()
+        ti.refresh_from_db(session=session)
         assert ti.state == State.NONE
         mock_queue_command.assert_not_called()
 
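For context, the join-loading behaviour described in the commit message is easy to see with a few lines of plain SQLAlchemy. The sketch below is not Airflow code: the `DagRun`/`TaskInstance` classes and their columns are made-up stand-ins, and SQLAlchemy 1.4+ is assumed. It shows how a `lazy="joined"` relationship pulls the related table into every SELECT for the mapped class (so `with_for_update()` applies to the whole joined statement, not just `task_instance`), and how selecting `*TaskInstance.__table__.columns`, as the patched `refresh_from_db` does, keeps the query on a single table and returns plain named tuples.

```python
# A standalone sketch, not Airflow's models: DagRun/TaskInstance here are toy
# stand-ins with made-up columns; SQLAlchemy 1.4+ is assumed.
from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()


class DagRun(Base):
    __tablename__ = "dag_run"
    id = Column(Integer, primary_key=True)


class TaskInstance(Base):
    __tablename__ = "task_instance"
    id = Column(Integer, primary_key=True)
    dag_run_id = Column(Integer, ForeignKey("dag_run.id"))
    # Eagerly join-loaded, like TI.dag_run since Airflow 2.2.0
    dag_run = relationship(DagRun, lazy="joined")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TaskInstance(dag_run=DagRun()))
    session.commit()

    # Querying the mapped class drags the lazy="joined" relationship into the
    # SELECT as an outer join, so the FOR UPDATE issued by with_for_update()
    # applies to the whole joined statement, not just task_instance.
    # (SQLite ignores FOR UPDATE, which keeps this sketch runnable anywhere.)
    ti_obj = session.query(TaskInstance).with_for_update().one()
    print(type(ti_obj).__name__, ti_obj.dag_run.id)  # TaskInstance 1

    # Selecting the table's columns keeps the statement (and any lock) on
    # task_instance only; the result is a plain named tuple, not an ORM object.
    ti_row = session.query(*TaskInstance.__table__.columns).with_for_update().one()
    print(ti_row._fields)  # ('id', 'dag_run_id')
```

The price of the column-based query, as the in-diff comment notes, is that `refresh_from_db` now copies values out of a named tuple rather than a `TaskInstance` object, which is why the patch reads `ti.try_number` directly instead of going through the `_try_number` attribute.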