[spot] Fix multiprocessing signal handling in spot controller (#1745)
Previously, we sent an interruption signal to the controller process, and the controller process handled the cleanup. However, we found that the behavior differs from cloud to cloud (e.g., GCP ignores 'SIGINT'). A possible explanation is https://unix.stackexchange.com/questions/356408/strange-problem-with-trap-and-sigint. In any case, a cleaner solution is to kill the controller process directly and then clean up the cluster state.

Tested (run the relevant ones):

- [ ] Any manual or new tests for this PR (please specify below)
- [ ] All smoke tests: `pytest tests/test_smoke.py` 
- [x] Relevant individual smoke tests: `pytest tests/test_smoke.py --managed-spot`  (both AWS and GCP as spot controllers)
- [ ] Backward compatibility tests: `bash tests/backward_comaptibility_tests.sh`
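
In essence, the new flow is: force-kill the controller's whole process tree (without relying on the child handling any signal), then run the cluster/storage cleanup in the parent. Below is a minimal, self-contained sketch of that pattern; `run_controller`, `kill_process_tree`, and `cleanup_cluster` are illustrative stand-ins, not SkyPilot APIs (the real logic lives in `start()`/`_cleanup()` in `sky/spot/controller.py` and `kill_children_processes()` in `sky/utils/subprocess_utils.py`).

```python
# Illustrative sketch only; function names are stand-ins.
import multiprocessing
import time

import psutil


def run_controller(job_id: int) -> None:
    """Stand-in for the real controller loop."""
    while True:
        time.sleep(1)


def kill_process_tree(pid: int) -> None:
    """Force-kill a process and all of its descendants.

    The root is killed first so it cannot react to its children dying or
    respawn new workers; SIGKILL is used so ignored or overridden signal
    handlers do not matter.
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    procs = [root] + root.children(recursive=True)
    for proc in procs:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass
    psutil.wait_procs(procs, timeout=10)


def cleanup_cluster(job_id: int) -> None:
    """Stand-in for terminating the spot cluster and ephemeral storage."""
    print(f'Cleaning up resources of job {job_id}...')


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn', force=True)
    proc = multiprocessing.Process(target=run_controller, args=(0,))
    proc.start()
    try:
        time.sleep(5)  # Pretend a cancellation request arrives here.
    finally:
        kill_process_tree(proc.pid)
        proc.join()
        cleanup_cluster(0)  # Cleanup runs in the parent, not the child.
```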
suquark authored Mar 16, 2023
1 parent 322ffad commit 464b5db
Showing 4 changed files with 131 additions and 86 deletions.
106 changes: 66 additions & 40 deletions sky/spot/controller.py
@@ -2,7 +2,6 @@
import argparse
import multiprocessing
import pathlib
import signal
import time
import traceback

@@ -181,7 +180,7 @@ def _run(self):
# Some spot resource (e.g., Spot TPU VM) may need to be
# cleaned up after preemption.
logger.info('Cleaning up the preempted spot cluster...')
self._strategy_executor.terminate_cluster()
recovery_strategy.terminate_cluster(self._cluster_name)

# Try to recover the spot jobs, when the cluster is preempted
# or the job status is failed to be fetched.
@@ -194,10 +193,6 @@ def run(self):
"""Run controller logic and handle exceptions."""
try:
self._run()
except KeyboardInterrupt:
# Kill the children processes launched by log_lib.run_with_log.
subprocess_utils.kill_children_processes()
spot_state.set_cancelled(self._job_id)
except exceptions.ProvisionPrechecksError as e:
# Please refer to the docstring of self._run for the cases when
# this exception can occur.
@@ -228,36 +223,10 @@ def run(self):
self._job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER,
failure_reason=msg)
finally:
self._strategy_executor.terminate_cluster()
job_status = spot_state.get_status(self._job_id)
# The job can be non-terminal if the controller exited abnormally,
# e.g. failed to launch cluster after reaching the MAX_RETRY.
if not job_status.is_terminal():
logger.info(f'Previous spot job status: {job_status.value}')
spot_state.set_failed(
self._job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER,
failure_reason=(
'Unexpected error occurred. For details, '
f'run: sky spot logs --controller {self._job_id}'))

# Clean up Storages with persistent=False.
self._backend.teardown_ephemeral_storage(self._task)


def _run_controller(job_id: int, task_yaml: str, retry_until_up: bool):
"""Runs the controller in a remote process for interruption."""

# Override the SIGTERM handler to gracefully terminate the controller.
def handle_interupt(signum, frame):
"""Handle the interrupt signal."""
# Need to raise KeyboardInterrupt to avoid the exception being caught by
# the strategy executor.
raise KeyboardInterrupt()

signal.signal(signal.SIGTERM, handle_interupt)

# The controller needs to be instantiated in the remote process, since
# the controller is not serializable.
spot_controller = SpotController(job_id, task_yaml, retry_until_up)
@@ -294,11 +263,31 @@ def _handle_signal(job_id):
f'User sent {user_signal.value} signal.')


def _cleanup(job_id: int, task_yaml: str):
# NOTE: The code to get the cluster name is the same as in the spot
# controller; keep it in sync with SpotController.__init__().
task = sky.Task.from_yaml(task_yaml)
task_name = task.name
cluster_name = spot_utils.generate_spot_cluster_name(task_name, job_id)
recovery_strategy.terminate_cluster(cluster_name)
# Clean up Storages with persistent=False.
# TODO(zhwu): this assumes the specific backend.
backend = cloud_vm_ray_backend.CloudVmRayBackend()
backend.teardown_ephemeral_storage(task)


def start(job_id, task_yaml, retry_until_up):
"""Start the controller."""
controller_process = None
cancelling = False
try:
_handle_signal(job_id)
# TODO(suquark): In theory, we should make the controller process a
# daemon process so it is killed when this process exits. However, a
# daemon process cannot launch subprocesses, as explained here:
# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.daemon # pylint: disable=line-too-long
# So we can only enable daemon mode once we no longer need to
# start daemon processes like Ray.
controller_process = multiprocessing.Process(target=_run_controller,
args=(job_id, task_yaml,
retry_until_up))
@@ -308,15 +297,49 @@ def start(job_id, task_yaml, retry_until_up):
time.sleep(1)
except exceptions.SpotUserCancelledError:
logger.info(f'Cancelling spot job {job_id}...')
cancelling = True
finally:
if controller_process is not None:
logger.info('Sending SIGTERM to controller process '
f'{controller_process.pid}')
# This will raise KeyboardInterrupt in the task.
# Using SIGTERM instead of SIGINT, as the SIGINT is weirdly ignored
# by the controller process when it is started inside a ray job.
controller_process.terminate()
if controller_process is not None:
controller_process.join()
logger.info(f'Killing controller process {controller_process.pid}')
# NOTE: it is ok to kill or join a killed process.
# Kill the controller process first; if its child process is
# killed first, then the controller process will raise errors.
# Kill any possible remaining children processes recursively.
subprocess_utils.kill_children_processes(controller_process.pid,
force=True)
controller_process.join()
logger.info(f'Controller process {controller_process.pid} killed.')

logger.info(f'Cleaning up spot clusters of job {job_id}.')
# NOTE: Originally, we sent an interruption signal to the controller
# process and the controller process handled cleanup. However, we
# found that the behavior differs from cloud to cloud
# (e.g., GCP ignores 'SIGINT'). A possible explanation is
# https://unix.stackexchange.com/questions/356408/strange-problem-with-trap-and-sigint
# In any case, a cleaner solution is to kill the controller process
# directly and then clean up the cluster state.
_cleanup(job_id, task_yaml=task_yaml)
logger.info(f'Spot clusters of job {job_id} have been taken down.')

# TODO(suquark): It could take a long time cleaning up the cluster.
# In the future, we may add a "cancelling" state for the spot
# controller.
if cancelling:
spot_state.set_cancelled(job_id)

# We should check job status after 'set_cancelled', otherwise
# the job status is not terminal.
job_status = spot_state.get_status(job_id)
# The job can be non-terminal if the controller exited abnormally,
# e.g. failed to launch cluster after reaching the MAX_RETRY.
assert job_status is not None
if not job_status.is_terminal():
logger.info(f'Previous spot job status: {job_status.value}')
spot_state.set_failed(
job_id,
failure_type=spot_state.SpotStatus.FAILED_CONTROLLER,
failure_reason=('Unexpected error occurred. For details, '
f'run: sky spot logs --controller {job_id}'))


if __name__ == '__main__':
@@ -332,4 +355,7 @@ def start(job_id, task_yaml, retry_until_up):
type=str,
help='The path to the user spot task yaml file.')
args = parser.parse_args()
# We start the process with 'spawn', because 'fork' could result in weird
# behaviors; 'spawn' is also cross-platform.
multiprocessing.set_start_method('spawn', force=True)
start(args.job_id, args.task_yaml, args.retry_until_up)
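
As an aside on the TODO in `start()` above (why the controller is not simply made a daemon process): a daemonic `multiprocessing` child is not allowed to create children of its own, so it could not start the subprocesses (e.g., Ray) that the controller needs. A small, self-contained demonstration of that limitation, assuming standard Python `multiprocessing` behavior and not part of this commit:

```python
# Demonstration only; not part of this commit.
import multiprocessing


def child():
    pass


def daemon_worker():
    try:
        grandchild = multiprocessing.Process(target=child)
        # Raises AssertionError: daemonic processes are not allowed
        # to have children.
        grandchild.start()
        grandchild.join()
    except AssertionError as e:
        print(f'Cannot spawn from a daemon process: {e}')


if __name__ == '__main__':
    # 'spawn' matches the start method set at the bottom of controller.py.
    multiprocessing.set_start_method('spawn', force=True)
    p = multiprocessing.Process(target=daemon_worker, daemon=True)
    p.start()
    p.join()
```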
47 changes: 24 additions & 23 deletions sky/spot/recovery_strategy.py
@@ -29,6 +29,28 @@
MAX_JOB_CHECKING_RETRY = 10


def terminate_cluster(cluster_name: str, max_retry: int = 3) -> None:
"""Terminate the spot cluster."""
retry_cnt = 0
while True:
try:
usage_lib.messages.usage.set_internal()
sky.down(cluster_name)
return
except ValueError:
# The cluster is already down.
return
except Exception as e: # pylint: disable=broad-except
retry_cnt += 1
if retry_cnt >= max_retry:
raise RuntimeError('Failed to terminate the spot cluster '
f'{cluster_name}.') from e
logger.error('Failed to terminate the spot cluster '
f'{cluster_name}. Retrying.'
f'Details: {common_utils.format_exception(e)}')
logger.error(f' Traceback: {traceback.format_exc()}')


class StrategyExecutor:
"""Handle each launching, recovery and termination of the spot clusters."""

@@ -101,27 +123,6 @@ def recover(self) -> float:
"""
raise NotImplementedError

def terminate_cluster(self, max_retry: int = 3) -> None:
"""Terminate the spot cluster."""
retry_cnt = 0
while True:
try:
usage_lib.messages.usage.set_internal()
sky.down(self.cluster_name)
return
except ValueError:
# The cluster is already down.
return
except Exception as e: # pylint: disable=broad-except
retry_cnt += 1
if retry_cnt >= max_retry:
raise RuntimeError('Failed to terminate the spot cluster '
f'{self.cluster_name}.') from e
logger.error('Failed to terminate the spot cluster '
f'{self.cluster_name}. Retrying.'
f'Details: {common_utils.format_exception(e)}')
logger.error(f' Traceback: {traceback.format_exc()}')

def _try_cancel_all_jobs(self):
handle = global_user_state.get_handle_from_cluster_name(
self.cluster_name)
@@ -306,7 +307,7 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
'launched cluster, due to unexpected submission errors or '
'the cluster being preempted during job submission.')

self.terminate_cluster()
terminate_cluster(self.cluster_name)
if max_retry is not None and retry_cnt >= max_retry:
# Retry forever if max_retry is None.
if raise_on_failure:
@@ -385,7 +386,7 @@ def recover(self) -> float:
logger.debug('Terminating unhealthy spot cluster and '
'reset cloud region.')
self._launched_cloud_region = None
self.terminate_cluster()
terminate_cluster(self.cluster_name)

# Step 3
logger.debug('Relaunch the cluster without constraining to prior '
39 changes: 34 additions & 5 deletions sky/utils/subprocess_utils.py
@@ -77,14 +77,43 @@ def handle_returncode(returncode: int,
raise exceptions.CommandError(returncode, command, format_err_msg)


def kill_children_processes():
# We need to kill the children, so that the underlying subprocess
# will not print the logs to the terminal, after this program
# exits.
def kill_children_processes(first_pid_to_kill: Optional[int] = None,
force: bool = False):
"""Kill children processes recursively.
We need to kill the children so that
1. The underlying subprocess will not print logs to the terminal
after this program exits.
2. The underlying subprocess will not continue starting a cluster,
etc., while we are cleaning up the clusters.
Args:
first_pid_to_kill: Optional PID of a process to be killed first.
This guarantees the order of cleanup and suppresses flaky errors.
"""
parent_process = psutil.Process()
child_processes = []
for child in parent_process.children(recursive=True):
if child.pid == first_pid_to_kill:
try:
if force:
child.kill()
else:
child.terminate()
child.wait()
except psutil.NoSuchProcess:
# The child process may have already been terminated.
pass
else:
child_processes.append(child)

for child in child_processes:
try:
child.terminate()
if force:
child.kill()
else:
child.terminate()
except psutil.NoSuchProcess:
# The child process may have already been terminated.
pass
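
To connect this helper back to the motivation in the commit message: graceful signals can simply be ignored by the target process, which is why `force=True` escalates to a hard kill. A standalone illustration of that difference (POSIX semantics; not part of this commit and independent of SkyPilot):

```python
# Demonstration only: SIGTERM can be ignored, SIGKILL cannot (POSIX).
import multiprocessing
import signal
import time

import psutil


def stubborn_child():
    # Simulate a process whose environment swallows termination signals,
    # similar to how SIGINT was observed to be ignored on some clouds.
    signal.signal(signal.SIGTERM, signal.SIG_IGN)
    while True:
        time.sleep(1)


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn', force=True)
    p = multiprocessing.Process(target=stubborn_child)
    p.start()
    time.sleep(2)  # Give the child time to install its signal handler.
    proc = psutil.Process(p.pid)

    proc.terminate()  # SIGTERM: the child ignores it and keeps running.
    _, alive = psutil.wait_procs([proc], timeout=3)
    print('alive after SIGTERM:', bool(alive))  # Expected: True

    proc.kill()  # SIGKILL: cannot be caught or ignored.
    _, alive = psutil.wait_procs([proc], timeout=3)
    print('alive after SIGKILL:', bool(alive))  # Expected: False
    p.join()
```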
25 changes: 7 additions & 18 deletions tests/test_smoke.py
@@ -1267,9 +1267,8 @@ def test_spot(generic_cloud: str):
f'{_SPOT_QUEUE_WAIT}| grep {name}-1 | head -n1 | grep "STARTING\|RUNNING"',
f'{_SPOT_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "STARTING\|RUNNING"',
f'sky spot cancel -y -n {name}-1',
'sleep 10',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name}-1 | head -n1 | grep CANCELLED',
'sleep 200',
f'{_SPOT_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "RUNNING\|SUCCEEDED"',
],
f'sky spot cancel -y -n {name}-1; sky spot cancel -y -n {name}-2',
@@ -1460,9 +1459,8 @@ def test_spot_cancellation_aws():
'sleep 60',
f'{_SPOT_QUEUE_WAIT}| grep {name} | head -n1 | grep "STARTING"',
f'sky spot cancel -y -n {name}',
'sleep 5',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name} | head -n1 | grep "CANCELLED"',
'sleep 100',
(f's=$(aws ec2 describe-instances --region {region} '
f'--filters Name=tag:ray-cluster-name,Values={name}-* '
f'--query Reservations[].Instances[].State[].Name '
@@ -1472,9 +1470,8 @@ def test_spot_cancellation_aws():
f'sky spot launch --cloud aws --region {region} -n {name}-2 tests/test_yamls/test_long_setup.yaml -y -d',
'sleep 300',
f'sky spot cancel -y -n {name}-2',
'sleep 5',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "CANCELLED"',
'sleep 100',
(f's=$(aws ec2 describe-instances --region {region} '
f'--filters Name=tag:ray-cluster-name,Values={name}-2-* '
f'--query Reservations[].Instances[].State[].Name '
@@ -1493,9 +1490,8 @@ def test_spot_cancellation_aws():
'sleep 100',
f'{_SPOT_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "RECOVERING"',
f'sky spot cancel -y -n {name}-3',
'sleep 10',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "CANCELLED"',
'sleep 90',
# The cluster should be terminated (shutting-down) after cancellation. We don't use the `=` operator here because
# there can be multiple VM with the same name due to the recovery.
(f's=$(aws ec2 describe-instances --region {region} '
@@ -1529,20 +1525,14 @@ def test_spot_cancellation_gcp():
'sleep 60',
f'{_SPOT_QUEUE_WAIT}| grep {name} | head -n1 | grep "STARTING"',
f'sky spot cancel -y -n {name}',
'sleep 5',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name} | head -n1 | grep "CANCELLED"',
'sleep 100',
f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || [[ "$s" = "STOPPING" ]]' # GCP shows STOPPING when shutting down
,
# Test cancelling the spot cluster during spot job being setup.
f'sky spot launch --cloud gcp --zone {zone} -n {name}-2 tests/test_yamls/test_long_setup.yaml -y -d',
'sleep 300',
f'sky spot cancel -y -n {name}-2',
'sleep 5',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "CANCELLED"',
'sleep 100',
(f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || [[ "$s" = "STOPPING" ]]'
),
# Test cancellation during spot job is recovering.
f'sky spot launch --cloud gcp --zone {zone} -n {name}-3 "sleep 1000" -y -d',
'sleep 300',
@@ -1552,9 +1542,8 @@ def test_spot_cancellation_gcp():
'sleep 100',
f'{_SPOT_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "RECOVERING"',
f'sky spot cancel -y -n {name}-3',
'sleep 10',
'sleep 120',
f'{_SPOT_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "CANCELLED"',
'sleep 90',
# The cluster should be terminated (STOPPING) after cancellation. We don't use the `=` operator here because
# there can be multiple VM with the same name due to the recovery.
(f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED|SUSPENDED"'
