-
Notifications
You must be signed in to change notification settings - Fork 14.6k
/
Copy pathdagrun.py
1527 lines (1343 loc) · 60.7 KB
/
dagrun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import itertools
import os
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload
import re2
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
PickleType,
PrimaryKeyConstraint,
String,
Text,
UniqueConstraint,
and_,
func,
or_,
text,
update,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import false, select, true
from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models.abstractoperator import NotMapped
from airflow.models.base import Base, StringID
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
from airflow.utils import timezone
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, DagRunType
# Names used only in annotations. They are imported solely for static type
# checkers; at runtime this block is skipped (presumably to avoid import cycles
# and import-time cost — TODO confirm).
if TYPE_CHECKING:
    from datetime import datetime

    from sqlalchemy.orm import Query, Session

    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.serialization.pydantic.dag_run import DagRunPydantic
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
    from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
    from airflow.typing_compat import Literal
    from airflow.utils.types import ArgNotSet

    # Type aliases for "task creator" callables: they take an operator plus an
    # iterable of ints (presumably map indexes — verify against callers) and
    # yield either TaskInstance objects or plain dicts.
    CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
    TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks]
# Built-in run_id shape: "<manual|scheduled|dataset_triggered>__YYYY-MM-DDTHH:MM:SS+00:00".
# DagRun.validate_run_id accepts a run_id that matches either this pattern or the
# configurable ``[scheduler] allowed_run_id_pattern``.
RUN_ID_REGEX = r"^(?:manual|scheduled|dataset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
class TISchedulingDecision(NamedTuple):
    """Result of ``DagRun.task_instance_scheduling_decisions``.

    Field order is part of the interface (NamedTuple supports positional unpacking).
    """

    # Task instances whose task was found in the DAG (those without one are
    # excluded and marked REMOVED).
    tis: list[TI]
    # Task instances returned by ``_get_ready_tis`` as ready to be scheduled
    # (empty when nothing is unfinished).
    schedulable_tis: list[TI]
    # True when ``_get_ready_tis`` reported changed TIs or mapped-task expansion happened.
    changed_tis: bool
    # Task instances whose state is in ``State.unfinished``.
    unfinished_tis: list[TI]
    # Task instances whose state is in ``State.finished``.
    finished_tis: list[TI]
def _creator_note(val):
    """Build the ``DagRunNote`` behind the ``note`` association proxy.

    Accepted inputs:

    * a ``str`` -- used as the note's ``content``;
    * a ``dict`` -- expanded as keyword arguments to ``DagRunNote``;
    * anything else -- expanded as positional arguments (must be iterable).
    """
    if isinstance(val, dict):
        return DagRunNote(**val)
    if isinstance(val, str):
        return DagRunNote(content=val)
    return DagRunNote(*val)
