diff --git a/airflow/cli/commands/kubernetes_command.py b/airflow/cli/commands/kubernetes_command.py index 25cfecc2b47b9..c367d4be87b6d 100644 --- a/airflow/cli/commands/kubernetes_command.py +++ b/airflow/cli/commands/kubernetes_command.py @@ -25,9 +25,10 @@ from kubernetes.client.api_client import ApiClient from kubernetes.client.rest import ApiException -from airflow.executors.kubernetes_executor import KubeConfig, create_pod_id +from airflow.executors.kubernetes_executor import KubeConfig from airflow.kubernetes import pod_generator from airflow.kubernetes.kube_client import get_kube_client +from airflow.kubernetes.kubernetes_helper_functions import create_pod_id from airflow.kubernetes.pod_generator import PodGenerator from airflow.models import DagRun, TaskInstance from airflow.utils import cli as cli_utils, yaml diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index ea84ad6dab646..c1f8861f895fd 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -29,477 +29,36 @@ import time from collections import defaultdict from contextlib import suppress +from datetime import datetime from queue import Empty, Queue -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Sequence -from kubernetes import client, watch -from kubernetes.client import Configuration, models as k8s -from kubernetes.client.rest import ApiException from sqlalchemy.orm import Session -from urllib3.exceptions import ReadTimeoutError from airflow.configuration import conf -from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError +from airflow.exceptions import PodMutationHookException, PodReconciliationError from airflow.executors.base_executor import BaseExecutor -from airflow.kubernetes import pod_generator -from airflow.kubernetes.kube_client import get_kube_client +from airflow.executors.kubernetes_executor_types import POD_EXECUTOR_DONE_KEY from airflow.kubernetes.kube_config import KubeConfig -from airflow.kubernetes.kubernetes_helper_functions import ( - annotations_for_logging_task_metadata, - annotations_to_key, - create_pod_id, -) -from airflow.kubernetes.pod_generator import PodGenerator +from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key from airflow.utils.event_scheduler import EventScheduler -from airflow.utils.log.logging_mixin import LoggingMixin, remove_escape_codes +from airflow.utils.log.logging_mixin import remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.singleton import Singleton from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: + from kubernetes import client + from kubernetes.client import models as k8s + from airflow.executors.base_executor import CommandType + from airflow.executors.kubernetes_executor_types import ( + KubernetesJobType, + KubernetesResultsType, + ) + from airflow.executors.kubernetes_executor_utils import AirflowKubernetesScheduler from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey - # TaskInstance key, command, configuration, pod_template_file - KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]] - - # key, pod state, pod_name, namespace, resource_version - KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str] - - # pod_name, namespace, pod state, annotations, resource_version 
- KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str] - -ALL_NAMESPACES = "ALL_NAMESPACES" -POD_EXECUTOR_DONE_KEY = "airflow_executor_done" - - -class ResourceVersion(metaclass=Singleton): - """Singleton for tracking resourceVersion from Kubernetes.""" - - resource_version: dict[str, str] = {} - - -class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): - """Watches for Kubernetes jobs.""" - - def __init__( - self, - namespace: str, - watcher_queue: Queue[KubernetesWatchType], - resource_version: str | None, - scheduler_job_id: str, - kube_config: Configuration, - ): - super().__init__() - self.namespace = namespace - self.scheduler_job_id = scheduler_job_id - self.watcher_queue = watcher_queue - self.resource_version = resource_version - self.kube_config = kube_config - - def run(self) -> None: - """Performs watching.""" - if TYPE_CHECKING: - assert self.scheduler_job_id - - kube_client: client.CoreV1Api = get_kube_client() - while True: - try: - self.resource_version = self._run( - kube_client, self.resource_version, self.scheduler_job_id, self.kube_config - ) - except ReadTimeoutError: - self.log.warning( - "There was a timeout error accessing the Kube API. Retrying request.", exc_info=True - ) - time.sleep(1) - except Exception: - self.log.exception("Unknown error in KubernetesJobWatcher. Failing") - self.resource_version = "0" - ResourceVersion().resource_version[self.namespace] = "0" - raise - else: - self.log.warning( - "Watch died gracefully, starting back up with: last resource_version: %s", - self.resource_version, - ) - - def _pod_events(self, kube_client: client.CoreV1Api, query_kwargs: dict): - watcher = watch.Watch() - try: - if self.namespace == ALL_NAMESPACES: - return watcher.stream(kube_client.list_pod_for_all_namespaces, **query_kwargs) - else: - return watcher.stream(kube_client.list_namespaced_pod, self.namespace, **query_kwargs) - except ApiException as e: - if e.status == 410: # Resource version is too old - if self.namespace == ALL_NAMESPACES: - pods = kube_client.list_pod_for_all_namespaces(watch=False) - else: - pods = kube_client.list_namespaced_pod(namespace=self.namespace, watch=False) - resource_version = pods.metadata.resource_version - query_kwargs["resource_version"] = resource_version - return self._pod_events(kube_client=kube_client, query_kwargs=query_kwargs) - else: - raise - - def _run( - self, - kube_client: client.CoreV1Api, - resource_version: str | None, - scheduler_job_id: str, - kube_config: Any, - ) -> str | None: - self.log.info("Event: and now my watch begins starting at resource_version: %s", resource_version) - - kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"} - if resource_version: - kwargs["resource_version"] = resource_version - if kube_config.kube_client_request_args: - for key, value in kube_config.kube_client_request_args.items(): - kwargs[key] = value - - last_resource_version: str | None = None - - for event in self._pod_events(kube_client=kube_client, query_kwargs=kwargs): - task = event["object"] - self.log.debug("Event: %s had an event of type %s", task.metadata.name, event["type"]) - if event["type"] == "ERROR": - return self.process_error(event) - annotations = task.metadata.annotations - task_instance_related_annotations = { - "dag_id": annotations["dag_id"], - "task_id": annotations["task_id"], - "execution_date": annotations.get("execution_date"), - "run_id": annotations.get("run_id"), - "try_number": annotations["try_number"], - } - map_index = annotations.get("map_index") - 
if map_index is not None: - task_instance_related_annotations["map_index"] = map_index - - self.process_status( - pod_name=task.metadata.name, - namespace=task.metadata.namespace, - status=task.status.phase, - annotations=task_instance_related_annotations, - resource_version=task.metadata.resource_version, - event=event, - ) - last_resource_version = task.metadata.resource_version - - return last_resource_version - - def process_error(self, event: Any) -> str: - """Process error response.""" - self.log.error("Encountered Error response from k8s list namespaced pod stream => %s", event) - raw_object = event["raw_object"] - if raw_object["code"] == 410: - self.log.info( - "Kubernetes resource version is too old, must reset to 0 => %s", (raw_object["message"],) - ) - # Return resource version 0 - return "0" - raise AirflowException( - f"Kubernetes failure for {raw_object['reason']} with code {raw_object['code']} and message: " - f"{raw_object['message']}" - ) - - def process_status( - self, - pod_name: str, - namespace: str, - status: str, - annotations: dict[str, str], - resource_version: str, - event: Any, - ) -> None: - pod = event["object"] - annotations_string = annotations_for_logging_task_metadata(annotations) - """Process status response.""" - if status == "Pending": - # deletion_timestamp is set by kube server when a graceful deletion is requested. - # since kube server have received request to delete pod set TI state failed - if event["type"] == "DELETED" and pod.metadata.deletion_timestamp: - self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) - else: - self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string) - elif status == "Failed": - self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) - elif status == "Succeeded": - # We get multiple events once the pod hits a terminal state, and we only want to - # send it along to the scheduler once. - # If our event type is DELETED, we have the POD_EXECUTOR_DONE_KEY, or the pod has - # a deletion timestamp, we've already seen the initial Succeeded event and sent it - # along to the scheduler. - if ( - event["type"] == "DELETED" - or POD_EXECUTOR_DONE_KEY in pod.metadata.labels - or pod.metadata.deletion_timestamp - ): - self.log.info( - "Skipping event for Succeeded pod %s - event for this pod already sent to executor", - pod_name, - ) - return - self.log.info("Event: %s Succeeded, annotations: %s", pod_name, annotations_string) - self.watcher_queue.put((pod_name, namespace, None, annotations, resource_version)) - elif status == "Running": - # deletion_timestamp is set by kube server when a graceful deletion is requested. 
- # since kube server have received request to delete pod set TI state failed - if event["type"] == "DELETED" and pod.metadata.deletion_timestamp: - self.log.info( - "Event: Pod %s deleted before it could complete, annotations: %s", - pod_name, - annotations_string, - ) - self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version)) - else: - self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string) - else: - self.log.warning( - "Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with " - "resource_version: %s", - status, - pod_name, - namespace, - annotations, - resource_version, - ) - - -class AirflowKubernetesScheduler(LoggingMixin): - """Airflow Scheduler for Kubernetes.""" - - def __init__( - self, - kube_config: Any, - result_queue: Queue[KubernetesResultsType], - kube_client: client.CoreV1Api, - scheduler_job_id: str, - ): - super().__init__() - self.log.debug("Creating Kubernetes executor") - self.kube_config = kube_config - self.result_queue = result_queue - self.namespace = self.kube_config.kube_namespace - self.log.debug("Kubernetes using namespace %s", self.namespace) - self.kube_client = kube_client - self._manager = multiprocessing.Manager() - self.watcher_queue = self._manager.Queue() - self.scheduler_job_id = scheduler_job_id - self.kube_watchers = self._make_kube_watchers() - - def run_pod_async(self, pod: k8s.V1Pod, **kwargs): - """Runs POD asynchronously.""" - sanitized_pod = self.kube_client.api_client.sanitize_for_serialization(pod) - json_pod = json.dumps(sanitized_pod, indent=2) - - self.log.debug("Pod Creation Request: \n%s", json_pod) - try: - resp = self.kube_client.create_namespaced_pod( - body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs - ) - self.log.debug("Pod Creation Response: %s", resp) - except Exception as e: - self.log.exception("Exception when attempting to create Namespaced Pod: %s", json_pod) - raise e - return resp - - def _make_kube_watcher(self, namespace) -> KubernetesJobWatcher: - resource_version = ResourceVersion().resource_version.get(namespace, "0") - watcher = KubernetesJobWatcher( - watcher_queue=self.watcher_queue, - namespace=namespace, - resource_version=resource_version, - scheduler_job_id=self.scheduler_job_id, - kube_config=self.kube_config, - ) - watcher.start() - return watcher - - def _make_kube_watchers(self) -> dict[str, KubernetesJobWatcher]: - watchers = {} - if self.kube_config.multi_namespace_mode: - namespaces_to_watch = ( - self.kube_config.multi_namespace_mode_namespace_list - if self.kube_config.multi_namespace_mode_namespace_list - else [ALL_NAMESPACES] - ) - else: - namespaces_to_watch = [self.kube_config.kube_namespace] - - for namespace in namespaces_to_watch: - watchers[namespace] = self._make_kube_watcher(namespace) - return watchers - - def _health_check_kube_watchers(self): - for namespace, kube_watcher in self.kube_watchers.items(): - if kube_watcher.is_alive(): - self.log.debug("KubeJobWatcher for namespace %s alive, continuing", namespace) - else: - self.log.error( - ( - "Error while health checking kube watcher process for namespace %s. 
" - "Process died for unknown reasons" - ), - namespace, - ) - ResourceVersion().resource_version[namespace] = "0" - self.kube_watchers[namespace] = self._make_kube_watcher(namespace) - - def run_next(self, next_job: KubernetesJobType) -> None: - """Receives the next job to run, builds the pod, and creates it.""" - key, command, kube_executor_config, pod_template_file = next_job - - dag_id, task_id, run_id, try_number, map_index = key - - if command[0:3] != ["airflow", "tasks", "run"]: - raise ValueError('The command must start with ["airflow", "tasks", "run"].') - - base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config) - - if not base_worker_pod: - raise AirflowException( - f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}" - ) - - pod = PodGenerator.construct_pod( - namespace=self.namespace, - scheduler_job_id=self.scheduler_job_id, - pod_id=create_pod_id(dag_id, task_id), - dag_id=dag_id, - task_id=task_id, - kube_image=self.kube_config.kube_image, - try_number=try_number, - map_index=map_index, - date=None, - run_id=run_id, - args=command, - pod_override_object=kube_executor_config, - base_worker_pod=base_worker_pod, - with_mutation_hook=True, - ) - # Reconcile the pod generated by the Operator and the Pod - # generated by the .cfg file - self.log.info( - "Creating kubernetes pod for job is %s, with pod name %s, annotations: %s", - key, - pod.metadata.name, - annotations_for_logging_task_metadata(pod.metadata.annotations), - ) - self.log.debug("Kubernetes running for command %s", command) - self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image) - - # the watcher will monitor pods, so we do not block. - self.run_pod_async(pod, **self.kube_config.kube_client_request_args) - self.log.debug("Kubernetes Job created!") - - def delete_pod(self, pod_name: str, namespace: str) -> None: - """Deletes Pod from a namespace. Does not raise if it does not exist.""" - try: - self.log.debug("Deleting pod %s in namespace %s", pod_name, namespace) - self.kube_client.delete_namespaced_pod( - pod_name, - namespace, - body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs), - **self.kube_config.kube_client_request_args, - ) - except ApiException as e: - # If the pod is already deleted - if e.status != 404: - raise - - def patch_pod_executor_done(self, *, pod_name: str, namespace: str): - """Add a "done" annotation to ensure we don't continually adopt pods.""" - self.log.debug("Patching pod %s in namespace %s to mark it as done", pod_name, namespace) - try: - self.kube_client.patch_namespaced_pod( - name=pod_name, - namespace=namespace, - body={"metadata": {"labels": {POD_EXECUTOR_DONE_KEY: "True"}}}, - ) - except ApiException as e: - self.log.info("Failed to patch pod %s with done annotation. Reason: %s", pod_name, e) - - def sync(self) -> None: - """ - Checks the status of all currently running kubernetes jobs. - - If a job is completed, its status is placed in the result queue to be sent back to the scheduler. 
- """ - self.log.debug("Syncing KubernetesExecutor") - self._health_check_kube_watchers() - while True: - try: - task = self.watcher_queue.get_nowait() - try: - self.log.debug("Processing task %s", task) - self.process_watcher_task(task) - finally: - self.watcher_queue.task_done() - except Empty: - break - - def process_watcher_task(self, task: KubernetesWatchType) -> None: - """Process the task by watcher.""" - pod_name, namespace, state, annotations, resource_version = task - self.log.debug( - "Attempting to finish pod; pod_name: %s; state: %s; annotations: %s", - pod_name, - state, - annotations_for_logging_task_metadata(annotations), - ) - key = annotations_to_key(annotations=annotations) - if key: - self.log.debug("finishing job %s - %s (%s)", key, state, pod_name) - self.result_queue.put((key, state, pod_name, namespace, resource_version)) - - def _flush_watcher_queue(self) -> None: - self.log.debug("Executor shutting down, watcher_queue approx. size=%d", self.watcher_queue.qsize()) - while True: - try: - task = self.watcher_queue.get_nowait() - # Ignoring it since it can only have either FAILED or SUCCEEDED pods - self.log.warning("Executor shutting down, IGNORING watcher task=%s", task) - self.watcher_queue.task_done() - except Empty: - break - - def terminate(self) -> None: - """Terminates the watcher.""" - self.log.debug("Terminating kube_watchers...") - for namespace, kube_watcher in self.kube_watchers.items(): - kube_watcher.terminate() - kube_watcher.join() - self.log.debug("kube_watcher=%s", kube_watcher) - self.log.debug("Flushing watcher_queue...") - self._flush_watcher_queue() - # Queue should be empty... - self.watcher_queue.join() - self.log.debug("Shutting down manager...") - self._manager.shutdown() - - -def get_base_pod_from_template(pod_template_file: str | None, kube_config: Any) -> k8s.V1Pod: - """ - Get base pod from template. - - Reads either the pod_template_file set in the executor_config or the base pod_template_file - set in the airflow.cfg to craft a "base pod" that will be used by the KubernetesExecutor - - :param pod_template_file: absolute path to a pod_template_file.yaml or None - :param kube_config: The KubeConfig class generated by airflow that contains all kube metadata - :return: a V1Pod that can be used as the base pod for k8s tasks - """ - if pod_template_file: - return PodGenerator.deserialize_model_file(pod_template_file) - else: - return PodGenerator.deserialize_model_file(kube_config.pod_template_file) - class KubernetesExecutor(BaseExecutor): """Executor for Kubernetes.""" @@ -536,6 +95,19 @@ def _list_pods(self, query_kwargs): return pods + def _make_safe_label_value(self, input_value: str | datetime) -> str: + """ + Normalize a provided label to be of valid length and characters. + See airflow.kubernetes.pod_generator.make_safe_label_value for more details. + """ + # airflow.kubernetes is an expensive import, locally import it here to + # speed up load times of the kubernetes_executor module. 
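+        # (The cost is largely pod_generator pulling in the kubernetes client
+        # models at module import time, so we defer it until a label is built.)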
+ from airflow.kubernetes import pod_generator + + if isinstance(input_value, datetime): + return pod_generator.datetime_to_label_safe_datestring(input_value) + return pod_generator.make_safe_label_value(input_value) + @provide_session def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> None: """ @@ -580,9 +152,9 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> Non # Build the pod selector base_label_selector = ( - f"dag_id={pod_generator.make_safe_label_value(ti.dag_id)}," - f"task_id={pod_generator.make_safe_label_value(ti.task_id)}," - f"airflow-worker={pod_generator.make_safe_label_value(str(ti.queued_by_job_id))}" + f"dag_id={self._make_safe_label_value(ti.dag_id)}," + f"task_id={self._make_safe_label_value(ti.task_id)}," + f"airflow-worker={self._make_safe_label_value(str(ti.queued_by_job_id))}" ) if ti.map_index >= 0: # Old tasks _couldn't_ be mapped, so we don't have to worry about compat @@ -592,15 +164,14 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> Non kwargs.update(**self.kube_config.kube_client_request_args) # Try run_id first - kwargs["label_selector"] += ",run_id=" + pod_generator.make_safe_label_value(ti.run_id) + kwargs["label_selector"] += ",run_id=" + self._make_safe_label_value(ti.run_id) pod_list = self._list_pods(kwargs) if pod_list: continue # Fallback to old style of using execution_date - kwargs["label_selector"] = ( - f"{base_label_selector}," - f"execution_date={pod_generator.datetime_to_label_safe_datestring(ti.execution_date)}" - ) + kwargs[ + "label_selector" + ] = f"{base_label_selector},execution_date={self._make_safe_label_value(ti.execution_date)}" pod_list = self._list_pods(kwargs) if pod_list: continue @@ -617,6 +188,9 @@ def start(self) -> None: self.log.info("Start Kubernetes executor") self.scheduler_job_id = str(self.job_id) self.log.debug("Start with scheduler_job_id: %s", self.scheduler_job_id) + from airflow.executors.kubernetes_executor_utils import AirflowKubernetesScheduler + from airflow.kubernetes.kube_client import get_kube_client + self.kube_client = get_kube_client() self.kube_scheduler = AirflowKubernetesScheduler( kube_config=self.kube_config, @@ -650,6 +224,8 @@ def execute_async( else: self.log.info("Add task %s with command %s", key, command) + from airflow.kubernetes.pod_generator import PodGenerator + try: kube_executor_config = PodGenerator.from_obj(executor_config) except Exception: @@ -706,15 +282,20 @@ def sync(self) -> None: except Empty: break + from airflow.executors.kubernetes_executor_utils import ResourceVersion + resource_instance = ResourceVersion() for ns in resource_instance.resource_version.keys(): resource_instance.resource_version[ns] = ( last_resource_version[ns] or resource_instance.resource_version[ns] ) + from kubernetes.client.rest import ApiException + for _ in range(self.kube_config.worker_pods_creation_batch_size): try: task = self.task_queue.get_nowait() + try: self.kube_scheduler.run_next(task) except PodReconciliationError as e: @@ -766,7 +347,6 @@ def _change_state( ) -> None: if TYPE_CHECKING: assert self.kube_scheduler - from airflow.models.taskinstance import TaskInstance if state == State.RUNNING: self.event_buffer[key] = state, None @@ -787,6 +367,8 @@ def _change_state( # If we don't have a TI state, look it up from the db. 
event_buffer expects the TI state if state is None: + from airflow.models.taskinstance import TaskInstance + state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar() self.event_buffer[key] = state, None @@ -803,6 +385,7 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li messages = [] log = [] try: + from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.pod_generator import PodGenerator @@ -849,7 +432,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task tis_to_flush_by_key = {ti.key: ti for ti in tis if ti.queued_by_job_id} kube_client: client.CoreV1Api = self.kube_client for scheduler_job_id in scheduler_job_ids: - scheduler_job_id = pod_generator.make_safe_label_value(str(scheduler_job_id)) + scheduler_job_id = self._make_safe_label_value(str(scheduler_job_id)) # We will look for any pods owned by the no-longer-running scheduler, # but will exclude only successful pods, as those TIs will have a terminal state # and not be up for adoption! @@ -880,6 +463,8 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: :param tis: List of Task Instances to clean up :return: List of readable task instances for a warning message """ + from airflow.kubernetes.pod_generator import PodGenerator + if TYPE_CHECKING: assert self.kube_client assert self.kube_scheduler @@ -930,8 +515,11 @@ def adopt_launched_task( self.log.error("attempting to adopt taskinstance which was not specified by database: %s", ti_key) return - new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id) + new_worker_id_label = self._make_safe_label_value(self.scheduler_job_id) + from kubernetes.client.rest import ApiException + try: + kube_client.patch_namespaced_pod( name=pod.metadata.name, namespace=pod.metadata.namespace, @@ -953,7 +541,7 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None: if TYPE_CHECKING: assert self.scheduler_job_id - new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id) + new_worker_id_label = self._make_safe_label_value(self.scheduler_job_id) query_kwargs = { "field_selector": "status.phase=Succeeded", "label_selector": ( @@ -964,7 +552,10 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None: pod_list = self._list_pods(query_kwargs) for pod in pod_list: self.log.info("Attempting to adopt pod %s", pod.metadata.name) + from kubernetes.client.rest import ApiException + try: + kube_client.patch_namespaced_pod( name=pod.metadata.name, namespace=pod.metadata.namespace, diff --git a/airflow/executors/kubernetes_executor_types.py b/airflow/executors/kubernetes_executor_types.py new file mode 100644 index 0000000000000..a13cd35f8d76f --- /dev/null +++ b/airflow/executors/kubernetes_executor_types.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +if TYPE_CHECKING: + from airflow.executors.base_executor import CommandType + from airflow.models.taskinstance import TaskInstanceKey + + # TaskInstance key, command, configuration, pod_template_file + KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]] + + # key, pod state, pod_name, namespace, resource_version + KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str] + + # pod_name, namespace, pod state, annotations, resource_version + KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str] + +ALL_NAMESPACES = "ALL_NAMESPACES" +POD_EXECUTOR_DONE_KEY = "airflow_executor_done" diff --git a/airflow/executors/kubernetes_executor_utils.py b/airflow/executors/kubernetes_executor_utils.py new file mode 100644 index 0000000000000..b1ee49cb695ea --- /dev/null +++ b/airflow/executors/kubernetes_executor_utils.py @@ -0,0 +1,477 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
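+# NOTE: KubernetesJobWatcher, AirflowKubernetesScheduler, and get_base_pod_from_template
+# below are moved out of airflow/executors/kubernetes_executor.py so that importing the
+# executor module no longer pulls in the kubernetes client; the executor imports this
+# module lazily in start().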
+from __future__ import annotations + +import json +import multiprocessing +import time +from queue import Empty, Queue +from typing import TYPE_CHECKING, Any + +from kubernetes import client, watch +from kubernetes.client import Configuration, models as k8s +from kubernetes.client.rest import ApiException +from urllib3.exceptions import ReadTimeoutError + +from airflow.exceptions import AirflowException +from airflow.kubernetes.kube_client import get_kube_client +from airflow.kubernetes.kubernetes_helper_functions import ( + annotations_for_logging_task_metadata, + annotations_to_key, + create_pod_id, +) +from airflow.kubernetes.pod_generator import PodGenerator +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.singleton import Singleton +from airflow.utils.state import State + +if TYPE_CHECKING: + from airflow.executors.kubernetes_executor_types import ( + KubernetesJobType, + KubernetesResultsType, + KubernetesWatchType, + ) + + +from airflow.executors.kubernetes_executor_types import ALL_NAMESPACES, POD_EXECUTOR_DONE_KEY + + +class ResourceVersion(metaclass=Singleton): + """Singleton for tracking resourceVersion from Kubernetes.""" + + resource_version: dict[str, str] = {} + + +class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): + """Watches for Kubernetes jobs.""" + + def __init__( + self, + namespace: str, + watcher_queue: Queue[KubernetesWatchType], + resource_version: str | None, + scheduler_job_id: str, + kube_config: Configuration, + ): + super().__init__() + self.namespace = namespace + self.scheduler_job_id = scheduler_job_id + self.watcher_queue = watcher_queue + self.resource_version = resource_version + self.kube_config = kube_config + + def run(self) -> None: + """Performs watching.""" + if TYPE_CHECKING: + assert self.scheduler_job_id + + kube_client: client.CoreV1Api = get_kube_client() + while True: + try: + self.resource_version = self._run( + kube_client, self.resource_version, self.scheduler_job_id, self.kube_config + ) + except ReadTimeoutError: + self.log.warning( + "There was a timeout error accessing the Kube API. Retrying request.", exc_info=True + ) + time.sleep(1) + except Exception: + self.log.exception("Unknown error in KubernetesJobWatcher. 
Failing")
+                self.resource_version = "0"
+                ResourceVersion().resource_version[self.namespace] = "0"
+                raise
+            else:
+                self.log.warning(
+                    "Watch died gracefully, starting back up with: last resource_version: %s",
+                    self.resource_version,
+                )
+
+    def _pod_events(self, kube_client: client.CoreV1Api, query_kwargs: dict):
+        watcher = watch.Watch()
+        try:
+            if self.namespace == ALL_NAMESPACES:
+                return watcher.stream(kube_client.list_pod_for_all_namespaces, **query_kwargs)
+            else:
+                return watcher.stream(kube_client.list_namespaced_pod, self.namespace, **query_kwargs)
+        except ApiException as e:
+            if e.status == 410:  # Resource version is too old
+                if self.namespace == ALL_NAMESPACES:
+                    pods = kube_client.list_pod_for_all_namespaces(watch=False)
+                else:
+                    pods = kube_client.list_namespaced_pod(namespace=self.namespace, watch=False)
+                resource_version = pods.metadata.resource_version
+                query_kwargs["resource_version"] = resource_version
+                return self._pod_events(kube_client=kube_client, query_kwargs=query_kwargs)
+            else:
+                raise
+
+    def _run(
+        self,
+        kube_client: client.CoreV1Api,
+        resource_version: str | None,
+        scheduler_job_id: str,
+        kube_config: Any,
+    ) -> str | None:
+        self.log.info("Event: and now my watch begins starting at resource_version: %s", resource_version)
+
+        kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"}
+        if resource_version:
+            kwargs["resource_version"] = resource_version
+        if kube_config.kube_client_request_args:
+            for key, value in kube_config.kube_client_request_args.items():
+                kwargs[key] = value
+
+        last_resource_version: str | None = None
+
+        for event in self._pod_events(kube_client=kube_client, query_kwargs=kwargs):
+            task = event["object"]
+            self.log.debug("Event: %s had an event of type %s", task.metadata.name, event["type"])
+            if event["type"] == "ERROR":
+                return self.process_error(event)
+            annotations = task.metadata.annotations
+            task_instance_related_annotations = {
+                "dag_id": annotations["dag_id"],
+                "task_id": annotations["task_id"],
+                "execution_date": annotations.get("execution_date"),
+                "run_id": annotations.get("run_id"),
+                "try_number": annotations["try_number"],
+            }
+            map_index = annotations.get("map_index")
+            if map_index is not None:
+                task_instance_related_annotations["map_index"] = map_index
+
+            self.process_status(
+                pod_name=task.metadata.name,
+                namespace=task.metadata.namespace,
+                status=task.status.phase,
+                annotations=task_instance_related_annotations,
+                resource_version=task.metadata.resource_version,
+                event=event,
+            )
+            last_resource_version = task.metadata.resource_version
+
+        return last_resource_version
+
+    def process_error(self, event: Any) -> str:
+        """Process error response."""
+        self.log.error("Encountered Error response from k8s list namespaced pod stream => %s", event)
+        raw_object = event["raw_object"]
+        if raw_object["code"] == 410:
+            self.log.info(
+                "Kubernetes resource version is too old, must reset to 0 => %s", (raw_object["message"],)
+            )
+            # Return resource version 0
+            return "0"
+        raise AirflowException(
+            f"Kubernetes failure for {raw_object['reason']} with code {raw_object['code']} and message: "
+            f"{raw_object['message']}"
+        )
+
+    def process_status(
+        self,
+        pod_name: str,
+        namespace: str,
+        status: str,
+        annotations: dict[str, str],
+        resource_version: str,
+        event: Any,
+    ) -> None:
+        """Process status response."""
+        pod = event["object"]
+        annotations_string = annotations_for_logging_task_metadata(annotations)
+        if status == "Pending":
+            # deletion_timestamp is set by the kube server when a graceful deletion is requested;
+            # since the kube server has received a request to delete the pod, set the TI state to failed.
+            if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
+                self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string)
+                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+            else:
+                self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string)
+        elif status == "Failed":
+            self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string)
+            self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+        elif status == "Succeeded":
+            # We get multiple events once the pod hits a terminal state, and we only want to
+            # send it along to the scheduler once.
+            # If our event type is DELETED, we have the POD_EXECUTOR_DONE_KEY, or the pod has
+            # a deletion timestamp, we've already seen the initial Succeeded event and sent it
+            # along to the scheduler.
+            if (
+                event["type"] == "DELETED"
+                or POD_EXECUTOR_DONE_KEY in pod.metadata.labels
+                or pod.metadata.deletion_timestamp
+            ):
+                self.log.info(
+                    "Skipping event for Succeeded pod %s - event for this pod already sent to executor",
+                    pod_name,
+                )
+                return
+            self.log.info("Event: %s Succeeded, annotations: %s", pod_name, annotations_string)
+            self.watcher_queue.put((pod_name, namespace, None, annotations, resource_version))
+        elif status == "Running":
+            # deletion_timestamp is set by the kube server when a graceful deletion is requested;
+            # since the kube server has received a request to delete the pod, set the TI state to failed.
+            if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
+                self.log.info(
+                    "Event: Pod %s deleted before it could complete, annotations: %s",
+                    pod_name,
+                    annotations_string,
+                )
+                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+            else:
+                self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string)
+        else:
+            self.log.warning(
+                "Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with "
+                "resource_version: %s",
+                status,
+                pod_name,
+                namespace,
+                annotations,
+                resource_version,
+            )
+
+
+class AirflowKubernetesScheduler(LoggingMixin):
+    """Airflow Scheduler for Kubernetes."""
+
+    def __init__(
+        self,
+        kube_config: Any,
+        result_queue: Queue[KubernetesResultsType],
+        kube_client: client.CoreV1Api,
+        scheduler_job_id: str,
+    ):
+        super().__init__()
+        self.log.debug("Creating Kubernetes executor")
+        self.kube_config = kube_config
+        self.result_queue = result_queue
+        self.namespace = self.kube_config.kube_namespace
+        self.log.debug("Kubernetes using namespace %s", self.namespace)
+        self.kube_client = kube_client
+        self._manager = multiprocessing.Manager()
+        self.watcher_queue = self._manager.Queue()
+        self.scheduler_job_id = scheduler_job_id
+        self.kube_watchers = self._make_kube_watchers()
+
+    def run_pod_async(self, pod: k8s.V1Pod, **kwargs):
+        """Runs POD asynchronously."""
+        sanitized_pod = self.kube_client.api_client.sanitize_for_serialization(pod)
+        json_pod = json.dumps(sanitized_pod, indent=2)
+
+        self.log.debug("Pod Creation Request: \n%s", json_pod)
+        try:
+            resp = self.kube_client.create_namespaced_pod(
+                body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs
+            )
+            self.log.debug("Pod Creation Response: %s", resp)
+        except Exception as e:
+            self.log.exception("Exception when attempting to create Namespaced Pod: %s", json_pod)
+            raise e
+        return resp
+
+    def _make_kube_watcher(self, namespace) -> KubernetesJobWatcher:
+        resource_version = ResourceVersion().resource_version.get(namespace, "0")
+        watcher = KubernetesJobWatcher(
+            watcher_queue=self.watcher_queue,
+            namespace=namespace,
+            resource_version=resource_version,
+            scheduler_job_id=self.scheduler_job_id,
+            kube_config=self.kube_config,
+        )
+        watcher.start()
+        return watcher
+
+    def _make_kube_watchers(self) -> dict[str, KubernetesJobWatcher]:
+        watchers = {}
+        if self.kube_config.multi_namespace_mode:
+            namespaces_to_watch = (
+                self.kube_config.multi_namespace_mode_namespace_list
+                if self.kube_config.multi_namespace_mode_namespace_list
+                else [ALL_NAMESPACES]
+            )
+        else:
+            namespaces_to_watch = [self.kube_config.kube_namespace]
+
+        for namespace in namespaces_to_watch:
+            watchers[namespace] = self._make_kube_watcher(namespace)
+        return watchers
+
+    def _health_check_kube_watchers(self):
+        for namespace, kube_watcher in self.kube_watchers.items():
+            if kube_watcher.is_alive():
+                self.log.debug("KubeJobWatcher for namespace %s alive, continuing", namespace)
+            else:
+                self.log.error(
+                    (
+                        "Error while health checking kube watcher process for namespace %s. "
+                        "Process died for unknown reasons"
+                    ),
+                    namespace,
+                )
+                ResourceVersion().resource_version[namespace] = "0"
+                self.kube_watchers[namespace] = self._make_kube_watcher(namespace)
+
+    def run_next(self, next_job: KubernetesJobType) -> None:
+        """Receives the next job to run, builds the pod, and creates it."""
+        key, command, kube_executor_config, pod_template_file = next_job
+
+        dag_id, task_id, run_id, try_number, map_index = key
+
+        if command[0:3] != ["airflow", "tasks", "run"]:
+            raise ValueError('The command must start with ["airflow", "tasks", "run"].')
+
+        base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config)
+
+        if not base_worker_pod:
+            raise AirflowException(
+                f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}"
+            )
+
+        pod = PodGenerator.construct_pod(
+            namespace=self.namespace,
+            scheduler_job_id=self.scheduler_job_id,
+            pod_id=create_pod_id(dag_id, task_id),
+            dag_id=dag_id,
+            task_id=task_id,
+            kube_image=self.kube_config.kube_image,
+            try_number=try_number,
+            map_index=map_index,
+            date=None,
+            run_id=run_id,
+            args=command,
+            pod_override_object=kube_executor_config,
+            base_worker_pod=base_worker_pod,
+            with_mutation_hook=True,
+        )
+        # Reconcile the pod generated by the Operator and the Pod
+        # generated by the .cfg file
+        self.log.info(
+            "Creating kubernetes pod for job %s, with pod name %s, annotations: %s",
+            key,
+            pod.metadata.name,
+            annotations_for_logging_task_metadata(pod.metadata.annotations),
+        )
+        self.log.debug("Kubernetes running for command %s", command)
+        self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image)
+
+        # The watcher will monitor pods, so we do not block.
+        self.run_pod_async(pod, **self.kube_config.kube_client_request_args)
+        self.log.debug("Kubernetes Job created!")
+
+    def delete_pod(self, pod_name: str, namespace: str) -> None:
+        """Deletes Pod from a namespace. Does not raise if it does not exist."""
+        try:
+            self.log.debug("Deleting pod %s in namespace %s", pod_name, namespace)
+            self.kube_client.delete_namespaced_pod(
+                pod_name,
+                namespace,
+                body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs),
+                **self.kube_config.kube_client_request_args,
+            )
+        except ApiException as e:
+            # If the pod is already deleted
+            if e.status != 404:
+                raise
+
+    def patch_pod_executor_done(self, *, pod_name: str, namespace: str):
+        """Add a "done" label to ensure we don't continually adopt pods."""
+        self.log.debug("Patching pod %s in namespace %s to mark it as done", pod_name, namespace)
+        try:
+            self.kube_client.patch_namespaced_pod(
+                name=pod_name,
+                namespace=namespace,
+                body={"metadata": {"labels": {POD_EXECUTOR_DONE_KEY: "True"}}},
+            )
+        except ApiException as e:
+            self.log.info("Failed to patch pod %s with done label. Reason: %s", pod_name, e)
+
+    def sync(self) -> None:
+        """
+        Checks the status of all currently running kubernetes jobs.
+
+        If a job is completed, its status is placed in the result queue to be sent back to the scheduler.
+        """
+        self.log.debug("Syncing KubernetesExecutor")
+        self._health_check_kube_watchers()
+        while True:
+            try:
+                task = self.watcher_queue.get_nowait()
+                try:
+                    self.log.debug("Processing task %s", task)
+                    self.process_watcher_task(task)
+                finally:
+                    self.watcher_queue.task_done()
+            except Empty:
+                break
+
+    def process_watcher_task(self, task: KubernetesWatchType) -> None:
+        """Process a task item from the watcher queue."""
+        pod_name, namespace, state, annotations, resource_version = task
+        self.log.debug(
+            "Attempting to finish pod; pod_name: %s; state: %s; annotations: %s",
+            pod_name,
+            state,
+            annotations_for_logging_task_metadata(annotations),
+        )
+        key = annotations_to_key(annotations=annotations)
+        if key:
+            self.log.debug("finishing job %s - %s (%s)", key, state, pod_name)
+            self.result_queue.put((key, state, pod_name, namespace, resource_version))
+
+    def _flush_watcher_queue(self) -> None:
+        self.log.debug("Executor shutting down, watcher_queue approx. size=%d", self.watcher_queue.qsize())
+        while True:
+            try:
+                task = self.watcher_queue.get_nowait()
+                # Ignoring it since it can only have either FAILED or SUCCEEDED pods
+                self.log.warning("Executor shutting down, IGNORING watcher task=%s", task)
+                self.watcher_queue.task_done()
+            except Empty:
+                break
+
+    def terminate(self) -> None:
+        """Terminates the watchers and shuts down the multiprocessing manager."""
+        self.log.debug("Terminating kube_watchers...")
+        for namespace, kube_watcher in self.kube_watchers.items():
+            kube_watcher.terminate()
+            kube_watcher.join()
+            self.log.debug("kube_watcher=%s", kube_watcher)
+        self.log.debug("Flushing watcher_queue...")
+        self._flush_watcher_queue()
+        # Queue should be empty...
+        self.watcher_queue.join()
+        self.log.debug("Shutting down manager...")
+        self._manager.shutdown()
+
+
+def get_base_pod_from_template(pod_template_file: str | None, kube_config: Any) -> k8s.V1Pod:
+    """
+    Get base pod from template.
+
+    Reads either the pod_template_file set in the executor_config or the base pod_template_file
+    set in the airflow.cfg to craft a "base pod" that will be used by the KubernetesExecutor.
+
+    :param pod_template_file: absolute path to a pod_template_file.yaml or None
+    :param kube_config: The KubeConfig class generated by airflow that contains all kube metadata
+    :return: a V1Pod that can be used as the base pod for k8s tasks
+    """
+    if pod_template_file:
+        return PodGenerator.deserialize_model_file(pod_template_file)
+    else:
+        return PodGenerator.deserialize_model_file(kube_config.pod_template_file)
diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py
index fdb76b0aa85fc..4cd3422cb646b 100644
--- a/airflow/kubernetes/kubernetes_helper_functions.py
+++ b/airflow/kubernetes/kubernetes_helper_functions.py
@@ -19,13 +19,16 @@
 import logging
 import secrets
 import string
+from typing import TYPE_CHECKING

 import pendulum
 from slugify import slugify

 from airflow.compat.functools import cache
 from airflow.configuration import conf
-from airflow.models.taskinstancekey import TaskInstanceKey
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstancekey import TaskInstanceKey

 log = logging.getLogger(__name__)

@@ -91,12 +94,12 @@ def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey:
     annotation_run_id = annotations.get("run_id")
     map_index = int(annotations.get("map_index", -1))

-    if not annotation_run_id and "execution_date" in annotations:
-        # Compat: Look up the run_id from the TI table!
-        from airflow.models.dagrun import DagRun
-        from airflow.models.taskinstance import TaskInstance
-        from airflow.settings import Session
+    # Compat: imports for the run_id lookup below (TaskInstanceKey is now a runtime-local import).
+    from airflow.models.dagrun import DagRun
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+    from airflow.settings import Session

+    if not annotation_run_id and "execution_date" in annotations:
         execution_date = pendulum.parse(annotations["execution_date"])
         # Do _not_ use create-session, we don't want to expunge
         session = Session()
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index ba9c33a36e421..4b4685e079e62 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -40,10 +40,10 @@
 from tests.test_utils.config import conf_vars

 try:
-    from airflow.executors.kubernetes_executor import (
-        POD_EXECUTOR_DONE_KEY,
+    from airflow.executors.kubernetes_executor import KubernetesExecutor
+    from airflow.executors.kubernetes_executor_types import POD_EXECUTOR_DONE_KEY
+    from airflow.executors.kubernetes_executor_utils import (
         AirflowKubernetesScheduler,
-        KubernetesExecutor,
         KubernetesJobWatcher,
         ResourceVersion,
         create_pod_id,
@@ -160,9 +160,9 @@ def test_execution_date_serialize_deserialize(self):
     @pytest.mark.skipif(
         AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed"
     )
-    @mock.patch("airflow.executors.kubernetes_executor.get_kube_client")
-    @mock.patch("airflow.executors.kubernetes_executor.client")
-    @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher")
+    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+    @mock.patch("airflow.executors.kubernetes_executor_utils.client")
+    @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher")
     def test_delete_pod_successfully(self, mock_watcher, mock_client, mock_kube_client):
         pod_name = "my-pod-1"
namespace = "my-namespace-1" @@ -182,9 +182,9 @@ def test_delete_pod_successfully(self, mock_watcher, mock_client, mock_kube_clie @pytest.mark.skipif( AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed" ) - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.client") - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.client") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") def test_delete_pod_raises_404(self, mock_watcher, mock_client, mock_kube_client): pod_name = "my-pod-1" namespace = "my-namespace-2" @@ -205,9 +205,9 @@ def test_delete_pod_raises_404(self, mock_watcher, mock_client, mock_kube_client @pytest.mark.skipif( AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed" ) - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.client") - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.client") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") def test_delete_pod_404_not_raised(self, mock_watcher, mock_client, mock_kube_client): pod_name = "my-pod-1" namespace = "my-namespace-3" @@ -249,8 +249,8 @@ def setup_method(self) -> None: pytest.param(400, False, id="400 BadRequest"), ], ) - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_run_next_exception_requeue( self, mock_get_kube_client, mock_kubernetes_job_watcher, status, should_requeue ): @@ -319,7 +319,7 @@ def test_run_next_exception_requeue( AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed" ) @mock.patch("airflow.settings.pod_mutation_hook") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_run_next_pmh_error(self, mock_get_kube_client, mock_pmh): """ Exception during Pod Mutation Hook execution should be handled gracefully. @@ -357,8 +357,8 @@ def test_run_next_pmh_error(self, mock_get_kube_client, mock_pmh): @pytest.mark.skipif( AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed" ) - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_run_next_pod_reconciliation_error(self, mock_get_kube_client, mock_kubernetes_job_watcher): """ When construct_pod raises PodReconciliationError, we should fail the task. 
@@ -417,8 +417,8 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock ] mock_stats_gauge.assert_has_calls(calls) - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_invalid_executor_config(self, mock_get_kube_client, mock_kubernetes_job_watcher): executor = self.kubernetes_executor executor.start() @@ -443,8 +443,8 @@ def test_invalid_executor_config(self, mock_get_kube_client, mock_kubernetes_job @pytest.mark.skipif( AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed" ) - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.run_pod_async") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.run_pod_async") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_pod_template_file_override_in_executor_config(self, mock_get_kube_client, mock_run_pod_async): current_folder = pathlib.Path(__file__).parent.resolve() template_file = str( @@ -528,8 +528,8 @@ def test_pod_template_file_override_in_executor_config(self, mock_get_kube_clien finally: executor.end() - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher): executor = self.kubernetes_executor executor.start() @@ -542,9 +542,9 @@ def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_wa finally: executor.end() - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod") def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher): executor = self.kubernetes_executor executor.start() @@ -558,9 +558,9 @@ def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_ finally: executor.end() - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler") def test_change_state_failed_no_deletion( self, mock_kubescheduler, mock_get_kube_client, mock_kubernetes_job_watcher ): @@ -584,9 +584,9 @@ def test_change_state_failed_no_deletion( @pytest.mark.parametrize( "ti_state", [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.DEFERRED] ) - 
@mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod") def test_change_state_none( self, mock_delete_pod, @@ -616,7 +616,7 @@ def test_change_state_none( pytest.param(None, ["ALL_NAMESPACES"]), ], ) - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_watchers_under_multi_namespace_mode( self, mock_get_kube_client, multi_namespace_mode_namespace_list, watchers_keys ): @@ -632,9 +632,9 @@ def test_watchers_under_multi_namespace_mode( finally: executor.end() - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler") def test_change_state_skip_pod_deletion( self, mock_kubescheduler, mock_get_kube_client, mock_kubernetes_job_watcher ): @@ -656,9 +656,9 @@ def test_change_state_skip_pod_deletion( finally: executor.end() - @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler") + @mock.patch("airflow.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler") def test_change_state_failed_pod_deletion( self, mock_kubescheduler, mock_get_kube_client, mock_kubernetes_job_watcher ): @@ -778,7 +778,7 @@ def test_try_adopt_task_instances_no_matching_pods( mock_adopt_launched_task.assert_not_called() mock_adopt_completed_pods.assert_called_once() - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_adopt_launched_task(self, mock_kube_client): executor = self.kubernetes_executor executor.scheduler_job_id = "modified" @@ -803,7 +803,7 @@ def test_adopt_launched_task(self, mock_kube_client): assert tis_to_flush_by_key == {} assert executor.running == {ti_key} - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_adopt_launched_task_api_exception(self, mock_kube_client): """We shouldn't think we are running the task if aren't able to patch the pod""" executor = self.kubernetes_executor @@ -828,7 +828,7 @@ def test_adopt_launched_task_api_exception(self, mock_kube_client): assert tis_to_flush_by_key == {ti_key: {}} assert executor.running == set() - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_adopt_completed_pods(self, mock_kube_client): """We should adopt all completed pods from other schedulers""" executor = 
self.kubernetes_executor @@ -878,7 +878,7 @@ def get_annotations(pod_name): ) assert executor.running == expected_running_ti_keys - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_not_adopt_unassigned_task(self, mock_kube_client): """ We should not adopt any tasks that were not assigned by the scheduler. @@ -904,8 +904,8 @@ def test_not_adopt_unassigned_task(self, mock_kube_client): assert not mock_kube_client.patch_namespaced_pod.called assert tis_to_flush_by_key == {"foobar": {}} - @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") - @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod") def test_cleanup_stuck_queued_tasks(self, mock_delete_pod, mock_kube_client, dag_maker, session): """Delete any pods associated with a task stuck in queued.""" executor = KubernetesExecutor() @@ -1232,7 +1232,7 @@ def setup_method(self): self.events = [] def _run(self): - with mock.patch("airflow.executors.kubernetes_executor.watch") as mock_watch: + with mock.patch("airflow.executors.kubernetes_executor_utils.watch") as mock_watch: mock_watch.Watch.return_value.stream.return_value = self.events latest_resource_version = self.watcher._run( self.kube_client, @@ -1362,7 +1362,7 @@ def effect(): self.watcher._run = mock_underscore_run - with mock.patch("airflow.executors.kubernetes_executor.get_kube_client"): + with mock.patch("airflow.executors.kubernetes_executor_utils.get_kube_client"): try: # self.watcher._run() is mocked and return "500" as last resource_version self.watcher.run()
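
The refactor above leans on two deferral patterns: imports guarded by `if TYPE_CHECKING:` (free at runtime, visible to type checkers) and imports moved into the call sites that need them (`start()`, `execute_async()`, `sync()`, `_change_state()`). A minimal, self-contained sketch of both follows; the helper names are illustrative and not part of the patch:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only import: evaluated by mypy and IDEs, never executed at runtime.
    from kubernetes.client import models as k8s


def pod_namespace(pod: k8s.V1Pod) -> str | None:
    # With "from __future__ import annotations", the k8s.V1Pod annotation is a
    # plain string at runtime, so the kubernetes package is never imported here.
    return pod.metadata.namespace if pod.metadata else None


def make_label_safe(value: str) -> str:
    # Function-local import: airflow.kubernetes (and, through pod_generator,
    # the kubernetes client) is loaded on first call, not at module import.
    from airflow.kubernetes import pod_generator

    return pod_generator.make_safe_label_value(value)
```

The trade-off is that a missing or broken `kubernetes` dependency now surfaces when the executor starts or builds a pod, rather than when `airflow.executors.kubernetes_executor` is first imported.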