diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index b87b8373592..5e626ca97d5 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -1228,34 +1228,58 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
     query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
                  f'\\(labels.ray-cluster-name={cluster_name}\\) '
                  f'--zone={zone} --format=value\\(name\\)')
-    if not get_internal_ips:
-        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
-                     f' --zone {zone} --format="value[delimiter=\'\\n\']'
-                     '(networkEndpoints.accessConfig.externalIp)"')
-    else:
-        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
-                     f' --zone {zone} --format="value[delimiter=\'\\n\']'
-                     '(networkEndpoints.ipAddress)"')
+    returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
+                                                      '/dev/null',
+                                                      shell=True,
+                                                      stream_logs=False,
+                                                      require_outputs=True)
+    subprocess_utils.handle_returncode(
+        returncode,
+        query_cmd,
+        'Failed to run gcloud to get TPU VM IDs.',
+        stderr=stdout + stderr)
+    if len(stdout) == 0:
+        logger.debug('No TPU VMs found with cluster name '
+                     f'{cluster_name} in zone {zone}.')
+    if len(stdout.splitlines()) > 1:
+        # Rare case, this could mean resource leakage. Hint user.
+        logger.warning('Found more than one TPU VM/Pod with the same cluster '
+                       f'name {cluster_name} in zone {zone}.')
+
+    all_ips = []
+    for tpu_id in stdout.splitlines():
+        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
+                     f' --zone {zone} --format=json')
+        returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
+                                                          '/dev/null',
+                                                          shell=True,
+                                                          stream_logs=False,
+                                                          require_outputs=True)
+        subprocess_utils.handle_returncode(
+            returncode,
+            tpuvm_cmd,
+            'Failed to run gcloud tpu-vm describe.',
+            stderr=stdout + stderr)
+
+        tpuvm_json = json.loads(stdout)
+        if tpuvm_json['state'] != 'READY':
+            # May be a leaked preempted resource.
+            logger.warning(f'TPU VM {tpu_id} is not in READY state. '
+                           'Could be a garbage resource. Skipping...')
+            continue
+
+        if not get_internal_ips:
+            ips = [
+                endpoint['accessConfig']['externalIp']
+                for endpoint in tpuvm_json['networkEndpoints']
+            ]
+        else:
+            ips = [
+                endpoint['ipAddress']
+                for endpoint in tpuvm_json['networkEndpoints']
+            ]
+        all_ips.extend(ips)
 
-    rcode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
-                                                 '/dev/null',
-                                                 shell=True,
-                                                 stream_logs=False,
-                                                 require_outputs=True)
-    if rcode != 0:
-        failure_massage = ('Failed to run gcloud to get TPU VM Pod IPs.\n'
-                           '**** STDOUT ****\n'
-                           '{stdout}\n'
-                           '**** STDERR ****\n'
-                           '{stderr}\n'
-                           '**** CMD ****\n'
-                           '{tpuvm_cmd}')
-        with ux_utils.print_exception_no_traceback():
-            raise RuntimeError(
-                failure_massage.format(stdout=stdout,
-                                       stderr=stderr,
-                                       tpuvm_cmd=tpuvm_cmd))
-    all_ips = re.findall(IP_ADDR_REGEX, stdout)
 
     return all_ips
 
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index c9f51452d4a..a3c7371da56 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -2684,9 +2684,29 @@ def teardown_no_lock(self,
                 f'gcloud compute tpus tpu-vm list --filter='
                 f'\\(labels.ray-cluster-name={cluster_name}\\) '
                 f'--zone={zone} --format=value\\(name\\)')
-            terminate_cmd = (
-                f'gcloud compute tpus tpu-vm delete --zone={zone}'
-                f' --quiet $({query_cmd})')
+            returncode, stdout, stderr = log_lib.run_with_log(
+                query_cmd,
+                log_abs_path,
+                shell=True,
+                stream_logs=False,
+                require_outputs=True)
+
+            # Skip the termination command, if the TPU ID
+            # query command fails.
+            if returncode != 0:
+                terminate_cmd = (f'echo "cmd: {query_cmd}" && '
+                                 f'echo "{stdout}" && '
+                                 f'echo "{stderr}" >&2 && '
+                                 f'exit {returncode}')
+            else:
+                # Need to create a list, as GCP does not allow deleting
+                # multiple TPU VMs at once.
+                tpu_terminate_cmds = []
+                for tpu_id in stdout.splitlines():
+                    tpu_terminate_cmds.append(
+                        'gcloud compute tpus tpu-vm delete '
+                        f'--zone={zone} --quiet {tpu_id}')
+                terminate_cmd = ' && '.join(tpu_terminate_cmds)
         else:
             query_cmd = (
                 f'gcloud compute instances list --filter='
diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py
index f7530561165..6eb856f526e 100644
--- a/sky/clouds/cloud.py
+++ b/sky/clouds/cloud.py
@@ -214,5 +214,18 @@ def accelerator_in_region_or_zone(self,
         """Returns whether the accelerator is valid in the region or zone."""
         raise NotImplementedError
 
+    def need_cleanup_after_preemption(self,
+                                      resource: 'resources.Resources') -> bool:
+        """Returns whether a spot resource needs cleanup after preemption.
+
+        In most cases, spot resources do not need cleanup after preemption,
+        as long as the cluster can be relaunched with the same name and tag,
+        regardless of whether the preemption behavior is to terminate or to
+        stop the cluster. The only exception so far is GCP's Spot TPU VM; we
+        override this method in gcp.py.
+        """
+        del resource
+        return False
+
     def __repr__(self):
         return self._REPR
diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py
index e176e20d2e8..fffd1c9a5ed 100644
--- a/sky/clouds/gcp.py
+++ b/sky/clouds/gcp.py
@@ -468,6 +468,18 @@ def accelerator_in_region_or_zone(self,
         return service_catalog.accelerator_in_region_or_zone(
             accelerator, acc_count, region, zone, 'gcp')
 
+    def need_cleanup_after_preemption(self,
+                                      resources: 'resources.Resources') -> bool:
+        """Returns whether a spot resource needs cleanup after preemption."""
+        # Spot TPU VMs require manual cleanup after preemption.
+        # "If your Cloud TPU is preempted,
+        # you must delete it and create a new one ..."
+        # See: https://cloud.google.com/tpu/docs/preemptible#tpu-vm
+
+        # pylint: disable=import-outside-toplevel
+        from sky.utils import tpu_utils
+        return tpu_utils.is_tpu_vm(resources)
+
     @classmethod
     def get_project_id(cls, dryrun: bool = False) -> str:
         # TODO(zhwu): change the project id fetching with the following command
diff --git a/sky/resources.py b/sky/resources.py
index 71754ec9ef7..e112cf700db 100644
--- a/sky/resources.py
+++ b/sky/resources.py
@@ -260,6 +260,10 @@ def _set_accelerators(
     def is_launchable(self) -> bool:
         return self.cloud is not None and self._instance_type is not None
 
+    def need_cleanup_after_preemption(self) -> bool:
+        """Returns whether a spot resource needs cleanup after preemption."""
+        return self.cloud.need_cleanup_after_preemption(self)
+
     def _set_region_zone(self, region: Optional[str],
                          zone: Optional[str]) -> None:
         if region is None and zone is None:
diff --git a/sky/spot/controller.py b/sky/spot/controller.py
index deb7fc22fc5..7a59e4900eb 100644
--- a/sky/spot/controller.py
+++ b/sky/spot/controller.py
@@ -145,6 +145,13 @@ def _run(self):
                             'cluster is healthy. Try to recover the job '
                             '(the cluster will not be restarted).')
 
+            resources = list(self._task.resources)[0]
+            if resources.need_cleanup_after_preemption():
+                # Some spot resources (e.g., Spot TPU VMs) may need to be
+                # cleaned up after preemption.
+                logger.info('Cleaning up the preempted spot cluster...')
+                self._strategy_executor.terminate_cluster()
+
             # Try to recover the spot jobs, when the cluster is preempted
             # or the job status is failed to be fetched.
             spot_state.set_recovering(self._job_id)
diff --git a/sky/spot/recovery_strategy.py b/sky/spot/recovery_strategy.py
index 3459cb2edf7..de6b8eeb213 100644
--- a/sky/spot/recovery_strategy.py
+++ b/sky/spot/recovery_strategy.py
@@ -276,10 +276,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
                                            launched_resources.region)
         return launch_time
 
-    def terminate_cluster(self, max_retry: int = 3) -> None:
-        super().terminate_cluster(max_retry)
-        self._launched_cloud_region = None
-
     def recover(self) -> float:
         # 1. Cancel the jobs and launch the cluster with the STOPPED status,
         #    so that it will try on the current region first until timeout.
@@ -313,7 +309,9 @@ def recover(self) -> float:
             return launched_time
 
         # Step 2
-        logger.debug('Terminating unhealthy spot cluster.')
+        logger.debug('Terminating unhealthy spot cluster and '
+                     'resetting the cloud region.')
+        self._launched_cloud_region = None
         self.terminate_cluster()
 
         # Step 3
diff --git a/sky/utils/tpu_utils.py b/sky/utils/tpu_utils.py
index 70f2b05b117..72f89f9ed54 100644
--- a/sky/utils/tpu_utils.py
+++ b/sky/utils/tpu_utils.py
@@ -4,28 +4,29 @@
 from sky import resources as resources_lib
 
 
-def is_tpu(resources: resources_lib.Resources) -> bool:
-    if resources.accelerators is None:
+def is_tpu(resources: Optional[resources_lib.Resources]) -> bool:
+    if resources is None or resources.accelerators is None:
         return False
     acc, _ = list(resources.accelerators.items())[0]
     return acc.startswith('tpu')
 
 
-def is_tpu_vm(resources: resources_lib.Resources) -> bool:
-    if resources.accelerator_args is None:
+def is_tpu_vm(resources: Optional[resources_lib.Resources]) -> bool:
+    if resources is None or resources.accelerator_args is None:
         return False
     return resources.accelerator_args.get('tpu_vm', False)
 
 
-def is_tpu_vm_pod(resources: resources_lib.Resources) -> bool:
-    if not is_tpu_vm(resources):
+def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool:
+    if resources is None or not is_tpu_vm(resources):
         return False
     acc, _ = list(resources.accelerators.items())[0]
     return acc not in ['tpu-v2-8', 'tpu-v3-8']
 
 
-def get_num_tpu_devices(resources: resources_lib.Resources) -> Optional[int]:
-    if not is_tpu(resources):
+def get_num_tpu_devices(
+        resources: Optional[resources_lib.Resources]) -> Optional[int]:
+    if resources is None or not is_tpu(resources):
         return None
     acc, _ = list(resources.accelerators.items())[0]
     num_tpu_devices = int(int(acc.split('-')[2]) / 8)
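
Note: below is a minimal standalone sketch of the per-TPU-ID describe-and-parse flow that the new _get_tpu_vm_pod_ips logic above follows. It is an illustrative approximation only: the function name and arguments are invented for the example, and it shells out with plain subprocess.run instead of the log_lib.run_with_log and subprocess_utils.handle_returncode helpers used in the diff.

import json
import subprocess
from typing import List


def collect_tpu_vm_ips(tpu_ids: List[str],
                       zone: str,
                       get_internal_ips: bool = False) -> List[str]:
    """Collects IPs of READY TPU VMs (sketch mirroring the diff's logic)."""
    all_ips = []
    for tpu_id in tpu_ids:
        # One `describe` call per TPU ID; `--format=json` replaces the old
        # delimiter-based `value[...]` format string.
        proc = subprocess.run(
            ['gcloud', 'compute', 'tpus', 'tpu-vm', 'describe', tpu_id,
             f'--zone={zone}', '--format=json'],
            check=True,
            capture_output=True,
            text=True)
        tpuvm_json = json.loads(proc.stdout)
        if tpuvm_json['state'] != 'READY':
            # Possibly a leaked, preempted resource; skip it.
            continue
        for endpoint in tpuvm_json['networkEndpoints']:
            if get_internal_ips:
                all_ips.append(endpoint['ipAddress'])
            else:
                all_ips.append(endpoint['accessConfig']['externalIp'])
    return all_ips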