class DagRun(Base, LoggingMixin):
    """Invocation instance of a DAG.

    A DAG run can be created by the scheduler (i.e. scheduled runs), or by an
    external trigger (i.e. manual runs).
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(StringID(), nullable=False)
    # Time of the latest move into QUEUED (maintained by __init__ and set_state).
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    start_date = Column(UtcDateTime)
    # set_state() fills this for finished states and resets it to None otherwise.
    end_date = Column(UtcDateTime)
    # Raw column behind the ``state`` synonym. Write through ``state``/``set_state``
    # so that end_date and queued_at are kept in sync.
    _state = Column("state", String(50), default=DagRunState.QUEUED)
    # Checked on every assignment by validate_run_id.
    run_id = Column(StringID(), nullable=False)
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    # Run configuration, persisted with PickleType (__init__ stores {} for a falsy value).
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    # NOTE(review): the constraint is named "task_instance_..." although it lives on
    # dag_run; presumably kept to match existing migrations — confirm before renaming.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
        default=select(func.max(LogTemplate.__table__.c.id)),
    )
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
    # Number of times this DagRun has been cleared. It is incremented only when
    # the DagRun is re-queued, i.e. when it is cleared.
    clear_number = Column(Integer, default=0, nullable=False)

    # In-memory DAG object (not a DB column). It must be attached before calling
    # get_dag(), which raises AirflowException when it is unset.
    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        dag: DAG | None = None

    __table_args__ = (
        Index("dag_id_state", dag_id, _state),
        UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
        UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
        Index("idx_last_scheduling_decision", last_scheduling_decision),
        Index("idx_dag_run_dag_id", dag_id),
        # Partial index on running runs (PostgreSQL/SQLite); a plain index elsewhere.
        Index(
            "idx_dag_run_running_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # since mysql lacks filtered/partial indices, this creates a
        # duplicate index on mysql. Not the end of the world
        Index(
            "idx_dag_run_queued_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    task_instances = relationship(
        TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
    )
    # Read-only link to the DagModel row; there is no FK constraint, hence the explicit foreign().
    dag_model = relationship(
        "DagModel",
        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
        uselist=False,
        viewonly=True,
    )
    dag_run_note = relationship(
        "DagRunNote",
        back_populates="dag_run",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    # Exposes the note's ``content`` directly; new DagRunNote objects are built by _creator_note.
    note = association_proxy("dag_run_note", "content", creator=_creator_note)

    # Default batch size for next_dagruns_to_examine(). Read from config once,
    # when this class body executes.
    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        "scheduler",
        "max_dagruns_per_loop_to_schedule",
        fallback=20,
    )
def __init__(
    self,
    dag_id: str | None = None,
    run_id: str | None = None,
    queued_at: datetime | None | ArgNotSet = NOTSET,
    execution_date: datetime | None = None,
    start_date: datetime | None = None,
    external_trigger: bool | None = None,
    conf: Any | None = None,
    state: DagRunState | None = None,
    run_type: str | None = None,
    dag_hash: str | None = None,
    creating_job_id: int | None = None,
    data_interval: tuple[datetime, datetime] | None = None,
):
    """Create a new (not yet persisted) DagRun.

    :param data_interval: ``(start, end)`` pair. ``None`` leaves both bounds unset,
        which only happens for legacy runs created before Airflow 2.2.
    :param queued_at: explicit queue time. When omitted it defaults to "now" for
        runs created in the QUEUED state and to ``None`` otherwise.
    :param conf: run configuration; a falsy value is stored as ``{}``.
    :param state: initial state; applied through ``set_state`` only when given.
    """
    # Legacy runs (pre Airflow 2.2) have no data interval at all.
    interval_start, interval_end = (None, None) if data_interval is None else data_interval
    self.data_interval_start = interval_start
    self.data_interval_end = interval_end

    self.dag_id = dag_id
    self.run_id = run_id
    self.execution_date = execution_date
    self.start_date = start_date
    self.external_trigger = external_trigger
    self.conf = conf if conf else {}

    if state is not None:
        # Routed through the ``state`` synonym, i.e. set_state().
        self.state = state

    if queued_at is NOTSET:
        queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
    self.queued_at = queued_at

    self.run_type = run_type
    self.dag_hash = dag_hash
    self.creating_job_id = creating_job_id
    # A brand-new run has never been cleared.
    self.clear_number = 0
    super().__init__()
def __repr__(self):
    """Return a one-line, human-readable summary of this run."""
    identity = f"<DagRun {self.dag_id} @ {self.execution_date}: {self.run_id}, state:{self.state}, "
    details = f"queued_at: {self.queued_at}. externally triggered: {self.external_trigger}>"
    return identity + details
@validates("run_id")
def validate_run_id(self, key: str, run_id: str) -> str | None:
    """Check ``run_id`` every time it is assigned.

    Empty values are normalised to ``None``. Any other value must match either the
    configured ``[scheduler] allowed_run_id_pattern`` or ``RUN_ID_REGEX``.

    :raises ValueError: if ``run_id`` matches neither pattern.
    """
    if not run_id:
        return None
    regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
    if re2.match(regex, run_id) or re2.match(RUN_ID_REGEX, run_id):
        return run_id
    raise ValueError(
        f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
    )
@property
def stats_tags(self) -> dict[str, str]:
    """Tags attached to metrics emitted for this run.

    The raw ``dag_id``/``run_type`` mapping is passed through ``prune_dict``
    (presumably dropping unset entries — see that helper).
    """
    raw_tags = {"dag_id": self.dag_id, "run_type": self.run_type}
    return prune_dict(raw_tags)
@property
def logical_date(self) -> datetime:
    """Alias for :attr:`execution_date`."""
    date_value = self.execution_date
    return date_value
def get_state(self):
    """Getter behind the ``state`` synonym: the raw ``_state`` column value."""
    current_state = self._state
    return current_state
def set_state(self, state: DagRunState) -> None:
    """Change the run's state, keeping ``end_date`` and ``queued_at`` consistent.

    Re-setting the current state is a no-op. Otherwise ``end_date`` becomes "now"
    for finished states and ``None`` for any other state, and ``queued_at`` is
    refreshed whenever the run is (re-)queued.

    :raises ValueError: if ``state`` is not a valid DagRun state.
    """
    if state not in State.dag_states:
        raise ValueError(f"invalid DagRun state: {state}")
    if self._state == state:
        return
    self._state = state
    is_finished = self._state in State.finished_dr_states
    self.end_date = timezone.utcnow() if is_finished else None
    if state == DagRunState.QUEUED:
        self.queued_at = timezone.utcnow()
@declared_attr
def state(self):
    """Expose ``_state`` as ``state``, routing reads/writes through get_state/set_state.

    ``declared_attr`` invokes this with the mapped class, so ``self`` is the class here.
    """
    accessor = property(self.get_state, self.set_state)
    return synonym("_state", descriptor=accessor)
@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
    """
    Reload this run's ``id`` and ``state`` from the database.

    The row is looked up by ``(dag_id, run_id)`` and fetched with ``.one()``, so a
    missing row raises. The state is written through ``set_state``, which may also
    update ``end_date``/``queued_at``.

    :param session: database session
    """
    stmt = select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id)
    persisted = session.scalars(stmt).one()
    self.id = persisted.id
    self.state = persisted.state
@classmethod
@provide_session
def active_runs_of_dags(
    cls,
    dag_ids: Iterable[str] | None = None,
    only_running: bool = False,
    session: Session = NEW_SESSION,
) -> dict[str, int]:
    """
    Get the number of active dag runs for each dag.

    :param dag_ids: restrict the count to these DAG ids (duplicates are ignored);
        ``None`` counts runs of every DAG.
    :param only_running: count only RUNNING runs instead of RUNNING and QUEUED ones.
    :param session: database session
    :return: mapping of dag_id to its number of active runs. DAGs without any
        active run are absent from the mapping.
    """
    query = select(cls.dag_id, func.count("*"))
    if dag_ids is not None:
        # De-duplicate while keeping the caller's order, then hand SQLAlchemy a
        # list: ``in_()`` is documented to take a list of values, and a stable
        # order keeps the rendered bind parameters deterministic (a raw set does not).
        query = query.where(cls.dag_id.in_(list(dict.fromkeys(dag_ids))))
    if only_running:
        query = query.where(cls.state == DagRunState.RUNNING)
    else:
        query = query.where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED)))
    query = query.group_by(cls.dag_id)
    return {dag_id: count for dag_id, count in session.execute(query)}
@classmethod
def next_dagruns_to_examine(
    cls,
    state: DagRunState,
    session: Session,
    max_number: int | None = None,
) -> Query:
    """
    Return the next DagRuns that the scheduler should attempt to schedule.

    This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
    query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
    the transaction is committed it will be unlocked.

    Only non-backfill runs in ``state`` whose DAG is active and not paused are
    considered. For QUEUED runs, a run is returned only while its DAG has fewer
    RUNNING runs than ``DagModel.max_active_runs``. Unless future execution dates
    are allowed, runs with an ``execution_date`` in the future are skipped.

    :param state: the DagRun state to select
    :param session: database session; the row locks belong to its transaction
    :param max_number: maximum number of runs to return; defaults to
        ``DEFAULT_DAGRUNS_TO_EXAMINE``
    """
    # NOTE(review): annotated as ``Query`` but what is returned is the result of
    # ``session.scalars(...)``, not a Query object.
    from airflow.models.dag import DagModel

    if max_number is None:
        max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

    # TODO: Bake this query, it is run _A lot_
    query = (
        select(cls)
        .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
        .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
        .join(DagModel, DagModel.dag_id == cls.dag_id)
        .where(DagModel.is_paused == false(), DagModel.is_active == true())
    )
    if state == DagRunState.QUEUED:
        # For dag runs in the queued state, we check if they have reached the max_active_runs limit
        # and if so we drop them
        running_drs = (
            select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
            .where(DagRun.state == DagRunState.RUNNING)
            .group_by(DagRun.dag_id)
            .subquery()
        )
        # Outer join + coalesce: DAGs with no running runs count as 0.
        query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
            func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
        )
    # Never-examined runs (NULL decision time) first, then least recently
    # examined, then oldest execution_date.
    query = query.order_by(
        nulls_first(cls.last_scheduling_decision, session=session),
        cls.execution_date,
    )

    if not settings.ALLOW_FUTURE_EXEC_DATES:
        query = query.where(DagRun.execution_date <= func.now())

    # Lock only DagRun rows; skip_locked presumably lets concurrent schedulers
    # skip rows another scheduler already holds (where the DB supports it).
    return session.scalars(
        with_row_locks(query.limit(max_number), of=cls, session=session, **skip_locked(session=session))
    )
@classmethod
@provide_session
def find(
    cls,
    dag_id: str | list[str] | None = None,
    run_id: Iterable[str] | None = None,
    execution_date: datetime | Iterable[datetime] | None = None,
    state: DagRunState | None = None,
    external_trigger: bool | None = None,
    no_backfills: bool = False,
    run_type: DagRunType | None = None,
    session: Session = NEW_SESSION,
    execution_start_date: datetime | None = None,
    execution_end_date: datetime | None = None,
) -> list[DagRun]:
    """
    Return the dag runs matching every given criterion, ordered by execution_date.

    Criteria left at their default are not applied. In particular, an empty
    ``dag_id`` list applies no dag_id filter at all.

    :param dag_id: the dag_id or list of dag_id to find dag runs for
    :param run_id: a run id, or a collection of run ids
    :param run_type: type of DagRun
    :param execution_date: an execution date, or a collection of them
    :param state: the state of the dag run
    :param external_trigger: whether this dag run is externally triggered
    :param no_backfills: return no backfills (True), return all (False).
        Defaults to False
    :param session: database session
    :param execution_start_date: lower bound (inclusive) on execution_date
    :param execution_end_date: upper bound (inclusive) on execution_date
    """
    criteria = []

    dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
    if dag_ids:
        criteria.append(cls.dag_id.in_(dag_ids))

    if is_container(run_id):
        criteria.append(cls.run_id.in_(run_id))
    elif run_id is not None:
        criteria.append(cls.run_id == run_id)

    if is_container(execution_date):
        criteria.append(cls.execution_date.in_(execution_date))
    elif execution_date is not None:
        criteria.append(cls.execution_date == execution_date)

    if execution_start_date and execution_end_date:
        criteria.append(cls.execution_date.between(execution_start_date, execution_end_date))
    elif execution_start_date:
        criteria.append(cls.execution_date >= execution_start_date)
    elif execution_end_date:
        criteria.append(cls.execution_date <= execution_end_date)

    if state:
        criteria.append(cls.state == state)
    if external_trigger is not None:
        criteria.append(cls.external_trigger == external_trigger)
    if run_type:
        criteria.append(cls.run_type == run_type)
    if no_backfills:
        criteria.append(cls.run_type != DagRunType.BACKFILL_JOB)

    stmt = select(cls).where(*criteria).order_by(cls.execution_date)
    return session.scalars(stmt).all()
@classmethod
@provide_session
def find_duplicate(
    cls,
    dag_id: str,
    run_id: str,
    execution_date: datetime,
    session: Session = NEW_SESSION,
) -> DagRun | None:
    """
    Return an existing run for the DAG with a specific run_id or execution_date.

    *None* is returned if no such DAG run is found. Because the filter is an OR,
    one run may match on ``run_id`` and a different run on ``execution_date``.
    In that case one of them is returned instead of raising.

    :param dag_id: the dag_id to find duplicates for
    :param run_id: defines the run id for this dag run
    :param execution_date: the execution date
    :param session: database session
    """
    # ``.one_or_none()`` raised MultipleResultsFound when two distinct rows
    # matched; callers only need to know whether *a* duplicate exists.
    return session.scalar(
        select(cls)
        .where(
            cls.dag_id == dag_id,
            or_(cls.run_id == run_id, cls.execution_date == execution_date),
        )
        .limit(1)
    )
@staticmethod
def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
    """Build the run_id for a run of ``run_type`` at ``execution_date``.

    ``run_type`` is first coerced via ``DagRunType(...)``, so a plain string from
    user code works as long as it is a valid DagRunType value.
    """
    normalized_type = DagRunType(run_type)
    return normalized_type.generate_run_id(execution_date)
@staticmethod
@internal_api_call
@provide_session
def fetch_task_instances(
    dag_id: str | None = None,
    run_id: str | None = None,
    task_ids: list[str] | None = None,
    state: Iterable[TaskInstanceState | None] | None = None,
    session: Session = NEW_SESSION,
) -> list[TI]:
    """
    Return the task instances of a dag run.

    :param dag_id: the DAG id
    :param run_id: the DAG run id
    :param task_ids: if given, only task instances of these tasks are returned
    :param state: a single state, or an iterable of states to filter on. ``None``
        inside the iterable matches task instances whose state is NULL. A falsy
        value (``None`` or an empty collection) disables state filtering.
    :param session: Sqlalchemy ORM Session
    """
    tis = (
        select(TI)
        .options(joinedload(TI.dag_run))
        .where(
            TI.dag_id == dag_id,
            TI.run_id == run_id,
        )
    )

    if state:
        if isinstance(state, str):
            tis = tis.where(TI.state == state)
        else:
            # Materialize once: ``state`` is typed as a plain Iterable and is
            # scanned several times below. Scanning a one-shot iterator (e.g. a
            # generator) repeatedly would exhaust it and silently drop states.
            states = list(state)
            # this is required to deal with NULL values
            if None in states:
                if all(x is None for x in states):
                    tis = tis.where(TI.state.is_(None))
                else:
                    not_none_states = [s for s in states if s]
                    tis = tis.where(or_(TI.state.in_(not_none_states), TI.state.is_(None)))
            else:
                tis = tis.where(TI.state.in_(states))

    if task_ids is not None:
        tis = tis.where(TI.task_id.in_(task_ids))
    return session.scalars(tis).all()
@provide_session
def get_task_instances(
    self,
    state: Iterable[TaskInstanceState | None] | None = None,
    session: Session = NEW_SESSION,
) -> list[TI]:
    """
    Return the task instances for this dag run.

    Thin wrapper around :meth:`DagRun.fetch_task_instances`, kept because it is
    widely used across the code base. When the attached DAG is partial, only
    task instances of tasks present in it are returned.
    """
    task_ids = None
    if self.dag and self.dag.partial:
        task_ids = self.dag.task_ids
    return DagRun.fetch_task_instances(
        dag_id=self.dag_id,
        run_id=self.run_id,
        task_ids=task_ids,
        state=state,
        session=session,
    )
@provide_session
def get_task_instance(
    self,
    task_id: str,
    session: Session = NEW_SESSION,
    *,
    map_index: int = -1,
) -> TI | TaskInstancePydantic | None:
    """
    Return one task instance of this dag run, or ``None`` if it does not exist.

    Delegates to :meth:`DagRun.fetch_task_instance` with this run's dag_id and run_id.

    :param task_id: the task id
    :param session: Sqlalchemy ORM Session
    :param map_index: map index of the task instance (default ``-1``, the
        non-expanded instance)
    """
    lookup = {
        "dag_id": self.dag_id,
        "dag_run_id": self.run_id,
        "task_id": task_id,
        "session": session,
        "map_index": map_index,
    }
    return DagRun.fetch_task_instance(**lookup)
@staticmethod
@internal_api_call
@provide_session
def fetch_task_instance(
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    session: Session = NEW_SESSION,
    map_index: int = -1,
) -> TI | TaskInstancePydantic | None:
    """
    Return the single task instance identified by its full key, or ``None``.

    :param dag_id: the DAG id
    :param dag_run_id: the DAG run id
    :param task_id: the task id
    :param session: Sqlalchemy ORM Session
    :param map_index: map index of the task instance (default ``-1``)
    """
    stmt = select(TI).where(
        TI.dag_id == dag_id,
        TI.run_id == dag_run_id,
        TI.task_id == task_id,
        TI.map_index == map_index,
    )
    return session.scalars(stmt).one_or_none()
def get_dag(self) -> DAG:
    """
    Return the DAG object attached to this run.

    :raises AirflowException: if no DAG has been attached via ``.dag``.
    """
    attached = self.dag
    if attached:
        return attached
    raise AirflowException(f"The DAG (.dag) for {self} needs to be set")
@staticmethod
@internal_api_call
@provide_session
def get_previous_dagrun(
    dag_run: DagRun | DagRunPydantic, state: DagRunState | None = None, session: Session = NEW_SESSION
) -> DagRun | None:
    """
    Return the run of the same DAG with the latest execution_date before ``dag_run``'s.

    :param dag_run: the dag run to look back from
    :param state: if given, only runs in this state are considered
    :param session: SQLAlchemy ORM Session
    :return: the previous DagRun, or ``None`` if there is none
    """
    state_filter = () if state is None else (DagRun.state == state,)
    newest_first = (
        select(DagRun)
        .where(
            DagRun.dag_id == dag_run.dag_id,
            DagRun.execution_date < dag_run.execution_date,
            *state_filter,
        )
        .order_by(DagRun.execution_date.desc())
    )
    return session.scalar(newest_first.limit(1))
@staticmethod
@internal_api_call
@provide_session
def get_previous_scheduled_dagrun(
    dag_run_id: int,
    session: Session = NEW_SESSION,
) -> DagRun | None:
    """
    Return the previous SCHEDULED DagRun, if there is one.

    "Scheduled" here means any run whose ``run_type`` is not MANUAL. The previous
    run is the one of the same DAG with the latest ``execution_date`` strictly
    before that of the referenced run.

    :param dag_run_id: the DAG run ID (primary key)
    :param session: SQLAlchemy ORM Session
    :return: the previous run, or ``None`` when there is none or when no DagRun
        with ``dag_run_id`` exists
    """
    dag_run = session.get(DagRun, dag_run_id)
    if dag_run is None:
        # Unknown id: there is no reference run to look back from.
        return None
    return session.scalar(
        select(DagRun)
        .where(
            DagRun.dag_id == dag_run.dag_id,
            DagRun.execution_date < dag_run.execution_date,
            DagRun.run_type != DagRunType.MANUAL,
        )
        .order_by(DagRun.execution_date.desc())
        .limit(1)
    )
def _tis_for_dagrun_state(self, *, dag, tis):
    """
    Select the task instances that decide the terminal state of the dag run.

    A task counts when it is an "effective leaf": all of its downstream tasks are
    ignorable teardowns, and the task itself is not an ignorable teardown. A
    teardown is ignorable unless ``on_failure_fail_dagrun`` is set on it. If no
    task qualifies (e.g. the DAG consists only of teardowns), the plain leaves
    (tasks without downstream tasks) are used instead. Task instances in the
    REMOVED state are always dropped.

    :param dag: the DAG the task instances belong to
    :param tis: candidate task instances
    :return: set of the task instances to consider
    """

    def _ignorable(task):
        # Teardowns don't affect the run state unless explicitly opted in.
        return task.is_teardown and not task.on_failure_fail_dagrun

    def is_effective_leaf(task):
        blocked = any(not _ignorable(dag.get_task(down_id)) for down_id in task.downstream_task_ids)
        return not blocked and not _ignorable(task)

    leaf_task_ids = {task.task_id for task in dag.tasks if is_effective_leaf(task)}
    if not leaf_task_ids:
        # can happen if dag is exclusively teardown tasks
        leaf_task_ids = {task.task_id for task in dag.tasks if not task.downstream_list}
    return {
        ti
        for ti in tis
        if ti.task_id in leaf_task_ids
        if ti.state != TaskInstanceState.REMOVED
    }
@provide_session
def update_state(
    self, session: Session = NEW_SESSION, execute_callbacks: bool = True
) -> tuple[list[TI], DagCallbackRequest | None]:
    """
    Determine the overall state of the DagRun based on the state of its TaskInstances.

    Outcomes, checked in this order ("relevant" TIs are those selected by
    ``_tis_for_dagrun_state``):

    * no unfinished TIs and at least one relevant TI failed -> FAILED;
    * no unfinished TIs and all relevant TIs succeeded -> SUCCESS;
    * unfinished TIs exist, none is constrained (depends_on_past, concurrency
      limits) or deferred, and none can run -> FAILED (deadlock);
    * otherwise -> RUNNING.

    :param session: Sqlalchemy ORM Session
    :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
        directly (default: true) or recorded as a pending request in the ``returned_callback`` property
    :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
        needs to be executed
    """
    # Callback to execute in case of Task Failures
    callback: DagCallbackRequest | None = None

    class _UnfinishedStates(NamedTuple):
        # Holds the unfinished TIs and answers whether deadlock detection applies.
        tis: Sequence[TI]

        @classmethod
        def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
            return cls(tis=unfinished_tis)

        @property
        def should_schedule(self) -> bool:
            # True only when there are unfinished TIs and none of them has a
            # constraint (or deferral) that could legitimately keep it waiting.
            return (
                bool(self.tis)
                and all(not t.task.depends_on_past for t in self.tis)
                and all(t.task.max_active_tis_per_dag is None for t in self.tis)
                and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)
                and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
            )

        def recalculate(self) -> _UnfinishedStates:
            # Drop TIs that have since moved out of an unfinished state.
            return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

    start_dttm = timezone.utcnow()
    self.last_scheduling_decision = start_dttm
    with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"), Stats.timer(
        "dagrun.dependency-check", tags=self.stats_tags
    ):
        dag = self.get_dag()
        info = self.task_instance_scheduling_decisions(session)

        tis = info.tis
        schedulable_tis = info.schedulable_tis
        changed_tis = info.changed_tis
        finished_tis = info.finished_tis
        unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

        # ``are_runnable_tasks`` is only bound here; the deadlock branch below
        # reads it only when ``should_schedule`` is true (short-circuit ``and``).
        if unfinished.should_schedule:
            are_runnable_tasks = schedulable_tis or changed_tis
            # small speed up
            if not are_runnable_tasks:
                are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                    unfinished.tis, finished_tis, session
                )
                if changed_by_upstream:  # Something changed, we need to recalculate!
                    unfinished = unfinished.recalculate()

    tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)

    # if all tasks finished and at least one failed, the run failed
    if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
        self.log.error("Marking run %s failed", self)
        self.set_state(DagRunState.FAILED)
        self.notify_dagrun_state_changed(msg="task_failure")

        # Either run the DAG callback now, or hand back a request for the
        # caller to execute (only when the DAG defines such a callback).
        if execute_callbacks:
            dag.handle_callback(self, success=False, reason="task_failure", session=session)
        elif dag.has_on_failure_callback:
            from airflow.models.dag import DagModel

            dag_model = DagModel.get_dagmodel(dag.dag_id, session)
            callback = DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                run_id=self.run_id,
                is_failure_callback=True,
                processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                msg="task_failure",
            )

    # if all leaves succeeded and no unfinished tasks, the run succeeded
    elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
        self.log.info("Marking run %s successful", self)
        self.set_state(DagRunState.SUCCESS)
        self.notify_dagrun_state_changed(msg="success")

        if execute_callbacks:
            dag.handle_callback(self, success=True, reason="success", session=session)
        elif dag.has_on_success_callback:
            from airflow.models.dag import DagModel

            dag_model = DagModel.get_dagmodel(dag.dag_id, session)
            callback = DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                run_id=self.run_id,
                is_failure_callback=False,
                processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                msg="success",
            )

    # if *all tasks* are deadlocked, the run failed
    elif unfinished.should_schedule and not are_runnable_tasks:
        self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
        self.set_state(DagRunState.FAILED)
        self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

        if execute_callbacks:
            dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
        elif dag.has_on_failure_callback:
            from airflow.models.dag import DagModel

            dag_model = DagModel.get_dagmodel(dag.dag_id, session)
            callback = DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                run_id=self.run_id,
                is_failure_callback=True,
                processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                msg="all_tasks_deadlocked",
            )

    # finally, if the leaves aren't done, the dag is still running
    else:
        self.set_state(DagRunState.RUNNING)

    if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
        # Log a single summary line for the finished run, then flush so the
        # final state is written before the stats below are emitted.
        msg = (
            "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
            "run_start_date=%s, run_end_date=%s, run_duration=%s, "
            "state=%s, external_trigger=%s, run_type=%s, "
            "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
        )
        self.log.info(
            msg,
            self.dag_id,
            self.execution_date,
            self.run_id,
            self.start_date,
            self.end_date,
            (self.end_date - self.start_date).total_seconds()
            if self.start_date and self.end_date
            else None,
            self._state,
            self.external_trigger,
            self.run_type,
            self.data_interval_start,
            self.data_interval_end,
            self.dag_hash,
        )
        session.flush()

    self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
    self._emit_duration_stats_for_finished_state()

    session.merge(self)
    # We do not flush here for performance reasons(It increases queries count by +20)

    return schedulable_tis, callback
@provide_session
def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
    """
    Classify this run's task instances and find the ones ready to be scheduled.

    Task instances whose task no longer exists in the DAG are marked REMOVED and
    excluded. The rest are split into unfinished and finished, and the
    schedulable unfinished ones are passed through ``_get_ready_tis``.

    :param session: Sqlalchemy ORM Session
    :return: a ``TISchedulingDecision`` with the resulting lists and flags
    """
    tis = self.get_task_instances(session=session, state=State.task_states)
    self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))

    def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]:
        """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
        for ti in tis:
            try:
                ti.task = dag.get_task(ti.task_id)
            except TaskNotFound:
                # Only log/flush on the first transition to REMOVED.
                if ti.state != TaskInstanceState.REMOVED:
                    self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
                    ti.state = TaskInstanceState.REMOVED
                    session.flush()
            else:
                yield ti

    tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))

    unfinished_tis = [t for t in tis if t.state in State.unfinished]
    finished_tis = [t for t in tis if t.state in State.finished]
    if unfinished_tis:
        schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES]
        self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(schedulable_tis))
        schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
            schedulable_tis,
            finished_tis,
            session=session,
        )

        # During expansion, we may change some tis into non-schedulable
        # states, so we need to re-compute.
        if expansion_happened:
            changed_tis = True
            new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
            finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
            unfinished_tis = new_unfinished_tis
    else:
        schedulable_tis = []
        changed_tis = False

    return TISchedulingDecision(
        tis=tis,
        schedulable_tis=schedulable_tis,
        changed_tis=changed_tis,
        unfinished_tis=unfinished_tis,
        finished_tis=finished_tis,
    )
