Skip to content

Commit

Permalink
[Spot] Let cancel interrupt the spot job (#1414) (#1433)
Browse files Browse the repository at this point in the history
* Let cancel interrupt the job

* Add test

* Fix test

* Cancel early

* fix test

* fix test

* Fix exceptions

* pass test

* increase waiting time

* address comments

* add job id

* remove 'auto' in ray.init

* Fix serialization problem

* refactor a bit

* Fix

* Add comments

* format

* pylint

* revert a format change

* Add docstr

* Move ray.init

* replace ray with multiprocess.Process

* Add test for setup cancellation

* Fix logging

* Fix test

* lint

* Use SIGTERM instead

* format

* Change exception type

* revert to KeyboardInterrupt

* remove

* Fix test

* fix test

* fix test

* typo
  • Loading branch information
Michaelvll authored Nov 22, 2022
1 parent f98fc9d commit 235d9d6
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 83 deletions.
18 changes: 2 additions & 16 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import jinja2
import jsonschema
from packaging import version
import psutil
import requests
from requests import adapters
from requests.packages.urllib3.util import retry as retry_lib
Expand Down Expand Up @@ -1999,23 +1998,10 @@ def check_gcp_cli_include_tpu_vm() -> None:
' TPU VM APIs, check "gcloud version" for details.')


def kill_children_processes():
    """Terminate all child processes (recursively) of this process.

    Called on interrupt/stop so the underlying subprocesses stop printing
    their logs to the terminal after this program exits.
    """
    for child_proc in psutil.Process().children(recursive=True):
        try:
            child_proc.terminate()
        except psutil.NoSuchProcess:
            # The child may have exited between enumeration and terminate().
            pass


