[Spot] Let cancel interrupt the spot job (#1414)
* Let cancel interrupt the job

* Add test

* Fix test

* Cancel early

* fix test

* fix test

* Fix exceptions

* pass test

* increase waiting time

* address comments

* add job id

* remove 'auto' in ray.init
Michaelvll authored Nov 18, 2022
1 parent 9aecc7b commit 3bbf4aa
Showing 4 changed files with 183 additions and 127 deletions.
243 changes: 133 additions & 110 deletions sky/spot/controller.py
@@ -6,6 +6,7 @@

import colorama
import filelock
import ray

import sky
from sky import exceptions
@@ -28,9 +29,9 @@ class SpotController:

def __init__(self, job_id: int, task_yaml: str,
retry_until_up: bool) -> None:
self._job_id = job_id
self._task_name = pathlib.Path(task_yaml).stem
self._task = sky.Task.from_yaml(task_yaml)
self.job_id = job_id
self.task_name = pathlib.Path(task_yaml).stem
self.task = sky.Task.from_yaml(task_yaml)

self._retry_until_up = retry_until_up
# TODO(zhwu): this assumes the specific backend.
@@ -39,145 +40,77 @@ def __init__(self, job_id: int, task_yaml: str,
# Add a unique identifier to the task environment variables, so that
# the user can have the same id for multiple recoveries.
# Example value: sky-2022-10-04-22-46-52-467694_id-17
task_envs = self._task.envs or {}
task_envs = self.task.envs or {}
job_id_env_var = common_utils.get_global_job_id(
self.backend.run_timestamp, 'spot', self._job_id)
self.backend.run_timestamp, 'spot', self.job_id)
task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var
self._task.set_envs(task_envs)
self.task.set_envs(task_envs)

spot_state.set_submitted(
self._job_id,
self._task_name,
self.job_id,
self.task_name,
self.backend.run_timestamp,
resources_str=backend_utils.get_task_resources_str(self._task))
resources_str=backend_utils.get_task_resources_str(self.task))
logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}')
self._cluster_name = spot_utils.generate_spot_cluster_name(
self._task_name, self._job_id)
self._strategy_executor = recovery_strategy.StrategyExecutor.make(
self._cluster_name, self.backend, self._task, retry_until_up,
self._handle_signal)

def _run(self):
"""Busy loop monitoring spot cluster status and handling recovery."""
logger.info(f'Started monitoring spot task {self._task_name} '
f'(id: {self._job_id})')
spot_state.set_starting(self._job_id)
start_at = self._strategy_executor.launch()

spot_state.set_started(self._job_id, start_time=start_at)
while True:
time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
# Handle the signal if it is sent by the user.
self._handle_signal()

# Check the network connection to avoid false alarm for job failure.
# Network glitch was observed even in the VM.
try:
backend_utils.check_network_connection()
except exceptions.NetworkError:
logger.info(
'Network is not available. Retrying again in '
f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
continue

# NOTE: we do not check cluster status first because race condition
# can occur, i.e. cluster can be down during the job status check.
job_status = spot_utils.get_job_status(self.backend,
self._cluster_name)

if job_status is not None and not job_status.is_terminal():
need_recovery = False
if self._task.num_nodes > 1:
# Check the cluster status for multi-node jobs, since the
# job may not be set to FAILED immediately when only some
# of the nodes are preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
self._cluster_name, force_refresh=True)
if cluster_status != global_user_state.ClusterStatus.UP:
# recover the cluster if it is not up.
logger.info(f'Cluster status {cluster_status.value}. '
'Recovering...')
need_recovery = True
if not need_recovery:
# The job and cluster are healthy, continue to monitor the
# job status.
continue

if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = spot_utils.get_job_timestamp(self.backend,
self._cluster_name,
get_end_time=True)
# The job is done.
spot_state.set_succeeded(self._job_id, end_time=end_time)
break

if job_status == job_lib.JobStatus.FAILED:
# Check the status of the spot cluster. If it is not UP,
# the cluster is preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
self._cluster_name, force_refresh=True)
if cluster_status == global_user_state.ClusterStatus.UP:
# The user code has probably crashed.
end_time = spot_utils.get_job_timestamp(self.backend,
self._cluster_name,
get_end_time=True)
logger.info(
'The user job failed. Please check the logs below.\n'
f'== Logs of the user job (ID: {self._job_id}) ==\n')
self.backend.tail_logs(handle,
None,
spot_job_id=self._job_id)
logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
spot_state.set_failed(
self._job_id,
failure_type=spot_state.SpotStatus.FAILED,
end_time=end_time)
break
# cluster can be down, INIT or STOPPED, based on the interruption
# behavior of the cloud.
# Failed to connect to the cluster or the cluster is partially down.
# job_status is None or job_status == job_lib.JobStatus.FAILED
logger.info('The cluster is preempted.')
spot_state.set_recovering(self._job_id)
recovered_time = self._strategy_executor.recover()
spot_state.set_recovered(self._job_id,
recovered_time=recovered_time)
self.cluster_name = spot_utils.generate_spot_cluster_name(
self.task_name, self.job_id)
self.strategy_executor = recovery_strategy.StrategyExecutor.make(
self.cluster_name, self.backend, self.task, retry_until_up)

def start(self):
"""Start the controller."""
try:
self._run()
self._handle_signal()
controller_task = _controller_run.remote(self)
# Signal can interrupt the underlying controller process.
ready, _ = ray.wait([controller_task], timeout=0)
while not ready:
try:
self._handle_signal()
except exceptions.SpotUserCancelledError as e:
logger.info(f'Cancelling spot job {self.job_id}...')
try:
ray.cancel(controller_task)
ray.get(controller_task)
except ray.exceptions.RayTaskError:
# When the controller task is cancelled, it will raise
# ray.exceptions.RayTaskError, which can be ignored,
# since the SpotUserCancelledError will be raised and
# handled later.
pass
raise e
ready, _ = ray.wait([controller_task], timeout=1)
# Need this to get the exception from the controller task.
ray.get(controller_task)
except exceptions.SpotUserCancelledError as e:
logger.info(e)
spot_state.set_cancelled(self._job_id)
spot_state.set_cancelled(self.job_id)
except exceptions.ResourcesUnavailableError as e:
logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}'
f'{colorama.Style.RESET_ALL}')
spot_state.set_failed(
self._job_id,
self.job_id,
failure_type=spot_state.SpotStatus.FAILED_NO_RESOURCE)
except (Exception, SystemExit) as e: # pylint: disable=broad-except
logger.error(traceback.format_exc())
logger.error(f'Unexpected error occurred: {type(e).__name__}: {e}')
finally:
self._strategy_executor.terminate_cluster()
job_status = spot_state.get_status(self._job_id)
self.strategy_executor.terminate_cluster()
job_status = spot_state.get_status(self.job_id)
# The job can be non-terminal if the controller exited abnormally,
# e.g. failed to launch cluster after reaching the MAX_RETRY.
if not job_status.is_terminal():
spot_state.set_failed(
self._job_id,
self.job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER)

# Clean up Storages with persistent=False.
self.backend.teardown_ephemeral_storage(self._task)
self.backend.teardown_ephemeral_storage(self.task)

def _handle_signal(self):
"""Handle the signal if the user sent it."""
signal_file = pathlib.Path(
spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id))
spot_utils.SIGNAL_FILE_PREFIX.format(self.job_id))
signal = None
if signal_file.exists():
# Filelock is needed to prevent race condition with concurrent
@@ -200,7 +133,97 @@ def _handle_signal(self):
raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.')


@ray.remote(num_cpus=0)
def _controller_run(spot_controller: SpotController):
"""Busy loop monitoring spot cluster status and handling recovery."""
logger.info(f'Started monitoring spot task {spot_controller.task_name} '
f'(id: {spot_controller.job_id})')
spot_state.set_starting(spot_controller.job_id)
start_at = spot_controller.strategy_executor.launch()

spot_state.set_started(spot_controller.job_id, start_time=start_at)
while True:
time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)

# Check the network connection to avoid false alarm for job failure.
# Network glitch was observed even in the VM.
try:
backend_utils.check_network_connection()
except exceptions.NetworkError:
logger.info('Network is not available. Retrying again in '
f'{spot_utils.JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
continue

# NOTE: we do not check cluster status first because race condition
# can occur, i.e. cluster can be down during the job status check.
job_status = spot_utils.get_job_status(spot_controller.backend,
spot_controller.cluster_name)

if job_status is not None and not job_status.is_terminal():
need_recovery = False
if spot_controller.task.num_nodes > 1:
# Check the cluster status for multi-node jobs, since the
# job may not be set to FAILED immediately when only some
# of the nodes are preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
spot_controller.cluster_name, force_refresh=True)
if cluster_status != global_user_state.ClusterStatus.UP:
# recover the cluster if it is not up.
logger.info(f'Cluster status {cluster_status.value}. '
'Recovering...')
need_recovery = True
if not need_recovery:
# The job and cluster are healthy, continue to monitor the
# job status.
continue

if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = spot_utils.get_job_timestamp(
spot_controller.backend,
spot_controller.cluster_name,
get_end_time=True)
# The job is done.
spot_state.set_succeeded(spot_controller.job_id, end_time=end_time)
break

if job_status == job_lib.JobStatus.FAILED:
# Check the status of the spot cluster. If it is not UP,
# the cluster is preempted.
(cluster_status,
handle) = backend_utils.refresh_cluster_status_handle(
spot_controller.cluster_name, force_refresh=True)
if cluster_status == global_user_state.ClusterStatus.UP:
# The user code has probably crashed.
end_time = spot_utils.get_job_timestamp(
spot_controller.backend,
spot_controller.cluster_name,
get_end_time=True)
logger.info(
'The user job failed. Please check the logs below.\n'
'== Logs of the user job (ID: '
f'{spot_controller.job_id}) ==\n')
spot_controller.backend.tail_logs(
handle, None, spot_job_id=spot_controller.job_id)
logger.info(
f'\n== End of logs (ID: {spot_controller.job_id}) ==')
spot_state.set_failed(spot_controller.job_id,
failure_type=spot_state.SpotStatus.FAILED,
end_time=end_time)
break
# cluster can be down, INIT or STOPPED, based on the interruption
# behavior of the cloud.
# Failed to connect to the cluster or the cluster is partially down.
# job_status is None or job_status == job_lib.JobStatus.FAILED
logger.info('The cluster is preempted.')
spot_state.set_recovering(spot_controller.job_id)
recovered_time = spot_controller.strategy_executor.recover()
spot_state.set_recovered(spot_controller.job_id,
recovered_time=recovered_time)


if __name__ == '__main__':
ray.init()
parser = argparse.ArgumentParser()
parser.add_argument('--job-id',
required=True,
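The new `start()` above replaces the in-process `self._run()` call with a Ray task so that a user cancel can interrupt the monitoring loop instead of waiting for the next status check. Below is a minimal, self-contained sketch of that wait/cancel shape; `monitor_loop`, `cancel_requested`, and `run_with_cancellation` are hypothetical stand-ins for `_controller_run`, `_handle_signal`, and `start`, not SkyPilot APIs.

```python
# Minimal sketch (not SkyPilot code) of the pattern used by the new start():
# a long-running Ray task is polled with ray.wait(), and a pending cancel
# request interrupts it via ray.cancel().
import time

import ray


@ray.remote(num_cpus=0)
def monitor_loop() -> None:
    # Stand-in for _controller_run: a long-lived loop that would normally
    # watch job status and trigger recovery.
    while True:
        time.sleep(1)


def cancel_requested() -> bool:
    # Stand-in for _handle_signal(); the real controller checks a signal file.
    return True


def run_with_cancellation() -> None:
    task = monitor_loop.remote()
    ready, _ = ray.wait([task], timeout=0)
    while not ready:
        if cancel_requested():
            ray.cancel(task)
            try:
                ray.get(task)
            except (ray.exceptions.RayTaskError,
                    ray.exceptions.TaskCancelledError):
                # Fetching a cancelled task raises; safe to swallow here.
                pass
            print('Monitoring task cancelled.')
            return
        # A short timeout keeps the driver responsive to new cancel requests
        # without busy-waiting.
        ready, _ = ray.wait([task], timeout=1)
    # Propagate any exception raised inside the task.
    ray.get(task)


if __name__ == '__main__':
    ray.init()
    run_with_cancellation()
```

Polling `ray.wait` with a short timeout is what lets the controller notice a cancel within about a second instead of only at the next `JOB_STATUS_CHECK_GAP_SECONDS` boundary.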
16 changes: 4 additions & 12 deletions sky/spot/recovery_strategy.py
@@ -1,7 +1,7 @@
"""The strategy to handle launching/recovery/termination of spot clusters."""
import time
import typing
from typing import Callable, Optional
from typing import Optional

import sky
from sky import exceptions
@@ -34,24 +34,20 @@ class StrategyExecutor:
RETRY_INIT_GAP_SECONDS = 60

def __init__(self, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> None:
task: 'task_lib.Task', retry_until_up: bool) -> None:
"""Initialize the strategy executor.
Args:
cluster_name: The name of the cluster.
backend: The backend to use. Only CloudVMRayBackend is supported.
task: The task to execute.
retry_until_up: Whether to retry until the cluster is up.
signal_handler: The signal handler that will raise an exception if a
SkyPilot signal is received.
"""
self.dag = sky.Dag()
self.dag.add(task)
self.cluster_name = cluster_name
self.backend = backend
self.retry_until_up = retry_until_up
self.signal_handler = signal_handler

def __init_subclass__(cls, name: str, default: bool = False):
SPOT_STRATEGIES[name] = cls
@@ -63,8 +59,7 @@ def __init_subclass__(cls, name: str, default: bool = False):

@classmethod
def make(cls, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> 'StrategyExecutor':
task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor':
"""Create a strategy from a task."""
resources = task.resources
assert len(resources) == 1, 'Only one resource is supported.'
@@ -77,7 +72,7 @@ def make(cls, cluster_name: str, backend: 'backends.Backend',
# will be handled by the strategy class.
task.set_resources({resources.copy(spot_recovery=None)})
return SPOT_STRATEGIES[spot_recovery](cluster_name, backend, task,
retry_until_up, signal_handler)
retry_until_up)

def launch(self) -> Optional[float]:
"""Launch the spot cluster for the first time.
@@ -147,9 +142,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
backoff = common_utils.Backoff(self.RETRY_INIT_GAP_SECONDS)
while True:
retry_cnt += 1
# Check the signal every time to be more responsive to user
# signals, such as Cancel.
self.signal_handler()
retry_launch = False
exception = None
try:
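In `recovery_strategy.py`, the `signal_handler` callback is dropped from `StrategyExecutor.__init__`, `make`, and `_launch`, since cancellation is now delivered by interrupting the controller's Ray task rather than by the cooperative `self.signal_handler()` check deleted above. Separately, the same hunks show the `__init_subclass__` registration used to map strategy names to classes; a compact sketch of that pattern follows, with hypothetical names (`Strategy`, `Failover`, `STRATEGIES`) rather than SkyPilot's actual classes.

```python
# Sketch of the name-based registry pattern behind StrategyExecutor's
# __init_subclass__; class and variable names here are illustrative only.
from typing import Dict, Optional, Type

STRATEGIES: Dict[str, Type['Strategy']] = {}
DEFAULT_STRATEGY: Optional[str] = None


class Strategy:
    """Base class; each subclass registers itself at definition time."""

    def __init_subclass__(cls, name: str, default: bool = False):
        global DEFAULT_STRATEGY
        STRATEGIES[name] = cls
        if default:
            DEFAULT_STRATEGY = name


class Failover(Strategy, name='FAILOVER', default=True):
    """Hypothetical recovery strategy registered under 'FAILOVER'."""


if __name__ == '__main__':
    assert STRATEGIES['FAILOVER'] is Failover
    assert DEFAULT_STRATEGY == 'FAILOVER'
    print(f'Registered strategies: {list(STRATEGIES)}')
```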
3 changes: 1 addition & 2 deletions sky/spot/spot_utils.py
@@ -178,8 +178,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'

return (f'{identity_str} scheduled to be cancelled within '
f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
return f'{identity_str} scheduled to be cancelled.'


def cancel_job_by_name(job_name: str) -> str:
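The shorter message from `cancel_jobs_by_id` reflects that a cancel no longer has to wait for the next `JOB_STATUS_CHECK_GAP_SECONDS` poll: the request still travels through a signal file that the controller's `_handle_signal()` reads under a `filelock`, but the controller task is now interrupted within roughly a second. Below is a rough sketch of that file-based handshake, using a hypothetical path and signal value rather than SkyPilot's actual `SIGNAL_FILE_PREFIX` or signal enum.

```python
# Illustrative sketch of a file-based cancel signal guarded by filelock; the
# path and signal string are made up, not SkyPilot's real constants.
import pathlib
from typing import Optional

import filelock

SIGNAL_FILE = pathlib.Path('/tmp/spot_signal_42')  # hypothetical path
LOCK_FILE = str(SIGNAL_FILE) + '.lock'


def send_cancel_signal() -> None:
    # Roughly what the CLI side does: write the signal under the lock so a
    # concurrent reader never sees a partial write.
    with filelock.FileLock(LOCK_FILE):
        SIGNAL_FILE.write_text('CANCEL')


def check_signal() -> Optional[str]:
    # Roughly what the controller's _handle_signal() does: read and clear the
    # signal file under the same lock, then act on its contents.
    if not SIGNAL_FILE.exists():
        return None
    with filelock.FileLock(LOCK_FILE):
        signal = SIGNAL_FILE.read_text().strip()
        SIGNAL_FILE.unlink()
    return signal


if __name__ == '__main__':
    send_cancel_signal()
    assert check_signal() == 'CANCEL'
```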