def notify_dagrun_state_changed(self, msg: str = ""):
    """Report this run's current state to registered listeners.

    RUNNING, SUCCESS and FAILED are forwarded to the ``on_dag_run_running``,
    ``on_dag_run_success`` and ``on_dag_run_failed`` hooks respectively. Any other
    state (notably QUEUED) is deliberately not reported: the state changes made by
    SchedulerJob, BackfillJob and LocalTaskJob cannot all be observed here, so we
    don't want to "falsely advertise" that we notify about them.
    """
    state_to_hook = (
        (DagRunState.RUNNING, "on_dag_run_running"),
        (DagRunState.SUCCESS, "on_dag_run_success"),
        (DagRunState.FAILED, "on_dag_run_failed"),
    )
    current_state = self.state
    for candidate_state, hook_name in state_to_hook:
        if current_state == candidate_state:
            getattr(get_listener_manager().hook, hook_name)(dag_run=self, msg=msg)
            return
def _get_ready_tis(
    self,
    schedulable_tis: list[TI],
    finished_tis: list[TI],
    session: Session,
) -> tuple[list[TI], bool, bool]:
    """Select the task instances that are ready to be scheduled, expanding mapped tasks on the way.

    Each TI in ``schedulable_tis`` is checked with ``are_dependencies_met``.

    * A TI whose dependencies are not met is skipped. Its state from before the
      check is remembered so the database can be re-queried afterwards.
    * A not-yet-expanded TI (``map_index < 0``) whose dependencies are met is
      expanded. The resulting TIs are fed back into the same loop and
      checked in turn.
    * Any other TI whose dependencies are met and whose state is in
      ``SCHEDULEABLE_STATES`` is returned as ready.
      ``_revise_map_indexes_if_mapped`` is called once per task id, and its
      results are added to the ready list as well.

    :param schedulable_tis: Candidate task instances to check.
    :param finished_tis: Already-finished task instances, passed to the dependency context.
    :param session: Database session for dependency checks, expansion and the final re-query.
    :return: A tuple ``(ready_tis, changed_tis, expansion_happened)``:

        * ``ready_tis`` is the list of TIs ready to schedule.
        * ``changed_tis`` is True if any TI whose dependencies were not met now
          has a different state in the database than it had before the check.
        * ``expansion_happened`` is True if any mapped TI was expanded, even into
          zero TIs.
    """
    # Maps TI key -> state observed *before* the dependency check, for TIs whose deps were not met.
    old_states = {}
    ready_tis: list[TI] = []
    changed_tis = False

    if not schedulable_tis:
        return ready_tis, changed_tis, False

    # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
    # `schedulable_tis` in place and have the `for` loop pick them up
    additional_tis: list[TI] = []
    # NOTE(review): flag_upstream_failed=True presumably lets dependency evaluation write a
    # new state (e.g. upstream_failed) for a TI -- which is why states are re-queried below.
    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
        finished_tis=finished_tis,
    )

    def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
        """Try to expand the ti, if needed.

        If the ti needs expansion, newly created task instances are
        returned as well as the original ti.
        The original ti is also modified in-place and assigned the
        ``map_index`` of 0.

        If the ti does not need expansion, either because the task is not
        mapped, or has already been expanded, *None* is returned.

        An empty tuple is returned when expansion ran but produced no task
        instances; callers distinguish this from *None* ("nothing to do").
        """
        if ti.map_index >= 0:  # Already expanded, we're good.
            return None

        from airflow.models.mappedoperator import MappedOperator

        if isinstance(ti.task, MappedOperator):
            # If we get here, it could be that we are moving from non-mapped to mapped
            # after task instance clearing or this ti is not yet expanded. Safe to clear
            # the db references.
            ti.clear_db_references(session=session)
        try:
            expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
        except NotMapped:  # Not a mapped task, nothing needed.
            return None
        if expanded_tis:
            return expanded_tis
        return ()

    # Check dependencies.
    expansion_happened = False
    # Set of task ids for which was already done _revise_map_indexes_if_mapped
    revised_map_index_task_ids: set[str] = set()
    # `chain` reaches `additional_tis` only after `schedulable_tis` is exhausted, and a list
    # iterator sees items appended during iteration, so TIs produced by expansion (even
    # expansion of an already-additional TI) are checked by this same loop.
    for schedulable in itertools.chain(schedulable_tis, additional_tis):
        # Capture the state first: the dependency check below may change it.
        old_state = schedulable.state
        if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
            old_states[schedulable.key] = old_state
            continue
        # If schedulable is not yet expanded, try doing it now. This is
        # called in two places: First and ideally in the mini scheduler at
        # the end of LocalTaskJob, and then as an "expansion of last resort"
        # in the scheduler to ensure that the mapped task is correctly
        # expanded before executed. Also see _revise_map_indexes_if_mapped
        # docstring for additional information.
        new_tis = None
        if schedulable.map_index < 0:
            new_tis = _expand_mapped_task_if_needed(schedulable)
            if new_tis is not None:
                # Expanded TIs (possibly none) are handled on later iterations; this TI is
                # not marked ready here.
                additional_tis.extend(new_tis)
                expansion_happened = True
        if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
            # It's enough to revise map index once per task id,
            # checking the map index for each mapped task significantly slows down scheduling
            if schedulable.task.task_id not in revised_map_index_task_ids:
                ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                revised_map_index_task_ids.add(schedulable.task.task_id)
            ready_tis.append(schedulable)

    # Check if any ti changed state
    # Only TIs whose dependencies were not met are re-read. Their fresh DB state is compared
    # with the state captured before the dependency check.
    tis_filter = TI.filter_for_tis(old_states)
    if tis_filter is not None:
        fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
        changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

    return ready_tis, changed_tis, expansion_happened
