diff --git a/sky/spot/controller.py b/sky/spot/controller.py
index c19cee24f4d..155fc202b5a 100644
--- a/sky/spot/controller.py
+++ b/sky/spot/controller.py
@@ -6,6 +6,7 @@
 
 import colorama
 import filelock
+import ray
 
 import sky
 from sky import exceptions
@@ -28,9 +29,9 @@ class SpotController:
 
     def __init__(self, job_id: int, task_yaml: str,
                  retry_until_up: bool) -> None:
-        self._job_id = job_id
-        self._task_name = pathlib.Path(task_yaml).stem
-        self._task = sky.Task.from_yaml(task_yaml)
+        self.job_id = job_id
+        self.task_name = pathlib.Path(task_yaml).stem
+        self.task = sky.Task.from_yaml(task_yaml)
 
         self._retry_until_up = retry_until_up
         # TODO(zhwu): this assumes the specific backend.
@@ -39,145 +40,77 @@ def __init__(self, job_id: int, task_yaml: str,
         # Add a unique identifier to the task environment variables, so that
         # the user can have the same id for multiple recoveries.
         # Example value: sky-2022-10-04-22-46-52-467694_id-17
-        task_envs = self._task.envs or {}
+        task_envs = self.task.envs or {}
         job_id_env_var = common_utils.get_global_job_id(
-            self.backend.run_timestamp, 'spot', self._job_id)
+            self.backend.run_timestamp, 'spot', self.job_id)
         task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var
-        self._task.set_envs(task_envs)
+        self.task.set_envs(task_envs)
 
         spot_state.set_submitted(
-            self._job_id,
-            self._task_name,
+            self.job_id,
+            self.task_name,
             self.backend.run_timestamp,
-            resources_str=backend_utils.get_task_resources_str(self._task))
+            resources_str=backend_utils.get_task_resources_str(self.task))
         logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}')
-        self._cluster_name = spot_utils.generate_spot_cluster_name(
-            self._task_name, self._job_id)
-        self._strategy_executor = recovery_strategy.StrategyExecutor.make(
-            self._cluster_name, self.backend, self._task, retry_until_up,
-            self._handle_signal)
-
-    def _run(self):
-        """Busy loop monitoring spot cluster status and handling recovery."""
-        logger.info(f'Started monitoring spot task {self._task_name} '
-                    f'(id: {self._job_id})')
-        spot_state.set_starting(self._job_id)
-        start_at = self._strategy_executor.launch()
-
-        spot_state.set_started(self._job_id, start_time=start_at)
-        while True:
-            time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
-            # Handle the signal if it is sent by the user.
-            self._handle_signal()
-
-            # Check the network connection to avoid false alarm for job failure.
-            # Network glitch was observed even in the VM.
-            try:
-                backend_utils.check_network_connection()
-            except exceptions.NetworkError:
-                logger.info(
-                    'Network is not available. Retrying again in '
-                    f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
-                continue
-
-            # NOTE: we do not check cluster status first because race condition
-            # can occur, i.e. cluster can be down during the job status check.
-            job_status = spot_utils.get_job_status(self.backend,
-                                                   self._cluster_name)
-
-            if job_status is not None and not job_status.is_terminal():
-                need_recovery = False
-                if self._task.num_nodes > 1:
-                    # Check the cluster status for multi-node jobs, since the
-                    # job may not be set to FAILED immediately when only some
-                    # of the nodes are preempted.
-                    (cluster_status,
-                     handle) = backend_utils.refresh_cluster_status_handle(
-                         self._cluster_name, force_refresh=True)
-                    if cluster_status != global_user_state.ClusterStatus.UP:
-                        # recover the cluster if it is not up.
-                        logger.info(f'Cluster status {cluster_status.value}. '
-                                    'Recovering...')
-                        need_recovery = True
-                if not need_recovery:
-                    # The job and cluster are healthy, continue to monitor the
-                    # job status.
-                    continue
-
-            if job_status == job_lib.JobStatus.SUCCEEDED:
-                end_time = spot_utils.get_job_timestamp(self.backend,
-                                                        self._cluster_name,
-                                                        get_end_time=True)
-                # The job is done.
-                spot_state.set_succeeded(self._job_id, end_time=end_time)
-                break
-
-            if job_status == job_lib.JobStatus.FAILED:
-                # Check the status of the spot cluster. If it is not UP,
-                # the cluster is preempted.
-                (cluster_status,
-                 handle) = backend_utils.refresh_cluster_status_handle(
-                     self._cluster_name, force_refresh=True)
-                if cluster_status == global_user_state.ClusterStatus.UP:
-                    # The user code has probably crashed.
-                    end_time = spot_utils.get_job_timestamp(self.backend,
-                                                            self._cluster_name,
-                                                            get_end_time=True)
-                    logger.info(
-                        'The user job failed. Please check the logs below.\n'
-                        f'== Logs of the user job (ID: {self._job_id}) ==\n')
-                    self.backend.tail_logs(handle,
-                                           None,
-                                           spot_job_id=self._job_id)
-                    logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
-                    spot_state.set_failed(
-                        self._job_id,
-                        failure_type=spot_state.SpotStatus.FAILED,
-                        end_time=end_time)
-                    break
-            # cluster can be down, INIT or STOPPED, based on the interruption
-            # behavior of the cloud.
-            # Failed to connect to the cluster or the cluster is partially down.
-            # job_status is None or job_status == job_lib.JobStatus.FAILED
-            logger.info('The cluster is preempted.')
-            spot_state.set_recovering(self._job_id)
-            recovered_time = self._strategy_executor.recover()
-            spot_state.set_recovered(self._job_id,
-                                     recovered_time=recovered_time)
+        self.cluster_name = spot_utils.generate_spot_cluster_name(
+            self.task_name, self.job_id)
+        self.strategy_executor = recovery_strategy.StrategyExecutor.make(
+            self.cluster_name, self.backend, self.task, retry_until_up)
 
     def start(self):
         """Start the controller."""
         try:
-            self._run()
+            self._handle_signal()
+            controller_task = _controller_run.remote(self)
+            # Signal can interrupt the underlying controller process.
+            ready, _ = ray.wait([controller_task], timeout=0)
+            while not ready:
+                try:
+                    self._handle_signal()
+                except exceptions.SpotUserCancelledError as e:
+                    logger.info(f'Cancelling spot job {self.job_id}...')
+                    try:
+                        ray.cancel(controller_task)
+                        ray.get(controller_task)
+                    except ray.exceptions.RayTaskError:
+                        # When the controller task is cancelled, it will raise
+                        # ray.exceptions.RayTaskError, which can be ignored,
+                        # since the SpotUserCancelledError will be raised and
+                        # handled later.
+                        pass
+                    raise e
+                ready, _ = ray.wait([controller_task], timeout=1)
+            # Need this to get the exception from the controller task.
+            ray.get(controller_task)
         except exceptions.SpotUserCancelledError as e:
             logger.info(e)
-            spot_state.set_cancelled(self._job_id)
+            spot_state.set_cancelled(self.job_id)
         except exceptions.ResourcesUnavailableError as e:
             logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}'
                          f'{colorama.Style.RESET_ALL}')
             spot_state.set_failed(
-                self._job_id,
+                self.job_id,
                 failure_type=spot_state.SpotStatus.FAILED_NO_RESOURCE)
         except (Exception, SystemExit) as e:  # pylint: disable=broad-except
             logger.error(traceback.format_exc())
             logger.error(f'Unexpected error occurred: {type(e).__name__}: {e}')
         finally:
-            self._strategy_executor.terminate_cluster()
-            job_status = spot_state.get_status(self._job_id)
+            self.strategy_executor.terminate_cluster()
+            job_status = spot_state.get_status(self.job_id)
             # The job can be non-terminal if the controller exited abnormally,
             # e.g. failed to launch cluster after reaching the MAX_RETRY.
             if not job_status.is_terminal():
                 spot_state.set_failed(
-                    self._job_id,
+                    self.job_id,
                     failure_type=spot_state.SpotStatus.FAILED_CONTROLLER)
 
             # Clean up Storages with persistent=False.
-            self.backend.teardown_ephemeral_storage(self._task)
+            self.backend.teardown_ephemeral_storage(self.task)
 
     def _handle_signal(self):
         """Handle the signal if the user sent it."""
         signal_file = pathlib.Path(
-            spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id))
+            spot_utils.SIGNAL_FILE_PREFIX.format(self.job_id))
         signal = None
         if signal_file.exists():
             # Filelock is needed to prevent race condition with concurrent
@@ -200,7 +133,97 @@ def _handle_signal(self):
         raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.')
 
 
+@ray.remote(num_cpus=0)
+def _controller_run(spot_controller: SpotController):
+    """Busy loop monitoring spot cluster status and handling recovery."""
+    logger.info(f'Started monitoring spot task {spot_controller.task_name} '
+                f'(id: {spot_controller.job_id})')
+    spot_state.set_starting(spot_controller.job_id)
+    start_at = spot_controller.strategy_executor.launch()
+
+    spot_state.set_started(spot_controller.job_id, start_time=start_at)
+    while True:
+        time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
+
+        # Check the network connection to avoid false alarm for job failure.
+        # Network glitch was observed even in the VM.
+        try:
+            backend_utils.check_network_connection()
+        except exceptions.NetworkError:
+            logger.info('Network is not available. Retrying again in '
+                        f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
+            continue
+
+        # NOTE: we do not check cluster status first because race condition
+        # can occur, i.e. cluster can be down during the job status check.
+        job_status = spot_utils.get_job_status(spot_controller.backend,
+                                               spot_controller.cluster_name)
+
+        if job_status is not None and not job_status.is_terminal():
+            need_recovery = False
+            if spot_controller.task.num_nodes > 1:
+                # Check the cluster status for multi-node jobs, since the
+                # job may not be set to FAILED immediately when only some
+                # of the nodes are preempted.
+                (cluster_status,
+                 handle) = backend_utils.refresh_cluster_status_handle(
+                     spot_controller.cluster_name, force_refresh=True)
+                if cluster_status != global_user_state.ClusterStatus.UP:
+                    # recover the cluster if it is not up.
+                    logger.info(f'Cluster status {cluster_status.value}. '
+                                'Recovering...')
+                    need_recovery = True
+            if not need_recovery:
+                # The job and cluster are healthy, continue to monitor the
+                # job status.
+                continue
+
+        if job_status == job_lib.JobStatus.SUCCEEDED:
+            end_time = spot_utils.get_job_timestamp(
+                spot_controller.backend,
+                spot_controller.cluster_name,
+                get_end_time=True)
+            # The job is done.
+            spot_state.set_succeeded(spot_controller.job_id, end_time=end_time)
+            break
+
+        if job_status == job_lib.JobStatus.FAILED:
+            # Check the status of the spot cluster. If it is not UP,
+            # the cluster is preempted.
+            (cluster_status,
+             handle) = backend_utils.refresh_cluster_status_handle(
+                 spot_controller.cluster_name, force_refresh=True)
+            if cluster_status == global_user_state.ClusterStatus.UP:
+                # The user code has probably crashed.
+                end_time = spot_utils.get_job_timestamp(
+                    spot_controller.backend,
+                    spot_controller.cluster_name,
+                    get_end_time=True)
+                logger.info(
+                    'The user job failed. Please check the logs below.\n'
+                    '== Logs of the user job (ID: '
+                    f'{spot_controller.job_id}) ==\n')
+                spot_controller.backend.tail_logs(
+                    handle, None, spot_job_id=spot_controller.job_id)
+                logger.info(
+                    f'\n== End of logs (ID: {spot_controller.job_id}) ==')
+                spot_state.set_failed(spot_controller.job_id,
+                                      failure_type=spot_state.SpotStatus.FAILED,
+                                      end_time=end_time)
+                break
+        # cluster can be down, INIT or STOPPED, based on the interruption
+        # behavior of the cloud.
+        # Failed to connect to the cluster or the cluster is partially down.
+        # job_status is None or job_status == job_lib.JobStatus.FAILED
+        logger.info('The cluster is preempted.')
+        spot_state.set_recovering(spot_controller.job_id)
+        recovered_time = spot_controller.strategy_executor.recover()
+        spot_state.set_recovered(spot_controller.job_id,
+                                 recovered_time=recovered_time)
+
+
 if __name__ == '__main__':
+    ray.init()
     parser = argparse.ArgumentParser()
     parser.add_argument('--job-id',
                         required=True,
diff --git a/sky/spot/recovery_strategy.py b/sky/spot/recovery_strategy.py
index 7cc439397df..534d2d38c9f 100644
--- a/sky/spot/recovery_strategy.py
+++ b/sky/spot/recovery_strategy.py
@@ -1,7 +1,7 @@
 """The strategy to handle launching/recovery/termination of spot clusters."""
 import time
 import typing
-from typing import Callable, Optional
+from typing import Optional
 
 import sky
 from sky import exceptions
@@ -34,8 +34,7 @@ class StrategyExecutor:
     RETRY_INIT_GAP_SECONDS = 60
 
     def __init__(self, cluster_name: str, backend: 'backends.Backend',
-                 task: 'task_lib.Task', retry_until_up: bool,
-                 signal_handler: Callable) -> None:
+                 task: 'task_lib.Task', retry_until_up: bool) -> None:
         """Initialize the strategy executor.
 
         Args:
@@ -43,15 +42,12 @@ def __init__(self, cluster_name: str, backend: 'backends.Backend',
             backend: The backend to use. Only CloudVMRayBackend is supported.
             task: The task to execute.
             retry_until_up: Whether to retry until the cluster is up.
-            signal_handler: The signal handler that will raise an exception if a
-                SkyPilot signal is received.
         """
         self.dag = sky.Dag()
         self.dag.add(task)
         self.cluster_name = cluster_name
         self.backend = backend
         self.retry_until_up = retry_until_up
-        self.signal_handler = signal_handler
 
     def __init_subclass__(cls, name: str, default: bool = False):
         SPOT_STRATEGIES[name] = cls
@@ -63,8 +59,7 @@ def __init_subclass__(cls, name: str, default: bool = False):
 
     @classmethod
     def make(cls, cluster_name: str, backend: 'backends.Backend',
-             task: 'task_lib.Task', retry_until_up: bool,
-             signal_handler: Callable) -> 'StrategyExecutor':
+             task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor':
         """Create a strategy from a task."""
         resources = task.resources
         assert len(resources) == 1, 'Only one resource is supported.'
@@ -77,7 +72,7 @@ def make(cls, cluster_name: str, backend: 'backends.Backend',
         # will be handled by the strategy class.
         task.set_resources({resources.copy(spot_recovery=None)})
         return SPOT_STRATEGIES[spot_recovery](cluster_name, backend, task,
-                                              retry_until_up, signal_handler)
+                                              retry_until_up)
 
     def launch(self) -> Optional[float]:
         """Launch the spot cluster for the first time.
@@ -147,9 +142,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
         backoff = common_utils.Backoff(self.RETRY_INIT_GAP_SECONDS)
         while True:
             retry_cnt += 1
-            # Check the signal every time to be more responsive to user
-            # signals, such as Cancel.
-            self.signal_handler()
             retry_launch = False
             exception = None
             try:
diff --git a/sky/spot/spot_utils.py b/sky/spot/spot_utils.py
index 452a0a15a52..9c134c7f6b2 100644
--- a/sky/spot/spot_utils.py
+++ b/sky/spot/spot_utils.py
@@ -178,8 +178,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
 
     cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
     identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'
-    return (f'{identity_str} scheduled to be cancelled within '
-            f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
+    return f'{identity_str} scheduled to be cancelled.'
 
 
 def cancel_job_by_name(job_name: str) -> str:
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 4257833c381..321080cc6bb 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -805,7 +805,7 @@ def test_spot_recovery():
         'managed-spot-recovery',
         [
             f'sky spot launch --cloud aws --region {region} -n {name} "echo SKYPILOT_JOB_ID: \$SKYPILOT_JOB_ID; sleep 1000" -y -d',
-            'sleep 300',
+            'sleep 360',
            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"',
            f'RUN_ID=$(sky spot logs -n {name} --no-follow | grep SKYPILOT_JOB_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
            # Terminate the cluster manually.
@@ -830,10 +830,10 @@ def test_spot_recovery_multi_node():
     name = _get_cluster_name()
     region = 'us-west-2'
     test = Test(
-        'managed-spot-recovery',
+        'managed-spot-recovery-multi',
         [
             f'sky spot launch --cloud aws --region {region} -n {name} --num-nodes 2 "echo SKYPILOT_JOB_ID: \$SKYPILOT_JOB_ID; sleep 1000" -y -d',
-            'sleep 360',
+            'sleep 400',
            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"',
            f'RUN_ID=$(sky spot logs -n {name} --no-follow | grep SKYPILOT_JOB_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
            # Terminate the worker manually.
@@ -854,6 +854,48 @@ def test_spot_recovery_multi_node():
     run_one_test(test)
 
 
+def test_spot_cancellation():
+    name = _get_cluster_name()
+    region = 'us-east-2'
+    test = Test(
+        'managed-spot-cancellation',
+        [
+            f'sky spot launch --cloud aws --region {region} -n {name} "sleep 1000" -y -d',
+            'sleep 60',
+            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "STARTING"',
+            # Test cancelling the spot job during launching.
+            f'sky spot cancel -y -n {name}',
+            'sleep 5',
+            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "CANCELLED"',
+            'sleep 100',
+            (f'aws ec2 describe-instances --region {region} '
+             f'--filters Name=tag:ray-cluster-name,Values={name}* '
+             f'--query Reservations[].Instances[].State[].Name '
+             '--output text | grep terminated'),
+            # Test cancelling the spot job during running.
+            f'sky spot launch --cloud aws --region {region} -n {name}-2 "sleep 1000" -y -d',
+            'sleep 300',
+            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"',
+            # Terminate the cluster manually.
+            (f'aws ec2 terminate-instances --region {region} --instance-ids $('
+             f'aws ec2 describe-instances --region {region} '
+             f'--filters Name=tag:ray-cluster-name,Values={name}-2* '
+             f'--query Reservations[].Instances[].InstanceId '
+             '--output text)'),
+            'sleep 50',
+            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name}-2 | head -n1 | grep "RECOVERING"',
+            f'sky spot cancel -y -n {name}-2',
+            'sleep 10',
+            f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name}-2 | head -n1 | grep "CANCELLED"',
+            'sleep 90',
+            (f'aws ec2 describe-instances --region {region} '
+             f'--filters Name=tag:ray-cluster-name,Values={name}-2* '
+             f'--query Reservations[].Instances[].State[].Name '
+             '--output text | grep terminated'),
+        ])
+    run_one_test(test)
+
+
 # ---------- Testing storage for managed spot ----------
 def test_spot_storage():
     """Test storage with managed spot"""
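
For reference, here is a minimal, self-contained sketch of the poll-and-cancel pattern that the new SpotController.start() builds on: a Ray remote task is launched, ray.wait() with a short timeout polls it while checking for a user signal, and ray.cancel() interrupts it when cancellation is requested. This is an illustration only, not part of the patch; the _monitor task, the sleep lengths, and cancel_requested() are made-up stand-ins for _controller_run and _handle_signal.

import time

import ray


@ray.remote(num_cpus=0)
def _monitor():
    # Stand-in for _controller_run: a long-running background task.
    time.sleep(600)
    return 'done'


def cancel_requested() -> bool:
    # Stand-in for SpotController._handle_signal(); a real implementation
    # would check the signal file written by `sky spot cancel`.
    return False


if __name__ == '__main__':
    ray.init()
    task = _monitor.remote()
    ready, _ = ray.wait([task], timeout=0)
    while not ready:
        if cancel_requested():
            ray.cancel(task)  # Interrupt the running task.
            try:
                ray.get(task)
            except (ray.exceptions.RayTaskError,
                    ray.exceptions.TaskCancelledError):
                pass  # Expected when the task was cancelled.
            break
        # Poll roughly once per second, mirroring the controller loop.
        ready, _ = ray.wait([task], timeout=1)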