Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spot] Let cancel interrupt the spot job (#1414) #1433

Merged
merged 35 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c3a8599
Let cancel interrupt the job
Michaelvll Nov 14, 2022
169e8b5
Add test
Michaelvll Nov 14, 2022
3e8bac6
Fix test
Michaelvll Nov 14, 2022
e82d89c
Cancel early
Michaelvll Nov 14, 2022
c081c60
fix test
Michaelvll Nov 14, 2022
3441b43
fix test
Michaelvll Nov 14, 2022
88f4db5
Fix exceptions
Michaelvll Nov 14, 2022
331fa32
pass test
Michaelvll Nov 14, 2022
25c568c
increase waiting time
Michaelvll Nov 15, 2022
71a253d
address comments
Michaelvll Nov 17, 2022
5253d5d
add job id
Michaelvll Nov 17, 2022
6e9ba0c
remove 'auto' in ray.init
Michaelvll Nov 17, 2022
7bffd84
Fix serialization problem
Michaelvll Nov 19, 2022
a5e7b20
refactor a bit
Michaelvll Nov 19, 2022
0b66584
Fix
Michaelvll Nov 19, 2022
aa6dd91
Add comments
Michaelvll Nov 19, 2022
3065fed
format
Michaelvll Nov 19, 2022
5f0d801
pylint
Michaelvll Nov 19, 2022
7875d35
revert a format change
Michaelvll Nov 19, 2022
f7b4f8b
Add docstr
Michaelvll Nov 19, 2022
3cfa747
Move ray.init
Michaelvll Nov 19, 2022
c140801
replace ray with multiprocess.Process
Michaelvll Nov 20, 2022
256d1f9
Add test for setup cancelation
Michaelvll Nov 20, 2022
3feb30f
Fix logging
Michaelvll Nov 20, 2022
8f469a5
Fix test
Michaelvll Nov 20, 2022
ba0f7b7
lint
Michaelvll Nov 20, 2022
af72709
Use SIGTERM instead
Michaelvll Nov 20, 2022
0556774
format
Michaelvll Nov 20, 2022
98db4a6
Change exception type
Michaelvll Nov 20, 2022
1338a22
revert to KeyboardInterrupt
Michaelvll Nov 20, 2022
76b62fb
remove
Michaelvll Nov 20, 2022
8985d73
Fix test
Michaelvll Nov 21, 2022
d93c5f6
fix test
Michaelvll Nov 21, 2022
832bde1
fix test
Michaelvll Nov 22, 2022
2e3fbe4
typo
Michaelvll Nov 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 83 additions & 44 deletions sky/spot/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import colorama
import filelock
import ray

import sky
from sky import exceptions
Expand Down Expand Up @@ -34,28 +35,27 @@ def __init__(self, job_id: int, task_yaml: str,

self._retry_until_up = retry_until_up
# TODO(zhwu): this assumes the specific backend.
self.backend = cloud_vm_ray_backend.CloudVmRayBackend()
self._backend = cloud_vm_ray_backend.CloudVmRayBackend()

# Add a unique identifier to the task environment variables, so that
# the user can have the same id for multiple recoveries.
# Example value: sky-2022-10-04-22-46-52-467694_id-17
task_envs = self._task.envs or {}
job_id_env_var = common_utils.get_global_job_id(
self.backend.run_timestamp, 'spot', self._job_id)
self._backend.run_timestamp, 'spot', self._job_id)
task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var
self._task.set_envs(task_envs)

spot_state.set_submitted(
self._job_id,
self._task_name,
self.backend.run_timestamp,
self._backend.run_timestamp,
resources_str=backend_utils.get_task_resources_str(self._task))
logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}')
self._cluster_name = spot_utils.generate_spot_cluster_name(
self._task_name, self._job_id)
self._strategy_executor = recovery_strategy.StrategyExecutor.make(
self._cluster_name, self.backend, self._task, retry_until_up,
self._handle_signal)
self._cluster_name, self._backend, self._task, retry_until_up)