def _are_premature_tis(
    self,
    unfinished_tis: Sequence[TI],
    finished_tis: list[TI],
    session: Session,
) -> tuple[bool, bool]:
    """Report whether any task instance could run if retry/reschedule waits were ignored.

    Dependencies are evaluated with the retry-period and reschedule-period
    checks turned off. A TI that is only waiting out a retry delay or a
    reschedule interval therefore still counts as runnable.

    :param unfinished_tis: Task instances to evaluate.
    :param finished_tis: Finished task instances, passed to the dependency context.
    :param session: Database session used for the dependency checks.
    :return: ``(some_ti_runnable, have_changed_ti_states)``.

        * ``some_ti_runnable`` is True as soon as one TI has its dependencies
          met. Evaluation stops at that TI.
        * ``have_changed_ti_states`` is the dependency context's flag of that
          name, read after evaluation.
    """
    # Runnable tasks may be up for retry but held back (retry delay, etc.); relaxing
    # these two periods makes them count.
    premature_context = DepContext(
        finished_tis=finished_tis,
        ignore_in_reschedule_period=True,
        ignore_in_retry_period=True,
        flag_upstream_failed=True,
    )

    some_ti_runnable = False
    for candidate in unfinished_tis:
        if candidate.are_dependencies_met(dep_context=premature_context, session=session):
            some_ti_runnable = True
            break

    # Read the flag only after evaluation, since the dependency checks may update it.
    return some_ti_runnable, premature_context.have_changed_ti_states
def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
"""Emit the true scheduling delay stats.
The true scheduling delay stats is defined as the time when the first
task in DAG starts minus the expected DAG run datetime.
This helper method is used in ``update_state`` when the state of the
DAG run is updated to a completed status (either success or failure).
It finds the first started task within the DAG, calculates the run's
expected start time based on the logical date and timetable, and gets
the delay from the difference of these two values.
The emitted data may contain outliers (e.g. when the first task was
cleared, so the second task's start date will be used), but we can get
rid of the outliers on the stats side through dashboards tooling.
Note that the stat will only be emitted for scheduler-triggered DAG runs
(i.e. when ``external_trigger`` is *False* and ``clear_number`` is
greater than 0).
"""
if self.state == TaskInstanceState.RUNNING:
return
if self.external_trigger:
return
if self.clear_number > 0:
return