From c6ae536d8dfedc3bbcf427a81480382b9d5f4c29 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Sat, 19 Oct 2024 15:41:44 -0700 Subject: [PATCH] [Serve] Support manually terminating a replica and with purge option (#4032) * define replica id param in cli * create endpoint on controller * call controller endpoint to scale down replica * add classmethod decorator * add handler methods for readability in cli * update docstr and error msg, and inline in cli * update log and return err msg * add docstr, catch and reraise err, add stopped and nonexistent message * inline constant to avoid circular import * fix error statement and return encoded str * add purge feature * add purge replica usage in docstr * use .get to handle unexpected packages * fix: diff terminate replica when failed/purging or not * fix: stay up to date for `is_controller_accessible` * revert * up to date with current APIs * error handling * when purged remove record in the main loop * refactor due to reviewer's suggestions * combine functions * fix: terminate the healthy replica even with purge option * remove abbr * Update sky/serve/core.py Co-authored-by: Tian Xia * Update sky/serve/core.py Co-authored-by: Tian Xia * Update sky/serve/controller.py Co-authored-by: Tian Xia * Update sky/serve/controller.py Co-authored-by: Tian Xia * Update sky/cli.py Co-authored-by: Tian Xia * got services hint * check if not yes in the outside if branch * fix some output messages * Update sky/serve/core.py Co-authored-by: Tian Xia * set conflict status code for already scheduled termination * combine purge and normal terminating down branch together * bump version * global exception handler to render a json response with error messages * fix: use responses.JSONResponse for dict serialize * error messages for old controller * fix: check version mismatch in generated code * revert mistakenly change update_service * refine already in terminating message * fix: branch code workaround in cls.build * wording Co-authored-by: Tian Xia * refactor due to reviewer's comments * fix use ux_utils Co-authored-by: Tian Xia * add changelog as comments * fix messages * edit the message for mismatch error Co-authored-by: Tian Xia * no traceback when raising in `terminate_replica` * messages decode * Apply suggestions from code review Co-authored-by: Tian Xia * format * forma * Empty commit --------- Co-authored-by: David Tran Co-authored-by: David Tran Co-authored-by: Tian Xia --- sky/cli.py | 58 +++++++++++++++++++++++------ sky/serve/__init__.py | 2 + sky/serve/constants.py | 9 ++++- sky/serve/controller.py | 70 +++++++++++++++++++++++++++++++++++ sky/serve/core.py | 47 +++++++++++++++++++++++ sky/serve/replica_managers.py | 17 +++++++-- sky/serve/serve_utils.py | 44 +++++++++++++++++++++- 7 files changed, 229 insertions(+), 18 deletions(-) diff --git a/sky/cli.py b/sky/cli.py index 114c18c9256..fb5a38bba7b 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -4380,9 +4380,14 @@ def serve_status(all: bool, endpoint: bool, service_names: List[str]): default=False, required=False, help='Skip confirmation prompt.') +@click.option('--replica-id', + default=None, + type=int, + help='Tear down a given replica') # pylint: disable=redefined-builtin -def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool): - """Teardown service(s). +def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool, + replica_id: Optional[int]): + """Teardown service(s) or a replica. SERVICE_NAMES is the name of the service (or glob pattern) to tear down. If both SERVICE_NAMES and ``--all`` are supplied, the latter takes precedence. @@ -4408,6 +4413,12 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool): \b # Forcefully tear down a service in failed status. sky serve down failed-service --purge + \b + # Tear down a specific replica + sky serve down my-service --replica-id 1 + \b + # Forcefully tear down a specific replica, even in failed status. + sky serve down my-service --replica-id 1 --purge """ if sum([len(service_names) > 0, all]) != 1: argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len( @@ -4417,22 +4428,45 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool): 'Can only specify one of SERVICE_NAMES or --all. ' f'Provided {argument_str!r}.') + replica_id_is_defined = replica_id is not None + if replica_id_is_defined: + if len(service_names) != 1: + service_names_str = ', '.join(service_names) + raise click.UsageError(f'The --replica-id option can only be used ' + f'with a single service name. Got: ' + f'{service_names_str}.') + if all: + raise click.UsageError('The --replica-id option cannot be used ' + 'with the --all option.') + backend_utils.is_controller_accessible( controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER, stopped_message='All services should have been terminated.', exit_if_not_accessible=True) if not yes: - quoted_service_names = [f'{name!r}' for name in service_names] - service_identity_str = f'service(s) {", ".join(quoted_service_names)}' - if all: - service_identity_str = 'all services' - click.confirm(f'Terminating {service_identity_str}. Proceed?', - default=True, - abort=True, - show_default=True) - - serve_lib.down(service_names=service_names, all=all, purge=purge) + if replica_id_is_defined: + click.confirm( + f'Terminating replica ID {replica_id} in ' + f'{service_names[0]!r}. Proceed?', + default=True, + abort=True, + show_default=True) + else: + quoted_service_names = [f'{name!r}' for name in service_names] + service_identity_str = (f'service(s) ' + f'{", ".join(quoted_service_names)}') + if all: + service_identity_str = 'all services' + click.confirm(f'Terminating {service_identity_str}. Proceed?', + default=True, + abort=True, + show_default=True) + + if replica_id_is_defined: + serve_lib.terminate_replica(service_names[0], replica_id, purge) + else: + serve_lib.down(service_names=service_names, all=all, purge=purge) @serve.command('logs', cls=_DocumentedCodeCommand) diff --git a/sky/serve/__init__.py b/sky/serve/__init__.py index d85b6e9311e..f93495809c3 100644 --- a/sky/serve/__init__.py +++ b/sky/serve/__init__.py @@ -8,6 +8,7 @@ from sky.serve.core import down from sky.serve.core import status from sky.serve.core import tail_logs +from sky.serve.core import terminate_replica from sky.serve.core import up from sky.serve.core import update from sky.serve.serve_state import ReplicaStatus @@ -42,6 +43,7 @@ 'SKY_SERVE_CONTROLLER_NAME', 'SKYSERVE_METADATA_DIR', 'status', + 'terminate_replica', 'tail_logs', 'up', 'update', diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 7775c3f8a6e..3974293190e 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -92,4 +92,11 @@ # change for the serve_utils.ServeCodeGen, we need to bump this version, so that # the user can be notified to update their SkyPilot serve version on the remote # cluster. -SERVE_VERSION = 1 +# Changelog: +# v1.0 - Introduce rolling update. +# v2.0 - Added template-replica feature. +SERVE_VERSION = 2 + +TERMINATE_REPLICA_VERSION_MISMATCH_ERROR = ( + 'The version of service is outdated and does not support manually ' + 'terminating replicas. Please terminate the service and spin up again.') diff --git a/sky/serve/controller.py b/sky/serve/controller.py index 580964273ef..75d14b76079 100644 --- a/sky/serve/controller.py +++ b/sky/serve/controller.py @@ -9,6 +9,7 @@ import traceback from typing import Any, Dict, List +import colorama import fastapi from fastapi import responses import uvicorn @@ -157,6 +158,75 @@ async def update_service(request: fastapi.Request) -> fastapi.Response: return responses.JSONResponse(content={'message': 'Error'}, status_code=500) + @self._app.post('/controller/terminate_replica') + async def terminate_replica( + request: fastapi.Request) -> fastapi.Response: + request_data = await request.json() + replica_id = request_data['replica_id'] + assert isinstance(replica_id, + int), 'Error: replica ID must be an integer.' + purge = request_data['purge'] + assert isinstance(purge, bool), 'Error: purge must be a boolean.' + replica_info = serve_state.get_replica_info_from_id( + self._service_name, replica_id) + assert replica_info is not None, (f'Error: replica ' + f'{replica_id} does not exist.') + replica_status = replica_info.status + + if replica_status == serve_state.ReplicaStatus.SHUTTING_DOWN: + return responses.JSONResponse( + status_code=409, + content={ + 'message': + f'Replica {replica_id} of service ' + f'{self._service_name!r} is already in the process ' + f'of terminating. Skip terminating now.' + }) + + if (replica_status in serve_state.ReplicaStatus.failed_statuses() + and not purge): + return responses.JSONResponse( + status_code=409, + content={ + 'message': f'{colorama.Fore.YELLOW}Replica ' + f'{replica_id} of service ' + f'{self._service_name!r} is in failed ' + f'status ({replica_info.status}). ' + f'Skipping its termination as it could ' + f'lead to a resource leak. ' + f'(Use `sky serve down ' + f'{self._service_name!r} --replica-id ' + f'{replica_id} --purge` to ' + 'forcefully terminate the replica.)' + f'{colorama.Style.RESET_ALL}' + }) + + self._replica_manager.scale_down(replica_id, purge=purge) + + action = 'terminated' if not purge else 'purged' + message = (f'{colorama.Fore.GREEN}Replica {replica_id} of service ' + f'{self._service_name!r} is scheduled to be ' + f'{action}.{colorama.Style.RESET_ALL}\n' + f'Please use {ux_utils.BOLD}sky serve status ' + f'{self._service_name}{ux_utils.RESET_BOLD} ' + f'to check the latest status.') + return responses.JSONResponse(status_code=200, + content={'message': message}) + + @self._app.exception_handler(Exception) + async def validation_exception_handler( + request: fastapi.Request, exc: Exception) -> fastapi.Response: + with ux_utils.enable_traceback(): + logger.error(f'Error in controller: {exc!r}') + return responses.JSONResponse( + status_code=500, + content={ + 'message': + (f'Failed method {request.method} at URL {request.url}.' + f' Exception message is {exc!r}.') + }, + ) + threading.Thread(target=self._run_autoscaler).start() logger.info('SkyServe Controller started on ' diff --git a/sky/serve/core.py b/sky/serve/core.py index 3ad260213f1..691a3edea0b 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -503,6 +503,53 @@ def down( sky_logging.print(stdout) +@usage_lib.entrypoint +def terminate_replica(service_name: str, replica_id: int, purge: bool) -> None: + """Tear down a specific replica for the given service. + + Args: + service_name: Name of the service. + replica_id: ID of replica to terminate. + purge: Whether to terminate replicas in a failed status. These replicas + may lead to resource leaks, so we require the user to explicitly + specify this flag to make sure they are aware of this potential + resource leak. + + Raises: + sky.exceptions.ClusterNotUpError: if the sky sere controller is not up. + RuntimeError: if failed to terminate the replica. + """ + handle = backend_utils.is_controller_accessible( + controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER, + stopped_message= + 'No service is running now. Please spin up a service first.', + non_existent_message='No service is running now. ' + 'Please spin up a service first.', + ) + + backend = backend_utils.get_backend_from_handle(handle) + assert isinstance(backend, backends.CloudVmRayBackend) + + code = serve_utils.ServeCodeGen.terminate_replica(service_name, replica_id, + purge) + returncode, stdout, stderr = backend.run_on_head(handle, + code, + require_outputs=True, + stream_logs=False, + separate_stderr=True) + + try: + subprocess_utils.handle_returncode(returncode, + code, + 'Failed to terminate the replica', + stderr, + stream_logs=True) + except exceptions.CommandError as e: + raise RuntimeError(e.error_msg) from e + + sky_logging.print(stdout) + + @usage_lib.entrypoint def status( service_names: Optional[Union[str, diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py index 337b28ba61b..c0e5220e779 100644 --- a/sky/serve/replica_managers.py +++ b/sky/serve/replica_managers.py @@ -247,6 +247,8 @@ class ReplicaStatusProperty: is_scale_down: bool = False # The replica's spot instance was preempted. preempted: bool = False + # Whether the replica is purged. + purged: bool = False def remove_terminated_replica(self) -> bool: """Whether to remove the replica record from the replica table. @@ -307,6 +309,8 @@ def should_track_service_status(self) -> bool: return False if self.preempted: return False + if self.purged: + return False return True def to_replica_status(self) -> serve_state.ReplicaStatus: @@ -590,7 +594,7 @@ def scale_up(self, """ raise NotImplementedError - def scale_down(self, replica_id: int) -> None: + def scale_down(self, replica_id: int, purge: bool = False) -> None: """Scale down replica with replica_id.""" raise NotImplementedError @@ -679,7 +683,8 @@ def _terminate_replica(self, replica_id: int, sync_down_logs: bool, replica_drain_delay_seconds: int, - is_scale_down: bool = False) -> None: + is_scale_down: bool = False, + purge: bool = False) -> None: if replica_id in self._launch_process_pool: info = serve_state.get_replica_info_from_id(self._service_name, @@ -763,16 +768,18 @@ def _download_and_stream_logs(info: ReplicaInfo): ) info.status_property.sky_down_status = ProcessStatus.RUNNING info.status_property.is_scale_down = is_scale_down + info.status_property.purged = purge serve_state.add_or_update_replica(self._service_name, replica_id, info) p.start() self._down_process_pool[replica_id] = p - def scale_down(self, replica_id: int) -> None: + def scale_down(self, replica_id: int, purge: bool = False) -> None: self._terminate_replica( replica_id, sync_down_logs=False, replica_drain_delay_seconds=_DEFAULT_DRAIN_SECONDS, - is_scale_down=True) + is_scale_down=True, + purge=purge) def _handle_preemption(self, info: ReplicaInfo) -> bool: """Handle preemption of the replica if any error happened. @@ -911,6 +918,8 @@ def _refresh_process_pool(self) -> None: # since user should fixed the error before update. elif info.version != self.latest_version: removal_reason = 'for version outdated' + elif info.status_property.purged: + removal_reason = 'for purge' else: logger.info(f'Termination of replica {replica_id} ' 'finished. Replica info is kept since some ' diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 0ecf34135a7..cb8b53f9814 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -313,6 +313,36 @@ def update_service_encoded(service_name: str, version: int, mode: str) -> str: return common_utils.encode_payload(service_msg) +def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str: + service_status = _get_service_status(service_name) + if service_status is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Service {service_name!r} does not exist.') + replica_info = serve_state.get_replica_info_from_id(service_name, + replica_id) + if replica_info is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Replica {replica_id} for service {service_name} does not ' + 'exist.') + + controller_port = service_status['controller_port'] + resp = requests.post( + _CONTROLLER_URL.format(CONTROLLER_PORT=controller_port) + + '/controller/terminate_replica', + json={ + 'replica_id': replica_id, + 'purge': purge, + }) + + message: str = resp.json()['message'] + if resp.status_code != 200: + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Failed to terminate replica {replica_id} ' + f'in {service_name}. Reason:\n{message}') + return message + + def _get_service_status( service_name: str, with_replica_info: bool = True) -> Optional[Dict[str, Any]]: @@ -735,7 +765,7 @@ def _get_replicas(service_record: Dict[str, Any]) -> str: def get_endpoint(service_record: Dict[str, Any]) -> str: - # Don't use backend_utils.is_controller_up since it is too slow. + # Don't use backend_utils.is_controller_accessible since it is too slow. handle = global_user_state.get_handle_from_cluster_name( SKY_SERVE_CONTROLLER_NAME) assert isinstance(handle, backends.CloudVmRayResourceHandle) @@ -915,6 +945,18 @@ def terminate_services(cls, service_names: Optional[List[str]], ] return cls._build(code) + @classmethod + def terminate_replica(cls, service_name: str, replica_id: int, + purge: bool) -> str: + code = [ + f'(lambda: print(serve_utils.terminate_replica({service_name!r}, ' + f'{replica_id}, {purge}), end="", flush=True) ' + 'if getattr(constants, "SERVE_VERSION", 0) >= 2 else ' + f'exec("raise RuntimeError(' + f'{constants.TERMINATE_REPLICA_VERSION_MISMATCH_ERROR!r})"))()' + ] + return cls._build(code) + @classmethod def wait_service_registration(cls, service_name: str, job_id: int) -> str: code = [