
Add safe guard for provisioning/terminating TPU VM and fix spot launch TPU resource leak #1500

Merged: 16 commits, Dec 15, 2022
78 changes: 51 additions & 27 deletions sky/backends/backend_utils.py
@@ -1228,34 +1228,58 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zone={zone} --format=value\\(name\\)')
if not get_internal_ips:
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
f' --zone {zone} --format="value[delimiter=\'\\n\']'
'(networkEndpoints.accessConfig.externalIp)"')
else:
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
f' --zone {zone} --format="value[delimiter=\'\\n\']'
'(networkEndpoints.ipAddress)"')
returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
subprocess_utils.handle_returncode(
returncode,
query_cmd,
'Failed to run gcloud to get TPU VM IDs.',
stderr=stdout + stderr)
if len(stdout) == 0:
logger.debug('No TPU VMs found with cluster name '
f'{cluster_name} in zone {zone}.')
if len(stdout.splitlines()) > 1:
# Rare case, this could mean resource leakage. Hint user.
logger.warning('Found more than one TPU VM/Pod with the same cluster '
f'name {cluster_name} in zone {zone}.')

all_ips = []
for tpu_id in stdout.splitlines():
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
f' --zone {zone} --format=json')
returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
subprocess_utils.handle_returncode(
returncode,
tpuvm_cmd,
'Failed to run gcloud tpu-vm describe.',
stderr=stdout + stderr)

tpuvm_json = json.loads(stdout)
if tpuvm_json['state'] != 'READY':
# May be a leaked preempted resource.
logger.warning(f'TPU VM {tpu_id} is not in READY state. '
'Could be a garbage resource. Skipping...')
continue
Comment on lines +1264 to +1269
Collaborator

Will this state be different for different tpu_id?

Member Author

Each TPU VM or TPU Pod only maps to a single tpu_id. So yes, different tpu_id can have different states.
But when this multiple-tpu_id situation happens, it means there is a leaked resource with the same cluster name as the current one. That's why I print the garbage-resource message.

Normally only one tpu_id should be returned by the query command below.

query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
                 f'\\(labels.ray-cluster-name={cluster_name}\\) '
                 f'--zone={zone} --format=value\\(name\\)')


if not get_internal_ips:
ips = [
endpoint['accessConfig']['externalIp']
for endpoint in tpuvm_json['networkEndpoints']
]
else:
ips = [
endpoint['ipAddress']
for endpoint in tpuvm_json['networkEndpoints']
]
all_ips.extend(ips)

rcode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
if rcode != 0:
failure_massage = ('Failed to run gcloud to get TPU VM Pod IPs.\n'
'**** STDOUT ****\n'
'{stdout}\n'
'**** STDERR ****\n'
'{stderr}\n'
'**** CMD ****\n'
'{tpuvm_cmd}')
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
failure_massage.format(stdout=stdout,
stderr=stderr,
tpuvm_cmd=tpuvm_cmd))
all_ips = re.findall(IP_ADDR_REGEX, stdout)
return all_ips


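For readers following the new flow above, here is a minimal standalone sketch of what _get_tpu_vm_pod_ips now does per TPU ID: describe each VM returned by the list query, skip anything not in the READY state (likely a leaked, preempted resource), and collect the endpoint IPs. It assumes gcloud is on PATH and substitutes subprocess.run for SkyPilot's log_lib.run_with_log; it is an illustration, not the PR's code verbatim.

import json
import subprocess
from typing import List

def collect_tpu_vm_ips(tpu_ids: List[str], zone: str,
                       get_internal_ips: bool = False) -> List[str]:
    all_ips = []
    for tpu_id in tpu_ids:
        # One `describe` call per TPU name returned by the earlier list query.
        proc = subprocess.run(
            ['gcloud', 'compute', 'tpus', 'tpu-vm', 'describe', tpu_id,
             f'--zone={zone}', '--format=json'],
            check=True, capture_output=True, text=True)
        tpuvm_json = json.loads(proc.stdout)
        if tpuvm_json['state'] != 'READY':
            # A non-READY VM sharing the cluster name is likely a leaked,
            # preempted resource; skip it instead of collecting its IPs.
            continue
        for endpoint in tpuvm_json['networkEndpoints']:
            if get_internal_ips:
                all_ips.append(endpoint['ipAddress'])
            else:
                all_ips.append(endpoint['accessConfig']['externalIp'])
    return all_ips
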
26 changes: 23 additions & 3 deletions sky/backends/cloud_vm_ray_backend.py
@@ -2684,9 +2684,29 @@ def teardown_no_lock(self,
f'gcloud compute tpus tpu-vm list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zone={zone} --format=value\\(name\\)')
terminate_cmd = (
f'gcloud compute tpus tpu-vm delete --zone={zone}'
f' --quiet $({query_cmd})')
returncode, stdout, stderr = log_lib.run_with_log(
query_cmd,
log_abs_path,
shell=True,
stream_logs=False,
require_outputs=True)

# Skip the termination command if the TPU ID
# query command fails.
if returncode != 0:
terminate_cmd = (f'echo "cmd: {query_cmd}" && '
f'echo "{stdout}" && '
f'echo "{stderr}" >&2 && '
f'exit {returncode}')
else:
# Need to create a list, as GCP does not allow deleting
# multiple TPU VMs at once.
tpu_terminate_cmds = []
for tpu_id in stdout.splitlines():
tpu_terminate_cmds.append(
'gcloud compute tpus tpu-vm delete '
f'--zone={zone} --quiet {tpu_id}')
terminate_cmd = ' && '.join(tpu_terminate_cmds)
else:
query_cmd = (
f'gcloud compute instances list --filter='
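Because gcloud compute tpus tpu-vm delete accepts a single TPU name, the teardown path above builds one delete command per ID and chains them with ' && '. A hedged sketch of the resulting shell string for two hypothetical TPU IDs:

zone = 'us-central1-b'                        # hypothetical zone
tpu_ids = ['sky-tpu-1234', 'sky-tpu-leaked']  # hypothetical IDs from the list query

terminate_cmd = ' && '.join(
    f'gcloud compute tpus tpu-vm delete --zone={zone} --quiet {tpu_id}'
    for tpu_id in tpu_ids)
# terminate_cmd ==
#   'gcloud compute tpus tpu-vm delete --zone=us-central1-b --quiet sky-tpu-1234'
#   ' && gcloud compute tpus tpu-vm delete --zone=us-central1-b --quiet sky-tpu-leaked'
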
13 changes: 13 additions & 0 deletions sky/clouds/cloud.py
@@ -214,5 +214,18 @@ def accelerator_in_region_or_zone(self,
"""Returns whether the accelerator is valid in the region or zone."""
raise NotImplementedError

def need_cleanup_after_preemption(self,
resource: 'resources.Resources') -> bool:
"""Returns whether a spot resource needs cleanup after preeemption.

In most cases, spot resources do not need cleanup after preemption,
as long as the cluster can be relaunched with the same name and tag,
regardless of whether the preemption behavior is to terminate or stop the
cluster. The only exception so far is GCP's Spot TPU VM. We override this
method in gcp.py.
"""
del resource
return False

def __repr__(self):
return self._REPR
12 changes: 12 additions & 0 deletions sky/clouds/gcp.py
@@ -468,6 +468,18 @@ def accelerator_in_region_or_zone(self,
return service_catalog.accelerator_in_region_or_zone(
accelerator, acc_count, region, zone, 'gcp')

def need_cleanup_after_preemption(self,
resources: 'resources.Resources') -> bool:
"""Returns whether a spot resource needs cleanup after preeemption."""
# Spot TPU VMs require manual cleanup after preemption.
# "If your Cloud TPU is preempted,
# you must delete it and create a new one ..."
# See: https://cloud.google.com/tpu/docs/preemptible#tpu-vm

# pylint: disable=import-outside-toplevel
from sky.utils import tpu_utils
return tpu_utils.is_tpu_vm(resources)

@classmethod
def get_project_id(cls, dryrun: bool = False) -> str:
# TODO(zhwu): change the project id fetching with the following command
4 changes: 4 additions & 0 deletions sky/resources.py
@@ -260,6 +260,10 @@ def _set_accelerators(
def is_launchable(self) -> bool:
return self.cloud is not None and self._instance_type is not None

def need_cleanup_after_preemption(self) -> bool:
"""Returns whether a spot resource needs cleanup after preeemption."""
return self.cloud.need_cleanup_after_preemption(self)

def _set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
if region is None and zone is None:
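Taken together, the new hook is a small delegation chain: the spot controller asks the task's Resources, Resources asks its Cloud, and only GCP's override (shown above) answers True, and only for TPU VMs. A hedged usage sketch follows; the constructor arguments are illustrative and not necessarily the exact SkyPilot API.

from sky import clouds
from sky import resources as resources_lib

# Hypothetical Spot TPU VM resource. GCP's override delegates to
# tpu_utils.is_tpu_vm(), which checks accelerator_args['tpu_vm'].
tpu_vm = resources_lib.Resources(cloud=clouds.GCP(),
                                 accelerators={'tpu-v2-8': 1},
                                 accelerator_args={'tpu_vm': True},
                                 use_spot=True)
assert tpu_vm.need_cleanup_after_preemption()

# A non-TPU-VM spot resource resolves to False, so the spot controller
# skips the extra terminate_cluster() call after preemption.
gpu = resources_lib.Resources(cloud=clouds.GCP(),
                              accelerators={'V100': 1},
                              use_spot=True)
assert not gpu.need_cleanup_after_preemption()
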
7 changes: 7 additions & 0 deletions sky/spot/controller.py
@@ -145,6 +145,13 @@ def _run(self):
'cluster is healthy. Try to recover the job '
'(the cluster will not be restarted).')

resources = list(self._task.resources)[0]
if resources.need_cleanup_after_preemption():
# Some spot resources (e.g., Spot TPU VM) may need to be
# cleaned up after preemption.
logger.info('Cleaning up the preempted spot cluster...')
self._strategy_executor.terminate_cluster()

# Try to recover the spot jobs when the cluster is preempted
# or the job status fails to be fetched.
spot_state.set_recovering(self._job_id)
8 changes: 3 additions & 5 deletions sky/spot/recovery_strategy.py
@@ -276,10 +276,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
launched_resources.region)
return launch_time

def terminate_cluster(self, max_retry: int = 3) -> None:
super().terminate_cluster(max_retry)
self._launched_cloud_region = None

def recover(self) -> float:
# 1. Cancel the jobs and launch the cluster with the STOPPED status,
# so that it will try on the current region first until timeout.
@@ -313,7 +309,9 @@ def recover(self) -> float:
return launched_time

# Step 2
logger.debug('Terminating unhealthy spot cluster.')
logger.debug('Terminating unhealthy spot cluster and '
'resetting cloud region.')
self._launched_cloud_region = None
self.terminate_cluster()

# Step 3
17 changes: 9 additions & 8 deletions sky/utils/tpu_utils.py
@@ -4,28 +4,29 @@
from sky import resources as resources_lib


def is_tpu(resources: resources_lib.Resources) -> bool:
if resources.accelerators is None:
def is_tpu(resources: Optional[resources_lib.Resources]) -> bool:
if resources is None or resources.accelerators is None:
return False
acc, _ = list(resources.accelerators.items())[0]
return acc.startswith('tpu')


def is_tpu_vm(resources: resources_lib.Resources) -> bool:
if resources.accelerator_args is None:
def is_tpu_vm(resources: Optional[resources_lib.Resources]) -> bool:
if resources is None or resources.accelerator_args is None:
return False
return resources.accelerator_args.get('tpu_vm', False)


def is_tpu_vm_pod(resources: resources_lib.Resources) -> bool:
if not is_tpu_vm(resources):
def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool:
if resources is None or not is_tpu_vm(resources):
return False
acc, _ = list(resources.accelerators.items())[0]
return acc not in ['tpu-v2-8', 'tpu-v3-8']


def get_num_tpu_devices(resources: resources_lib.Resources) -> Optional[int]:
if not is_tpu(resources):
def get_num_tpu_devices(
resources: Optional[resources_lib.Resources]) -> Optional[int]:
if resources is None or not is_tpu(resources):
return None
acc, _ = list(resources.accelerators.items())[0]
num_tpu_devices = int(int(acc.split('-')[2]) / 8)
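The Optional handling above lets callers pass None safely instead of raising an AttributeError. A brief hedged sketch of the resulting behavior, with the device-count arithmetic spelled out:

from sky.utils import tpu_utils

assert tpu_utils.is_tpu(None) is False
assert tpu_utils.is_tpu_vm(None) is False
assert tpu_utils.get_num_tpu_devices(None) is None

# For an accelerator name such as 'tpu-v3-32', get_num_tpu_devices computes
# int(int('tpu-v3-32'.split('-')[2]) / 8) == int(32 / 8) == 4 devices.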