[k8s] Add sky status flag to query global Kubernetes status #4040

Merged · 24 commits · Oct 11, 2024
Changes from all commits
95 changes: 92 additions & 3 deletions sky/cli.py
@@ -1458,6 +1458,79 @@ def _get_services(service_names: Optional[List[str]],
return num_services, msg


def _status_kubernetes(show_all: bool):
"""Show all SkyPilot resources in the current Kubernetes context.

Args:
show_all (bool): Show all job information (e.g., start time, failures).
"""
context = kubernetes_utils.get_current_kube_config_context_name()
try:
pods = kubernetes_utils.get_skypilot_pods(context)
except exceptions.ResourcesUnavailableError as e:
with ux_utils.print_exception_no_traceback():
raise ValueError('Failed to get SkyPilot pods from '
f'Kubernetes: {str(e)}') from e
all_clusters, jobs_controllers, serve_controllers = (
status_utils.process_skypilot_pods(pods, context))
all_jobs = []
with rich_utils.safe_status(
'[bold cyan]Checking in-progress managed jobs[/]') as spinner:
for i, (_, job_controller_info) in enumerate(jobs_controllers.items()):
user = job_controller_info['user']
pod = job_controller_info['pods'][0]
status_message = ('[bold cyan]Checking managed jobs controller')
if len(jobs_controllers) > 1:
status_message += f's ({i+1}/{len(jobs_controllers)})'
spinner.update(f'{status_message}[/]')
try:
job_list = managed_jobs.queue_from_kubernetes_pod(
pod.metadata.name)
except RuntimeError as e:
logger.warning('Failed to get managed jobs from controller '
f'{pod.metadata.name}: {str(e)}')
job_list = []
# Add user field to jobs
for job in job_list:
job['user'] = user
all_jobs.extend(job_list)
# Reconcile cluster state between managed jobs and clusters:
# To maintain a clear separation between regular SkyPilot clusters
# and those from managed jobs, we need to exclude the latter from
# the main cluster list.
# We do this by reconstructing managed job cluster names from each
# job's name and ID. We then use this set to filter out managed
# clusters from the main cluster list. This is necessary because there
# are no identifiers distinguishing clusters from managed jobs from
# regular clusters.
managed_job_cluster_names = set()
for job in all_jobs:
# Managed job cluster name is <job_name>-<job_id>
managed_cluster_name = f'{job["job_name"]}-{job["job_id"]}'
managed_job_cluster_names.add(managed_cluster_name)
unmanaged_clusters = [
c for c in all_clusters
if c['cluster_name'] not in managed_job_cluster_names
]
click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Kubernetes cluster state (context: {context})'
f'{colorama.Style.RESET_ALL}')
status_utils.show_kubernetes_cluster_status_table(unmanaged_clusters,
show_all)
if all_jobs:
click.echo(f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Managed jobs'
f'{colorama.Style.RESET_ALL}')
msg = managed_jobs.format_job_table(all_jobs, show_all=show_all)
click.echo(msg)
if serve_controllers:
# TODO: Parse serve controllers and show services separately.
# Currently we show a hint that services are shown as clusters.
click.echo(f'\n{colorama.Style.DIM}Hint: SkyServe replica pods are '
'shown in the "SkyPilot clusters" section.'
f'{colorama.Style.RESET_ALL}')
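
A note for reviewers on the reconciliation above: nothing on a managed-job cluster identifies it as managed, so the code rebuilds each expected name as `<job_name>-<job_id>` and filters those out of the main cluster list. A minimal sketch of just that filtering, using made-up records rather than real SkyPilot state:

```python
# Sketch of the reconciliation step above, with illustrative data.
jobs = [
    {'job_name': 'train', 'job_id': 3},
    {'job_name': 'eval', 'job_id': 7},
]
clusters = [
    {'cluster_name': 'train-3'},  # backs a managed job -> filtered out
    {'cluster_name': 'dev'},      # regular cluster -> kept
]

managed_job_cluster_names = {f'{j["job_name"]}-{j["job_id"]}' for j in jobs}
unmanaged_clusters = [
    c for c in clusters if c['cluster_name'] not in managed_job_cluster_names
]
assert [c['cluster_name'] for c in unmanaged_clusters] == ['dev']
```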


@cli.command()
@click.option('--all',
'-a',
@@ -1503,6 +1576,14 @@ def _get_services(service_names: Optional[List[str]],
is_flag=True,
required=False,
help='Also show sky serve services, if any.')
@click.option(
'--kubernetes',
'--k8s',
default=False,
is_flag=True,
required=False,
help='[Experimental] Show all SkyPilot resources (including from other '
'users) in the current Kubernetes context.')
@click.argument('clusters',
required=False,
type=str,
@@ -1512,7 +1593,7 @@ def _get_services(service_names: Optional[List[str]],
# pylint: disable=redefined-builtin
def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
endpoint: Optional[int], show_managed_jobs: bool,
-           show_services: bool, clusters: List[str]):
+           show_services: bool, kubernetes: bool, clusters: List[str]):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Show clusters.

Expand Down Expand Up @@ -1571,6 +1652,9 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
or for autostop-enabled clusters, use ``--refresh`` to query the latest
cluster statuses from the cloud providers.
"""
if kubernetes:
_status_kubernetes(all)
return
    # Using a pool with 2 workers to run the managed job query and sky serve
    # service query in parallel to speed up. The pool provides an AsyncResult
    # object that can be used as a future.
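
One click detail worth calling out for the new flag above: passing two long names to a single `@click.option` makes them aliases for one parameter, and click derives the parameter name from the first long option, which is why both `--kubernetes` and `--k8s` arrive as the same `kubernetes: bool` argument. A self-contained sketch (the `demo` command is hypothetical):

```python
import click


@click.command()
@click.option('--kubernetes', '--k8s', default=False, is_flag=True,
              help='Alias demo: both spellings set the same parameter.')
def demo(kubernetes: bool):
    # The parameter name comes from the first long option: --kubernetes.
    click.echo(f'kubernetes={kubernetes}')


if __name__ == '__main__':
    demo()  # `demo --k8s` and `demo --kubernetes` both print kubernetes=True
```
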
@@ -3113,7 +3197,12 @@ def _output():
print_section_titles = False
# If cloud is kubernetes, we want to show real-time capacity
if kubernetes_is_enabled and (cloud is None or cloud_is_kubernetes):
-            context = region
+            if region:
+                context = region
+            else:
+                # If region is not specified, we use the current context
+                context = (
+                    kubernetes_utils.get_current_kube_config_context_name())
try:
# If --cloud kubernetes is not specified, we want to catch
# the case where no GPUs are available on the cluster and
@@ -3128,7 +3217,7 @@ def _output():
else:
print_section_titles = True
yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
-                      f'Kubernetes GPUs (Context: {context})'
+                      f'Kubernetes GPUs (context: {context})'
f'{colorama.Style.RESET_ALL}\n')
yield from k8s_realtime_table.get_string()
k8s_node_table = _get_kubernetes_node_info_table(context)
7 changes: 4 additions & 3 deletions sky/data/storage_utils.py
@@ -12,7 +12,6 @@
from sky.skylet import constants
from sky.utils import common_utils
from sky.utils import log_utils
-from sky.utils.cli_utils import status_utils

logger = sky_logging.init_logger(__name__)

@@ -22,6 +21,8 @@
'to the cloud storage for {path!r}'
'due to the following error: {error_msg!r}')

_LAST_USE_TRUNC_LENGTH = 25


def format_storage_table(storages: List[Dict[str, Any]],
show_all: bool = False) -> str:
@@ -46,8 +47,8 @@ def format_storage_table(storages: List[Dict[str, Any]],
if show_all:
command = row['last_use']
else:
-            command = status_utils.truncate_long_string(
-                row['last_use'], status_utils.COMMAND_TRUNC_LENGTH)
+            command = common_utils.truncate_long_string(row['last_use'],
+                                                        _LAST_USE_TRUNC_LENGTH)
storage_table.add_row([
# NAME
row['name'],
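
Besides the truncation-length change, this hunk drops the `sky.utils.cli_utils` import (see the deleted line above), so the storage layer no longer depends on a CLI-level module. A rough stand-in to illustrate the effect of the new 25-character cap; the real `common_utils.truncate_long_string` may differ in details such as word-boundary handling:

```python
# Approximation of common_utils.truncate_long_string; illustrative only.
def truncate_long_string(s: str, max_length: int = 35) -> str:
    if len(s) <= max_length:
        return s
    return s[:max_length - 3] + '...'


# LAST_USE entries longer than 25 characters are shortened in the table:
print(truncate_long_string('sky launch -c demo --gpus A100:1 task.yaml', 25))
# -> 'sky launch -c demo --g...'
```
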
2 changes: 2 additions & 0 deletions sky/jobs/__init__.py
@@ -8,6 +8,7 @@
from sky.jobs.core import cancel
from sky.jobs.core import launch
from sky.jobs.core import queue
from sky.jobs.core import queue_from_kubernetes_pod
from sky.jobs.core import tail_logs
from sky.jobs.recovery_strategy import DEFAULT_RECOVERY_STRATEGY
from sky.jobs.recovery_strategy import RECOVERY_STRATEGIES
@@ -34,6 +35,7 @@
'cancel',
'launch',
'queue',
'queue_from_kubernetes_pod',
'tail_logs',
# utils
'ManagedJobCodeGen',
78 changes: 78 additions & 0 deletions sky/jobs/core.py
@@ -9,13 +9,15 @@
import sky
from sky import backends
from sky import exceptions
from sky import provision as provision_lib
from sky import sky_logging
from sky import status_lib
from sky import task as task_lib
from sky.backends import backend_utils
from sky.clouds.service_catalog import common as service_catalog_common
from sky.jobs import constants as managed_job_constants
from sky.jobs import utils as managed_job_utils
from sky.provision import common
from sky.skylet import constants as skylet_constants
from sky.usage import usage_lib
from sky.utils import admin_policy_utils
@@ -138,6 +140,82 @@ def launch(
_disable_controller_check=True)


def queue_from_kubernetes_pod(
pod_name: str,
context: Optional[str] = None,
skip_finished: bool = False) -> List[Dict[str, Any]]:
"""Gets the jobs queue from a specific controller pod.

Args:
pod_name (str): The name of the controller pod to query for jobs.
context (Optional[str]): The Kubernetes context to use. If None, the
current context is used.
skip_finished (bool): If True, does not return finished jobs.

Returns:
[
{
'job_id': int,
'job_name': str,
'resources': str,
'submitted_at': (float) timestamp of submission,
'end_at': (float) timestamp of end,
'duration': (float) duration in seconds,
'recovery_count': (int) Number of retries,
'status': (sky.jobs.ManagedJobStatus) of the job,
'cluster_resources': (str) resources of the cluster,
'region': (str) region of the cluster,
}
]

Raises:
RuntimeError: If there's an error fetching the managed jobs.
"""
# Create dummy cluster info to get the command runner.
provider_config = {'context': context}
instances = {
pod_name: [
common.InstanceInfo(instance_id=pod_name,
internal_ip='',
external_ip='',
tags={})
]
} # Internal IP is not required for Kubernetes
cluster_info = common.ClusterInfo(provider_name='kubernetes',
head_instance_id=pod_name,
provider_config=provider_config,
instances=instances)
managed_jobs_runner = provision_lib.get_command_runners(
'kubernetes', cluster_info)[0]

code = managed_job_utils.ManagedJobCodeGen.get_job_table()
returncode, job_table_payload, stderr = managed_jobs_runner.run(
code,
require_outputs=True,
separate_stderr=True,
stream_logs=False,
)
try:
subprocess_utils.handle_returncode(returncode,
code,
'Failed to fetch managed jobs',
job_table_payload + stderr,
stream_logs=False)
except exceptions.CommandError as e:
raise RuntimeError(str(e)) from e

jobs = managed_job_utils.load_managed_job_queue(job_table_payload)
if skip_finished:
# Filter out the finished jobs. If a multi-task job is partially
# finished, we will include all its tasks.
non_finished_tasks = list(
filter(lambda job: not job['status'].is_terminal(), jobs))
non_finished_job_ids = {job['job_id'] for job in non_finished_tasks}
jobs = list(
filter(lambda job: job['job_id'] in non_finished_job_ids, jobs))
return jobs
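
Because the function is re-exported from `sky/jobs/__init__.py` (see the diff above), it can also be driven directly as a library call, which is useful when debugging a single controller. A hypothetical invocation, with an illustrative pod name and an assumed working kubeconfig:

```python
# Hypothetical direct use of the new API; the pod name is made up and a
# working kubeconfig for the target context is assumed.
from sky import jobs as managed_jobs

job_list = managed_jobs.queue_from_kubernetes_pod(
    'sky-jobs-controller-2ea485ea-2ea4-head', skip_finished=True)
for job in job_list:
    print(job['job_id'], job['job_name'], job['status'])
```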


@usage_lib.entrypoint
def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
29 changes: 20 additions & 9 deletions sky/jobs/utils.py
@@ -599,29 +599,35 @@ def format_job_table(
a list of "rows" (each of which is a list of str).
"""
jobs = collections.defaultdict(list)
# Check if the tasks have user information.
tasks_have_user = any([task.get('user') for task in tasks])
if max_jobs and tasks_have_user:
raise ValueError('max_jobs is not supported when tasks have user info.')

def get_hash(task):
if tasks_have_user:
return (task['user'], task['job_id'])
return task['job_id']

for task in tasks:
# The tasks within the same job_id are already sorted
# by the task_id.
-        jobs[task['job_id']].append(task)
-    jobs = dict(jobs)
+        jobs[get_hash(task)].append(task)

status_counts: Dict[str, int] = collections.defaultdict(int)
for job_tasks in jobs.values():
managed_job_status = _get_job_status_from_tasks(job_tasks)[0]
if not managed_job_status.is_terminal():
status_counts[managed_job_status.value] += 1

if max_jobs is not None:
job_ids = sorted(jobs.keys(), reverse=True)
job_ids = job_ids[:max_jobs]
jobs = {job_id: jobs[job_id] for job_id in job_ids}

columns = [
'ID', 'TASK', 'NAME', 'RESOURCES', 'SUBMITTED', 'TOT. DURATION',
'JOB DURATION', '#RECOVERIES', 'STATUS'
]
if show_all:
columns += ['STARTED', 'CLUSTER', 'REGION', 'FAILURE']
if tasks_have_user:
columns.insert(0, 'USER')
job_table = log_utils.create_table(columns)

status_counts: Dict[str, int] = collections.defaultdict(int)
@@ -636,9 +642,9 @@ def format_job_table(
for task in all_tasks:
# The tasks within the same job_id are already sorted
# by the task_id.
-            jobs[task['job_id']].append(task)
+            jobs[get_hash(task)].append(task)

-    for job_id, job_tasks in jobs.items():
+    for job_hash, job_tasks in jobs.items():
if len(job_tasks) > 1:
# Aggregate the tasks into a new row in the table.
job_name = job_tasks[0]['job_name']
@@ -674,6 +680,7 @@
if not managed_job_status.is_terminal():
status_str += f' (task: {current_task_id})'

job_id = job_hash[1] if tasks_have_user else job_hash
job_values = [
job_id,
'',
@@ -692,6 +699,8 @@
'-',
failure_reason if failure_reason is not None else '-',
])
if tasks_have_user:
job_values.insert(0, job_tasks[0].get('user', '-'))
job_table.add_row(job_values)

for task in job_tasks:
@@ -724,6 +733,8 @@
task['failure_reason']
if task['failure_reason'] is not None else '-',
])
if tasks_have_user:
values.insert(0, task.get('user', '-'))
job_table.add_row(values)

if len(job_tasks) > 1:
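
The `get_hash` indirection in this file is what keeps the multi-user `--k8s` view correct: once tasks carry a `user` field, rows are keyed by `(user, job_id)` so that two users' jobs with the same ID are not merged into one table row. A minimal sketch of just the keying, with made-up tasks:

```python
import collections

# Illustrative tasks, as produced by the --k8s path (user field attached).
tasks = [
    {'user': 'alice', 'job_id': 1, 'task_id': 0},
    {'user': 'bob', 'job_id': 1, 'task_id': 0},  # same job_id, other user
]
tasks_have_user = any([task.get('user') for task in tasks])


def get_hash(task):
    if tasks_have_user:
        return (task['user'], task['job_id'])
    return task['job_id']


jobs = collections.defaultdict(list)
for task in tasks:
    jobs[get_hash(task)].append(task)
assert len(jobs) == 2  # keying by job_id alone would merge these rows
```
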
25 changes: 25 additions & 0 deletions sky/provision/kubernetes/utils.py
@@ -1998,3 +1998,28 @@ def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]:
# we need to use in-cluster auth.
context = None
return context


def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
"""Gets all SkyPilot pods in the Kubernetes cluster.

Args:
context: Kubernetes context to use. If None, uses the current context.

Returns:
A list of Kubernetes pod objects.
"""
if context is None:
context = get_current_kube_config_context_name()

try:
pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
label_selector='skypilot-cluster',
_request_timeout=kubernetes.API_TIMEOUT).items
except kubernetes.max_retry_error():
raise exceptions.ResourcesUnavailableError(
'Timed out trying to get SkyPilot pods from Kubernetes cluster. '
'Please check if the cluster is healthy and retry. To debug, run: '
'kubectl get pods --selector=skypilot-cluster --all-namespaces'
) from None
return pods
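
A hypothetical smoke test for the new helper, assuming a kubeconfig pointing at a cluster that runs SkyPilot pods; the `metadata` attributes used below are standard fields on the Kubernetes Python client's pod objects:

```python
# Hypothetical smoke test; requires a reachable cluster in the kubeconfig.
from sky.provision.kubernetes import utils as kubernetes_utils

pods = kubernetes_utils.get_skypilot_pods()  # None -> current context
for pod in pods:
    print(pod.metadata.namespace, pod.metadata.name,
          pod.metadata.labels.get('skypilot-cluster'))
```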