Clean up preempted resources for TPU (#1483)
* fix in controller

* remove debug msg

* msg

* handle job_status == None case and refactor

* space

* update

* comments

* comments
infwinston authored Dec 6, 2022
1 parent 172f6e3 commit ee73e7d
Showing 3 changed files with 41 additions and 45 deletions.
5 changes: 1 addition & 4 deletions sky/backends/backend_utils.py
@@ -1225,11 +1225,8 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
     cluster_name = ray_config['cluster_name']
     zone = ray_config['provider']['availability_zone']
-    # Excluding preempted VMs is safe as they are already terminated and
-    # do not charge.
     query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
-                 f'"(labels.ray-cluster-name={cluster_name} AND '
-                 f'state!=PREEMPTED)" '
+                 f'\\(labels.ray-cluster-name={cluster_name}\\) '
                  f'--zone={zone} --format=value\\(name\\)')
     if not get_internal_ips:
         tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
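For context, the filter change above is what makes preempted TPU VMs visible again: the old "state!=PREEMPTED" clause hid them from the listing, so their leftover resources were never returned for cleanup. Below is a minimal standalone sketch of the new-style listing; the cluster_name and zone values are hypothetical, it requires the gcloud CLI, and it uses plain shell quoting instead of the escaped parentheses the in-tree command string needs.

import subprocess

# Hypothetical example values; the real code fills these in from the
# cluster's Ray config.
cluster_name = 'sky-tpu-cluster'
zone = 'us-central1-b'

# List every TPU VM carrying the cluster label, including PREEMPTED ones,
# so that already-preempted (terminated but not deleted) VMs can still be
# found and cleaned up.
query_cmd = ('gcloud compute tpus tpu-vm list '
             f'--filter="labels.ray-cluster-name={cluster_name}" '
             f'--zone={zone} --format="value(name)"')
proc = subprocess.run(query_cmd, shell=True, check=True,
                      capture_output=True, text=True)
print(proc.stdout.split())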
5 changes: 1 addition & 4 deletions sky/backends/cloud_vm_ray_backend.py
@@ -2680,12 +2680,9 @@ def teardown_no_lock(self,
                 # check if gcloud includes TPU VM API
                 backend_utils.check_gcp_cli_include_tpu_vm()

-                # Excluding preempted VMs is safe as they are already
-                # terminated and do not charge.
                 query_cmd = (
                     f'gcloud compute tpus tpu-vm list --filter='
-                    f'"(labels.ray-cluster-name={cluster_name} AND '
-                    f'state!=PREEMPTED)" '
+                    f'\\(labels.ray-cluster-name={cluster_name}\\) '
                     f'--zone={zone} --format=value\\(name\\)')
                 terminate_cmd = (
                     f'gcloud compute tpus tpu-vm delete --zone={zone}'
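The same filter change feeds the teardown path: because the query no longer skips PREEMPTED VMs, the delete command that consumes its output now also removes preempted TPU VMs and whatever resources they still hold. A rough sketch of that list-then-delete composition, again with hypothetical cluster_name/zone values and simplified quoting (not the backend's exact command strings):

import subprocess

# Hypothetical values; the real backend interpolates them from the handle.
cluster_name = 'sky-tpu-cluster'
zone = 'us-central1-b'

query_cmd = ('gcloud compute tpus tpu-vm list '
             f'--filter="labels.ray-cluster-name={cluster_name}" '
             f'--zone={zone} --format="value(name)"')
# Feed the matched names into delete via command substitution, mirroring
# the pattern in the diff. Preempted TPU VMs are now part of $(query_cmd),
# so they get cleaned up as well.
terminate_cmd = (f'gcloud compute tpus tpu-vm delete --zone={zone} '
                 f'--quiet $({query_cmd})')
# check=False: the delete exits nonzero when no TPU VM matches, which is
# expected when tearing down a cluster that has no leftover TPU VMs.
subprocess.run(terminate_cmd, shell=True, check=False)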
76 changes: 39 additions & 37 deletions sky/spot/controller.py
@@ -85,30 +85,6 @@ def _run(self):
             job_status = spot_utils.get_job_status(self._backend,
                                                    self._cluster_name)

-            if job_status is not None and not job_status.is_terminal():
-                need_recovery = False
-                if self._task.num_nodes > 1:
-                    # Check the cluster status for multi-node jobs, since the
-                    # job may not be set to FAILED immediately when only some
-                    # of the nodes are preempted.
-                    (cluster_status,
-                     handle) = backend_utils.refresh_cluster_status_handle(
-                         self._cluster_name, force_refresh=True)
-                    if cluster_status != global_user_state.ClusterStatus.UP:
-                        # recover the cluster if it is not up.
-                        # The status could be None when the cluster is preempted
-                        # right after the job was found FAILED.
-                        cluster_status_str = ('is preempted'
-                                              if cluster_status is None else
-                                              f'status {cluster_status.value}')
-                        logger.info(f'Cluster {cluster_status_str}. '
-                                    'Recovering...')
-                        need_recovery = True
-                if not need_recovery:
-                    # The job and cluster are healthy, continue to monitor the
-                    # job status.
-                    continue
-
             if job_status == job_lib.JobStatus.SUCCEEDED:
                 end_time = spot_utils.get_job_timestamp(self._backend,
                                                         self._cluster_name,
@@ -117,14 +93,35 @@
                 spot_state.set_succeeded(self._job_id, end_time=end_time)
                 break

-            if job_status == job_lib.JobStatus.FAILED:
-                # Check the status of the spot cluster. If it is not UP,
-                # the cluster is preempted.
-                (cluster_status,
-                 handle) = backend_utils.refresh_cluster_status_handle(
-                     self._cluster_name, force_refresh=True)
-                if cluster_status == global_user_state.ClusterStatus.UP:
-                    # The user code has probably crashed.
+            # For single-node jobs, nonterminated job_status indicates a
+            # healthy cluster. We can safely continue monitoring.
+            # For multi-node jobs, since the job may not be set to FAILED
+            # immediately (depending on user program) when only some of the
+            # nodes are preempted, need to check the actual cluster status.
+            if (job_status is not None and not job_status.is_terminal() and
+                    self._task.num_nodes == 1):
+                continue
+
+            # Pull the actual cluster status from the cloud provider to
+            # determine whether the cluster is preempted.
+            (cluster_status,
+             handle) = backend_utils.refresh_cluster_status_handle(
+                 self._cluster_name, force_refresh=True)
+
+            if cluster_status != global_user_state.ClusterStatus.UP:
+                # The cluster is (partially) preempted. It can be down, INIT
+                # or STOPPED, based on the interruption behavior of the cloud.
+                # Spot recovery is needed (will be done later in the code).
+                cluster_status_str = ('' if cluster_status is None else
+                                      f' (status: {cluster_status.value})')
+                logger.info(
+                    f'Cluster is preempted{cluster_status_str}. Recovering...')
+            else:
+                if job_status is not None and not job_status.is_terminal():
+                    # The multi-node job is still running, continue monitoring.
+                    continue
+                elif job_status == job_lib.JobStatus.FAILED:
+                    # The user code has probably crashed, fail immediately.
                     end_time = spot_utils.get_job_timestamp(self._backend,
                                                             self._cluster_name,
                                                             get_end_time=True)
@@ -140,11 +137,16 @@
                         failure_type=spot_state.SpotStatus.FAILED,
                         end_time=end_time)
                     break
-            # cluster can be down, INIT or STOPPED, based on the interruption
-            # behavior of the cloud.
-            # Failed to connect to the cluster or the cluster is partially down.
-            # job_status is None or job_status == job_lib.JobStatus.FAILED
-            logger.info('The cluster is preempted.')
+                # Although the cluster is healthy, we fail to access the
+                # job status. Try to recover the job (will not restart the
+                # cluster, if the cluster is healthy).
+                assert job_status is None, job_status
+                logger.info('Failed to fetch the job status while the '
+                            'cluster is healthy. Try to recover the job '
+                            '(the cluster will not be restarted).')
+
+            # Try to recover the spot jobs, when the cluster is preempted
+            # or the job status is failed to be fetched.
             spot_state.set_recovering(self._job_id)
             recovered_time = self._strategy_executor.recover()
             spot_state.set_recovered(self._job_id,
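To summarize the controller.py refactor: each monitoring iteration now follows a single decision path instead of two overlapping status checks. Roughly: succeed on SUCCEEDED; keep monitoring if a single-node job is still non-terminal; otherwise refresh the cluster status, recover if the cluster is not UP, keep monitoring a still-running multi-node job, fail fast on FAILED with a healthy cluster, and recover (without restarting the cluster) when the job status cannot be fetched. Below is a condensed standalone sketch of that flow with simplified stand-in types, not the controller's real classes or API.

from enum import Enum
from typing import Callable, Optional


class JobStatus(Enum):
    RUNNING = 'RUNNING'
    SUCCEEDED = 'SUCCEEDED'
    FAILED = 'FAILED'

    def is_terminal(self) -> bool:
        return self in (JobStatus.SUCCEEDED, JobStatus.FAILED)


class ClusterStatus(Enum):
    UP = 'UP'
    INIT = 'INIT'
    STOPPED = 'STOPPED'


def next_action(
        job_status: Optional[JobStatus], num_nodes: int,
        refresh_cluster_status: Callable[[], Optional[ClusterStatus]]) -> str:
    """One monitoring iteration of the (simplified) spot controller loop."""
    if job_status is JobStatus.SUCCEEDED:
        return 'set_succeeded'
    # A non-terminal job on a single-node cluster implies the cluster is
    # healthy; no need to query the cloud provider.
    if (job_status is not None and not job_status.is_terminal() and
            num_nodes == 1):
        return 'keep_monitoring'
    # Multi-node jobs may not report FAILED promptly when only some nodes
    # are preempted, so check the actual cluster status.
    cluster_status = refresh_cluster_status()
    if cluster_status is not ClusterStatus.UP:
        return 'recover'          # cluster (partially) preempted
    if job_status is not None and not job_status.is_terminal():
        return 'keep_monitoring'  # multi-node job still running
    if job_status is JobStatus.FAILED:
        return 'set_failed'       # user code crashed on a healthy cluster
    # job_status is None: the status could not be fetched even though the
    # cluster is healthy; recover the job without restarting the cluster.
    return 'recover'


# Example: a 2-node job whose status cannot be fetched while the cluster
# reports INIT is treated as preempted and recovered.
print(next_action(None, 2, lambda: ClusterStatus.INIT))  # -> 'recover'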