# Handle ctrl-c
def interrupt_handler(signum, frame):
del signum, frame
kill_children_processes()
subprocess_utils.kill_children_processes()
# Avoid using logger here, as it will print the stack trace for broken
# pipe, when the output is piped to another program.
print(f'{colorama.Style.DIM}Tip: The job will keep '
Expand All @@ -2027,7 +2013,7 @@ def interrupt_handler(signum, frame):
# Handle ctrl-z
def stop_handler(signum, frame):
del signum, frame
kill_children_processes()
subprocess_utils.kill_children_processes()
# Avoid using logger here, as it will print the stack trace for broken
# pipe, when the output is piped to another program.
print(f'{colorama.Style.DIM}Tip: The job will keep '
Expand Down
140 changes: 94 additions & 46 deletions sky/spot/controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Controller: handles the life cycle of a managed spot cluster (job)."""
import argparse
import multiprocessing
import pathlib
import signal
import time
import traceback

Expand All @@ -19,6 +21,7 @@
from sky.spot import spot_state
from sky.spot import spot_utils
from sky.utils import common_utils
from sky.utils import subprocess_utils

logger = sky_logging.init_logger(__name__)

Expand All @@ -34,28 +37,27 @@ def __init__(self, job_id: int, task_yaml: str,

self._retry_until_up = retry_until_up
# TODO(zhwu): this assumes the specific backend.
self.backend = cloud_vm_ray_backend.CloudVmRayBackend()
self._backend = cloud_vm_ray_backend.CloudVmRayBackend()

# Add a unique identifier to the task environment variables, so that
# the user can have the same id for multiple recoveries.
# Example value: sky-2022-10-04-22-46-52-467694_id-17
task_envs = self._task.envs or {}
job_id_env_var = common_utils.get_global_job_id(
self.backend.run_timestamp, 'spot', self._job_id)
self._backend.run_timestamp, 'spot', self._job_id)
task_envs[constants.JOB_ID_ENV_VAR] = job_id_env_var
self._task.set_envs(task_envs)

spot_state.set_submitted(
self._job_id,
self._task_name,
self.backend.run_timestamp,
self._backend.run_timestamp,
resources_str=backend_utils.get_task_resources_str(self._task))
logger.info(f'Submitted spot job; SKYPILOT_JOB_ID: {job_id_env_var}')
self._cluster_name = spot_utils.generate_spot_cluster_name(
self._task_name, self._job_id)
self._strategy_executor = recovery_strategy.StrategyExecutor.make(
self._cluster_name, self.backend, self._task, retry_until_up,
self._handle_signal)
self._cluster_name, self._backend, self._task, retry_until_up)

def _run(self):
"""Busy loop monitoring spot cluster status and handling recovery."""
Expand All @@ -67,8 +69,6 @@ def _run(self):
spot_state.set_started(self._job_id, start_time=start_at)
while True:
time.sleep(spot_utils.JOB_STATUS_CHECK_GAP_SECONDS)
# Handle the signal if it is sent by the user.
self._handle_signal()

# Check the network connection to avoid false alarm for job failure.
# Network glitch was observed even in the VM.
Expand All @@ -82,7 +82,7 @@ def _run(self):

# NOTE: we do not check cluster status first because race condition
# can occur, i.e. cluster can be down during the job status check.
job_status = spot_utils.get_job_status(self.backend,
job_status = spot_utils.get_job_status(self._backend,
self._cluster_name)

if job_status is not None and not job_status.is_terminal():
Expand All @@ -105,7 +105,7 @@ def _run(self):
continue

if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = spot_utils.get_job_timestamp(self.backend,
end_time = spot_utils.get_job_timestamp(self._backend,
self._cluster_name,
get_end_time=True)
# The job is done.
Expand All @@ -120,15 +120,15 @@ def _run(self):
self._cluster_name, force_refresh=True)
if cluster_status == global_user_state.ClusterStatus.UP:
# The user code has probably crashed.
end_time = spot_utils.get_job_timestamp(self.backend,
end_time = spot_utils.get_job_timestamp(self._backend,
self._cluster_name,
get_end_time=True)
logger.info(
'The user job failed. Please check the logs below.\n'
f'== Logs of the user job (ID: {self._job_id}) ==\n')
self.backend.tail_logs(handle,
None,
spot_job_id=self._job_id)
self._backend.tail_logs(handle,
None,
spot_job_id=self._job_id)
logger.info(f'\n== End of logs (ID: {self._job_id}) ==')
spot_state.set_failed(
self._job_id,
Expand All @@ -145,12 +145,13 @@ def _run(self):
spot_state.set_recovered(self._job_id,
recovered_time=recovered_time)

def start(self):
"""Start the controller."""
def run(self):
"""Run controller logic and handle exceptions."""
try:
self._run()
except exceptions.SpotUserCancelledError as e:
logger.info(e)
except KeyboardInterrupt:
# Kill the children processes launched by log_lib.run_with_log.
subprocess_utils.kill_children_processes()
spot_state.set_cancelled(self._job_id)
except exceptions.ResourcesUnavailableError as e:
logger.error(f'Resources unavailable: {colorama.Fore.RED}{e}'
Expand All @@ -167,37 +168,86 @@ def start(self):
# The job can be non-terminal if the controller exited abnormally,
# e.g. failed to launch cluster after reaching the MAX_RETRY.
if not job_status.is_terminal():
logger.info(f'Previous spot job status: {job_status.value}')
spot_state.set_failed(
self._job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER)

# Clean up Storages with persistent=False.
self.backend.teardown_ephemeral_storage(self._task)

def _handle_signal(self):
"""Handle the signal if the user sent it."""
signal_file = pathlib.Path(
spot_utils.SIGNAL_FILE_PREFIX.format(self._job_id))
signal = None
if signal_file.exists():
# Filelock is needed to prevent race condition with concurrent
# signal writing.
# TODO(mraheja): remove pylint disabling when filelock version
# updated
# pylint: disable=abstract-class-instantiated
with filelock.FileLock(str(signal_file) + '.lock'):
with signal_file.open(mode='r') as f:
signal = f.read().strip()
signal = spot_utils.UserSignal(signal)
# Remove the signal file, after reading the signal.
signal_file.unlink()
if signal is None:
return
if signal == spot_utils.UserSignal.CANCEL:
raise exceptions.SpotUserCancelledError(
f'User sent {signal.value} signal.')

raise RuntimeError(f'Unknown SkyPilot signal received: {signal.value}.')
self._backend.teardown_ephemeral_storage(self._task)


def _run_controller(job_id: int, task_yaml: str, retry_until_up: bool):
    """Runs the controller in a remote process for interruption."""

    def _sigterm_to_interrupt(signum, frame):
        """Convert SIGTERM into a KeyboardInterrupt."""
        del signum, frame  # Unused.
        # Raise KeyboardInterrupt (rather than a custom exception) so it is
        # not swallowed by the strategy executor's exception handling.
        raise KeyboardInterrupt()

    # Override the SIGTERM handler so the parent can gracefully terminate
    # the controller by signalling this process.
    signal.signal(signal.SIGTERM, _sigterm_to_interrupt)

    # The controller must be instantiated here, inside the remote process,
    # since it is not serializable and cannot cross the process boundary.
    SpotController(job_id, task_yaml, retry_until_up).run()


def _handle_signal(job_id):
    """Handle the signal if the user sent it."""
    signal_file = pathlib.Path(spot_utils.SIGNAL_FILE_PREFIX.format(job_id))
    if not signal_file.exists():
        # No signal was written; nothing to do.
        return
    user_signal = None
    # Filelock is needed to prevent race condition with concurrent
    # signal writing.
    # TODO(mraheja): remove pylint disabling when filelock version
    # updated
    # pylint: disable=abstract-class-instantiated
    with filelock.FileLock(str(signal_file) + '.lock'):
        with signal_file.open(mode='r') as f:
            user_signal = f.read().strip()
        try:
            user_signal = spot_utils.UserSignal(user_signal)
        except ValueError:
            logger.warning(
                f'Unknown signal received: {user_signal}. Ignoring.')
            user_signal = None
        # Remove the signal file, after reading the signal.
        signal_file.unlink()
    if user_signal is None:
        # Either no parsable signal or an empty file.
        return
    assert user_signal == spot_utils.UserSignal.CANCEL, (
        f'Only cancel signal is supported, but {user_signal} got.')
    raise exceptions.SpotUserCancelledError(
        f'User sent {user_signal.value} signal.')


def start(job_id, task_yaml, retry_until_up):
    """Start the controller."""
    proc = None
    try:
        # Check for a pre-existing cancel signal before launching anything.
        _handle_signal(job_id)
        proc = multiprocessing.Process(target=_run_controller,
                                       args=(job_id, task_yaml,
                                             retry_until_up))
        proc.start()
        # Poll for user signals once per second while the controller runs.
        while proc.is_alive():
            _handle_signal(job_id)
            time.sleep(1)
    except exceptions.SpotUserCancelledError:
        logger.info(f'Cancelling spot job {job_id}...')
        if proc is not None:
            logger.info(f'Sending SIGTERM to controller process {proc.pid}')
            # This will raise KeyboardInterrupt in the task.
            # Using SIGTERM instead of SIGINT, as the SIGINT is weirdly ignored
            # by the controller process when it is started inside a ray job.
            proc.terminate()
    if proc is not None:
        proc.join()


if __name__ == '__main__':
Expand All @@ -214,6 +264,4 @@ def _handle_signal(self):
help='The path to the user spot task yaml file. '
'The file name is the spot task name.')
args = parser.parse_args()
controller = SpotController(args.job_id, args.task_yaml,
args.retry_until_up)
controller.start()
start(args.job_id, args.task_yaml, args.retry_until_up)
16 changes: 4 additions & 12 deletions sky/spot/recovery_strategy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The strategy to handle launching/recovery/termination of spot clusters."""
import time
import typing
from typing import Callable, Optional
from typing import Optional

import sky
from sky import exceptions
Expand Down Expand Up @@ -34,24 +34,20 @@ class StrategyExecutor:
RETRY_INIT_GAP_SECONDS = 60

def __init__(self, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> None:
task: 'task_lib.Task', retry_until_up: bool) -> None:
"""Initialize the strategy executor.
Args:
cluster_name: The name of the cluster.
backend: The backend to use. Only CloudVMRayBackend is supported.
task: The task to execute.
retry_until_up: Whether to retry until the cluster is up.
signal_handler: The signal handler that will raise an exception if a
SkyPilot signal is received.
"""
self.dag = sky.Dag()
self.dag.add(task)
self.cluster_name = cluster_name
self.backend = backend
self.retry_until_up = retry_until_up
self.signal_handler = signal_handler

def __init_subclass__(cls, name: str, default: bool = False):
SPOT_STRATEGIES[name] = cls
Expand All @@ -63,8 +59,7 @@ def __init_subclass__(cls, name: str, default: bool = False):

@classmethod
def make(cls, cluster_name: str, backend: 'backends.Backend',
task: 'task_lib.Task', retry_until_up: bool,
signal_handler: Callable) -> 'StrategyExecutor':
task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor':
"""Create a strategy from a task."""
resources = task.resources
assert len(resources) == 1, 'Only one resource is supported.'
Expand All @@ -77,7 +72,7 @@ def make(cls, cluster_name: str, backend: 'backends.Backend',
# will be handled by the strategy class.
task.set_resources({resources.copy(spot_recovery=None)})
return SPOT_STRATEGIES[spot_recovery](cluster_name, backend, task,
retry_until_up, signal_handler)
retry_until_up)

def launch(self) -> Optional[float]:
"""Launch the spot cluster for the first time.
Expand Down Expand Up @@ -147,9 +142,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
backoff = common_utils.Backoff(self.RETRY_INIT_GAP_SECONDS)
while True:
retry_cnt += 1
# Check the signal every time to be more responsive to user
# signals, such as Cancel.
self.signal_handler()
retry_launch = False
exception = None
try:
Expand Down
6 changes: 2 additions & 4 deletions sky/spot/spot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str:
cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'

return (f'{identity_str} scheduled to be cancelled within '
f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
return f'{identity_str} scheduled to be cancelled.'


def cancel_job_by_name(job_name: str) -> str:
Expand All @@ -192,8 +191,7 @@ def cancel_job_by_name(job_name: str) -> str:
f'with name {job_name!r}.\n'
f'Job IDs: {job_ids}{colorama.Style.RESET_ALL}')
cancel_jobs_by_id(job_ids)
return (f'Job {job_name!r} is scheduled to be cancelled within '
f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.')
return f'Job {job_name!r} is scheduled to be cancelled.'


def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
Expand Down
14 changes: 14 additions & 0 deletions sky/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility functions for subprocesses."""
from multiprocessing import pool
import psutil
import subprocess
from typing import Any, Callable, List, Optional, Union

Expand Down Expand Up @@ -74,3 +75,16 @@ def handle_returncode(returncode: int,
f'{colorama.Fore.RED}{error_msg}{colorama.Style.RESET_ALL}')
with ux_utils.print_exception_no_traceback():
raise exceptions.CommandError(returncode, command, format_err_msg)


def kill_children_processes():
    """Terminate all (recursive) children of the current process.

    Best-effort: children that have already exited are skipped.
    """
    # We need to kill the children, so that the underlying subprocess
    # will not print the logs to the terminal, after this program
    # exits.
    parent_process = psutil.Process()
    for child in parent_process.children(recursive=True):
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            # The child process may have already been terminated.
            pass
Loading

0 comments on commit 235d9d6

Please sign in to comment.