diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index dac7be010ad54..92e1a8945fd5a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3049,8 +3049,8 @@ def bulk_write_to_db(
         )
         query = with_row_locks(query, of=DagModel, session=session)
         orm_dags: list[DagModel] = session.scalars(query).unique().all()
-        existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags}
-        missing_dag_ids = dag_ids.difference(existing_dags)
+        existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags}
+        missing_dag_ids = dag_ids.difference(existing_dags.keys())
 
         for missing_dag_id in missing_dag_ids:
             orm_dag = DagModel(dag_id=missing_dag_id)
@@ -3067,7 +3067,7 @@ def bulk_write_to_db(
         # Skip these queries entirely if no DAGs can be scheduled to save time.
         if any(dag.timetable.can_be_scheduled for dag in dags):
             # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
-            query = cls._get_latest_runs_query(existing_dags, session)
+            query = cls._get_latest_runs_query(dags=list(existing_dags.keys()))
             latest_runs = {run.dag_id: run for run in session.scalars(query)}
 
             # Get number of active dagruns for all dags we are processing as a single query.
@@ -3240,16 +3240,15 @@ def bulk_write_to_db(
             cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
 
     @classmethod
-    def _get_latest_runs_query(cls, dags, session) -> Query:
+    def _get_latest_runs_query(cls, dags: list[str]) -> Query:
         """
         Query the database to retrieve the last automated run for each dag.
 
         :param dags: dags to query
-        :param session: sqlalchemy session object
         """
         if len(dags) == 1:
             # Index optimized fast path to avoid more complicated & slower groupby queryplan
-            existing_dag_id = list(dags)[0].dag_id
+            existing_dag_id = dags[0]
             last_automated_runs_subq = (
                 select(func.max(DagRun.execution_date).label("max_execution_date"))
                 .where(
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 7c337ed965c28..1f70ba051af09 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -4140,7 +4140,7 @@ def test_validate_setup_teardown_trigger_rule(self):
 def test_get_latest_runs_query_one_dag(dag_maker, session):
     with dag_maker(dag_id="dag1") as dag1:
         ...
-    query = DAG._get_latest_runs_query(dags=[dag1], session=session)
+    query = DAG._get_latest_runs_query(dags=[dag1.dag_id])
     actual = [x.strip() for x in str(query.compile()).splitlines()]
     expected = [
         "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
@@ -4157,7 +4157,7 @@ def test_get_latest_runs_query_two_dags(dag_maker, session):
         ...
     with dag_maker(dag_id="dag2") as dag2:
         ...
-    query = DAG._get_latest_runs_query(dags=[dag1, dag2], session=session)
+    query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id])
     actual = [x.strip() for x in str(query.compile()).splitlines()]
     print("\n".join(actual))
     expected = [
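
Usage sketch (not part of the patch): with the refactored signature, callers pass plain dag_id strings and no session, mirroring the updated tests above; the dag_id values "dag1" and "dag2" here are illustrative only.

    from airflow.models.dag import DAG

    # Build the "latest automated run per dag" query from dag_id strings;
    # a session is only needed later, when the query is executed.
    query = DAG._get_latest_runs_query(dags=["dag1", "dag2"])

    # The tests above assert on the compiled SQL text of this query.
    print(str(query.compile()))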