Skip to content

Commit

Permalink
Replace save_to_db() to finish_task()
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Aug 31, 2023
1 parent 7921dc5 commit f76bc56
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 18 deletions.
2 changes: 1 addition & 1 deletion airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _initialize_map() -> dict[str, Callable]:
SerializedDagModel.get_serialized_dag,
TaskInstance.get_task_instance,
TaskInstance.fetch_handle_failure_context,
TaskInstance.save_to_db,
TaskInstance.finish_task,
TaskInstance.set_end_date,
Trigger.from_object,
Trigger.bulk_fetch,
Expand Down
86 changes: 69 additions & 17 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,17 @@ def _handle_failure(
)

if not test_mode:
TaskInstance.save_to_db(failure_context["ti"], session)
TaskInstance.finish_task(
task_id=task_instance.task_id,
dag_id=task_instance.dag_id,
run_id=task_instance.run_id,
map_index=task_instance.map_index,
end_date=failure_context["end_date"],
duration=failure_context["duration"],
state=failure_context["state"],
try_number=failure_context["try_number"],
session=session,
)


def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
Expand Down Expand Up @@ -1151,6 +1161,24 @@ def _get_previous_ti(
return dagrun.get_task_instance(task_instance.task_id, session=session)


def _set_end_date(
task_instance: TaskInstance | TaskInstancePydantic,
end_date: datetime,
):
"""
Set the end date and compute the duration of the task instance.
:param task_instance: the task instance
:param end_date: the end date for the task
:meta private:
"""
task_instance.end_date = end_date
if task_instance.start_date:
task_instance.duration = (end_date - task_instance.start_date).total_seconds()
else:
task_instance.duration = None


class TaskInstance(Base, LoggingMixin):
"""
Task instances store the state of a task instance.
Expand Down Expand Up @@ -1734,12 +1762,8 @@ def set_end_date(
task_instance = session.get(
TaskInstance, {"task_id": task_id, "dag_id": dag_id, "run_id": run_id, "map_index": map_index}
)
task_instance.end_date = end_date
_set_end_date(task_instance=task_instance, end_date=end_date)

if task_instance.start_date:
task_instance.duration = (end_date - task_instance.start_date).total_seconds()
else:
task_instance.duration = None
cls.logger().debug("Task Duration set to %s", task_instance.duration)
session.commit()

Expand Down Expand Up @@ -2625,14 +2649,7 @@ def fetch_handle_failure_context(
if not test_mode:
ti.refresh_from_db(session)

TaskInstance.set_end_date(
dag_id=ti.dag_id,
run_id=ti.run_id,
task_id=ti.task_id,
map_index=ti.map_index,
end_date=timezone.utcnow(),
session=session,
)
_set_end_date(task_instance=ti, end_date=timezone.utcnow())

Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags)
# Same metric with tagging
Expand Down Expand Up @@ -2689,7 +2706,10 @@ def fetch_handle_failure_context(
callbacks = task.on_retry_callback if task else None

return {
"ti": ti,
"end_date": ti.end_date,
"duration": ti.duration,
"state": ti.state,
"try_number": ti._try_number,
"email_for_state": email_for_state,
"task": task,
"callbacks": callbacks,
Expand All @@ -2699,8 +2719,40 @@ def fetch_handle_failure_context(
@staticmethod
@internal_api_call
@provide_session
def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION):
session.merge(ti)
def finish_task(
task_id: str,
dag_id: str,
run_id: str,
map_index: int,
end_date: datetime,
duration: int,
state: str,
try_number: int,
session: Session = NEW_SESSION,
):
"""
Finish a task.
Set the end date, duration, state and the try number of the task instance.
:param task_id: the task ID
:param dag_id: the DAG ID
:param run_id: the run ID
:param map_index: the map index
:param end_date: the end date of the task to set
:param duration: the duration of the task to set
:param state: the state of the task to set
:param try_number: the try number of the task to set
:param session: SQLAlchemy ORM Session
"""
task_instance = session.get(
TaskInstance, {"task_id": task_id, "dag_id": dag_id, "run_id": run_id, "map_index": map_index}
)
task_instance.end_date = end_date
task_instance.duration = duration
task_instance.state = state
task_instance._try_number = try_number
session.merge(task_instance)
session.flush()

@provide_session
Expand Down

0 comments on commit f76bc56

Please sign in to comment.