diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index e06d15929d3..b87b8373592 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1225,11 +1225,8 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any], cluster_name = ray_config['cluster_name'] zone = ray_config['provider']['availability_zone'] - # Excluding preempted VMs is safe as they are already terminated and - # do not charge. query_cmd = (f'gcloud compute tpus tpu-vm list --filter=' - f'"(labels.ray-cluster-name={cluster_name} AND ' - f'state!=PREEMPTED)" ' + f'\\(labels.ray-cluster-name={cluster_name}\\) ' f'--zone={zone} --format=value\\(name\\)') if not get_internal_ips: tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})' diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index d85d4d1148e..19566bcaf22 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2680,12 +2680,9 @@ def teardown_no_lock(self, # check if gcloud includes TPU VM API backend_utils.check_gcp_cli_include_tpu_vm() - # Excluding preempted VMs is safe as they are already - # terminated and do not charge. query_cmd = ( f'gcloud compute tpus tpu-vm list --filter=' - f'"(labels.ray-cluster-name={cluster_name} AND ' - f'state!=PREEMPTED)" ' + f'\\(labels.ray-cluster-name={cluster_name}\\) ' f'--zone={zone} --format=value\\(name\\)') terminate_cmd = ( f'gcloud compute tpus tpu-vm delete --zone={zone}' diff --git a/sky/spot/controller.py b/sky/spot/controller.py index 27640bbfe6a..deb7fc22fc5 100644 --- a/sky/spot/controller.py +++ b/sky/spot/controller.py @@ -85,30 +85,6 @@ def _run(self): job_status = spot_utils.get_job_status(self._backend, self._cluster_name) - if job_status is not None and not job_status.is_terminal(): - need_recovery = False - if self._task.num_nodes > 1: - # Check the cluster status for multi-node jobs, since the - # job may not be set to FAILED immediately when only some - # of the nodes are preempted. - (cluster_status, - handle) = backend_utils.refresh_cluster_status_handle( - self._cluster_name, force_refresh=True) - if cluster_status != global_user_state.ClusterStatus.UP: - # recover the cluster if it is not up. - # The status could be None when the cluster is preempted - # right after the job was found FAILED. - cluster_status_str = ('is preempted' - if cluster_status is None else - f'status {cluster_status.value}') - logger.info(f'Cluster {cluster_status_str}. ' - 'Recovering...') - need_recovery = True - if not need_recovery: - # The job and cluster are healthy, continue to monitor the - # job status. - continue - if job_status == job_lib.JobStatus.SUCCEEDED: end_time = spot_utils.get_job_timestamp(self._backend, self._cluster_name, @@ -117,14 +93,35 @@ def _run(self): spot_state.set_succeeded(self._job_id, end_time=end_time) break - if job_status == job_lib.JobStatus.FAILED: - # Check the status of the spot cluster. If it is not UP, - # the cluster is preempted. - (cluster_status, - handle) = backend_utils.refresh_cluster_status_handle( - self._cluster_name, force_refresh=True) - if cluster_status == global_user_state.ClusterStatus.UP: - # The user code has probably crashed. + # For single-node jobs, nonterminated job_status indicates a + # healthy cluster. We can safely continue monitoring. + # For multi-node jobs, since the job may not be set to FAILED + # immediately (depending on user program) when only some of the + # nodes are preempted, need to check the actual cluster status. + if (job_status is not None and not job_status.is_terminal() and + self._task.num_nodes == 1): + continue + + # Pull the actual cluster status from the cloud provider to + # determine whether the cluster is preempted. + (cluster_status, + handle) = backend_utils.refresh_cluster_status_handle( + self._cluster_name, force_refresh=True) + + if cluster_status != global_user_state.ClusterStatus.UP: + # The cluster is (partially) preempted. It can be down, INIT + # or STOPPED, based on the interruption behavior of the cloud. + # Spot recovery is needed (will be done later in the code). + cluster_status_str = ('' if cluster_status is None else + f' (status: {cluster_status.value})') + logger.info( + f'Cluster is preempted{cluster_status_str}. Recovering...') + else: + if job_status is not None and not job_status.is_terminal(): + # The multi-node job is still running, continue monitoring. + continue + elif job_status == job_lib.JobStatus.FAILED: + # The user code has probably crashed, fail immediately. end_time = spot_utils.get_job_timestamp(self._backend, self._cluster_name, get_end_time=True) @@ -140,11 +137,16 @@ def _run(self): failure_type=spot_state.SpotStatus.FAILED, end_time=end_time) break - # cluster can be down, INIT or STOPPED, based on the interruption - # behavior of the cloud. - # Failed to connect to the cluster or the cluster is partially down. - # job_status is None or job_status == job_lib.JobStatus.FAILED - logger.info('The cluster is preempted.') + # Although the cluster is healthy, we fail to access the + # job status. Try to recover the job (will not restart the + # cluster, if the cluster is healthy). + assert job_status is None, job_status + logger.info('Failed to fetch the job status while the ' + 'cluster is healthy. Try to recover the job ' + '(the cluster will not be restarted).') + + # Try to recover the spot jobs, when the cluster is preempted + # or the job status is failed to be fetched. spot_state.set_recovering(self._job_id) recovered_time = self._strategy_executor.recover() spot_state.set_recovered(self._job_id,