From 4294ca4d503752f2edb7385a02ef9589e7048ac6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 18 Nov 2022 13:59:00 -0800 Subject: [PATCH] Revert "[Spot] Let cancel interrupt the spot job" (#1432) Revert "[Spot] Let cancel interrupt the spot job (#1414)" This reverts commit 3bbf4aaae40d88d100febede4fd62a48c56d66a4. --- sky/spot/controller.py | 243 +++++++++++++++------------------- sky/spot/recovery_strategy.py | 16 ++- sky/spot/spot_utils.py | 3 +- tests/test_smoke.py | 48 +------ 4 files changed, 127 insertions(+), 183 deletions(-) diff --git a/sky/spot/controller.py b/sky/spot/controller.py index 155fc202b5a..c19cee24f4d 100644 --- a/sky/spot/controller.py +++ b/sky/spot/controller.py @@ -6,7 +6,6 @@ import colorama import filelock -import ray import sky from sky import exceptions @@ -29,9 +28,9 @@ class SpotController: def __init__(self, job_id: int, task_yaml: str, retry_until_up: bool) -> None: - self.job_id = job_id - self.task_name = pathlib.Path(task_yaml).stem - self.task = sky.Task.from_yaml(task_yaml) + self._job_id = job_id + self._task_name = pathlib.Path(task_yaml).stem + self._task = sky.Task.from_yaml(task_yaml) self._retry_until_up = retry_until_up # TODO(zhwu): this assumes the specific backend. @@ -40,77 +39,145 @@ def __init__(self, job_id: int, task_yaml: str, # Add a unique identifier to the task environment variables, so that # the user can have the same id for multiple recoveries. # Example value: sky-2022-10-04-22-46-52-467694_id-17 - task_envs = self.task.envs or {} + task_envs = self._task.envs or {} job_id_env_var = common_utils.get_global_job_id( - self.backend.run_timestamp, 'spot', self.job_id) + self.backend.run_timestamp, 'spot', self._job_id) task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var - self.task.set_envs(task_envs) + self._task.set_envs(task_envs) spot_state.set_submitted( - self.job_id, - self.task_name, + self._job_id, + self._task_name, self.backend.run_timestamp, - resources_str=backend_utils.get_task_resources_str(self.task)) + resources_str=backend_utils.get_task_resources_str(self._task)) logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}') - self.cluster_name = spot_utils.generate_spot_cluster_name( - self.task_name, self.job_id) - self.strategy_executor = recovery_strategy.StrategyExecutor.make( - self.cluster_name, self.backend, self.task, retry_until_up) + self._cluster_name = spot_utils.generate_spot_cluster_name( + self._task_name, self._job_id) + self._strategy_executor = recovery_strategy.StrategyExecutor.make( + self._cluster_name, self.backend, self._task, retry_until_up, + self._handle_signal) + + def _run(self): + """Busy loop monitoring spot cluster status and handling recovery.""" + logger.info(f'Started monitoring spot task {self._task_name} ' + f'(id: {self._job_id})') + spot_state.set_starting(self._job_id) + start_at = self._strategy_executor.launch() + + spot_state.set_started(self._job_id, start_time=start_at) + while True: + time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS) + # Handle the signal if it is sent by the user. + self._handle_signal() + + # Check the network connection to avoid false alarm for job failure. + # Network glitch was observed even in the VM. + try: + backend_utils.check_network_connection() + except exceptions.NetworkError: + logger.info( + 'Network is not available. Retrying again in ' + f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.') + continue + + # NOTE: we do not check cluster status first because race condition + # can occur, i.e. 
cluster can be down during the job status check. + job_status = spot_utils.get_job_status(self.backend, + self._cluster_name) + + if job_status is not None and not job_status.is_terminal(): + need_recovery = False + if self._task.num_nodes > 1: + # Check the cluster status for multi-node jobs, since the + # job may not be set to FAILED immediately when only some + # of the nodes are preempted. + (cluster_status, + handle) = backend_utils.refresh_cluster_status_handle( + self._cluster_name, force_refresh=True) + if cluster_status != global_user_state.ClusterStatus.UP: + # recover the cluster if it is not up. + logger.info(f'Cluster status {cluster_status.value}. ' + 'Recovering...') + need_recovery = True + if not need_recovery: + # The job and cluster are healthy, continue to monitor the + # job status. + continue + + if job_status == job_lib.JobStatus.SUCCEEDED: + end_time = spot_utils.get_job_timestamp(self.backend, + self._cluster_name, + get_end_time=True) + # The job is done. + spot_state.set_succeeded(self._job_id, end_time=end_time) + break + + if job_status == job_lib.JobStatus.FAILED: + # Check the status of the spot cluster. If it is not UP, + # the cluster is preempted. + (cluster_status, + handle) = backend_utils.refresh_cluster_status_handle( + self._cluster_name, force_refresh=True) + if cluster_status == global_user_state.ClusterStatus.UP: + # The user code has probably crashed. + end_time = spot_utils.get_job_timestamp(self.backend, + self._cluster_name, + get_end_time=True) + logger.info( + 'The user job failed. Please check the logs below.\n' + f'== Logs of the user job (ID: {self._job_id}) ==\n') + self.backend.tail_logs(handle, + None, + spot_job_id=self._job_id) + logger.info(f'\n== End of logs (ID: {self._job_id}) ==') + spot_state.set_failed( + self._job_id, + failure_type=spot_state.SpotStatus.FAILED, + end_time=end_time) + break + # cluster can be down, INIT or STOPPED, based on the interruption + # behavior of the cloud. + # Failed to connect to the cluster or the cluster is partially down. + # job_status is None or job_status == job_lib.JobStatus.FAILED + logger.info('The cluster is preempted.') + spot_state.set_recovering(self._job_id) + recovered_time = self._strategy_executor.recover() + spot_state.set_recovered(self._job_id, + recovered_time=recovered_time) def start(self): """Start the controller.""" try: - self._handle_signal() - controller_task = _controller_run.remote(self) - # Signal can interrupt the underlying controller process. - ready, _ = ray.wait([controller_task], timeout=0) - while not ready: - try: - self._handle_signal() - except exceptions.SpotUserCancelledError as e: - logger.info(f'Cancelling spot job {self.job_id}...') - try: - ray.cancel(controller_task) - ray.get(controller_task) - except ray.exceptions.RayTaskError: - # When the controller task is cancelled, it will raise - # ray.exceptions.RayTaskError, which can be ignored, - # since the SpotUserCancelledError will be raised and - # handled later. - pass - raise e - ready, _ = ray.wait([controller_task], timeout=1) - # Need this to get the exception from the controller task. 
- ray.get(controller_task) + self._run() except exceptions.SpotUserCancelledError as e: logger.info(e) - spot_state.set_cancelled(self.job_id) + spot_state.set_cancelled(self._job_id) except exceptions.ResourcesUnavailableError as e: logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}' f'{colorama.Style.RESET_ALL}') spot_state.set_failed( - self.job_id, + self._job_id, failure_type=spot_state.SpotStatus.FAILED_NO_RESOURCE) except (Exception, SystemExit) as e: # pylint: disable=broad-except logger.error(traceback.format_exc()) logger.error(f'Unexpected error occurred: {type(e).__name__}: {e}') finally: - self.strategy_executor.terminate_cluster() - job_status = spot_state.get_status(self.job_id) + self._strategy_executor.terminate_cluster() + job_status = spot_state.get_status(self._job_id) # The job can be non-terminal if the controller exited abnormally, # e.g. failed to launch cluster after reaching the MAX_RETRY. if not job_status.is_terminal(): spot_state.set_failed( - self.job_id, + self._job_id, failure_type=spot_state.SpotStatus.FAILED_CONTROLLER) # Clean up Storages with persistent=False. - self.backend.teardown_ephemeral_storage(self.task) + self.backend.teardown_ephemeral_storage(self._task) def _handle_signal(self): """Handle the signal if the user sent it.""" signal_file = pathlib.Path( - spot_utils.SIGNAL_FILE_PREFIX.format(self.job_id)) + spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id)) signal = None if signal_file.exists(): # Filelock is needed to prevent race condition with concurrent @@ -133,97 +200,7 @@ def _handle_signal(self): raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.') -@ray.remote(num_cpus=0) -def _controller_run(spot_controller: SpotController): - """Busy loop monitoring spot cluster status and handling recovery.""" - logger.info(f'Started monitoring spot task {spot_controller.task_name} ' - f'(id: {spot_controller.job_id})') - spot_state.set_starting(spot_controller.job_id) - start_at = spot_controller.strategy_executor.launch() - - spot_state.set_started(spot_controller.job_id, start_time=start_at) - while True: - time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS) - - # Check the network connection to avoid false alarm for job failure. - # Network glitch was observed even in the VM. - try: - backend_utils.check_network_connection() - except exceptions.NetworkError: - logger.info('Network is not available. Retrying again in ' - f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.') - continue - - # NOTE: we do not check cluster status first because race condition - # can occur, i.e. cluster can be down during the job status check. - job_status = spot_utils.get_job_status(spot_controller.backend, - spot_controller.cluster_name) - - if job_status is not None and not job_status.is_terminal(): - need_recovery = False - if spot_controller.task.num_nodes > 1: - # Check the cluster status for multi-node jobs, since the - # job may not be set to FAILED immediately when only some - # of the nodes are preempted. - (cluster_status, - handle) = backend_utils.refresh_cluster_status_handle( - spot_controller.cluster_name, force_refresh=True) - if cluster_status != global_user_state.ClusterStatus.UP: - # recover the cluster if it is not up. - logger.info(f'Cluster status {cluster_status.value}. ' - 'Recovering...') - need_recovery = True - if not need_recovery: - # The job and cluster are healthy, continue to monitor the - # job status. 
- continue - - if job_status == job_lib.JobStatus.SUCCEEDED: - end_time = spot_utils.get_job_timestamp( - spot_controller.backend, - spot_controller.cluster_name, - get_end_time=True) - # The job is done. - spot_state.set_succeeded(spot_controller.job_id, end_time=end_time) - break - - if job_status == job_lib.JobStatus.FAILED: - # Check the status of the spot cluster. If it is not UP, - # the cluster is preempted. - (cluster_status, - handle) = backend_utils.refresh_cluster_status_handle( - spot_controller.cluster_name, force_refresh=True) - if cluster_status == global_user_state.ClusterStatus.UP: - # The user code has probably crashed. - end_time = spot_utils.get_job_timestamp( - spot_controller.backend, - spot_controller.cluster_name, - get_end_time=True) - logger.info( - 'The user job failed. Please check the logs below.\n' - '== Logs of the user job (ID: ' - f'{spot_controller.job_id}) ==\n') - spot_controller.backend.tail_logs( - handle, None, spot_job_id=spot_controller.job_id) - logger.info( - f'\n== End of logs (ID: {spot_controller.job_id}) ==') - spot_state.set_failed(spot_controller.job_id, - failure_type=spot_state.SpotStatus.FAILED, - end_time=end_time) - break - # cluster can be down, INIT or STOPPED, based on the interruption - # behavior of the cloud. - # Failed to connect to the cluster or the cluster is partially down. - # job_status is None or job_status == job_lib.JobStatus.FAILED - logger.info('The cluster is preempted.') - spot_state.set_recovering(spot_controller.job_id) - recovered_time = spot_controller.strategy_executor.recover() - spot_state.set_recovered(spot_controller.job_id, - recovered_time=recovered_time) - - if __name__ == '__main__': - ray.init() parser = argparse.ArgumentParser() parser.add_argument('--job-id', required=True, diff --git a/sky/spot/recovery_strategy.py b/sky/spot/recovery_strategy.py index 534d2d38c9f..7cc439397df 100644 --- a/sky/spot/recovery_strategy.py +++ b/sky/spot/recovery_strategy.py @@ -1,7 +1,7 @@ """The strategy to handle launching/recovery/termination of spot clusters.""" import time import typing -from typing import Optional +from typing import Callable, Optional import sky from sky import exceptions @@ -34,7 +34,8 @@ class StrategyExecutor: RETRY_INIT_GAP_SECONDS = 60 def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool) -> None: + task: 'task_lib.Task', retry_until_up: bool, + signal_handler: Callable) -> None: """Initialize the strategy executor. Args: @@ -42,12 +43,15 @@ def __init__(self, cluster_name: str, backend: 'backends.Backend', backend: The backend to use. Only CloudVMRayBackend is supported. task: The task to execute. retry_until_up: Whether to retry until the cluster is up. + signal_handler: The signal handler that will raise an exception if a + SkyPilot signal is received. 
""" self.dag = sky.Dag() self.dag.add(task) self.cluster_name = cluster_name self.backend = backend self.retry_until_up = retry_until_up + self.signal_handler = signal_handler def __init_subclass__(cls, name: str, default: bool = False): SPOT_STRATEGIES[name] = cls @@ -59,7 +63,8 @@ def __init_subclass__(cls, name: str, default: bool = False): @classmethod def make(cls, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor': + task: 'task_lib.Task', retry_until_up: bool, + signal_handler: Callable) -> 'StrategyExecutor': """Create a strategy from a task.""" resources = task.resources assert len(resources) == 1, 'Only one resource is supported.' @@ -72,7 +77,7 @@ def make(cls, cluster_name: str, backend: 'backends.Backend', # will be handled by the strategy class. task.set_resources({resources.copy(spot_recovery=None)}) return SPOT_STRATEGIES[spot_recovery](cluster_name, backend, task, - retry_until_up) + retry_until_up, signal_handler) def launch(self) -> Optional[float]: """Launch the spot cluster for the first time. @@ -142,6 +147,9 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]: backoff = common_utils.Backoff(self.RETRY_INIT_GAP_SECONDS) while True: retry_cnt += 1 + # Check the signal every time to be more responsive to user + # signals, such as Cancel. + self.signal_handler() retry_launch = False exception = None try: diff --git a/sky/spot/spot_utils.py b/sky/spot/spot_utils.py index 9c134c7f6b2..452a0a15a52 100644 --- a/sky/spot/spot_utils.py +++ b/sky/spot/spot_utils.py @@ -178,7 +178,8 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids)) identity_str = f'Jobs with IDs {cancelled_job_ids_str} are' - return f'{identity_str} scheduled to be cancelled.' + return (f'{identity_str} scheduled to be cancelled within ' + f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.') def cancel_job_by_name(job_name: str) -> str: diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 321080cc6bb..4257833c381 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -805,7 +805,7 @@ def test_spot_recovery(): 'managed-spot-recovery', [ f'sky spot launch --cloud aws --region {region} -n {name} "echo SKYPILOT_JOB_ID: \$SKYPILOT_JOB_ID; sleep 1000" -y -d', - 'sleep 360', + 'sleep 300', f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky spot logs -n {name} --no-follow | grep SKYPILOT_JOB_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. @@ -830,10 +830,10 @@ def test_spot_recovery_multi_node(): name = _get_cluster_name() region = 'us-west-2' test = Test( - 'managed-spot-recovery-multi', + 'managed-spot-recovery', [ f'sky spot launch --cloud aws --region {region} -n {name} --num-nodes 2 "echo SKYPILOT_JOB_ID: \$SKYPILOT_JOB_ID; sleep 1000" -y -d', - 'sleep 400', + 'sleep 360', f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky spot logs -n {name} --no-follow | grep SKYPILOT_JOB_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. 
@@ -854,48 +854,6 @@ def test_spot_recovery_multi_node(): run_one_test(test) -def test_spot_cancellation(): - name = _get_cluster_name() - region = 'us-east-2' - test = Test( - 'managed-spot-cancellation', - [ - f'sky spot launch --cloud aws --region {region} -n {name} "sleep 1000" -y -d', - 'sleep 60', - f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "STARTING"', - # Test cancelling the spot job during launching. - f'sky spot cancel -y -n {name}', - 'sleep 5', - f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "CANCELLED"', - 'sleep 100', - (f'aws ec2 describe-instances --region {region} ' - f'--filters Name=tag:ray-cluster-name,Values={name}* ' - f'--query Reservations[].Instances[].State[].Name ' - '--output text | grep terminated'), - # Test cancelling the spot job during running. - f'sky spot launch --cloud aws --region {region} -n {name}-2 "sleep 1000" -y -d', - 'sleep 300', - f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"', - # Terminate the cluster manually. - (f'aws ec2 terminate-instances --region {region} --instance-ids $(' - f'aws ec2 describe-instances --region {region} ' - f'--filters Name=tag:ray-cluster-name,Values={name}-2* ' - f'--query Reservations[].Instances[].InstanceId ' - '--output text)'), - 'sleep 50', - f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name}-2 | head -n1 | grep "RECOVERING"', - f'sky spot cancel -y -n {name}-2', - 'sleep 10', - f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name}-2 | head -n1 | grep "CANCELLED"', - 'sleep 90', - (f'aws ec2 describe-instances --region {region} ' - f'--filters Name=tag:ray-cluster-name,Values={name}-2* ' - f'--query Reservations[].Instances[].State[].Name ' - '--output text | grep terminated'), - ]) - run_one_test(test) - - # ---------- Testing storage for managed spot ---------- def test_spot_storage(): """Test storage with managed spot"""
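
For context, the flow this revert restores is cooperative, polling-based cancellation: instead of letting ray.cancel interrupt a _controller_run Ray task, the controller's _run() loop calls _handle_signal() between job-status checks, and StrategyExecutor._launch() calls the same handler (passed in as signal_handler) before each launch retry, so a cancel is only noticed at the next poll or retry. The sketch below is a minimal, self-contained illustration of that pattern, not SkyPilot code; UserCancelled, signal_file_path, check_signal, and run_controller_loop are made-up names standing in for exceptions.SpotUserCancelledError, spot_utils.SIGNAL_FILE_PREFIX, SpotController._handle_signal(), and SpotController._run().

    # Illustrative sketch only: a simplified stand-in for the polling-based
    # cancellation flow restored by this revert. UserCancelled, signal_file_path,
    # check_signal and run_controller_loop are placeholder names, not SkyPilot APIs.
    import pathlib
    import time

    import filelock


    class UserCancelled(Exception):
        """Raised when a cancel signal is found for the job."""


    def signal_file_path(job_id: int) -> pathlib.Path:
        # Hypothetical location; SkyPilot derives the real path from
        # spot_utils.SIGNAL_FILE_PREFIX instead.
        return pathlib.Path(f'/tmp/sky_spot_signal_{job_id}')


    def check_signal(job_id: int) -> None:
        """Raise UserCancelled if the user has requested cancellation."""
        signal_file = signal_file_path(job_id)
        if not signal_file.exists():
            return
        # The file lock avoids racing with a concurrent `sky spot cancel`
        # that may still be writing the signal file.
        with filelock.FileLock(str(signal_file) + '.lock'):
            signal = signal_file.read_text().strip()
            signal_file.unlink()
        if signal == 'CANCEL':
            raise UserCancelled(f'User sent cancel signal for job {job_id}.')


    def run_controller_loop(job_id: int, poll_gap_seconds: int = 20) -> None:
        """Busy loop: check for a cancel signal once per status-poll interval."""
        while True:
            time.sleep(poll_gap_seconds)
            # Cooperative cancellation: the loop itself (not ray.cancel) notices
            # the signal, so a cancel takes effect within one poll interval.
            check_signal(job_id)
            # ... query the spot job status here; recover on preemption and
            # break on SUCCEEDED/FAILED, as controller.py's _run() does ...


    if __name__ == '__main__':
        try:
            run_controller_loop(job_id=1)
        except UserCancelled:
            # The real controller marks the job CANCELLED and tears down the
            # spot cluster at this point.
            pass

Because the signal is only observed at the next status poll, the user-facing message in spot_utils.cancel_jobs_by_id is updated by this patch to report that jobs are "scheduled to be cancelled within {JOB_STATUS_CHECK_GAP_SECONDS} seconds" rather than immediately; the signal_handler callback threaded into StrategyExecutor._launch() covers the launch/retry phase, when the monitoring loop is not yet running.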