def _run(self):
"""Busy loop monitoring spot cluster status and handling recovery."""
Expand All @@ -67,8 +67,6 @@ def _run(self):
spot_state.set_started(self._job_id, start_time=start_at)
while True:
time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
# Handle the signal if it is sent by the user.
self._handle_signal()

# Check the network connection to avoid false alarm for job failure.
# Network glitch was observed even in the VM.
Expand All @@ -82,7 +80,7 @@ def _run(self):

# NOTE: we do not check cluster status first because race condition
# can occur, i.e. cluster can be down during the job status check.
job_status = spot_utils.get_job_status(self.backend,
job_status = spot_utils.get_job_status(self._backend,
self._cluster_name)

if job_status is not None and not job_status.is_terminal():
Expand All @@ -105,7 +103,7 @@ def _run(self):
continue

if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = spot_utils.get_job_timestamp(self.backend,
end_time = spot_utils.get_job_timestamp(self._backend,
self._cluster_name,
get_end_time=True)
# The job is done.
Expand All @@ -120,15 +118,15 @@ def _run(self):
self._cluster_name, force_refresh=True)
if cluster_status == global_user_state.ClusterStatus.UP:
# The user code has probably crashed.
end_time = spot_utils.get_job_timestamp(self.backend,
end_time = spot_utils.get_job_timestamp(self._backend,
self._cluster_name,
get_end_time=True)
logger.info(
'The user job failed. Please check the logs below.\n'
f'== Logs of the user job (ID: {self._job_id}) ==\n')
self.backend.tail_logs(handle,
None,
spot_job_id=self._job_id)
self._backend.tail_logs(handle,
None,
spot_job_id=self._job_id)
logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
spot_state.set_failed(
self._job_id,
Expand All @@ -145,12 +143,13 @@ def _run(self):
spot_state.set_recovered(self._job_id,
recovered_time=recovered_time)

def start(self):
"""Start the controller."""
def run(self):
"""Run controller logic and handle exceptions."""
try:
self._run()
except exceptions.SpotUserCancelledError as e:
logger.info(e)
except KeyboardInterrupt as e:
# ray.cancel will raise KeyboardInterrupt.
logger.error(e)
spot_state.set_cancelled(self._job_id)
except exceptions.ResourcesUnavailableError as e:
logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}'
Expand All @@ -167,40 +166,82 @@ def start(self):
# The job can be non-terminal if the controller exited abnormally,
# e.g. failed to launch cluster after reaching the MAX_RETRY.
if not job_status.is_terminal():
logger.info(f'Previous spot job status: {job_status.value}')
spot_state.set_failed(
self._job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER)

# Clean up Storages with persistent=False.
self.backend.teardown_ephemeral_storage(self._task)

def _handle_signal(self):
"""Handle the signal if the user sent it."""
signal_file = pathlib.Path(
spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id))
signal = None
if signal_file.exists():
# Filelock is needed to prevent race condition with concurrent
# signal writing.
# TODO(mraheja): remove pylint disabling when filelock version
# updated
# pylint: disable=abstract-class-instantiated
with filelock.FileLock(str(signal_file) + '.lock'):
with signal_file.open(mode='r') as f:
signal = f.read().strip()
self._backend.teardown_ephemeral_storage(self._task)


@ray.remote(num_cpus=0)
def _run_controller(job_id: int, task_yaml: str, retry_until_up: bool):
"""Runs the controller in a remote process for interruption."""
# The controller needs to be instantiated in the remote process, since
# the controller is not serializable.
spot_controller = SpotController(job_id, task_yaml, retry_until_up)
spot_controller.run()


def _handle_signal(job_id):
"""Handle the signal if the user sent it."""
signal_file = pathlib.Path(spot_utils.SIGNAL_FILE_PREFIX.format(job_id))
signal = None
if signal_file.exists():
# Filelock is needed to prevent race condition with concurrent
# signal writing.
# TODO(mraheja): remove pylint disabling when filelock version
# updated
# pylint: disable=abstract-class-instantiated
with filelock.FileLock(str(signal_file) + '.lock'):
with signal_file.open(mode='r') as f:
signal = f.read().strip()
try:
signal = spot_utils.UserSignal(signal)
# Remove the signal file, after reading the signal.
signal_file.unlink()
if signal is None:
return
if signal == spot_utils.UserSignal.CANCEL:
raise exceptions.SpotUserCancelledError(
f'User sent {signal.value} signal.')
except ValueError:
logger.warning(
f'Unknown signal received: {signal}. Ignoring.')
signal = None
# Remove the signal file, after reading the signal.
signal_file.unlink()
if signal is None:
# None or empty string.
return
assert signal == spot_utils.UserSignal.CANCEL, (
f'Only cancel signal is supported, but {signal} got.')
raise exceptions.SpotUserCancelledError(f'User sent {signal.value} signal.')


raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.')
def start(job_id, task_yaml, retry_until_up):
"""Start the controller."""
controller_task = None
try:
_handle_signal(job_id)
controller_task = _run_controller.remote(job_id, task_yaml,
retry_until_up)
# Signal can interrupt the underlying controller process.
ready, _ = ray.wait([controller_task], timeout=0)
while not ready:
_handle_signal(job_id)
ready, _ = ray.wait([controller_task], timeout=1)
except exceptions.SpotUserCancelledError:
logger.info(f'Cancelling spot job {job_id}...')
try:
if controller_task is not None:
# This will raise KeyboardInterrupt in the task.
ray.cancel(controller_task)
ray.get(controller_task)
except ray.exceptions.RayTaskError:
# When the controller task is cancelled, it will raise
# ray.exceptions.RayTaskError, which can be ignored,
# since the SpotUserCancelledError will be raised and
# handled later.
pass


if __name__ == '__main__':
ray.init()
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
parser = argparse.ArgumentParser()
parser.add_argument('--job-id',
required=True,
Expand All @@ -214,6 +255,4 @@ def _handle_signal(self):
help='The path to the user spot task yaml file. '
'The file name is the spot task name.')
args = parser.parse_args()
controller = SpotController(args.job_id, args.task_yaml,
args.retry_until_up)
controller.start()
start(args.job_id, args.task_yaml, args.retry_until_up)
16 changes: 4 additions & 12 deletions sky/spot/recovery_strategy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The strategy to handle launching/recovery/termination of spot clusters."""
import time
import typing
from typing import Callable, Optional
from typing import Optional

import sky
from sky import exceptions
Expand Down Expand Up @@ -34,24 +34,20 @@ class StrategyExecutor:
RETRY_INIT_GAP_SECONDS = 60

def __init__(self, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> None:
task: 'task_lib.Task', retry_until_up: bool) -> None:
"""Initialize the strategy executor.

Args:
cluster_name: The name of the cluster.
backend: The backend to use. Only CloudVMRayBackend is supported.
task: The task to execute.
retry_until_up: Whether to retry until the cluster is up.
signal_handler: The signal handler that will raise an exception if a
SkyPilot signal is received.
"""
self.dag = sky.Dag()
self.dag.add(task)
self.cluster_name = cluster_name
self.backend = backend
self.retry_until_up = retry_until_up
self.signal_handler = signal_handler

def __init_subclass__(cls, name: str, default: bool = False):
SPOT_STRATEGIES[name] = cls
Expand All @@ -63,8 +59,7 @@ def __init_subclass__(cls, name: str, default: bool = False):

@classmethod
def make(cls, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> 'StrategyExecutor':
task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor':
"""Create a strategy from a task."""
resources = task.resources
assert len(resources) == 1, 'Only one resource is supported.'
Expand All @@ -77,7 +72,7 @@ def make(cls, cluster_name: str, backend: 'backends.Backend',
# will be handled by the strategy class.
task.set_resources({resources.copy(spot_recovery=None)})
return SPOT_STRATEGIES[spot_recovery](cluster_name, backend, task,
retry_until_up, signal_handler)
retry_until_up)

def launch(self) -> Optional[float]:
"""Launch the spot cluster for the first time.
Expand Down Expand Up @@ -147,9 +142,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
backoff = common_utils.Backoff(self.RETRY_INIT_GAP_SECONDS)
while True:
retry_cnt += 1
# Check the signal every time to be more responsive to user
# signals, such as Cancel.
self.signal_handler()
retry_launch = False
exception = None
try:
Expand Down
3 changes: 1 addition & 2 deletions sky/spot/spot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'

return (f'{identity_str} scheduled to be cancelled within '
f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
return f'{identity_str} scheduled to be cancelled.'


def cancel_job_by_name(job_name: str) -> str:
Expand Down
50 changes: 46 additions & 4 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def test_spot():


# ---------- Testing managed spot ----------
def test_gcp_spot():
def test_spot_gcp():
"""Test managed spot on GCP."""
name = _get_cluster_name()
test = Test(
Expand Down Expand Up @@ -805,7 +805,7 @@ def test_spot_recovery():
'managed-spot-recovery',
[
f'sky spot launch --cloud aws --region {region} -n {name} "echo SKYPILOT_JOB_ID: \$SKYPILOT_JOB_ID; sleep 1000" -y -d',
'sleep 300',
'sleep 360',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"',
f'RUN_ID=$(sky spot logs -n {name} --no-follow | grep SKYPILOT_JOB_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
# Terminate the cluster manually.
Expand All @@ -830,10 +830,10 @@ def test_spot_recovery_multi_node():
name = _get_cluster_name()
region = 'us-west-2'
test = Test(
'managed-spot-recovery',
'managed-spot-recovery-multi',
[
f'sky spot launch --cloud aws --region {region} -n {name} --num-nodes 2 "echo SKYPILOT_JOB_ID: \$SKYPILOT_JOB_ID; sleep 1000" -y -d',
'sleep 360',
'sleep 400',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"',
f'RUN_ID=$(sky spot logs -n {name} --no-follow | grep SKYPILOT_JOB_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id',
# Terminate the worker manually.
Expand All @@ -854,6 +854,48 @@ def test_spot_recovery_multi_node():
run_one_test(test)


def test_spot_cancellation():
name = _get_cluster_name()
region = 'us-east-2'
test = Test(
'managed-spot-cancellation',
[
f'sky spot launch --cloud aws --region {region} -n {name} "sleep 1000" -y -d',
'sleep 60',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "STARTING"',
# Test cancelling the spot job during launching.
f'sky spot cancel -y -n {name}',
'sleep 5',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "CANCELLED"',
'sleep 100',
(f'aws ec2 describe-instances --region {region} '
f'--filters Name=tag:ray-cluster-name,Values={name}* '
f'--query Reservations[].Instances[].State[].Name '
'--output text | grep terminated'),
# Test cancelling the spot job during running.
f'sky spot launch --cloud aws --region {region} -n {name}-2 "sleep 1000" -y -d',
'sleep 300',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name} | head -n1 | grep "RUNNING"',
# Terminate the cluster manually.
(f'aws ec2 terminate-instances --region {region} --instance-ids $('
f'aws ec2 describe-instances --region {region} '
f'--filters Name=tag:ray-cluster-name,Values={name}-2* '
f'--query Reservations[].Instances[].InstanceId '
'--output text)'),
'sleep 50',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name}-2 | head -n1 | grep "RECOVERING"',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to test recovery in this test? Isn't it already covered by test_spot_recovery and test_spot_recovery_multi_node? Maybe we can remove it in the interest of keeping our tests fast (while avoiding test duplication)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intended to test that the spot job can be canceled immediately during the recovering. Do you think that makes sense?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh that does make a lot of sense - good to keep this then!

f'sky spot cancel -y -n {name}-2',
'sleep 10',
f's=$(sky spot queue); printf "$s"; echo; echo; printf "$s" | grep {name}-2 | head -n1 | grep "CANCELLED"',
'sleep 90',
(f'aws ec2 describe-instances --region {region} '
f'--filters Name=tag:ray-cluster-name,Values={name}-2* '
f'--query Reservations[].Instances[].State[].Name '
'--output text | grep terminated'),
])
run_one_test(test)


# ---------- Testing storage for managed spot ----------
def test_spot_storage():
"""Test storage with managed spot"""
Expand Down