Skip to content

Commit

Permalink
Kubernetes executor can adopt tasks from other schedulers (#10996)
Browse files Browse the repository at this point in the history
* KubernetesExecutor can adopt tasks from other schedulers

* simplify

* recreate tables properly

* fix pylint

Co-authored-by: Daniel Imberman <[email protected]>
  • Loading branch information
dimberman and astro-sql-decorator authored Oct 1, 2020
1 parent 427a4a8 commit 3ca11eb
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 218 deletions.
6 changes: 3 additions & 3 deletions airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def generate_pod_yaml(args):
"""Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor"""
from kubernetes.client.api_client import ApiClient

from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
from airflow.executors.kubernetes_executor import KubeConfig, create_pod_id
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.settings import pod_mutation_hook
Expand All @@ -399,14 +399,14 @@ def generate_pod_yaml(args):
pod = PodGenerator.construct_pod(
dag_id=args.dag_id,
task_id=ti.task_id,
pod_id=AirflowKubernetesScheduler._create_pod_id( # pylint: disable=W0212
pod_id=create_pod_id(
args.dag_id, ti.task_id),
try_number=ti.try_number,
kube_image=kube_config.kube_image,
date=ti.execution_date,
command=ti.command_as_list(),
pod_override_object=PodGenerator.from_obj(ti.executor_config),
worker_uuid="worker-config",
scheduler_job_id="worker-config",
namespace=kube_config.executor_namespace,
base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file)
)
Expand Down
2 changes: 2 additions & 0 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BaseExecutor(LoggingMixin):
``0`` for infinity
"""

job_id: Optional[str] = None

def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
self.parallelism: int = parallelism
Expand Down
192 changes: 140 additions & 52 deletions airflow/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import multiprocessing
import time
from queue import Empty, Queue # pylint: disable=unused-import
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import kubernetes
from dateutil import parser
Expand All @@ -44,7 +44,7 @@
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.pod_generator import MAX_POD_ID_LEN, PodGenerator
from airflow.kubernetes.pod_launcher import PodLauncher
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance
from airflow.models import TaskInstance
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
Expand All @@ -60,6 +60,18 @@
KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]


class ResourceVersion:
    """
    Singleton holder for the most recent Kubernetes ``resourceVersion``.

    Every ``ResourceVersion()`` call hands back the same object, so a value
    assigned to ``resource_version`` through one reference is seen through
    every other reference.
    """

    # The single shared instance; created by the first ``ResourceVersion()`` call.
    _instance = None
    # Class-level default that is read until a value is assigned on the instance.
    resource_version = "0"

    def __new__(cls):
        shared = cls._instance
        if shared is None:
            shared = super().__new__(cls)
            cls._instance = shared
        return shared


class KubeConfig: # pylint: disable=too-many-instance-attributes
"""Configuration for Kubernetes"""

Expand Down Expand Up @@ -134,25 +146,25 @@ def __init__(self,
multi_namespace_mode: bool,
watcher_queue: 'Queue[KubernetesWatchType]',
resource_version: Optional[str],
worker_uuid: Optional[str],
scheduler_job_id: Optional[str],
kube_config: Configuration):
super().__init__()
self.namespace = namespace
self.multi_namespace_mode = multi_namespace_mode
self.worker_uuid = worker_uuid
self.scheduler_job_id = scheduler_job_id
self.watcher_queue = watcher_queue
self.resource_version = resource_version
self.kube_config = kube_config

def run(self) -> None:
"""Performs watching"""
kube_client: client.CoreV1Api = get_kube_client()
if not self.worker_uuid:
if not self.scheduler_job_id:
raise AirflowException(NOT_STARTED_MESSAGE)
while True:
try:
self.resource_version = self._run(kube_client, self.resource_version,
self.worker_uuid, self.kube_config)
self.scheduler_job_id, self.kube_config)
except ReadTimeoutError:
self.log.warning("There was a timeout error accessing the Kube API. "
"Retrying request.", exc_info=True)
Expand All @@ -167,15 +179,15 @@ def run(self) -> None:
def _run(self,
kube_client: client.CoreV1Api,
resource_version: Optional[str],
worker_uuid: str,
scheduler_job_id: str,
kube_config: Any) -> Optional[str]:
self.log.info(
'Event: and now my watch begins starting at resource_version: %s',
resource_version
)
watcher = watch.Watch()

kwargs = {'label_selector': 'airflow-worker={}'.format(worker_uuid)}
kwargs = {'label_selector': 'airflow-worker={}'.format(scheduler_job_id)}
if resource_version:
kwargs['resource_version'] = resource_version
if kube_config.kube_client_request_args:
Expand Down Expand Up @@ -277,7 +289,7 @@ def __init__(self,
task_queue: 'Queue[KubernetesJobType]',
result_queue: 'Queue[KubernetesResultsType]',
kube_client: client.CoreV1Api,
worker_uuid: str):
scheduler_job_id: str):
super().__init__()
self.log.debug("Creating Kubernetes executor")
self.kube_config = kube_config
Expand All @@ -289,16 +301,16 @@ def __init__(self,
self.launcher = PodLauncher(kube_client=self.kube_client)
self._manager = multiprocessing.Manager()
self.watcher_queue = self._manager.Queue()
self.worker_uuid = worker_uuid
self.scheduler_job_id = scheduler_job_id
self.kube_watcher = self._make_kube_watcher()

def _make_kube_watcher(self) -> KubernetesJobWatcher:
resource_version = KubeResourceVersion.get_current_resource_version()
resource_version = ResourceVersion().resource_version
watcher = KubernetesJobWatcher(watcher_queue=self.watcher_queue,
namespace=self.kube_config.kube_namespace,
multi_namespace_mode=self.kube_config.multi_namespace_mode,
resource_version=resource_version,
worker_uuid=self.worker_uuid,
scheduler_job_id=self.scheduler_job_id,
kube_config=self.kube_config)
watcher.start()
return watcher
Expand Down Expand Up @@ -333,8 +345,8 @@ def run_next(self, next_job: KubernetesJobType) -> None:

pod = PodGenerator.construct_pod(
namespace=self.namespace,
worker_uuid=self.worker_uuid,
pod_id=self._create_pod_id(dag_id, task_id),
scheduler_job_id=self.scheduler_job_id,
pod_id=create_pod_id(dag_id, task_id),
dag_id=dag_id,
task_id=task_id,
kube_image=self.kube_config.kube_image,
Expand Down Expand Up @@ -404,21 +416,6 @@ def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInsta

return TaskInstanceKey(dag_id, task_id, execution_date, try_number)

@staticmethod
def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
"""
Kubernetes only supports lowercase alphanumeric characters and "-" and "." in
the pod name
However, there are special rules about how "-" and "." can be used so let's
only keep
alphanumeric chars see here for detail:
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
:param string: The requested Pod name
:return: ``str`` Pod name stripped of any unsafe characters
"""
return ''.join(ch.lower() for ind, ch in enumerate(string) if ch.isalnum())

@staticmethod
def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> str:
r"""
Expand All @@ -437,14 +434,6 @@ def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str, safe_uuid: str) -> st

