Skip to content

Commit

Permalink
[k8s] Fix mounting when launching from a service account (#3532)
Browse files Browse the repository at this point in the history
* wip

* wip

* smoke tests

* move service account constant

* comment

* add comparisons for k8s object equality checks

* lint
  • Loading branch information
romilbhardwaj authored May 10, 2024
1 parent 05aafb8 commit 8a0a34d
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 43 deletions.
10 changes: 10 additions & 0 deletions sky/adaptors/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_custom_objects_api = None
_node_api = None
_apps_api = None
_api_client = None

# Timeout to use for API calls
API_TIMEOUT = 5
Expand Down Expand Up @@ -118,6 +119,15 @@ def apps_api():
return _apps_api


def api_client():
global _api_client
if _api_client is None:
_load_config()
_api_client = kubernetes.client.ApiClient()

return _api_client


def api_exception():
return kubernetes.client.rest.ApiException

Expand Down
8 changes: 4 additions & 4 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# Namespace for SkyPilot resources shared across multiple tenants on the
# same cluster (even if they might be running in different namespaces).
# E.g., FUSE device manager daemonset is run in this namespace.
_SKY_SYSTEM_NAMESPACE = 'skypilot-system'
_SKYPILOT_SYSTEM_NAMESPACE = 'skypilot-system'


@clouds.CLOUD_REGISTRY.register
Expand All @@ -38,7 +38,6 @@ class Kubernetes(clouds.Cloud):

SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys'
SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod'
SKY_DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account'
PORT_FORWARD_PROXY_CMD_TEMPLATE = \
'kubernetes-port-forward-proxy-command.sh.j2'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh'
Expand Down Expand Up @@ -284,7 +283,8 @@ def make_deploy_resources_variables(
elif (remote_identity ==
schemas.RemoteIdentityOptions.SERVICE_ACCOUNT.value):
# Use the default service account
k8s_service_account_name = self.SKY_DEFAULT_SERVICE_ACCOUNT_NAME
k8s_service_account_name = (
kubernetes_utils.DEFAULT_SERVICE_ACCOUNT_NAME)
k8s_automount_sa_token = 'true'
else:
# User specified a custom service account
Expand Down Expand Up @@ -313,7 +313,7 @@ def make_deploy_resources_variables(
'k8s_automount_sa_token': k8s_automount_sa_token,
'k8s_fuse_device_required': fuse_device_required,
# Namespace to run the FUSE device manager in
'k8s_fuse_device_manager_namespace': _SKY_SYSTEM_NAMESPACE,
'k8s_skypilot_system_namespace': _SKYPILOT_SYSTEM_NAMESPACE,
'image_id': image_id,
}

Expand Down
173 changes: 139 additions & 34 deletions sky/provision/kubernetes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import math
import os
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

import yaml

Expand All @@ -27,11 +27,9 @@ def bootstrap_instances(

config = _configure_ssh_jump(namespace, config)

if config.provider_config.get('fuse_device_required', False):
_configure_fuse_mounting(config.provider_config)

requested_service_account = config.node_config['spec']['serviceAccountName']
if requested_service_account == 'skypilot-service-account':
if (requested_service_account ==
kubernetes_utils.DEFAULT_SERVICE_ACCOUNT_NAME):
# If the user has requested a different service account (via pod_config
# in ~/.sky/config.yaml), we assume they have already set up the
# necessary roles and role bindings.
Expand Down Expand Up @@ -71,6 +69,26 @@ def bootstrap_instances(
elif requested_service_account != 'default':
logger.info(f'Using service account {requested_service_account!r}, '
'skipping role and role binding setup.')

# SkyPilot system namespace is required for FUSE mounting. Here we just
# create the namespace and set up the necessary permissions.
#
# We need to setup the namespace outside the if block below because if
# we put in the if block, the following happens:
# 1. User launches job controller on Kubernetes with SERVICE_ACCOUNT. No
# namespace is created at this point since the controller does not
# require FUSE.
# 2. User submits a job requiring FUSE.
# 3. The namespace is created here, but since the job controller is using
# SERVICE_ACCOUNT, it does not have the necessary permissions to create
# a role for itself to create the FUSE device manager.
# 4. The job fails to launch.
_configure_skypilot_system_namespace(config.provider_config,
requested_service_account)

if config.provider_config.get('fuse_device_required', False):
_configure_fuse_mounting(config.provider_config)

return config


Expand Down Expand Up @@ -238,6 +256,9 @@ def _configure_autoscaler_service_account(
namespace, field_selector=field_selector).items)
if len(accounts) > 0:
assert len(accounts) == 1
# Nothing to check for equality and patch here,
# since the service_account.metadata.name is the only important
# attribute, which is already filtered for above.
logger.info('_configure_autoscaler_service_account: '
f'{using_existing_msg(account_field, name)}')
return
Expand Down Expand Up @@ -272,12 +293,20 @@ def _configure_autoscaler_role(namespace: str, provider_config: Dict[str, Any],

name = role['metadata']['name']
field_selector = f'metadata.name={name}'
accounts = (kubernetes.auth_api().list_namespaced_role(
roles = (kubernetes.auth_api().list_namespaced_role(
namespace, field_selector=field_selector).items)
if len(accounts) > 0:
assert len(accounts) == 1
if len(roles) > 0:
assert len(roles) == 1
existing_role = roles[0]
# Convert to k8s object to compare
new_role = kubernetes_utils.dict_to_k8s_object(role, 'V1Role')
if new_role.rules == existing_role.rules:
logger.info('_configure_autoscaler_role: '
f'{using_existing_msg(role_field, name)}')
return
logger.info('_configure_autoscaler_role: '
f'{using_existing_msg(role_field, name)}')
f'{updating_existing_msg(role_field, name)}')
kubernetes.auth_api().patch_namespaced_role(name, namespace, role)
return

logger.info('_configure_autoscaler_role: '
Expand All @@ -286,9 +315,12 @@ def _configure_autoscaler_role(namespace: str, provider_config: Dict[str, Any],
logger.info(f'_configure_autoscaler_role: {created_msg(role_field, name)}')


def _configure_autoscaler_role_binding(namespace: str,
provider_config: Dict[str, Any],
binding_field: str) -> None:
def _configure_autoscaler_role_binding(
namespace: str,
provider_config: Dict[str, Any],
binding_field: str,
override_name: Optional[str] = None,
override_subject_namespace: Optional[str] = None) -> None:
""" Reads the role binding from the config, creates if it does not exist.
Args:
Expand All @@ -309,22 +341,37 @@ def _configure_autoscaler_role_binding(namespace: str,
else:
rb_namespace = binding['metadata']['namespace']

# If override_subject_namespace is provided, we will use that
# namespace for the subject. Otherwise, we will raise an error.
subject_namespace = override_subject_namespace or namespace
for subject in binding['subjects']:
if 'namespace' not in subject:
subject['namespace'] = namespace
elif subject['namespace'] != namespace:
subject['namespace'] = subject_namespace
elif subject['namespace'] != subject_namespace:
subject_name = subject['name']
raise InvalidNamespaceError(
binding_field + f' subject {subject_name}', namespace)

# Override name if provided
binding['metadata']['name'] = override_name or binding['metadata']['name']
name = binding['metadata']['name']

field_selector = f'metadata.name={name}'
accounts = (kubernetes.auth_api().list_namespaced_role_binding(
role_bindings = (kubernetes.auth_api().list_namespaced_role_binding(
rb_namespace, field_selector=field_selector).items)
if len(accounts) > 0:
assert len(accounts) == 1
if len(role_bindings) > 0:
assert len(role_bindings) == 1
existing_binding = role_bindings[0]
new_rb = kubernetes_utils.dict_to_k8s_object(binding, 'V1RoleBinding')
if (new_rb.role_ref == existing_binding.role_ref and
new_rb.subjects == existing_binding.subjects):
logger.info('_configure_autoscaler_role_binding: '
f'{using_existing_msg(binding_field, name)}')
return
logger.info('_configure_autoscaler_role_binding: '
f'{using_existing_msg(binding_field, name)}')
f'{updating_existing_msg(binding_field, name)}')
kubernetes.auth_api().patch_namespaced_role_binding(
name, rb_namespace, binding)
return

logger.info('_configure_autoscaler_role_binding: '
Expand All @@ -350,12 +397,19 @@ def _configure_autoscaler_cluster_role(namespace,

name = role['metadata']['name']
field_selector = f'metadata.name={name}'
accounts = (kubernetes.auth_api().list_cluster_role(
cluster_roles = (kubernetes.auth_api().list_cluster_role(
field_selector=field_selector).items)
if len(accounts) > 0:
assert len(accounts) == 1
if len(cluster_roles) > 0:
assert len(cluster_roles) == 1
existing_cr = cluster_roles[0]
new_cr = kubernetes_utils.dict_to_k8s_object(role, 'V1ClusterRole')
if new_cr.rules == existing_cr.rules:
logger.info('_configure_autoscaler_cluster_role: '
f'{using_existing_msg(role_field, name)}')
return
logger.info('_configure_autoscaler_cluster_role: '
f'{using_existing_msg(role_field, name)}')
f'{updating_existing_msg(role_field, name)}')
kubernetes.auth_api().patch_cluster_role(name, role)
return

logger.info('_configure_autoscaler_cluster_role: '
Expand Down Expand Up @@ -388,12 +442,21 @@ def _configure_autoscaler_cluster_role_binding(

name = binding['metadata']['name']
field_selector = f'metadata.name={name}'
accounts = (kubernetes.auth_api().list_cluster_role_binding(
cr_bindings = (kubernetes.auth_api().list_cluster_role_binding(
field_selector=field_selector).items)
if len(accounts) > 0:
assert len(accounts) == 1
if len(cr_bindings) > 0:
assert len(cr_bindings) == 1
existing_binding = cr_bindings[0]
new_binding = kubernetes_utils.dict_to_k8s_object(
binding, 'V1ClusterRoleBinding')
if (new_binding.role_ref == existing_binding.role_ref and
new_binding.subjects == existing_binding.subjects):
logger.info('_configure_autoscaler_cluster_role_binding: '
f'{using_existing_msg(binding_field, name)}')
return
logger.info('_configure_autoscaler_cluster_role_binding: '
f'{using_existing_msg(binding_field, name)}')
f'{updating_existing_msg(binding_field, name)}')
kubernetes.auth_api().patch_cluster_role_binding(name, binding)
return

logger.info('_configure_autoscaler_cluster_role_binding: '
Expand Down Expand Up @@ -438,6 +501,48 @@ def _configure_ssh_jump(namespace, config: common.ProvisionConfig):
return config


def _configure_skypilot_system_namespace(
provider_config: Dict[str,
Any], service_account: Optional[str]) -> None:
"""Creates the namespace for skypilot-system mounting if it does not exist.
Also patches the SkyPilot service account to have the necessary permissions
to manage resources in the namespace.
"""
svc_account_namespace = provider_config['namespace']
skypilot_system_namespace = provider_config['skypilot_system_namespace']
kubernetes_utils.create_namespace(skypilot_system_namespace)

# Setup permissions if using the default service account.
# If the user has requested a different service account (via
# remote_identity in ~/.sky/config.yaml), we assume they have already set
# up the necessary roles and role bindings.
if service_account == kubernetes_utils.DEFAULT_SERVICE_ACCOUNT_NAME:
# Note - this must be run only after the service account has been
# created in the cluster (in bootstrap_instances).
# Create the role in the skypilot-system namespace if it does not exist.
_configure_autoscaler_role(skypilot_system_namespace,
provider_config,
role_field='autoscaler_skypilot_system_role')
# We must create a unique role binding per-namespace that SkyPilot is
# running in, so we override the name with a unique name identifying
# the namespace. This is required for multi-tenant setups where
# different SkyPilot instances may be running in different namespaces.
override_name = provider_config[
'autoscaler_skypilot_system_role_binding']['metadata'][
'name'] + '-' + svc_account_namespace

# Create the role binding in the skypilot-system namespace, and have
# the subject namespace be the namespace that the SkyPilot service
# account is created in.
_configure_autoscaler_role_binding(
skypilot_system_namespace,
provider_config,
binding_field='autoscaler_skypilot_system_role_binding',
override_name=override_name,
override_subject_namespace=svc_account_namespace)


def _configure_fuse_mounting(provider_config: Dict[str, Any]) -> None:
"""Creates sidecars required for FUSE mounting.
Expand All @@ -446,17 +551,15 @@ def _configure_fuse_mounting(provider_config: Dict[str, Any]) -> None:
which exposes the host /dev/fuse device as a Kubernetes resource. The
SkyPilot pod requests this resource to mount the FUSE filesystem.
We create this daemonset in a common namespace, which is configurable in the
provider config. This allows the FUSE mounting sidecar to be shared across
multiple tenants. The default namespace is 'sky-system' (populated in
clouds.Kubernetes)
We create this daemonset in the skypilot_system_namespace, which is
configurable in the provider config. This allows the FUSE mounting sidecar
to be shared across multiple tenants. The default namespace is
'skypilot-system' (populated in clouds.Kubernetes).
"""

logger.info('_configure_fuse_mounting: Setting up FUSE device manager.')

fuse_device_manager_namespace = provider_config.get(
'fuse_device_manager_namespace', 'default')
kubernetes_utils.create_namespace(fuse_device_manager_namespace)
fuse_device_manager_namespace = provider_config['skypilot_system_namespace']

# Read the device manager YAMLs from the manifests directory
root_dir = os.path.dirname(os.path.dirname(__file__))
Expand Down Expand Up @@ -526,7 +629,9 @@ def _configure_services(namespace: str, provider_config: Dict[str,
if len(services) > 0:
assert len(services) == 1
existing_service = services[0]
if service == existing_service:
# Convert to k8s object to compare
new_svc = kubernetes_utils.dict_to_k8s_object(service, 'V1Service')
if new_svc.spec.ports == existing_service.spec.ports:
logger.info('_configure_services: '
f'{using_existing_msg("service", name)}')
return
Expand Down
24 changes: 24 additions & 0 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Kubernetes utilities for SkyPilot."""
import json
import math
import os
import re
Expand All @@ -21,8 +22,11 @@
from sky.utils import schemas
from sky.utils import ux_utils

# TODO(romilb): Move constants to constants.py
DEFAULT_NAMESPACE = 'default'

DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account'

MEMORY_SIZE_UNITS = {
'B': 1,
'K': 2**10,
Expand Down Expand Up @@ -1443,3 +1447,23 @@ def get_autoscaler_type(
autoscaler_type = kubernetes_enums.KubernetesAutoscalerType(
autoscaler_type)
return autoscaler_type


def dict_to_k8s_object(object_dict: Dict[str, Any], object_type: 'str') -> Any:
"""Converts a dictionary to a Kubernetes object.
Useful for comparing two Kubernetes objects. Adapted from
https://github.com/kubernetes-client/python/issues/977#issuecomment-592030030 # pylint: disable=line-too-long
Args:
object_dict: Dictionary representing the Kubernetes object
object_type: Type of the Kubernetes object. E.g., 'V1Pod', 'V1Service'.
"""

class FakeKubeResponse:

def __init__(self, obj):
self.data = json.dumps(obj)

fake_kube_response = FakeKubeResponse(object_dict)
return kubernetes.api_client().deserialize(fake_kube_response, object_type)
Loading

0 comments on commit 8a0a34d

Please sign in to comment.