From f76bc56a032cb068fc2d21ca789fc8c7f8138b94 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Thu, 31 Aug 2023 13:45:38 -0400 Subject: [PATCH] Replace `save_to_db()` to `finish_task()` --- .../endpoints/rpc_api_endpoint.py | 2 +- airflow/models/taskinstance.py | 86 +++++++++++++++---- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 0e24f90ee5945..c8b09d840a535 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -70,7 +70,7 @@ def _initialize_map() -> dict[str, Callable]: SerializedDagModel.get_serialized_dag, TaskInstance.get_task_instance, TaskInstance.fetch_handle_failure_context, - TaskInstance.save_to_db, + TaskInstance.finish_task, TaskInstance.set_end_date, Trigger.from_object, Trigger.bulk_fetch, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index db3d63f5dcd73..d8eb38da5b74d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -805,7 +805,17 @@ def _handle_failure( ) if not test_mode: - TaskInstance.save_to_db(failure_context["ti"], session) + TaskInstance.finish_task( + task_id=task_instance.task_id, + dag_id=task_instance.dag_id, + run_id=task_instance.run_id, + map_index=task_instance.map_index, + end_date=failure_context["end_date"], + duration=failure_context["duration"], + state=failure_context["state"], + try_number=failure_context["try_number"], + session=session, + ) def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic): @@ -1151,6 +1161,24 @@ def _get_previous_ti( return dagrun.get_task_instance(task_instance.task_id, session=session) +def _set_end_date( + task_instance: TaskInstance | TaskInstancePydantic, + end_date: datetime, +): + """ + Set the end date and compute the duration of the task instance. + + :param task_instance: the task instance + :param end_date: the end date for the task + :meta private: + """ + task_instance.end_date = end_date + if task_instance.start_date: + task_instance.duration = (end_date - task_instance.start_date).total_seconds() + else: + task_instance.duration = None + + class TaskInstance(Base, LoggingMixin): """ Task instances store the state of a task instance. @@ -1734,12 +1762,8 @@ def set_end_date( task_instance = session.get( TaskInstance, {"task_id": task_id, "dag_id": dag_id, "run_id": run_id, "map_index": map_index} ) - task_instance.end_date = end_date + _set_end_date(task_instance=task_instance, end_date=end_date) - if task_instance.start_date: - task_instance.duration = (end_date - task_instance.start_date).total_seconds() - else: - task_instance.duration = None cls.logger().debug("Task Duration set to %s", task_instance.duration) session.commit() @@ -2625,14 +2649,7 @@ def fetch_handle_failure_context( if not test_mode: ti.refresh_from_db(session) - TaskInstance.set_end_date( - dag_id=ti.dag_id, - run_id=ti.run_id, - task_id=ti.task_id, - map_index=ti.map_index, - end_date=timezone.utcnow(), - session=session, - ) + _set_end_date(task_instance=ti, end_date=timezone.utcnow()) Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags) # Same metric with tagging @@ -2689,7 +2706,10 @@ def fetch_handle_failure_context( callbacks = task.on_retry_callback if task else None return { - "ti": ti, + "end_date": ti.end_date, + "duration": ti.duration, + "state": ti.state, + "try_number": ti._try_number, "email_for_state": email_for_state, "task": task, "callbacks": callbacks, @@ -2699,8 +2719,40 @@ def fetch_handle_failure_context( @staticmethod @internal_api_call @provide_session - def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION): - session.merge(ti) + def finish_task( + task_id: str, + dag_id: str, + run_id: str, + map_index: int, + end_date: datetime, + duration: int, + state: str, + try_number: int, + session: Session = NEW_SESSION, + ): + """ + Finish a task. + + Set the end date, duration, state and the try number of the task instance. + + :param task_id: the task ID + :param dag_id: the DAG ID + :param run_id: the run ID + :param map_index: the map index + :param end_date: the end date of the task to set + :param duration: the duration of the task to set + :param state: the state of the task to set + :param try_number: the try number of the task to set + :param session: SQLAlchemy ORM Session + """ + task_instance = session.get( + TaskInstance, {"task_id": task_id, "dag_id": dag_id, "run_id": run_id, "map_index": map_index} + ) + task_instance.end_date = end_date + task_instance.duration = duration + task_instance.state = state + task_instance._try_number = try_number + session.merge(task_instance) session.flush() @provide_session