return safe_pod_id

@staticmethod
def _create_pod_id(dag_id: str, task_id: str) -> str:
safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
dag_id)
safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
task_id)
return safe_dag_id + safe_task_id

def _flush_watcher_queue(self) -> None:
self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize())
while True:
Expand All @@ -470,6 +459,36 @@ def terminate(self) -> None:
self._manager.shutdown()


def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
    """
    Reduce a string to lowercase alphanumeric characters for use in a pod name.

    Kubernetes only supports lowercase alphanumeric characters, "-" and "." in
    the pod name. However, there are special rules about how "-" and "." can be
    used, so only alphanumeric characters are kept. See
    https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

    :param string: The requested Pod name
    :return: ``str`` Pod name stripped of any unsafe characters
    """
    # NOTE(review): ``str.isalnum`` is Unicode-aware, so non-ASCII letters and
    # digits are kept as well -- confirm callers only pass ASCII identifiers.
    return ''.join(ch.lower() for ch in string if ch.isalnum())


def create_pod_id(dag_id: str, task_id: str) -> str:
    """
    Build the kubernetes safe pod_id for a DAG/task pairing.

    Note that this is NOT the full ID that will be launched to k8s; a uuid
    is added to it later to ensure uniqueness.

    :param dag_id: DAG ID
    :param task_id: Task ID
    :return: The non-unique pod_id for this task/DAG pairing
    """
    # Sanitise each identifier on its own, then concatenate DAG part + task part.
    return ''.join(
        _strip_unsafe_kubernetes_special_chars(raw_id)
        for raw_id in (dag_id, task_id)
    )


