Skip to content

Commit

Permalink
TriggerDagRunOperator depreacte exection_date in favor of `logica…
Browse files Browse the repository at this point in the history
…l_date` (#39285)

* added logical_date parameter

* fix comment
  • Loading branch information
flolas authored Apr 27, 2024
1 parent 8dfdc3a commit c946fc3
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 90 deletions.
62 changes: 37 additions & 25 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import datetime
import json
import time
import warnings
from typing import TYPE_CHECKING, Any, Sequence, cast

from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
Expand All @@ -41,7 +42,7 @@
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso"
XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"


Expand All @@ -64,7 +65,7 @@ class TriggerDagRunLink(BaseOperatorLink):
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
# Fetch the correct execution date for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
when = XCom.get_value(ti_key=ti_key, key=XCOM_LOGICAL_DATE_ISO)
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
return build_airflow_url_with_query(query)

Expand All @@ -77,7 +78,7 @@ class TriggerDagRunOperator(BaseOperator):
:param trigger_run_id: The run ID to use for the triggered DAG run (templated).
If not provided, a run ID will be automatically generated.
:param conf: Configuration for the DAG run (templated).
:param execution_date: Execution date for the dag (templated).
:param logical_date: Logical date for the dag (templated).
:param reset_dag_run: Whether clear existing dag run if already exists.
This is useful when backfill or rerun an existing dag run.
This only resets (not recreates) the dag run.
Expand All @@ -91,12 +92,13 @@ class TriggerDagRunOperator(BaseOperator):
:param failed_states: List of failed or dis-allowed states, default is ``None``.
:param deferrable: If waiting for completion, whether or not to defer the task until done,
default is ``False``.
:param execution_date: Deprecated parameter; same as ``logical_date``.
"""

template_fields: Sequence[str] = (
"trigger_dag_id",
"trigger_run_id",
"execution_date",
"logical_date",
"conf",
"wait_for_completion",
)
Expand All @@ -110,13 +112,14 @@ def __init__(
trigger_dag_id: str,
trigger_run_id: str | None = None,
conf: dict | None = None,
execution_date: str | datetime.datetime | None = None,
logical_date: str | datetime.datetime | None = None,
reset_dag_run: bool = False,
wait_for_completion: bool = False,
poke_interval: int = 60,
allowed_states: list[str] | None = None,
failed_states: list[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
execution_date: str | datetime.datetime | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -136,20 +139,29 @@ def __init__(
self.failed_states = [DagRunState.FAILED]
self._defer = deferrable

if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)):
if execution_date is not None:
warnings.warn(
"Parameter 'execution_date' is deprecated. Use 'logical_date' instead.",
RemovedInAirflow3Warning,
stacklevel=2,
)
logical_date = execution_date

if logical_date is not None and not isinstance(logical_date, (str, datetime.datetime)):
type_name = type(logical_date).__name__
raise TypeError(
f"Expected str or datetime.datetime type for execution_date.Got {type(execution_date)}"
f"Expected str or datetime.datetime type for parameter 'logical_date'. Got {type_name}"
)

self.execution_date = execution_date
self.logical_date = logical_date

def execute(self, context: Context):
if isinstance(self.execution_date, datetime.datetime):
parsed_execution_date = self.execution_date
elif isinstance(self.execution_date, str):
parsed_execution_date = timezone.parse(self.execution_date)
if isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date
elif isinstance(self.logical_date, str):
parsed_logical_date = timezone.parse(self.logical_date)
else:
parsed_execution_date = timezone.utcnow()
parsed_logical_date = timezone.utcnow()

try:
json.dumps(self.conf)
Expand All @@ -159,20 +171,20 @@ def execute(self, context: Context):
if self.trigger_run_id:
run_id = str(self.trigger_run_id)
else:
run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_execution_date)
run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date)

try:
dag_run = trigger_dag(
dag_id=self.trigger_dag_id,
run_id=run_id,
conf=self.conf,
execution_date=parsed_execution_date,
execution_date=parsed_logical_date,
replace_microseconds=False,
)

except DagRunAlreadyExists as e:
if self.reset_dag_run:
self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_execution_date)
self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_logical_date)

# Get target dag object and call clear()
dag_model = DagModel.get_current(self.trigger_dag_id)
Expand All @@ -182,15 +194,15 @@ def execute(self, context: Context):
dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
dag_run = e.dag_run
dag.clear(start_date=dag_run.execution_date, end_date=dag_run.execution_date)
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
else:
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
# Store the execution date from the dag run (either created or found above) to
# be used when creating the extra link on the webserver.
ti = context["task_instance"]
ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat())
ti.xcom_push(key=XCOM_LOGICAL_DATE_ISO, value=dag_run.logical_date.isoformat())
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

if self.wait_for_completion:
Expand All @@ -200,7 +212,7 @@ def execute(self, context: Context):
trigger=DagStateTrigger(
dag_id=self.trigger_dag_id,
states=self.allowed_states + self.failed_states,
execution_dates=[parsed_execution_date],
execution_dates=[parsed_logical_date],
poll_interval=self.poke_interval,
),
method_name="execute_complete",
Expand All @@ -210,7 +222,7 @@ def execute(self, context: Context):
self.log.info(
"Waiting for %s on %s to become allowed state %s ...",
self.trigger_dag_id,
dag_run.execution_date,
dag_run.logical_date,
self.allowed_states,
)
time.sleep(self.poke_interval)
Expand All @@ -225,17 +237,17 @@ def execute(self, context: Context):

@provide_session
def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):
# This execution date is parsed from the return trigger event
provided_execution_date = event[1]["execution_dates"][0]
# This logical_date is parsed from the return trigger event
provided_logical_date = event[1]["execution_dates"][0]
try:
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_execution_date
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
)
).scalar_one()
except NoResultFound:
raise AirflowException(
f"No DAG run found for DAG {self.trigger_dag_id} and execution date {self.execution_date}"
f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
)

state = dag_run.state
Expand Down
Loading

0 comments on commit c946fc3

Please sign in to comment.