Use namedtuple for TaskInstanceKeyType (#9712)
* Use namedtuple for TaskInstanceKeyType

GitOrigin-RevId: ecf2f8499b78ea12751c5e364e57d34912c25f7c
turbaszek authored and Cloud Composer Team committed Sep 12, 2024
1 parent bff0c8f commit 6a34c25
Showing 13 changed files with 125 additions and 94 deletions.
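For context, a minimal sketch of what the new TaskInstanceKey namedtuple plausibly looks like, reconstructed from its call sites in this diff (attribute access like ti_key.dag_id, construction via TaskInstanceKey(dag_id, task_id, ex_time, try_num), and ti.key.reduced in backfill_job.py); the docstring and exact layout are assumptions, not the committed code:

from datetime import datetime
from typing import NamedTuple


class TaskInstanceKey(NamedTuple):
    """Key used to identify a task instance (sketch inferred from this diff)."""

    dag_id: str
    task_id: str
    execution_date: datetime
    try_number: int

    @property
    def reduced(self) -> 'TaskInstanceKey':
        # Remake the key with try_number - 1 so it matches the in-memory
        # try number; see the _update_counters change in backfill_job.py below.
        return TaskInstanceKey(
            self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1)
        )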
29 changes: 14 additions & 15 deletions airflow/executors/base_executor.py
@@ -21,7 +21,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from airflow.configuration import conf
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKeyType
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
@@ -58,10 +58,10 @@ class BaseExecutor(LoggingMixin):
    def __init__(self, parallelism: int = PARALLELISM):
        super().__init__()
        self.parallelism: int = parallelism
-        self.queued_tasks: OrderedDict[TaskInstanceKeyType, QueuedTaskInstanceType] \
+        self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] \
            = OrderedDict()
-        self.running: Set[TaskInstanceKeyType] = set()
-        self.event_buffer: Dict[TaskInstanceKeyType, EventBufferValueType] = {}
+        self.running: Set[TaskInstanceKey] = set()
+        self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {}

    def start(self):  # pragma: no cover
        """
@@ -155,7 +155,7 @@ def heartbeat(self) -> None:
        self.log.debug("Calling the %s sync method", self.__class__)
        self.sync()

-    def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKeyType, QueuedTaskInstanceType]]:
+    def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKey, QueuedTaskInstanceType]]:
        """
        Orders the queued tasks by priority.
@@ -183,7 +183,7 @@ def trigger_tasks(self, open_slots: int) -> None:
                               queue=None,
                               executor_config=simple_ti.executor_config)

-    def change_state(self, key: TaskInstanceKeyType, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
        """
        Changes state of the task.
@@ -198,7 +198,7 @@ def change_state(self, key: TaskInstanceKeyType, state: str, info=None) -> None:
            self.log.debug('Could not find key: %s', str(key))
        self.event_buffer[key] = state, info

-    def fail(self, key: TaskInstanceKeyType, info=None) -> None:
+    def fail(self, key: TaskInstanceKey, info=None) -> None:
        """
        Set fail state for the event.
@@ -207,7 +207,7 @@ def fail(self, key: TaskInstanceKeyType, info=None) -> None:
        """
        self.change_state(key, State.FAILED, info)

-    def success(self, key: TaskInstanceKeyType, info=None) -> None:
+    def success(self, key: TaskInstanceKey, info=None) -> None:
        """
        Set success state for the event.
@@ -216,7 +216,7 @@ def success(self, key: TaskInstanceKeyType, info=None) -> None:
        """
        self.change_state(key, State.SUCCESS, info)

-    def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKeyType, EventBufferValueType]:
+    def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferValueType]:
        """
        Returns and flush the event buffer. In case dag_ids is specified
        it will only return and flush events for the given dag_ids. Otherwise
@@ -225,20 +225,19 @@ def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKeyType, EventBufferValueType]:
        :param dag_ids: to dag_ids to return events for, if None returns all
        :return: a dict of events
        """
-        cleared_events: Dict[TaskInstanceKeyType, EventBufferValueType] = {}
+        cleared_events: Dict[TaskInstanceKey, EventBufferValueType] = {}
        if dag_ids is None:
            cleared_events = self.event_buffer
            self.event_buffer = {}
        else:
-            for key in list(self.event_buffer.keys()):
-                dag_id, _, _, _ = key
-                if dag_id in dag_ids:
-                    cleared_events[key] = self.event_buffer.pop(key)
+            for ti_key in list(self.event_buffer.keys()):
+                if ti_key.dag_id in dag_ids:
+                    cleared_events[ti_key] = self.event_buffer.pop(ti_key)

        return cleared_events

    def execute_async(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:  # pragma: no cover
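What the base_executor.py hunks buy in practice: key fields are read by name (ti_key.dag_id in get_event_buffer) rather than by positional unpacking, and because a NamedTuple is still a tuple it hashes and compares like the old 4-tuples, so tuple-keyed lookups keep working during the transition. A small demo using the sketch above (all values invented):

from datetime import datetime

key = TaskInstanceKey('example_dag', 'example_task', datetime(2020, 7, 1), 1)

# New style: attribute access, as in the rewritten get_event_buffer().
assert key.dag_id == 'example_dag'

# Equality and hashing match the plain tuple, so callers that still
# build bare tuple keys hit the same dict entries.
event_buffer = {key: ('success', None)}
assert ('example_dag', 'example_task', datetime(2020, 7, 1), 1) in event_buffer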
12 changes: 6 additions & 6 deletions airflow/executors/celery_executor.py
@@ -39,7 +39,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.timeout import timeout
@@ -102,12 +102,12 @@ def __init__(self, exception: Exception, exception_traceback: str):


# Task instance that is sent over Celery queues
-# TaskInstanceKeyType, SimpleTaskInstance, Command, queue_name, CallableTask
-TaskInstanceInCelery = Tuple[TaskInstanceKeyType, SimpleTaskInstance, CommandType, Optional[str], Task]
+# TaskInstanceKey, SimpleTaskInstance, Command, queue_name, CallableTask
+TaskInstanceInCelery = Tuple[TaskInstanceKey, SimpleTaskInstance, CommandType, Optional[str], Task]


def send_task_to_executor(task_tuple: TaskInstanceInCelery) \
-        -> Tuple[TaskInstanceKeyType, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
+        -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
    """Sends task to executor."""
    key, _, command, queue, task_to_run = task_tuple
    try:
@@ -235,7 +235,7 @@ def update_all_task_states(self) -> None:
            if state:
                self.update_task_state(key, state, info)

-    def update_task_state(self, key: TaskInstanceKeyType, state: str, info: Any) -> None:
+    def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
        """Updates state of a single task."""
        # noinspection PyBroadException
        try:
@@ -265,7 +265,7 @@ def end(self, synchronous: bool = False) -> None:
            self.sync()

    def execute_async(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None):
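The Celery path needs nothing beyond the type-name change: the key still rides as the first element of TaskInstanceInCelery, so the positional unpacking in send_task_to_executor is unaffected by the switch to a namedtuple. An illustrative sketch reusing the key from the demo above (command and queue values are invented):

# Mirrors the unpacking in send_task_to_executor(); only 'key' is a real
# TaskInstanceKey here, the remaining elements are placeholders.
task_tuple = (key, None, ['airflow', 'tasks', 'run'], 'default_queue', None)
ti_key, _, command, queue, task_to_run = task_tuple
assert ti_key is key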
6 changes: 3 additions & 3 deletions airflow/executors/dask_executor.py
@@ -31,7 +31,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKeyType
+from airflow.models.taskinstance import TaskInstanceKey


class DaskExecutor(BaseExecutor):
@@ -50,7 +50,7 @@ def __init__(self, cluster_address=None):
        self.tls_key = conf.get('dask', 'tls_key')
        self.tls_cert = conf.get('dask', 'tls_cert')
        self.client: Optional[Client] = None
-        self.futures: Optional[Dict[Future, TaskInstanceKeyType]] = None
+        self.futures: Optional[Dict[Future, TaskInstanceKey]] = None

    def start(self) -> None:
        if self.tls_ca or self.tls_key or self.tls_cert:
@@ -67,7 +67,7 @@ def start(self) -> None:
        self.futures = {}

    def execute_async(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
6 changes: 3 additions & 3 deletions airflow/executors/debug_executor.py
@@ -28,7 +28,7 @@

from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKeyType
+from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.utils.state import State


@@ -46,7 +46,7 @@ def __init__(self):
        super().__init__()
        self.tasks_to_run: List[TaskInstance] = []
        # Place where we keep information for task instance raw run
-        self.tasks_params: Dict[TaskInstanceKeyType, Dict[str, Any]] = {}
+        self.tasks_params: Dict[TaskInstanceKey, Dict[str, Any]] = {}
        self.fail_fast = conf.getboolean("debug", "fail_fast")

    def execute_async(self, *args, **kwargs) -> None:  # pylint: disable=signature-differs
@@ -147,7 +147,7 @@ def end(self) -> None:
    def terminate(self) -> None:
        self._terminated.set()

-    def change_state(self, key: TaskInstanceKeyType, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
        self.log.debug("Popping %s from executor task queue.", key)
        self.running.remove(key)
        self.event_buffer[key] = state, info
16 changes: 8 additions & 8 deletions airflow/executors/kubernetes_executor.py
@@ -46,18 +46,18 @@
from airflow.kubernetes.pod_launcher import PodLauncher
from airflow.kubernetes.worker_configuration import WorkerConfiguration
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance
-from airflow.models.taskinstance import TaskInstanceKeyType
+from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State

MAX_LABEL_LEN = 63

# TaskInstance key, command, configuration
-KubernetesJobType = Tuple[TaskInstanceKeyType, CommandType, Any]
+KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any]

# key, state, pod_id, namespace, resource_version
-KubernetesResultsType = Tuple[TaskInstanceKeyType, Optional[str], str, str, str]
+KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]

# pod_id, namespace, state, labels, resource_version
KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
@@ -588,7 +588,7 @@ def _datetime_to_label_safe_datestring(datetime_obj: datetime.datetime) -> str:
        """
        return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_')

-    def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType]:
+    def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKey]:
        try_num = 1
        try:
            try_num = int(labels.get('try_number', '1'))
@@ -618,7 +618,7 @@ def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType]:
                    'Found matching task %s-%s (%s) with current state of %s',
                    task.dag_id, task.task_id, task.execution_date, task.state
                )
-                return (dag_id, task_id, ex_time, try_num)
+                return TaskInstanceKey(dag_id, task_id, ex_time, try_num)
            else:
                self.log.warning(
                    'task_id/dag_id are not safe to use as Kubernetes labels. This can cause '
@@ -649,7 +649,7 @@ def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType]:
                )
                dag_id = task.dag_id
                task_id = task.task_id
-                return dag_id, task_id, ex_time, try_num
+                return TaskInstanceKey(dag_id, task_id, ex_time, try_num)
        self.log.warning(
            'Failed to find and match task details to a pod; labels: %s',
            labels
@@ -798,7 +798,7 @@ def start(self) -> None:
            self.clear_not_launched_queued_tasks()

    def execute_async(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
@@ -871,7 +871,7 @@ def sync(self) -> None:
        # pylint: enable=too-many-nested-blocks

    def _change_state(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      state: Optional[str],
                      pod_id: str,
                      namespace: str) -> None:
16 changes: 8 additions & 8 deletions airflow/executors/local_executor.py
@@ -31,15 +31,15 @@
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType
from airflow.models.taskinstance import (  # pylint: disable=unused-import # noqa: F401
-    TaskInstanceKeyType, TaskInstanceStateType,
+    TaskInstanceKey, TaskInstanceStateType,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State

# This is a work to be executed by a worker.
# It can Key and Command - but it can also be None, None which is actually a
# "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
-ExecutorWorkType = Tuple[Optional[TaskInstanceKeyType], Optional[CommandType]]
+ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]]


class LocalWorkerBase(Process, LoggingMixin):
@@ -54,7 +54,7 @@ def __init__(self, result_queue: 'Queue[TaskInstanceStateType]'):
        self.daemon: bool = True
        self.result_queue: 'Queue[TaskInstanceStateType]' = result_queue

-    def execute_work(self, key: TaskInstanceKeyType, command: CommandType) -> None:
+    def execute_work(self, key: TaskInstanceKey, command: CommandType) -> None:
        """
        Executes command received and stores result state in queue.
@@ -83,10 +83,10 @@ class LocalWorker(LocalWorkerBase):
    """
    def __init__(self,
                 result_queue: 'Queue[TaskInstanceStateType]',
-                 key: TaskInstanceKeyType,
+                 key: TaskInstanceKey,
                 command: CommandType):
        super().__init__(result_queue)
-        self.key: TaskInstanceKeyType = key
+        self.key: TaskInstanceKey = key
        self.command: CommandType = command

    def run(self) -> None:
@@ -156,7 +156,7 @@ def start(self) -> None:
    # pylint: disable=unused-argument # pragma: no cover
    # noinspection PyUnusedLocal
    def execute_async(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
@@ -227,7 +227,7 @@ def start(self) -> None:
    # noinspection PyUnusedLocal
    def execute_async(
        self,
-        key: TaskInstanceKeyType,
+        key: TaskInstanceKey,
        command: CommandType,
        queue: Optional[str] = None,  # pylint: disable=unused-argument
        executor_config: Optional[Any] = None  # pylint: disable=unused-argument
@@ -279,7 +279,7 @@ def start(self) -> None:

        self.impl.start()

-    def execute_async(self, key: TaskInstanceKeyType,
+    def execute_async(self, key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
4 changes: 2 additions & 2 deletions airflow/executors/sequential_executor.py
@@ -26,7 +26,7 @@
from typing import Any, Optional

from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import TaskInstanceKeyType
+from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.state import State


@@ -45,7 +45,7 @@ def __init__(self):
        self.commands_to_run = []

    def execute_async(self,
-                      key: TaskInstanceKeyType,
+                      key: TaskInstanceKey,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
22 changes: 12 additions & 10 deletions airflow/jobs/backfill_job.py
@@ -34,7 +34,7 @@
from airflow.jobs.base_job import BaseJob
from airflow.models import DAG, DagPickle
from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKeyType
+from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import BACKFILL_QUEUED_DEPS
from airflow.utils import timezone
@@ -67,17 +67,17 @@ class _DagRunTaskStatus:
    it easier to pass it around.

    :param to_run: Tasks to run in the backfill
-    :type to_run: dict[tuple[TaskInstanceKeyType], airflow.models.TaskInstance]
+    :type to_run: dict[tuple[TaskInstanceKey], airflow.models.TaskInstance]
    :param running: Maps running task instance key to task instance object
-    :type running: dict[tuple[TaskInstanceKeyType], airflow.models.TaskInstance]
+    :type running: dict[tuple[TaskInstanceKey], airflow.models.TaskInstance]
    :param skipped: Tasks that have been skipped
-    :type skipped: set[tuple[TaskInstanceKeyType]]
+    :type skipped: set[tuple[TaskInstanceKey]]
    :param succeeded: Tasks that have succeeded so far
-    :type succeeded: set[tuple[TaskInstanceKeyType]]
+    :type succeeded: set[tuple[TaskInstanceKey]]
    :param failed: Tasks that have failed
-    :type failed: set[tuple[TaskInstanceKeyType]]
+    :type failed: set[tuple[TaskInstanceKey]]
    :param not_ready: Tasks not ready for execution
-    :type not_ready: set[tuple[TaskInstanceKeyType]]
+    :type not_ready: set[tuple[TaskInstanceKey]]
    :param deadlocked: Deadlocked tasks
    :type deadlocked: set[airflow.models.TaskInstance]
    :param active_runs: Active dag runs at a certain point in time
@@ -196,7 +196,7 @@ def _update_counters(self, ti_status, session=None):

        for ti in refreshed_tis:
            # Here we remake the key by subtracting 1 to match in memory information
-            key = (ti.dag_id, ti.task_id, ti.execution_date, max(1, ti.try_number - 1))
+            key = ti.key.reduced
            if ti.state == State.SUCCESS:
                ti_status.succeeded.add(key)
                self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
@@ -637,10 +637,12 @@ def _per_task_process(task, key, ti, session=None):  # pylint: disable=too-many-

    @provide_session
    def _collect_errors(self, ti_status, session=None):
-        def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKeyType]) -> str:
+        def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str:
            # Sorting by execution date first
            sorted_ti_keys = sorted(
-                set_ti_keys, key=lambda ti_key: (ti_key[2], ti_key[0], ti_key[1], ti_key[3]))
+                set_ti_keys, key=lambda ti_key:
+                (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number)
+            )
            return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"])

        def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
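A side benefit visible in tabulate_ti_keys_set above: tabulate treats each namedtuple as a ready-made row, so the sorted keys render directly under the four headers with no conversion. A quick sketch (requires the tabulate package; sample keys invented):

from datetime import datetime
from tabulate import tabulate

keys = {
    TaskInstanceKey('dag_b', 'task_1', datetime(2020, 7, 1), 1),
    TaskInstanceKey('dag_a', 'task_2', datetime(2020, 6, 30), 2),
}
sorted_keys = sorted(keys, key=lambda k: (k.execution_date, k.dag_id, k.task_id, k.try_number))
print(tabulate(sorted_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"]))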
(Diffs for the remaining 5 changed files are not shown.)