class KubernetesExecutor(BaseExecutor, LoggingMixin):
"""Executor for Kubernetes"""

Expand All @@ -480,7 +499,7 @@ def __init__(self):
self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue()
self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None
self.kube_client: Optional[client.CoreV1Api] = None
self.worker_uuid: Optional[str] = None
self.scheduler_job_id: Optional[str] = None
super().__init__(parallelism=self.kube_config.parallelism)

@provide_session
Expand Down Expand Up @@ -519,7 +538,7 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
pod_generator.datetime_to_label_safe_datestring(
task.execution_date
),
self.worker_uuid
self.scheduler_job_id
)
)
# pylint: enable=protected-access
Expand Down Expand Up @@ -568,19 +587,14 @@ def _create_or_update_secret(secret_name, secret_path):
def start(self) -> None:
"""Starts the executor"""
self.log.info('Start Kubernetes executor')
self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid()
if not self.worker_uuid:
raise AirflowException("Could not get worker uuid")
self.log.debug('Start with worker_uuid: %s', self.worker_uuid)
# always need to reset resource version since we don't know
# when we last started, note for behavior below
# https://github.com/kubernetes-client/python/blob/master/kubernetes/docs
# /CoreV1Api.md#list_namespaced_pod
KubeResourceVersion.reset_resource_version()
if not self.job_id:
raise AirflowException("Could not get scheduler_job_id")
self.scheduler_job_id = self.job_id
self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id)
self.kube_client = get_kube_client()
self.kube_scheduler = AirflowKubernetesScheduler(
self.kube_config, self.task_queue, self.result_queue,
self.kube_client, self.worker_uuid
self.kube_client, self.scheduler_job_id
)
self._inject_secrets()
self.clear_not_launched_queued_tasks()
Expand All @@ -595,10 +609,10 @@ def execute_async(self,
'Add task %s with command %s with executor_config %s',
key, command, executor_config
)

kube_executor_config = PodGenerator.from_obj(executor_config)
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
self.task_queue.put((key, command, kube_executor_config))

def sync(self) -> None:
Expand All @@ -607,7 +621,7 @@ def sync(self) -> None:
self.log.debug('self.running: %s', self.running)
if self.queued_tasks:
self.log.debug('self.queued: %s', self.queued_tasks)
if not self.worker_uuid:
if not self.scheduler_job_id:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.kube_scheduler:
raise AirflowException(NOT_STARTED_MESSAGE)
Expand Down Expand Up @@ -640,7 +654,8 @@ def sync(self) -> None:
except Empty:
break

KubeResourceVersion.checkpoint_resource_version(last_resource_version)
resource_instance = ResourceVersion()
resource_instance.resource_version = last_resource_version or resource_instance.resource_version

