From 18ba98f8ebd2685231b1ae89699bece63bf67579 Mon Sep 17 00:00:00 2001 From: Tian Xia Date: Tue, 25 Jul 2023 18:12:24 -0700 Subject: [PATCH] [SkyServe] `sky serve` CLI prototype (#2276) * Add service schema * use new serve YAML * change to qpm * change to fix node * refactor init of SkyServiceSpec * change http example to new yaml format * update default value of from_yaml_config and handle service in task * Launching successfully * use argument in controller & redirector * resolve comments * use qps instead * raise when multiple task found * change to qps * introduce constants * introduce constants & fix bugs * add sky down * add Services No existing services. without STATUS (but with #healthy replica * format * add llama2 example * add fields to service db * status with replica information * fix policy parsing bug * add auth todo * add replica status todo * change cluster name prefix and order of the column * minor fixes * reorder status * change name: controller --> control plane * change name: middleware --> controller * clean code * rename default service name * env vars * add purge and skip identity check on serve controller * upload filemounts and workdir to storage & enhance --purge --- sky/__init__.py | 4 +- sky/backends/backend_utils.py | 58 +++++ sky/backends/cloud_vm_ray_backend.py | 13 +- sky/cli.py | 187 ++++++++++++++++ sky/core.py | 5 + sky/execution.py | 205 +++++++++++++++++- sky/global_user_state.py | 140 ++++++++++++ sky/serve/__init__.py | 3 + sky/serve/autoscalers.py | 48 ++-- sky/serve/common.py | 119 ++++++++-- sky/serve/constants.py | 10 + sky/serve/{controller.py => control_plane.py} | 83 ++++--- sky/serve/examples/http_server/README.md | 11 - sky/serve/examples/http_server/server.py | 3 + sky/serve/examples/http_server/task.yaml | 7 +- sky/serve/examples/llama2/chat.py | 42 ++++ sky/serve/examples/llama2/llama2.yaml | 50 +++++ sky/serve/infra_providers.py | 47 ++-- sky/serve/load_balancers.py | 10 +- sky/serve/redirector.py | 44 ++-- sky/setup_files/setup.py | 5 +- sky/status_lib.py | 34 +++ sky/task.py | 30 ++- sky/templates/skyserve-controller.yaml.j2 | 27 +++ sky/utils/cli_utils/status_utils.py | 86 ++++++++ sky/utils/schemas.py | 53 ++++- 26 files changed, 1196 insertions(+), 128 deletions(-) create mode 100644 sky/serve/__init__.py create mode 100644 sky/serve/constants.py rename sky/serve/{controller.py => control_plane.py} (59%) delete mode 100644 sky/serve/examples/http_server/README.md create mode 100644 sky/serve/examples/llama2/chat.py create mode 100644 sky/serve/examples/llama2/llama2.yaml create mode 100644 sky/templates/skyserve-controller.yaml.j2 diff --git a/sky/__init__.py b/sky/__init__.py index 715a126e1d1..a3f0631a4a7 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -12,7 +12,7 @@ from sky import clouds from sky.clouds.service_catalog import list_accelerators from sky.dag import Dag -from sky.execution import launch, exec, spot_launch # pylint: disable=redefined-builtin +from sky.execution import launch, exec, spot_launch, serve_up, serve_down # pylint: disable=redefined-builtin from sky.resources import Resources from sky.task import Task from sky.optimizer import Optimizer, OptimizeTarget @@ -64,6 +64,8 @@ 'launch', 'exec', 'spot_launch', + 'serve_up', + 'serve_down', # core APIs 'status', 'start', diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 58cfcd195a8..6d631af1502 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1,4 +1,5 @@ """Util constants/functions for the backends.""" +import base64 from datetime import datetime import difflib import enum @@ -6,6 +7,7 @@ import json import os import pathlib +import pickle import re import subprocess import tempfile @@ -37,6 +39,7 @@ from sky import skypilot_config from sky import sky_logging from sky import spot as spot_lib +from sky import serve as serve_lib from sky import status_lib from sky.backends import onprem_utils from sky.skylet import constants @@ -1326,6 +1329,10 @@ def generate_cluster_name(): return f'sky-{uuid.uuid4().hex[:4]}-{get_cleaned_username()}' +def generate_service_name(): + return f'service-{uuid.uuid4().hex[:4]}' + + def get_cleaned_username() -> str: """Cleans the current username to be used as part of a cluster name. @@ -2408,6 +2415,57 @@ def _refresh_cluster(cluster_name): return kept_records +def refresh_service_status(service: Optional[str]) -> List[Dict[str, Any]]: + if service is None: + service_records = global_user_state.get_services() + else: + service_record = global_user_state.get_service_from_name(service) + if service_record is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Service {service} does not exist.') + service_records = [service_record] + # TODO(tian): Make it run in parallel. + for record in service_records: + controller_cluster_name = record['controller_cluster_name'] + endpoint = record['endpoint'] + if not endpoint: + continue + # TODO(tian): Refactor: store ip and app_port separately. + controller_ip = endpoint.split(':')[0] + with requests.Session() as session: + try: + resp = session.get( + f'http://{controller_ip}:{serve_lib.CONTROL_PLANE_PORT}/control_plane/get_replica_nums', + timeout=5) + except requests.RequestException: + pass + else: + record.update(resp.json()) + if record['num_healthy_replicas'] > 0: + record['status'] = status_lib.ServiceStatus.RUNNING + elif record['num_unhealthy_replicas'] > 0: + record['status'] = status_lib.ServiceStatus.REPLICA_INIT + global_user_state.add_or_update_service(**record) + if service is not None: + assert record['name'] == service + try: + resp = session.get( + f'http://{controller_ip}:{serve_lib.CONTROL_PLANE_PORT}/control_plane/get_replica_info', + timeout=5) + except requests.RequestException: + pass + else: + record['replica_info'] = resp.json()['replica_info'] + decoded_info = [] + for info in record['replica_info']: + decoded_info.append({ + k: pickle.loads(base64.b64decode(v)) + for k, v in info.items() + }) + record['replica_info'] = decoded_info + return service_records + + @typing.overload def get_backend_from_handle( handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle' diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index e522a8d8f39..ce79a90230c 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -33,6 +33,7 @@ from sky import optimizer from sky import skypilot_config from sky import spot as spot_lib +from sky import serve as serve_lib from sky import status_lib from sky import task as task_lib from sky.data import data_utils @@ -2892,8 +2893,9 @@ def _exec_code_on_head( f'Failed to submit job {job_id}.', stderr=stdout + stderr) - logger.info('Job submitted with Job ID: ' - f'{style.BRIGHT}{job_id}{style.RESET_ALL}') + if not handle.cluster_name.startswith(serve_lib.CONTROLLER_PREFIX): + logger.info('Job submitted with Job ID: ' + f'{style.BRIGHT}{job_id}{style.RESET_ALL}') try: if not detach_run: @@ -2924,7 +2926,9 @@ def _exec_code_on_head( '\nTo view the spot job dashboard:\t' f'{backend_utils.BOLD}sky spot dashboard' f'{backend_utils.RESET_BOLD}') - else: + elif not name.startswith(serve_lib.CONTROLLER_PREFIX): + # Skip logging for submit control plane & redirector jobs + # to controller logger.info(f'{fore.CYAN}Job ID: ' f'{style.BRIGHT}{job_id}{style.RESET_ALL}' '\nTo cancel the job:\t' @@ -3039,7 +3043,8 @@ def _post_execute(self, handle: CloudVmRayResourceHandle, fore = colorama.Fore style = colorama.Style name = handle.cluster_name - if name == spot_lib.SPOT_CONTROLLER_NAME or down: + if (name == spot_lib.SPOT_CONTROLLER_NAME or down or + name.startswith(serve_lib.CONTROLLER_PREFIX)): return stop_str = ('\nTo stop the cluster:' f'\t{backend_utils.BOLD}sky stop {name}' diff --git a/sky/cli.py b/sky/cli.py index 292b0dc4155..3cf9ddc87a0 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -438,6 +438,13 @@ def _complete_cluster_name(ctx: click.Context, param: click.Parameter, return global_user_state.get_cluster_names_start_with(incomplete) +def _complete_service_name(ctx: click.Context, param: click.Parameter, + incomplete: str) -> List[str]: + """Handle shell completion for service names.""" + del ctx, param # Unused. + return global_user_state.get_service_names_start_with(incomplete) + + def _complete_storage_name(ctx: click.Context, param: click.Parameter, incomplete: str) -> List[str]: """Handle shell completion for storage names.""" @@ -3798,6 +3805,186 @@ def spot_dashboard(port: Optional[int]): click.echo('Exiting.') +@cli.group(cls=_NaturalOrderGroup) +def serve(): + """SkyServe commands CLI.""" + pass + + +@serve.command('up', cls=_DocumentedCodeCommand) +@click.argument('entrypoint', + required=True, + type=str, + **_get_shell_complete_args(_complete_file_name)) +@click.option('--service', + '-s', + default=None, + type=str, + help='A service name. Unique for each service. If not provided, ' + 'provision a new service with an autogenerated name.') +@click.option('--yes', + '-y', + is_flag=True, + default=False, + required=False, + help='Skip confirmation prompt.') +def serve_up( + entrypoint: str, + service: Optional[str], + yes: bool, +): + """Launches a SkyServe instance. + + ENTRYPOINT must points to a valid YAML file. + + Example: + + .. code-block:: bash + + sky serve up service.yaml + """ + if service is None: + # TODO(tian): Check service name is unique. + service = backend_utils.generate_service_name() + + shell_splits = shlex.split(entrypoint) + yaml_file_provided = (len(shell_splits) == 1 and + (shell_splits[0].endswith('yaml') or + shell_splits[0].endswith('.yml'))) + if not yaml_file_provided: + click.secho('ENTRYPOINT must points to a valid YAML file.', fg='red') + return + + is_yaml = True + config: Optional[List[Dict[str, Any]]] = None + try: + with open(entrypoint, 'r') as f: + try: + config = list(yaml.safe_load_all(f)) + if config: + # FIXME(zongheng): in a chain DAG YAML it only returns the + # first section. OK for downstream but is weird. + result = config[0] + else: + result = {} + if isinstance(result, str): + invalid_reason = ( + 'cannot be parsed into a valid YAML file. ' + 'Please check syntax.') + is_yaml = False + except yaml.YAMLError as e: + if yaml_file_provided: + logger.debug(e) + invalid_reason = ('contains an invalid configuration. ' + ' Please check syntax.') + is_yaml = False + except OSError: + entry_point_path = os.path.expanduser(entrypoint) + if not os.path.exists(entry_point_path): + invalid_reason = ('does not exist. Please check if the path' + ' is correct.') + elif not os.path.isfile(entry_point_path): + invalid_reason = ('is not a file. Please check if the path' + ' is correct.') + else: + invalid_reason = ('yaml.safe_load() failed. Please check if the' + ' path is correct.') + is_yaml = False + if not is_yaml: + click.secho( + f'{entrypoint!r} looks like a yaml path but {invalid_reason}', + fg='red') + return + + click.secho('Service from YAML spec: ', fg='yellow', nl=False) + click.secho(entrypoint, bold=True) + usage_lib.messages.usage.update_user_task_yaml(entrypoint) + dag = dag_utils.load_chain_dag_from_yaml(entrypoint) + if len(dag.tasks) > 1: + click.secho('Multiple tasks found in the YAML file.', fg='red') + return + task = dag.tasks[0] + if task.service is None: + click.secho('Service section not found in the YAML file.', fg='red') + return + + if not yes: + prompt = f'Launching a new service {service}. Proceed?' + if prompt is not None: + click.confirm(prompt, default=True, abort=True, show_default=True) + + sky.serve_up(task, service, entrypoint) + + +@serve.command('status', cls=_DocumentedCodeCommand) +@click.option('--all', + '-a', + default=False, + is_flag=True, + required=False, + help='Show all information in full.') +@click.argument('service', + required=False, + type=str, + **_get_shell_complete_args(_complete_service_name)) +@usage_lib.entrypoint +# pylint: disable=redefined-builtin +def serve_status(all: bool, service: Optional[str]): + service_records = core.service_status(service) + click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}Services' + f'{colorama.Style.RESET_ALL}') + status_utils.show_service_table(service_records, all) + if service is not None: + # If service not exist, we should already raise an error in + # core.service_status. + assert len(service_records) == 1, service_records + service_record = service_records[0] + if 'replica_info' not in service_record: + click.secho(f'Failed to refresh status of service: {service}.', + fg='red') + return + click.echo( + f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}Replicas of {service}' + f'{colorama.Style.RESET_ALL}') + status_utils.show_replica_table(service_record['replica_info'], all) + + +@serve.command('down', cls=_DocumentedCodeCommand) +@click.argument('service', + required=True, + **_get_shell_complete_args(_complete_service_name)) +@click.option('--yes', + '-y', + is_flag=True, + default=False, + required=False, + help='Skip confirmation prompt.') +@click.option('--purge', + '-p', + is_flag=True, + default=False, + required=False, + help='Ignore errors (if any). ') +def serve_down( + service: str, + yes: bool, + purge: bool, +): + """Stops a SkyServe instance. + + Example: + + .. code-block:: bash + + sky serve down my-service + """ + if not yes: + prompt = f'Tearing down service {service}. Proceed?' + click.confirm(prompt, default=True, abort=True, show_default=True) + + sky.serve_down(service, purge) + + # ============================== # Sky Benchmark CLIs # ============================== diff --git a/sky/core.py b/sky/core.py index 64a3161a943..71bca86b29c 100644 --- a/sky/core.py +++ b/sky/core.py @@ -109,6 +109,11 @@ def status(cluster_names: Optional[Union[str, List[str]]] = None, cluster_names=cluster_names) +@usage_lib.entrypoint +def service_status(service: Optional[str]) -> List[Dict[str, Any]]: + return backend_utils.refresh_service_status(service) + + @usage_lib.entrypoint def cost_report() -> List[Dict[str, Any]]: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. diff --git a/sky/execution.py b/sky/execution.py index 4abaf886ac6..e3c25e6a7b5 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -25,12 +25,15 @@ import sky from sky import backends from sky import clouds +from sky import core from sky import exceptions from sky import global_user_state from sky import optimizer from sky import skypilot_config from sky import sky_logging from sky import spot +from sky import serve +from sky import status_lib from sky import task as task_lib from sky.backends import backend_utils from sky.clouds import gcp @@ -365,7 +368,9 @@ def _execute( backend.teardown_ephemeral_storage(task) backend.teardown(handle, terminate=True) finally: - if cluster_name != spot.SPOT_CONTROLLER_NAME: + if (cluster_name != spot.SPOT_CONTROLLER_NAME and + cluster_name is not None and + not cluster_name.startswith(serve.CONTROLLER_PREFIX)): # UX: print live clusters to make users aware (to save costs). # # Don't print if this job is launched by the spot controller, @@ -942,3 +947,201 @@ def _maybe_translate_local_file_mounts_and_sync_up(task: task_lib.Task): raise exceptions.NotSupportedError( f'Unsupported store type: {store_type}') storage_obj.force_delete = True + + +@usage_lib.entrypoint +def serve_up( + task: 'sky.Task', + name: str, + original_yaml_path: str, +): + """Serve up a service. + + Please refer to the sky.cli.serve_up for the document. + + Args: + task: sky.Task to serve up. + name: Name of the RESTful API. + + Raises: + """ + controller_cluster_name = serve.CONTROLLER_PREFIX + name + assert task.service is not None, task + policy = task.service.policy_str() + assert len(task.resources) == 1 + requested_resources = list(task.resources)[0] + global_user_state.add_or_update_service( + name, controller_cluster_name, '', + status_lib.ServiceStatus.CONTROLLER_INIT, 0, 0, 0, policy, + requested_resources) + app_port = int(task.service.app_port) + assert len(task.resources) == 1, task + task.set_resources(list(task.resources)[0].copy(ports=[app_port])) + + # TODO(tian): Use skyserve constants. + _maybe_translate_local_file_mounts_and_sync_up(task) + + with tempfile.NamedTemporaryFile(prefix=f'serve-task-{name}-', + mode='w') as f: + task_config = task.to_yaml_config() + if 'resources' in task_config and 'spot_recovery' in task_config[ + 'resources']: + del task_config['resources']['spot_recovery'] + common_utils.dump_yaml(f.name, task_config) + remote_task_yaml_path = f'{serve.SERVICE_YAML_PREFIX}/service_{name}.yaml' + vars_to_fill = { + 'ports': [app_port, serve.CONTROL_PLANE_PORT], + 'remote_task_yaml_path': remote_task_yaml_path, + 'local_task_yaml_path': f.name, + 'is_dev': env_options.Options.IS_DEVELOPER.get(), + 'is_debug': env_options.Options.SHOW_DEBUG_INFO.get(), + 'disable_logging': env_options.Options.DISABLE_LOGGING.get(), + } + controller_yaml_path = os.path.join(serve.CONTROLLER_YAML_PREFIX, + f'{name}.yaml') + backend_utils.fill_template(serve.CONTROLLER_TEMPLATE, + vars_to_fill, + output_path=controller_yaml_path) + controller_task = task_lib.Task.from_yaml(controller_yaml_path) + assert len(controller_task.resources) == 1, controller_task + print(f'{colorama.Fore.YELLOW}' + f'Launching controller for {name}...' + f'{colorama.Style.RESET_ALL}') + + _execute( + entrypoint=controller_task, + stream_logs=True, + cluster_name=controller_cluster_name, + retry_until_up=True, + ) + + handle = global_user_state.get_handle_from_cluster_name( + controller_cluster_name) + assert isinstance(handle, backends.CloudVmRayResourceHandle) + endpoint = f'{handle.head_ip}:{task.service.app_port}' + global_user_state.add_or_update_service( + name, controller_cluster_name, endpoint, + status_lib.ServiceStatus.REPLICA_INIT, 0, 0, 0, policy, + requested_resources) + + print( + f'{colorama.Fore.YELLOW}' + 'Launching control plane process on controller...' + f'{colorama.Style.RESET_ALL}', + end='') + _execute( + entrypoint=sky.Task( + name='run-control-plane', + run='python -m sky.serve.control_plane --service-name ' + f'{name} --task-yaml {remote_task_yaml_path} ' + f'--port {serve.CONTROL_PLANE_PORT}'), + stream_logs=False, + handle=handle, + stages=[Stage.EXEC], + cluster_name=controller_cluster_name, + detach_run=True, + ) + + print( + f'{colorama.Fore.YELLOW}' + 'Launching redirector process on controller...' + f'{colorama.Style.RESET_ALL}', + end='') + _execute( + entrypoint=sky.Task( + name='run-redirector', + run='python -m sky.serve.redirector --task-yaml ' + f'{remote_task_yaml_path} --port {app_port} ' + f'--control-plane-addr http://0.0.0.0:{serve.CONTROL_PLANE_PORT}' + ), + stream_logs=False, + handle=handle, + stages=[Stage.EXEC], + cluster_name=controller_cluster_name, + detach_run=True, + ) + + print(f'{colorama.Style.BRIGHT}{colorama.Fore.CYAN}Serving at ' + f'{colorama.Style.RESET_ALL}{colorama.Fore.CYAN}' + f'{endpoint}.\n' + f'{colorama.Style.RESET_ALL}') + + +def serve_down( + name: str, + purge: bool, +): + """Teardown a service. + + Please refer to the sky.cli.serve_down for the document. + + Args: + name: Name of the service. + + Raises: + """ + service_record = global_user_state.get_service_from_name(name) + if service_record is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Service {name} does not exist.') + controller_cluster_name = service_record['controller_cluster_name'] + num_healthy_replicas = service_record['num_healthy_replicas'] + num_unhealthy_replicas = service_record['num_unhealthy_replicas'] + num_replicas = num_healthy_replicas + num_unhealthy_replicas + handle = global_user_state.get_handle_from_cluster_name( + controller_cluster_name) + global_user_state.set_service_status(name, + status_lib.ServiceStatus.SHUTTING_DOWN) + + try: + print( + f'{colorama.Fore.YELLOW}' + f'Stopping control plane and redirector processes on controller...' + f'{colorama.Style.RESET_ALL}') + core.cancel(controller_cluster_name, all=True) + except (ValueError, sky.exceptions.ClusterNotUpError) as e: + if purge: + logger.warning(f'Ignoring error when stopping controller: {e}') + else: + raise e + + try: + if handle is not None: + plural = '' + # TODO(tian): Change to #num replica (including failed one) + if num_replicas > 1: + plural = 's' + print(f'{colorama.Fore.YELLOW}' + f'Tearing down {num_replicas} replica{plural}...' + f'{colorama.Style.RESET_ALL}') + _execute( + entrypoint=sky.Task(name='teardown-all-replicas', + run='sky down -a -y'), + stream_logs=False, + handle=handle, + stages=[Stage.EXEC], + cluster_name=controller_cluster_name, + detach_run=False, + ) + except (RuntimeError, ValueError) as e: + if purge: + logger.warning(f'Ignoring error when cleaning controller: {e}') + else: + raise e + + try: + print(f'{colorama.Fore.YELLOW}' + 'Teardown controller...' + f'{colorama.Style.RESET_ALL}') + core.down(controller_cluster_name, purge=purge) + except (RuntimeError, ValueError) as e: + if purge: + logger.warning(f'Ignoring error when cleaning controller: {e}') + else: + raise e + + global_user_state.remove_service(name) + + print(f'{colorama.Fore.GREEN}' + f'Tear down service {name} done.' + f'{colorama.Style.RESET_ALL}') diff --git a/sky/global_user_state.py b/sky/global_user_state.py index 37d7c9ba903..28c7717b812 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -26,6 +26,7 @@ if typing.TYPE_CHECKING: from sky import backends from sky.data import Storage + from sky import resources as resources_lib _ENABLED_CLOUDS_KEY = 'enabled_clouds' @@ -92,6 +93,18 @@ def create_table(cursor, conn): handle BLOB, last_use TEXT, status TEXT)""") + # Table for Services + cursor.execute("""\ + CREATE TABLE IF NOT EXISTS services ( + name TEXT PRIMARY KEY, + controller_cluster_name TEXT, + endpoint TEXT, + status TEXT, + num_healthy_replicas INTEGER DEFAULT 0, + num_unhealthy_replicas INTEGER DEFAULT 0, + num_failed_replicas INTEGER DEFAULT 0, + policy TEXT, + requested_resources BLOB)""") # For backward compatibility. # TODO(zhwu): Remove this function after all users have migrated to # the latest version of SkyPilot. @@ -272,6 +285,60 @@ def add_or_update_cluster(cluster_name: str, _DB.conn.commit() +def add_or_update_service( + name: str, controller_cluster_name: str, endpoint: str, + status: status_lib.ServiceStatus, num_healthy_replicas: int, + num_unhealthy_replicas: int, num_failed_replicas, policy: str, + requested_resources: Optional['resources_lib.Resources']): + _DB.cursor.execute( + 'INSERT or REPLACE INTO services' + '(name, controller_cluster_name, endpoint, status, ' + 'num_healthy_replicas, num_unhealthy_replicas, ' + 'num_failed_replicas, policy, requested_resources) ' + 'VALUES (' + # name + '?, ' + # controller_cluster_name + '?, ' + # endpoint + '?, ' + # status + '?, ' + # num_healthy_replicas + '?, ' + # num_unhealthy_replicas + '?, ' + # num_failed_replicas + '?, ' + # policy + '?, ' + # requested_resources + '?' + ')', + ( + # name + name, + # controller_cluster_name + controller_cluster_name, + # endpoint + endpoint, + # status + status.value, + # num_healthy_replicas + num_healthy_replicas, + # num_unhealthy_replicas + num_unhealthy_replicas, + # num_failed_replicas + num_failed_replicas, + # policy + policy, + # requested_resources + pickle.dumps(requested_resources), + )) + + _DB.conn.commit() + + def update_last_use(cluster_name: str): """Updates the last used command for the cluster.""" _DB.cursor.execute('UPDATE clusters SET last_use=(?) WHERE name=(?)', @@ -313,6 +380,21 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None: _DB.conn.commit() +def remove_service(service_name: str): + _DB.cursor.execute('DELETE FROM services WHERE name=(?)', (service_name,)) + _DB.conn.commit() + + +def set_service_status(service_name: str, status: status_lib.ServiceStatus): + _DB.cursor.execute('UPDATE services SET status=(?) ' + 'WHERE name=(?)', (status.value, service_name)) + count = _DB.cursor.rowcount + _DB.conn.commit() + assert count <= 1, count + if count == 0: + raise ValueError(f'Service {service_name} not found.') + + def get_handle_from_cluster_name( cluster_name: str) -> Optional['backends.ResourceHandle']: assert cluster_name is not None, 'cluster_name cannot be None' @@ -534,6 +616,33 @@ def get_cluster_from_name( return None +def get_service_from_name( + service_name: Optional[str]) -> Optional[Dict[str, Any]]: + rows = _DB.cursor.execute('SELECT * FROM services WHERE name=(?)', + (service_name,)).fetchall() + for row in rows: + # Explicitly specify the number of fields to unpack, so that + # we can add new fields to the database in the future without + # breaking the previous code. + (name, controller_cluster_name, endpoint, status, num_healthy_replicas, + num_unhealthy_replicas, num_failed_replicas, policy, + requested_resources) = row[:9] + # TODO: use namedtuple instead of dict + record = { + 'name': name, + 'controller_cluster_name': controller_cluster_name, + 'endpoint': endpoint, + 'status': status_lib.ServiceStatus[status], + 'num_healthy_replicas': num_healthy_replicas, + 'num_unhealthy_replicas': num_unhealthy_replicas, + 'num_failed_replicas': num_failed_replicas, + 'policy': policy, + 'requested_resources': pickle.loads(requested_resources), + } + return record + return None + + def get_clusters() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( 'select * from clusters order by launched_at desc').fetchall() @@ -560,6 +669,31 @@ def get_clusters() -> List[Dict[str, Any]]: return records +def get_services() -> List[Dict[str, Any]]: + rows = _DB.cursor.execute('select * from services').fetchall() + records = [] + for row in rows: + (name, controller_cluster_name, endpoint, status, num_healthy_replicas, + num_unhealthy_replicas, num_failed_replicas, policy, + requested_resources) = row[:9] + # TODO: use namedtuple instead of dict + + record = { + 'name': name, + 'controller_cluster_name': controller_cluster_name, + 'endpoint': endpoint, + 'status': status_lib.ServiceStatus[status], + 'num_healthy_replicas': num_healthy_replicas, + 'num_unhealthy_replicas': num_unhealthy_replicas, + 'num_failed_replicas': num_failed_replicas, + 'policy': policy, + 'requested_resources': pickle.loads(requested_resources), + } + + records.append(record) + return records + + def get_clusters_from_history() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( 'SELECT ch.cluster_hash, ch.name, ch.num_nodes, ' @@ -611,6 +745,12 @@ def get_cluster_names_start_with(starts_with: str) -> List[str]: return [row[0] for row in rows] +def get_service_names_start_with(starts_with: str) -> List[str]: + rows = _DB.cursor.execute('SELECT name FROM services WHERE name LIKE (?)', + (f'{starts_with}%',)) + return [row[0] for row in rows] + + def get_enabled_clouds() -> List[clouds.Cloud]: rows = _DB.cursor.execute('SELECT value FROM config WHERE key = ?', (_ENABLED_CLOUDS_KEY,)) diff --git a/sky/serve/__init__.py b/sky/serve/__init__.py new file mode 100644 index 00000000000..d17081db6d2 --- /dev/null +++ b/sky/serve/__init__.py @@ -0,0 +1,3 @@ +from sky.serve.constants import (CONTROLLER_PREFIX, CONTROLLER_TEMPLATE, + CONTROLLER_YAML_PREFIX, SERVICE_YAML_PREFIX, + CONTROL_PLANE_PORT) diff --git a/sky/serve/autoscalers.py b/sky/serve/autoscalers.py index a9075beedfb..ecf7116888d 100644 --- a/sky/serve/autoscalers.py +++ b/sky/serve/autoscalers.py @@ -1,6 +1,8 @@ import logging import time +from typing import Optional + from sky.serve.infra_providers import InfraProvider from sky.serve.load_balancers import LoadBalancer @@ -43,7 +45,7 @@ def __init__(self, lower_threshold: int = 1, min_nodes: int = 1, **kwargs): - ''' + """ Autoscaler that scales up when the average latency of all servers is above the upper threshold and scales down when the average latency of all servers is below the lower threshold. :param args: @@ -51,7 +53,7 @@ def __init__(self, :param lower_threshold: lower threshold for latency in seconds :param min_nodes: minimum number of nodes to keep running :param kwargs: - ''' + """ super().__init__(*args, **kwargs) self.upper_threshold = upper_threshold self.lower_threshold = lower_threshold @@ -78,10 +80,10 @@ class RequestRateAutoscaler(Autoscaler): def __init__(self, *args, - query_interval: int = 10, - upper_threshold: int = 10, - lower_threshold: int = 2, min_nodes: int = 1, + max_nodes: Optional[int] = None, + upper_threshold: Optional[float] = None, + lower_threshold: Optional[float] = None, cooldown: int = 60, **kwargs): """ @@ -95,10 +97,11 @@ def __init__(self, :param kwargs: """ super().__init__(*args, **kwargs) - self.query_interval = query_interval + self.min_nodes = min_nodes + self.max_nodes = max_nodes or min_nodes + self.query_interval = 60 # Therefore thresholds represent queries per minute. self.upper_threshold = upper_threshold self.lower_threshold = lower_threshold - self.min_nodes = min_nodes self.cooldown = cooldown self.last_scale_operation = 0 # Time of last scale operation. @@ -107,11 +110,11 @@ def evaluate_scaling(self): # Check if cooldown period has passed since the last scaling operation if current_time - self.last_scale_operation < self.cooldown: + logger.info(f'Current time: {current_time}, ' + f'last scale operation: {self.last_scale_operation}, ' + f'cooldown: {self.cooldown}') logger.info( - f'Current time: {current_time}, last scale operation: {self.last_scale_operation}, cooldown: {self.cooldown}' - ) - logger.info( - f'Cooldown period has not passed since last scaling operation. Skipping scaling.' + 'Cooldown period has not passed since last scaling operation. Skipping scaling.' ) return @@ -121,27 +124,30 @@ def evaluate_scaling(self): self.load_balancer.request_timestamps.popleft() num_requests = len(self.load_balancer.request_timestamps) + num_requests = float( + num_requests) / 60 # Convert to requests per second. num_nodes = self.infra_provider.total_servers() requests_per_node = num_requests / num_nodes if num_nodes else num_requests # To account for zero case. logger.info(f'Requests per node: {requests_per_node}') - logger.info( - f'Upper threshold: {self.upper_threshold} q/node, lower threshold: {self.lower_threshold} q/node, queries per node: {requests_per_node} q/node' - ) + logger.info(f'Upper threshold: {self.upper_threshold} qps/node, ' + f'lower threshold: {self.lower_threshold} qps/node, ' + f'queries per node: {requests_per_node} qps/node') scaled = True # Bootstrap case logger.info(f'Number of nodes: {num_nodes}') - if num_nodes == 0 and requests_per_node > 0: - logger.info(f'Bootstrapping autoscaler.') + if num_nodes < self.min_nodes: + logger.info('Bootstrapping autoscaler.') self.scale_up(1) self.last_scale_operation = current_time - elif requests_per_node > self.upper_threshold: - self.scale_up(1) - self.last_scale_operation = current_time - elif requests_per_node < self.lower_threshold: + elif self.upper_threshold is not None and requests_per_node > self.upper_threshold: + if self.infra_provider.total_servers() < self.max_nodes: + self.scale_up(1) + self.last_scale_operation = current_time + elif self.lower_threshold is not None and requests_per_node < self.lower_threshold: if self.infra_provider.total_servers() > self.min_nodes: self.scale_down(1) self.last_scale_operation = current_time else: - logger.info(f'No scaling needed.') + logger.info('No scaling needed.') diff --git a/sky/serve/common.py b/sky/serve/common.py index 5f211dc5a10..f5e66fdf4e1 100644 --- a/sky/serve/common.py +++ b/sky/serve/common.py @@ -1,32 +1,115 @@ -import yaml +from typing import Optional, Dict, Any + +from sky.backends import backend_utils +from sky.utils import schemas +from sky.utils import ux_utils class SkyServiceSpec: - def __init__(self, yaml_path: str): - with open(yaml_path, 'r') as f: - self.task = yaml.safe_load(f) - if 'service' not in self.task: - raise ValueError('Task YAML must have a "service" section') - if 'port' not in self.task['service']: - raise ValueError('Task YAML must have a "port" section') - if 'readiness_probe' not in self.task['service']: - raise ValueError('Task YAML must have a "readiness_probe" section') - self._readiness_path = self.get_readiness_path() - self._app_port = self.get_app_port() - - def get_readiness_path(self): + def __init__( + self, + readiness_path: str, + readiness_timeout: int, + app_port: int, + min_replica: int, + max_replica: Optional[int] = None, + qps_upper_threshold: Optional[float] = None, + qps_lower_threshold: Optional[float] = None, + ): + if max_replica is not None and max_replica < min_replica: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + 'max_replica must be greater than or equal to min_replica') # TODO: check if the path is valid - return f':{self.task["service"]["port"]}{self.task["service"]["readiness_probe"]}' - - def get_app_port(self): + self._readiness_path = f':{app_port}{readiness_path}' + self._readiness_timeout = readiness_timeout # TODO: check if the port is valid - return f'{self.task["service"]["port"]}' + self._app_port = str(app_port) + self._min_replica = min_replica + self._max_replica = max_replica + self._qps_upper_threshold = qps_upper_threshold + self._qps_lower_threshold = qps_lower_threshold + + @classmethod + def from_yaml_config(cls, config: Optional[Dict[str, Any]]): + if config is None: + return None + + backend_utils.validate_schema(config, schemas.get_service_schema(), + 'Invalid service YAML:') + + service_config = {} + service_config['readiness_path'] = config['readiness_probe']['path'] + service_config['readiness_timeout'] = config['readiness_probe'][ + 'readiness_timeout'] + service_config['app_port'] = config['port'] + service_config['min_replica'] = config['replica_policy']['min_replica'] + service_config['max_replica'] = config['replica_policy'].get( + 'max_replica', None) + service_config['qps_upper_threshold'] = config['replica_policy'].get( + 'qps_upper_threshold', None) + service_config['qps_lower_threshold'] = config['replica_policy'].get( + 'qps_lower_threshold', None) + + return SkyServiceSpec(**service_config) + + def to_yaml_config(self): + replica_policy = {} + + def add_if_not_none(key, value, no_empty: bool = False): + if no_empty and not value: + return + if value is not None: + replica_policy[key] = value + + add_if_not_none('min_replica', self.min_replica) + add_if_not_none('max_replica', self.max_replica) + add_if_not_none('qps_upper_threshold', self.qps_upper_threshold) + add_if_not_none('qps_lower_threshold', self.qps_lower_threshold) + + return { + 'port': int(self.app_port), + 'readiness_probe': { + 'path': self.readiness_path[len(f':{self.app_port}'):], + 'readiness_timeout': self.readiness_timeout, + }, + 'replica_policy': replica_policy, + } + + def policy_str(self): + if self.max_replica == self.min_replica or self.max_replica is None: + plural = '' + if self.min_replica > 1: + plural = 'S' + return f'FIXED NODE{plural}: {self.min_replica}' + # TODO(tian): Refactor to contain more information + return f'AUTOSCALE [{self.min_replica}, {self.max_replica}]' @property def readiness_path(self): return self._readiness_path + @property + def readiness_timeout(self): + return self._readiness_timeout + @property def app_port(self): return self._app_port + + @property + def min_replica(self): + return self._min_replica + + @property + def max_replica(self): + return self._max_replica + + @property + def qps_upper_threshold(self): + return self._qps_upper_threshold + + @property + def qps_lower_threshold(self): + return self._qps_lower_threshold diff --git a/sky/serve/constants.py b/sky/serve/constants.py new file mode 100644 index 00000000000..e8ac94662c3 --- /dev/null +++ b/sky/serve/constants.py @@ -0,0 +1,10 @@ +"""Constants used for SkyServe.""" + +CONTROLLER_PREFIX = 'controller-' + +CONTROLLER_TEMPLATE = 'skyserve-controller.yaml.j2' +CONTROLLER_YAML_PREFIX = '~/.sky/serve' + +SERVICE_YAML_PREFIX = '~/.sky/service' + +CONTROL_PLANE_PORT = 31001 diff --git a/sky/serve/controller.py b/sky/serve/control_plane.py similarity index 59% rename from sky/serve/controller.py rename to sky/serve/control_plane.py index dfe5a953764..8b710646a96 100644 --- a/sky/serve/controller.py +++ b/sky/serve/control_plane.py @@ -9,6 +9,7 @@ import time import threading +import yaml from typing import Optional @@ -23,13 +24,13 @@ logger = logging.getLogger(__name__) -class Controller: +class ControlPlane: def __init__(self, + port: int, infra_provider: InfraProvider, load_balancer: LoadBalancer, - autoscaler: Optional[Autoscaler] = None, - port: int = 8082): + autoscaler: Optional[Autoscaler] = None): self.port = port self.infra_provider = infra_provider self.load_balancer = load_balancer @@ -43,9 +44,10 @@ def server_fetcher(self): self.load_balancer.probe_endpoints(server_ips) time.sleep(10) + # TODO(tian): Authentication!!! def run(self): - @self.app.post('/controller/increment_request_count') + @self.app.post('/control_plane/increment_request_count') async def increment_request_count(request: Request): # await request request_data = await request.json() @@ -55,10 +57,26 @@ async def increment_request_count(request: Request): self.load_balancer.increment_request_count(count=count) return {'message': 'Success'} - @self.app.get('/controller/get_server_ips') + @self.app.get('/control_plane/get_server_ips') def get_server_ips(): return {'server_ips': list(self.load_balancer.servers_queue)} + @self.app.get('/control_plane/get_replica_info') + def get_replica_info(): + return {'replica_info': self.infra_provider.get_replica_info()} + + @self.app.get('/control_plane/get_replica_nums') + def get_replica_nums(): + return { + 'num_healthy_replicas': len(self.load_balancer.available_servers + ), + 'num_unhealthy_replicas': + self.infra_provider.total_servers() - + len(self.load_balancer.available_servers), + # TODO(tian): Detect error replicas + 'num_failed_replicas': 0 + } + # Run server_monitor and autoscaler.monitor (if autoscaler is defined) in separate threads in the background. This should not block the main thread. server_fetcher_thread = threading.Thread(target=self.server_fetcher, daemon=True) @@ -73,7 +91,11 @@ def get_server_ips(): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='SkyServe Server') + parser = argparse.ArgumentParser(description='SkyServe Control Plane') + parser.add_argument('--service-name', + type=str, + help='Name of the service', + required=True) parser.add_argument('--task-yaml', type=str, help='Task YAML file', @@ -81,24 +103,26 @@ def get_server_ips(): parser.add_argument('--port', '-p', type=int, - help='Port to run the controller', - default=8082) - parser.add_argument('--min-nodes', - type=int, - default=1, - help='Minimum nodes to keep running') + help='Port to run the control plane', + required=True) args = parser.parse_args() # ======= Infra Provider ========= # infra_provider = DummyInfraProvider() - infra_provider = SkyPilotInfraProvider(args.task_yaml) + infra_provider = SkyPilotInfraProvider(args.task_yaml, args.service_name) # ======= Load Balancer ========= - service_spec = SkyServiceSpec(args.task_yaml) + with open(args.task_yaml, 'r') as f: + task = yaml.safe_load(f) + if 'service' not in task: + raise ValueError('Task YAML must have a "service" section') + service_config = task['service'] + service_spec = SkyServiceSpec.from_yaml_config(service_config) # Select the load balancing policy: RoundRobinLoadBalancer or LeastLoadedLoadBalancer load_balancer = RoundRobinLoadBalancer( infra_provider=infra_provider, - endpoint_path=service_spec.readiness_path) + endpoint_path=service_spec.readiness_path, + readiness_timeout=service_spec.readiness_timeout) # load_balancer = LeastLoadedLoadBalancer(n=5) # autoscaler = LatencyThresholdAutoscaler(load_balancer, # upper_threshold=0.5, # 500ms @@ -106,17 +130,18 @@ def get_server_ips(): # ======= Autoscaler ========= # Create an autoscaler with the RequestRateAutoscaler policy. Thresholds are defined as requests per node in the defined interval. - autoscaler = RequestRateAutoscaler(infra_provider, - load_balancer, - frequency=5, - query_interval=60, - lower_threshold=0, - upper_threshold=1, - min_nodes=args.min_nodes, - cooldown=60) - - # ======= Controller ========= - # Create a controller object and run it. - controller = Controller(infra_provider, load_balancer, autoscaler, - args.port) - controller.run() + autoscaler = RequestRateAutoscaler( + infra_provider, + load_balancer, + frequency=5, + min_nodes=service_spec.min_replica, + max_nodes=service_spec.max_replica, + upper_threshold=service_spec.qps_upper_threshold, + lower_threshold=service_spec.qps_lower_threshold, + cooldown=60) + + # ======= ControlPlane ========= + # Create a control plane object and run it. + control_plane = ControlPlane(args.port, infra_provider, load_balancer, + autoscaler) + control_plane.run() diff --git a/sky/serve/examples/http_server/README.md b/sky/serve/examples/http_server/README.md deleted file mode 100644 index 2bc059cee94..00000000000 --- a/sky/serve/examples/http_server/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# HTTP Server example for SkyServe - -## Usage - -```bash -# Run controller. -python -m sky.serve.controller --task-yaml sky/serve/examples/http_server/task.yaml - -# Run redirector. -python -m sky.serve.redirector --task-yaml sky/serve/examples/http_server/task.yaml -``` diff --git a/sky/serve/examples/http_server/server.py b/sky/serve/examples/http_server/server.py index 4ea616b148e..303b117d26d 100644 --- a/sky/serve/examples/http_server/server.py +++ b/sky/serve/examples/http_server/server.py @@ -3,7 +3,9 @@ PORT = 8081 + class MyHttpRequestHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): # Return 200 for all paths # Therefore, readiness_probe will return 200 at path '/health' @@ -23,6 +25,7 @@ def do_GET(self): self.wfile.write(bytes(html, 'utf8')) return + Handler = MyHttpRequestHandler with socketserver.TCPServer(("", PORT), Handler) as httpd: diff --git a/sky/serve/examples/http_server/task.yaml b/sky/serve/examples/http_server/task.yaml index b82fbb29f75..d0fe866f259 100644 --- a/sky/serve/examples/http_server/task.yaml +++ b/sky/serve/examples/http_server/task.yaml @@ -9,4 +9,9 @@ run: python3 server.py service: port: 8081 - readiness_probe: /health + readiness_probe: + path: /health + readiness_timeout: 12000 + replica_policy: + min_replica: 1 + max_replica: 1 diff --git a/sky/serve/examples/llama2/chat.py b/sky/serve/examples/llama2/chat.py new file mode 100644 index 00000000000..2f450479851 --- /dev/null +++ b/sky/serve/examples/llama2/chat.py @@ -0,0 +1,42 @@ +import requests +import json +import openai + +stream = True +model = "Llama-2-7b-chat-hf" +init_prompt = "You are a helful assistant." +history = [{"role": "system", "content": init_prompt}] +endpoint = input("Endpoint: ") +url = f"http://{endpoint}/v1/chat/completions" +openai.api_base = f"http://{endpoint}/v1" +openai.api_key = "placeholder" + +try: + while True: + user_input = input("[User] ") + history.append({"role": "user", "content": user_input}) + if stream: + resp = openai.ChatCompletion.create(model=model, + messages=history, + stream=True) + print("[Chatbot]", end="", flush=True) + tot = "" + for i in resp: + dlt = i["choices"][0]["delta"] + if "content" not in dlt: + continue + print(dlt["content"], end="", flush=True) + tot += dlt["content"] + print() + history.append({"role": "assistant", "content": tot}) + else: + resp = requests.post(url, + data=json.dumps({ + "model": model, + "messages": history + })) + msg = resp.json()["choices"][0]["message"] + print("[Chatbot]" + msg["content"]) + history.append(msg) +except KeyboardInterrupt: + print("\nBye!") diff --git a/sky/serve/examples/llama2/llama2.yaml b/sky/serve/examples/llama2/llama2.yaml new file mode 100644 index 00000000000..a1317e33509 --- /dev/null +++ b/sky/serve/examples/llama2/llama2.yaml @@ -0,0 +1,50 @@ +resources: + cloud: gcp + memory: 32+ + accelerators: T4:1 + disk_size: 1024 + disk_tier: high + +service: + port: 8087 + readiness_probe: + path: /v1/models + readiness_timeout: 1200 + replica_policy: + min_replica: 2 + +envs: + MODEL_SIZE: 7 + HF_TOKEN: # TODO: Replace with huggingface token + +setup: | + conda activate chatbot + if [ $? -ne 0 ]; then + conda create -n chatbot python=3.9 -y + conda activate chatbot + fi + + # Install dependencies + pip install git+https://github.com/lm-sys/FastChat.git + # Need the latest transformers to support 70B model + pip install git+https://github.com/huggingface/transformers.git + + python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')" + +run: | + conda activate chatbot + + echo 'Starting controller...' + python -u -m fastchat.serve.controller --host 0.0.0.0 > ~/controller.log 2>&1 & + sleep 10 + echo 'Starting model worker...' + python -u -m fastchat.serve.model_worker --host 0.0.0.0 \ + --model-path meta-llama/Llama-2-${MODEL_SIZE}b-chat-hf \ + --num-gpus $SKYPILOT_NUM_GPUS_PER_NODE 2>&1 \ + | tee model_worker.log & + + echo 'Waiting for model worker to start...' + while ! `cat model_worker.log | grep -q 'Uvicorn running on'`; do sleep 1; done + + echo 'Starting openai api server server...' + python -u -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8087 | tee ~/openai_api_server.log diff --git a/sky/serve/infra_providers.py b/sky/serve/infra_providers.py index 0ebb3b32c58..0a4c931e678 100644 --- a/sky/serve/infra_providers.py +++ b/sky/serve/infra_providers.py @@ -1,6 +1,8 @@ import logging from typing import List import time +import pickle +import base64 import sky from sky.backends import backend_utils @@ -77,21 +79,21 @@ def terminate_servers(self, unhealthy_servers: List[str]): class SkyPilotInfraProvider(InfraProvider): - CLUSTER_NAME_PREFIX = 'skyserve-' - def __init__(self, task_yaml_path: str): + def __init__(self, task_yaml_path: str, cluster_name_prefix: str): self.task_yaml_path = task_yaml_path + self.cluster_name_prefix = cluster_name_prefix + '-' self.id_counter = self._get_id_start() def _get_id_start(self): - ''' + """ Returns the id to start from when creating a new cluster - ''' + """ clusters = sky.global_user_state.get_clusters() # Filter out clusters that don't have the prefix clusters = [ cluster for cluster in clusters - if self.CLUSTER_NAME_PREFIX in cluster['name'] + if self.cluster_name_prefix in cluster['name'] ] # Get the greatest id max_id = 0 @@ -108,7 +110,7 @@ def _get_ip_clusname_map(self): ip_clusname_map = {} for cluster in clusters: name = cluster['name'] - if self.CLUSTER_NAME_PREFIX in name: + if self.cluster_name_prefix in name: handle = cluster['handle'] try: # Get the head node ip @@ -121,6 +123,23 @@ def _get_ip_clusname_map(self): continue return ip_clusname_map + def get_replica_info(self): + clusters = sky.global_user_state.get_clusters() + infos = [] + for cluster in clusters: + if self.cluster_name_prefix in cluster['name']: + info = { + 'name': cluster['name'], + 'handle': cluster['handle'], + 'status': cluster['status'], + } + info = { + k: base64.b64encode(pickle.dumps(v)).decode('utf-8') + for k, v in info.items() + } + infos.append(info) + return infos + def _get_server_ips(self): return list(self._get_ip_clusname_map().keys()) @@ -130,7 +149,7 @@ def _return_total_servers(self): # FIXME - this is a hack to get around. should implement a better filtering mechanism clusters = [ cluster for cluster in clusters - if self.CLUSTER_NAME_PREFIX in cluster['name'] + if self.cluster_name_prefix in cluster['name'] ] return len(clusters) @@ -138,10 +157,12 @@ def _scale_up(self, n): # Launch n new clusters task = sky.Task.from_yaml(self.task_yaml_path) for i in range(0, n): - cluster_name = f'{self.CLUSTER_NAME_PREFIX}{self.id_counter}' + cluster_name = f'{self.cluster_name_prefix}{self.id_counter}' logger.info(f'Creating SkyPilot cluster {cluster_name}') - sky.launch(task, cluster_name=cluster_name, - detach_run=True) # TODO - make the launch parallel + sky.launch(task, + cluster_name=cluster_name, + detach_run=True, + retry_until_up=True) # TODO - make the launch parallel self.id_counter += 1 def _scale_down(self, n): @@ -151,7 +172,7 @@ def _scale_down(self, n): # Filter out clusters that don't have the prefix clusters = [ cluster for cluster in clusters - if self.CLUSTER_NAME_PREFIX in cluster['name'] + if self.cluster_name_prefix in cluster['name'] ] num_clusters = len(clusters) if num_clusters > 0: @@ -194,10 +215,8 @@ def terminate_servers(self, unhealthy_servers: List[str]): name = ip_to_name_map[endpoint_url] if endpoint_url in unhealthy_servers: logger.info(f'Deleting SkyPilot cluster {name}') - # Run sky.down in a daemon thread so that it doesn't block the main thread threading.Thread(target=sky.down, args=(name,), kwargs={ 'purge': True - }, - daemon=True).start() + }).start() diff --git a/sky/serve/load_balancers.py b/sky/serve/load_balancers.py index 4c0c95c6ff7..f1d1b464a76 100644 --- a/sky/serve/load_balancers.py +++ b/sky/serve/load_balancers.py @@ -12,12 +12,17 @@ class LoadBalancer: - def __init__(self, infra_provider, endpoint_path, post_data=None): + def __init__(self, + infra_provider, + endpoint_path, + readiness_timeout, + post_data=None): self.available_servers = [] self.request_count = 0 self.request_timestamps = deque() self.infra_provider = infra_provider self.endpoint_path = endpoint_path + self.readiness_timeout = readiness_timeout self.post_data = post_data def increment_request_count(self, count=1): @@ -37,7 +42,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.servers_queue = deque() self.first_unhealthy_time = {} - self.timeout = 18000 logger.info(f'Endpoint path: {self.endpoint_path}') def probe_endpoints(self, endpoint_ips): @@ -101,7 +105,7 @@ def probe_endpoint(endpoint_ip): if server not in self.first_unhealthy_time: self.first_unhealthy_time[server] = time.time() elif time.time() - self.first_unhealthy_time[ - server] > self.timeout: # cooldown before terminating a dead server to avoid hysterisis + server] > self.readiness_timeout: # cooldown before terminating a dead server to avoid hysterisis servers_to_terminate.append(server) self.infra_provider.terminate_servers(servers_to_terminate) diff --git a/sky/serve/redirector.py b/sky/serve/redirector.py index 5c3df12c62c..95b364f184a 100644 --- a/sky/serve/redirector.py +++ b/sky/serve/redirector.py @@ -1,6 +1,8 @@ import time import logging +import yaml from collections import deque +from typing import List, Deque from sky.serve.common import SkyServiceSpec @@ -24,34 +26,34 @@ class SkyServeRedirector: def __init__(self, - controller_url: str, + control_plane_url: str, service_spec: SkyServiceSpec, port: int = 8081): - self.controller_url = controller_url + self.control_plane_url = control_plane_url self.port = port self.app_port = service_spec.app_port - self.server_ips = [] - self.servers_queue = deque() + self.server_ips: List[str] = [] + self.servers_queue: Deque[str] = deque() self.app = FastAPI() self.request_count = 0 - self.controller_sync_timeout = 20 + self.control_plane_sync_timeout = 20 - def sync_with_controller(self): + def sync_with_control_plane(self): while True: server_ips = [] with requests.Session() as session: try: # send request count response = session.post( - self.controller_url + - '/controller/increment_request_count', + self.control_plane_url + + '/control_plane/increment_request_count', json={'counts': self.request_count}, timeout=5) response.raise_for_status() self.request_count = 0 # get server ips - response = session.get(self.controller_url + - '/controller/get_server_ips') + response = session.get(self.control_plane_url + + '/control_plane/get_server_ips') response.raise_for_status() server_ips = response.json()['server_ips'] except requests.RequestException as e: @@ -59,7 +61,7 @@ def sync_with_controller(self): else: logger.info(f'Server IPs: {server_ips}') self.servers_queue = deque(server_ips) - time.sleep(self.controller_sync_timeout) + time.sleep(self.control_plane_sync_timeout) def select_server(self): if not self.servers_queue: @@ -86,7 +88,7 @@ def serve(self): methods=['GET', 'POST', 'PUT', 'DELETE']) server_fetcher_thread = threading.Thread( - target=self.sync_with_controller, daemon=True) + target=self.sync_with_control_plane, daemon=True) server_fetcher_thread.start() logger.info(f'Sky Server started on http://0.0.0.0:{self.port}') @@ -106,15 +108,21 @@ def serve(self): '-p', type=int, help='Port to run the redirector on', - default=8081) - parser.add_argument('--controller-addr', - default='http://localhost:8082', + required=True) + parser.add_argument('--control-plane-addr', type=str, - help='Controller address (ip:port).') + help='Control plane address (ip:port).', + required=True) args = parser.parse_args() - service_spec = SkyServiceSpec(args.task_yaml) - redirector = SkyServeRedirector(controller_url=args.controller_addr, + with open(args.task_yaml, 'r') as f: + task = yaml.safe_load(f) + if 'service' not in task: + raise ValueError('Task YAML must have a "service" section') + service_config = task['service'] + service_spec = SkyServiceSpec.from_yaml_config(service_config) + + redirector = SkyServeRedirector(control_plane_url=args.control_plane_addr, service_spec=service_spec, port=args.port) redirector.serve() diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index acfaafacaae..c28fc671655 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -112,7 +112,10 @@ def parse_readme(readme: str) -> str: 'pulp', # Ray job has an issue with pydantic>2.0.0, due to API changes of pydantic. See # https://github.com/ray-project/ray/issues/36990 - 'pydantic<2.0' + 'pydantic<2.0', + # Required by the SkyServe library + 'uvicorn', + 'fastapi' ] # NOTE: Change the templates/spot-controller.yaml.j2 file if any of the diff --git a/sky/status_lib.py b/sky/status_lib.py index ae9a00c84de..ff977b009e2 100644 --- a/sky/status_lib.py +++ b/sky/status_lib.py @@ -49,3 +49,37 @@ class StorageStatus(enum.Enum): # Finished uploading, in terminal state READY = 'READY' + + +class ServiceStatus(enum.Enum): + """Service status as recorded in table 'services'.""" + + # Middleware is initializing + CONTROLLER_INIT = 'CONTROLLER_INIT' + + # Replica is initializing + REPLICA_INIT = 'REPLICA_INIT' + + # At least one replica is ready + RUNNING = 'RUNNING' + + # Service is being stopped + SHUTTING_DOWN = 'SHUTTING_DOWN' + + # At least one replica is failed + FAILED = 'FAILED' + + def colored_str(self): + color = _SERVICE_STATUS_TO_COLOR[self] + return f'{color}{self.value}{colorama.Style.RESET_ALL}' + + +_SERVICE_STATUS_TO_COLOR = { + ServiceStatus.CONTROLLER_INIT: colorama.Fore.BLUE, + ServiceStatus.REPLICA_INIT: colorama.Fore.BLUE, + ServiceStatus.RUNNING: colorama.Fore.GREEN, + ServiceStatus.SHUTTING_DOWN: colorama.Fore.YELLOW, + ServiceStatus.FAILED: colorama.Fore.RED, +} + +# TODO(tian): Add status for replicas to distinguish 'skypilot UP' and 'health probe succeeded' diff --git a/sky/task.py b/sky/task.py index 2105d0feeda..150113da5f1 100644 --- a/sky/task.py +++ b/sky/task.py @@ -15,6 +15,7 @@ from sky.backends import backend_utils from sky.data import storage as storage_lib from sky.data import data_utils +from sky.serve import common from sky.skylet import constants from sky.utils import schemas from sky.utils import ux_utils @@ -194,6 +195,7 @@ def __init__( self.estimated_outputs_size_gigabytes = None # Default to CPUNode self.resources = {sky.Resources()} + self._service = None self.time_estimator_func: Optional[Callable[['sky.Resources'], int]] = None self.file_mounts: Optional[Dict[str, str]] = None @@ -365,10 +367,12 @@ def from_yaml_config( resources = config.pop('resources', None) resources = sky.Resources.from_yaml_config(resources) - # FIXME: find a better way to exclude unused fields. - config.pop('service', None) - task.set_resources({resources}) + + service = config.pop('service', None) + service = common.SkyServiceSpec.from_yaml_config(service) + task.set_service(service) + assert not config, f'Invalid task args: {config.keys()}' return task @@ -528,6 +532,22 @@ def set_resources( def get_resources(self): return self.resources + @property + def service(self) -> Optional[common.SkyServiceSpec]: + return self._service + + def set_service(self, service: Optional[common.SkyServiceSpec]) -> 'Task': + """Sets the service spec for this task. + + Args: + service: a SkyServiceSpec object. + + Returns: + self: The current task, with service set. + """ + self._service = service + return self + def set_time_estimator(self, func: Callable[['sky.Resources'], int]) -> 'Task': """Sets a func mapping resources to estimated time (secs). @@ -884,6 +904,10 @@ def add_if_not_none(key, value, no_empty: bool = False): assert len(self.resources) == 1 resources = list(self.resources)[0] add_if_not_none('resources', resources.to_yaml_config()) + + if self.service is not None: + add_if_not_none('service', self.service.to_yaml_config()) + add_if_not_none('num_nodes', self.num_nodes) if self.inputs is not None: diff --git a/sky/templates/skyserve-controller.yaml.j2 b/sky/templates/skyserve-controller.yaml.j2 new file mode 100644 index 00000000000..a750c01e899 --- /dev/null +++ b/sky/templates/skyserve-controller.yaml.j2 @@ -0,0 +1,27 @@ +resources: + cloud: gcp + disk_size: 100 + ports: +{%- for port in ports %} + - {{port}} +{%- endfor %} + +# {% if workdir is not none %} +# workdir: {{workdir}} +# {% endif %} + +file_mounts: + {{remote_task_yaml_path}}: {{local_task_yaml_path}} + +envs: + # skip cloud identity check for serve controller to avoid the overhead. + SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK: 1 +{% if is_dev %} + SKYPILOT_DEV: 1 +{% endif %} +{% if is_debug %} + SKYPILOT_DEBUG: 1 +{% endif %} +{% if disable_logging %} + SKYPILOT_DISABLE_USAGE_COLLECTION: 1 +{% endif %} diff --git a/sky/utils/cli_utils/status_utils.py b/sky/utils/cli_utils/status_utils.py index 8ec8222c599..331494b672e 100644 --- a/sky/utils/cli_utils/status_utils.py +++ b/sky/utils/cli_utils/status_utils.py @@ -19,6 +19,9 @@ _ClusterRecord = Dict[str, Any] # A record returned by core.cost_report(); see its docstr for all fields. _ClusterCostReportRecord = Dict[str, Any] +# A record in global_user_state's 'services' table. +_ServiceRecord = Dict[str, Any] +_ReplicaRecord = Dict[str, Any] def truncate_long_string(s: str, max_length: int = 35) -> str: @@ -107,6 +110,68 @@ def show_status_table(cluster_records: List[_ClusterRecord], return num_pending_autostop +def show_service_table(service_records: List[_ServiceRecord], show_all: bool): + status_columns = [ + StatusColumn('NAME', _get_name), + StatusColumn('CONTROLLER_CLUSTER_NAME', + _get_controller_cluster_name, + show_by_default=False), + StatusColumn('ENDPOINT', _get_endpoint), + StatusColumn('#HEALTHY_REPLICAS', _get_healthy_replicas), + StatusColumn('#UNHEALTHY_REPLICAS', _get_unhealthy_replicas), + # TODO(tian): After we have a better way to detect failed replicas + # StatusColumn('#FAILED_REPLICAS', _get_failed_replicas), + StatusColumn('STATUS', _get_service_status_colored), + StatusColumn('POLICY', _get_policy, show_by_default=False), + StatusColumn('REQUESTED_RESOURCES', + _get_requested_resources, + show_by_default=False), + ] + + columns = [] + for status_column in status_columns: + if status_column.show_by_default or show_all: + columns.append(status_column.name) + service_table = log_utils.create_table(columns) + for record in service_records: + row = [] + for status_column in status_columns: + if status_column.show_by_default or show_all: + row.append(status_column.calc(record)) + service_table.add_row(row) + if service_records: + click.echo(service_table) + else: + click.echo('No existing services.') + + +def show_replica_table(replica_records: List[_ReplicaRecord], show_all: bool): + status_columns = [ + StatusColumn('NAME', _get_name), + StatusColumn('RESOURCES', + _get_resources, + trunc_length=70 if not show_all else 0), + StatusColumn('REGION', _get_region), + StatusColumn('STATUS', _get_status_colored), + ] + + columns = [] + for status_column in status_columns: + if status_column.show_by_default or show_all: + columns.append(status_column.name) + replica_table = log_utils.create_table(columns) + for record in replica_records: + row = [] + for status_column in status_columns: + if status_column.show_by_default or show_all: + row.append(status_column.calc(record)) + replica_table.add_row(row) + if replica_records: + click.echo(replica_table) + else: + click.echo('No existing replicas.') + + def get_total_cost_of_displayed_records( cluster_records: List[_ClusterCostReportRecord], display_all: bool): """Compute total cost of records to be displayed in cost report.""" @@ -307,6 +372,27 @@ def show_local_status_table(local_clusters: List[str]): _get_command = (lambda cluster_record: cluster_record['last_use']) _get_duration = (lambda cluster_record: log_utils.readable_time_duration( 0, cluster_record['duration'], absolute=True)) +_get_controller_cluster_name = ( + lambda service_record: service_record['controller_cluster_name']) +_get_endpoint = (lambda service_record: service_record['endpoint']) +_get_healthy_replicas = ( + lambda service_record: service_record['num_healthy_replicas']) +_get_unhealthy_replicas = ( + lambda service_record: service_record['num_unhealthy_replicas']) +_get_failed_replicas = ( + lambda service_record: service_record['num_failed_replicas']) +_get_policy = (lambda service_record: service_record['policy']) +_get_requested_resources = ( + lambda service_record: service_record['requested_resources']) + + +def _get_service_status( + service_record: _ServiceRecord) -> status_lib.ServiceStatus: + return service_record['status'] + + +def _get_service_status_colored(service_record: _ServiceRecord) -> str: + return _get_service_status(service_record).colored_str() def _get_status(cluster_record: _ClusterRecord) -> status_lib.ClusterStatus: diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 8126a97da9d..d711640a7f0 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -143,6 +143,52 @@ def get_storage_schema(): } +def get_service_schema(): + return { + '$schema': 'http://json-schema.org/draft-07/schema#', + 'type': 'object', + 'required': ['port', 'readiness_probe', 'replica_policy'], + 'additionalProperties': False, + 'properties': { + 'port': { + 'type': 'integer', + }, + 'readiness_probe': { + 'type': 'object', + 'required': ['path', 'readiness_timeout'], + 'additionalProperties': False, + 'properties': { + 'path': { + 'type': 'string', + }, + 'readiness_timeout': { + 'type': 'number', + }, + } + }, + 'replica_policy': { + 'type': 'object', + 'required': ['min_replica'], + 'additionalProperties': False, + 'properties': { + 'min_replica': { + 'type': 'integer', + }, + 'max_replica': { + 'type': 'integer', + }, + 'qps_upper_threshold': { + 'type': 'number', + }, + 'qps_lower_threshold': { + 'type': 'number', + }, + } + } + } + } + + def get_task_schema(): return { '$schema': 'https://json-schema.org/draft/2020-12/schema', @@ -170,6 +216,10 @@ def get_task_schema(): 'file_mounts': { 'type': 'object', }, + # service config is validated separately using SERVICE_SCHEMA + 'service': { + 'type': 'object', + }, 'setup': { 'type': 'string', }, @@ -203,9 +253,6 @@ def get_task_schema(): 'additionalProperties': { 'type': 'number' } - }, - 'service': { - 'type': 'object', } } }