# pylint: disable=too-many-nested-blocks
for _ in range(self.kube_config.worker_pods_creation_batch_size):
Expand Down Expand Up @@ -681,6 +696,79 @@ def _change_state(self,
self.log.debug('Could not find key: %s', str(key))
self.event_buffer[key] = state, None

def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
    """
    Try to take over the pods backing the given task instances.

    For every distinct ``external_executor_id`` found on ``tis``, the pods
    labelled ``airflow-worker=<that id>`` are listed and handed to
    ``adopt_launched_task``. Completed pods are then handed to
    ``_adopt_completed_pods``.

    :param tis: task instances to try to adopt
    :return: the task instances that were not adopted: those without an
        ``external_executor_id`` plus those whose pod was not patched
    :raises AirflowException: if the executor has no Kubernetes client yet
    """
    if not self.kube_client:
        raise AirflowException(NOT_STARTED_MESSAGE)
    kube_client: client.CoreV1Api = self.kube_client

    tis_to_flush = [ti for ti in tis if not ti.external_executor_id]
    # Query each scheduler only once, and never query for a missing id:
    # an unset id would otherwise become the selector "airflow-worker=None".
    # ``dict.fromkeys`` de-duplicates while keeping first-seen order.
    scheduler_job_ids = list(dict.fromkeys(
        ti.external_executor_id for ti in tis if ti.external_executor_id
    ))
    # NOTE(review): keyed by dag_id + task_id only, so several task instances
    # of the same task share one key and only the last one is kept -- confirm
    # whether that is acceptable for the caller.
    pod_ids = {
        create_pod_id(dag_id=ti.dag_id, task_id=ti.task_id): ti
        for ti in tis if ti.external_executor_id
    }
    for scheduler_job_id in scheduler_job_ids:
        pod_list = kube_client.list_namespaced_pod(
            namespace=self.kube_config.kube_namespace,
            label_selector=f'airflow-worker={scheduler_job_id}',
        )
        for pod in pod_list.items:
            # Removes the matching entry from ``pod_ids`` when the patch succeeds.
            self.adopt_launched_task(kube_client, pod, pod_ids)
    self._adopt_completed_pods(kube_client)
    # Whatever is still in ``pod_ids`` was not adopted.
    tis_to_flush.extend(pod_ids.values())
    return tis_to_flush

def adopt_launched_task(self, kube_client, pod, pod_ids: dict):
    """
    Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors

    The pod is only patched when the pod_id derived from its ``dag_id`` and
    ``task_id`` labels is a key of ``pod_ids``; on a successful patch that
    key is removed from ``pod_ids``.

    :param kube_client: kubernetes client for speaking to kube API
    :param pod: V1Pod spec that we will patch with new label
    :param pod_ids: pod_ids we expect to patch.
    """
    metadata = pod.metadata
    self.log.info("attempting to adopt pod %s", metadata.name)
    labels = metadata.labels
    # Re-label the pod with this executor's scheduler job id.
    labels['airflow-worker'] = str(self.scheduler_job_id)
    dag_id = labels['dag_id']
    task_id = labels['task_id']
    pod_id = create_pod_id(dag_id=dag_id, task_id=task_id)

    if pod_id not in pod_ids:
        self.log.error("attempting to adopt task %s in dag %s which was not specified by database",
                       task_id, dag_id)
        return

    try:
        kube_client.patch_namespaced_pod(
            name=metadata.name,
            namespace=metadata.namespace,
            body=PodGenerator.serialize_pod(pod),
        )
    except ApiException as e:
        self.log.info("Failed to adopt pod %s. Reason: %s", metadata.name, e)
    else:
        # Only forget the pod_id once the patch went through.
        pod_ids.pop(pod_id)

def _adopt_completed_pods(self, kube_client: kubernetes.client.CoreV1Api):
    """
    Patch completed pod so that the KubernetesJobWatcher can delete it.

    Lists the pods in the configured namespace that carry the label
    ``kubernetes_executor=True`` and are in phase ``Succeeded``, and patches
    each one with this executor's ``airflow-worker`` label. A failed patch
    is logged and does not stop the remaining pods from being processed.

    :param kube_client: kubernetes client for speaking to kube API
    """
    completed_pods = kube_client.list_namespaced_pod(
        namespace=self.kube_config.kube_namespace,
        field_selector="status.phase=Succeeded",
        label_selector='kubernetes_executor=True',
    )
    for pod in completed_pods.items:
        metadata = pod.metadata
        self.log.info("Attempting to adopt pod %s", metadata.name)
        metadata.labels['airflow-worker'] = str(self.scheduler_job_id)
        try:
            kube_client.patch_namespaced_pod(
                name=metadata.name,
                namespace=metadata.namespace,
                body=PodGenerator.serialize_pod(pod),
            )
        except ApiException as e:
            self.log.info("Failed to adopt pod %s. Reason: %s", metadata.name, e)

def _flush_task_queue(self) -> None:
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
Expand Down
Loading

0 comments on commit 3ca11eb

Please sign in to comment.