
Commit

Add safe guard for provisioning/terminating TPU VM and fix spot launch TPU resource leak (#1500)

* safe guard

* terminate the cluster to be safe

* update

* rm

* better abstraction

* comment

* comments

* rename

* comments

* comment

* msg

* comment

* bug..

* msg

* miss one place

* output error msg
infwinston authored Dec 15, 2022
1 parent 2c0685d commit af1b7fd
Showing 8 changed files with 122 additions and 43 deletions.
78 changes: 51 additions & 27 deletions sky/backends/backend_utils.py
@@ -1245,34 +1245,58 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zone={zone} --format=value\\(name\\)')
- if not get_internal_ips:
- tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
- f' --zone {zone} --format="value[delimiter=\'\\n\']'
- '(networkEndpoints.accessConfig.externalIp)"')
- else:
- tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
- f' --zone {zone} --format="value[delimiter=\'\\n\']'
- '(networkEndpoints.ipAddress)"')
+ returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
+ '/dev/null',
+ shell=True,
+ stream_logs=False,
+ require_outputs=True)
+ subprocess_utils.handle_returncode(
+ returncode,
+ query_cmd,
+ 'Failed to run gcloud to get TPU VM IDs.',
+ stderr=stdout + stderr)
+ if len(stdout) == 0:
+ logger.debug('No TPU VMs found with cluster name '
+ f'{cluster_name} in zone {zone}.')
+ if len(stdout.splitlines()) > 1:
+ # Rare case, this could mean resource leakage. Hint user.
+ logger.warning('Found more than one TPU VM/Pod with the same cluster '
+ f'name {cluster_name} in zone {zone}.')
+
+ all_ips = []
+ for tpu_id in stdout.splitlines():
+ tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
+ f' --zone {zone} --format=json')
+ returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
+ '/dev/null',
+ shell=True,
+ stream_logs=False,
+ require_outputs=True)
+ subprocess_utils.handle_returncode(
+ returncode,
+ tpuvm_cmd,
+ 'Failed to run gcloud tpu-vm describe.',
+ stderr=stdout + stderr)
+
+ tpuvm_json = json.loads(stdout)
+ if tpuvm_json['state'] != 'READY':
+ # May be a leaked preempted resource.
+ logger.warning(f'TPU VM {tpu_id} is not in READY state. '
+ 'Could be a garbage resource. Skipping...')
+ continue
+
+ if not get_internal_ips:
+ ips = [
+ endpoint['accessConfig']['externalIp']
+ for endpoint in tpuvm_json['networkEndpoints']
+ ]
+ else:
+ ips = [
+ endpoint['ipAddress']
+ for endpoint in tpuvm_json['networkEndpoints']
+ ]
+ all_ips.extend(ips)

- rcode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
- '/dev/null',
- shell=True,
- stream_logs=False,
- require_outputs=True)
- if rcode != 0:
- failure_massage = ('Failed to run gcloud to get TPU VM Pod IPs.\n'
- '**** STDOUT ****\n'
- '{stdout}\n'
- '**** STDERR ****\n'
- '{stderr}\n'
- '**** CMD ****\n'
- '{tpuvm_cmd}')
- with ux_utils.print_exception_no_traceback():
- raise RuntimeError(
- failure_massage.format(stdout=stdout,
- stderr=stderr,
- tpuvm_cmd=tpuvm_cmd))
- all_ips = re.findall(IP_ADDR_REGEX, stdout)
return all_ips
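
For readers who want the new pattern in isolation, here is a minimal, self-contained sketch of the per-TPU query flow added above. It uses plain subprocess instead of SkyPilot's log_lib.run_with_log and subprocess_utils.handle_returncode, and the function name get_tpu_vm_ips is made up for illustration; the gcloud commands and JSON fields (state, networkEndpoints, accessConfig.externalIp, ipAddress) mirror the diff.

import json
import subprocess
from typing import List


def get_tpu_vm_ips(cluster_name: str, zone: str,
                   get_internal_ips: bool = False) -> List[str]:
    """Illustrative stand-in for the per-TPU IP query above."""
    # List all TPU VM names labeled with this cluster name (>1 hints at a leak).
    query_cmd = ('gcloud compute tpus tpu-vm list '
                 f'--filter=\\(labels.ray-cluster-name={cluster_name}\\) '
                 f'--zone={zone} --format=value\\(name\\)')
    names = subprocess.run(query_cmd, shell=True, check=True, text=True,
                           capture_output=True).stdout.splitlines()
    if len(names) > 1:
        print(f'Warning: found {len(names)} TPU VMs/Pods named {cluster_name}')

    all_ips: List[str] = []
    for tpu_id in names:
        # Describe each TPU VM individually and parse the JSON payload.
        describe_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
                        f' --zone {zone} --format=json')
        out = subprocess.run(describe_cmd, shell=True, check=True, text=True,
                             capture_output=True).stdout
        tpuvm_json = json.loads(out)
        if tpuvm_json['state'] != 'READY':
            # Possibly a leaked/preempted TPU; skip it rather than fail.
            continue
        for endpoint in tpuvm_json['networkEndpoints']:
            all_ips.append(endpoint['ipAddress'] if get_internal_ips else
                           endpoint['accessConfig']['externalIp'])
    return all_ips

Listing the TPU names first and describing each one individually is what lets the code warn about duplicates and skip non-READY (possibly leaked) TPUs, instead of piping a single $(query) substitution into one describe call.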


26 changes: 23 additions & 3 deletions sky/backends/cloud_vm_ray_backend.py
@@ -2693,9 +2693,29 @@ def teardown_no_lock(self,
f'gcloud compute tpus tpu-vm list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zone={zone} --format=value\\(name\\)')
- terminate_cmd = (
- f'gcloud compute tpus tpu-vm delete --zone={zone}'
- f' --quiet $({query_cmd})')
+ returncode, stdout, stderr = log_lib.run_with_log(
+ query_cmd,
+ log_abs_path,
+ shell=True,
+ stream_logs=False,
+ require_outputs=True)
+
+ # Skip the termination command, if the TPU ID
+ # query command fails.
+ if returncode != 0:
+ terminate_cmd = (f'echo "cmd: {query_cmd}" && '
+ f'echo "{stdout}" && '
+ f'echo "{stderr}" >&2 && '
+ f'exit {returncode}')
+ else:
+ # Needs to create a list as GCP does not allow deleting
+ # multiple TPU VMs at once.
+ tpu_terminate_cmds = []
+ for tpu_id in stdout.splitlines():
+ tpu_terminate_cmds.append(
+ 'gcloud compute tpus tpu-vm delete '
+ f'--zone={zone} --quiet {tpu_id}')
+ terminate_cmd = ' && '.join(tpu_terminate_cmds)
else:
query_cmd = (
f'gcloud compute instances list --filter='
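
As a side-by-side illustration of the terminate-command construction above, here is a hedged, self-contained sketch; the function name and the example TPU IDs are made up, while the gcloud invocation and the '&&' chaining mirror the diff.

from typing import List


def build_tpu_terminate_cmd(query_cmd: str, zone: str, returncode: int,
                            stdout: str, stderr: str) -> str:
    """Build the shell command that tears down the TPU VMs found by query_cmd."""
    if returncode != 0:
        # The ID query failed: surface its output and exit code instead of
        # piping garbage into `gcloud ... delete`.
        return (f'echo "cmd: {query_cmd}" && echo "{stdout}" && '
                f'echo "{stderr}" >&2 && exit {returncode}')
    # gcloud deletes one TPU VM per invocation, so chain one delete per ID.
    cmds: List[str] = [
        f'gcloud compute tpus tpu-vm delete --zone={zone} --quiet {tpu_id}'
        for tpu_id in stdout.splitlines()
    ]
    return ' && '.join(cmds)


# Hypothetical example with two TPU IDs returned by the query:
print(build_tpu_terminate_cmd('gcloud ... list ...', 'us-central1-b',
                              0, 'my-tpu-1\nmy-tpu-2\n', ''))

The early-return branch is the safeguard named in the commit title: if the ID query fails, the generated command only reports the failure and exits non-zero, so teardown never runs delete against an empty or bogus ID list.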
13 changes: 13 additions & 0 deletions sky/clouds/cloud.py
@@ -214,5 +214,18 @@ def accelerator_in_region_or_zone(self,
"""Returns whether the accelerator is valid in the region or zone."""
raise NotImplementedError

+ def need_cleanup_after_preemption(self,
+ resource: 'resources.Resources') -> bool:
+ """Returns whether a spot resource needs cleanup after preemption.
+
+ In most cases, spot resources do not need cleanup after preemption,
+ as long as the cluster can be relaunched with the same name and tag,
+ regardless of whether the preemption behavior is to terminate or stop
+ the cluster. The only exception so far is GCP's Spot TPU VM. We
+ override this method in gcp.py.
+ """
+ del resource
+ return False
+
def __repr__(self):
return self._REPR
12 changes: 12 additions & 0 deletions sky/clouds/gcp.py
@@ -468,6 +468,18 @@ def accelerator_in_region_or_zone(self,
return service_catalog.accelerator_in_region_or_zone(
accelerator, acc_count, region, zone, 'gcp')

+ def need_cleanup_after_preemption(self,
+ resources: 'resources.Resources') -> bool:
+ """Returns whether a spot resource needs cleanup after preemption."""
+ # Spot TPU VMs require manual cleanup after preemption.
+ # "If your Cloud TPU is preempted,
+ # you must delete it and create a new one ..."
+ # See: https://cloud.google.com/tpu/docs/preemptible#tpu-vm
+
+ # pylint: disable=import-outside-toplevel
+ from sky.utils import tpu_utils
+ return tpu_utils.is_tpu_vm(resources)
+
@classmethod
def get_project_id(cls, dryrun: bool = False) -> str:
# TODO(zhwu): change the project id fetching with the following command
4 changes: 4 additions & 0 deletions sky/resources.py
@@ -260,6 +260,10 @@ def _set_accelerators(
def is_launchable(self) -> bool:
return self.cloud is not None and self._instance_type is not None

+ def need_cleanup_after_preemption(self) -> bool:
+ """Returns whether a spot resource needs cleanup after preemption."""
+ return self.cloud.need_cleanup_after_preemption(self)
+
def _set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
if region is None and zone is None:
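
Taken together, the three hunks above (sky/clouds/cloud.py, sky/clouds/gcp.py, sky/resources.py) form one hook. Below is a minimal, self-contained sketch of how they compose; the classes are simplified stand-ins, not SkyPilot's real definitions.

from typing import Optional


def is_tpu_vm(resources: Optional['Resources']) -> bool:
    # Simplified copy of sky.utils.tpu_utils.is_tpu_vm (see the last file below).
    if resources is None or resources.accelerator_args is None:
        return False
    return resources.accelerator_args.get('tpu_vm', False)


class Cloud:
    def need_cleanup_after_preemption(self, resources: 'Resources') -> bool:
        # Default: relaunching under the same cluster name/tag is enough.
        return False


class GCP(Cloud):
    def need_cleanup_after_preemption(self, resources: 'Resources') -> bool:
        # Spot TPU VMs must be deleted and recreated after preemption.
        return is_tpu_vm(resources)


class Resources:
    def __init__(self, cloud: Cloud,
                 accelerator_args: Optional[dict] = None) -> None:
        self.cloud = cloud
        self.accelerator_args = accelerator_args

    def need_cleanup_after_preemption(self) -> bool:
        # Delegate to the cloud, exactly as sky/resources.py does above.
        return self.cloud.need_cleanup_after_preemption(self)


assert not Resources(Cloud()).need_cleanup_after_preemption()
assert Resources(GCP(), {'tpu_vm': True}).need_cleanup_after_preemption()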
7 changes: 7 additions & 0 deletions sky/spot/controller.py
@@ -145,6 +145,13 @@ def _run(self):
'cluster is healthy. Try to recover the job '
'(the cluster will not be restarted).')

+ resources = list(self._task.resources)[0]
+ if resources.need_cleanup_after_preemption():
+ # Some spot resources (e.g., Spot TPU VMs) may need to be
+ # cleaned up after preemption.
+ logger.info('Cleaning up the preempted spot cluster...')
+ self._strategy_executor.terminate_cluster()
+
# Try to recover the spot jobs, when the cluster is preempted
# or the job status is failed to be fetched.
spot_state.set_recovering(self._job_id)
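
The ordering in this hunk matters: cleanup runs before the job is marked as recovering and the strategy relaunches. A small self-contained sketch of that control flow, with stand-in classes (FakeResources and FakeStrategy are made up for illustration, not the real controller types):

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class FakeResources:
    """Stand-in for sky.Resources; only the hook the controller consults."""
    cleanup_needed: bool = False

    def need_cleanup_after_preemption(self) -> bool:
        return self.cleanup_needed


@dataclass
class FakeStrategy:
    """Stand-in for the spot recovery strategy executor."""
    calls: List[str] = field(default_factory=list)

    def terminate_cluster(self) -> None:
        self.calls.append('terminate')

    def recover(self) -> None:
        self.calls.append('recover')


def on_preemption(resources: FakeResources, strategy: FakeStrategy,
                  set_recovering: Callable[[], None]) -> None:
    # Mirrors the ordering above: cleanup (only when the cloud needs it)
    # happens before the job is marked recovering and the cluster relaunches.
    if resources.need_cleanup_after_preemption():
        strategy.terminate_cluster()
    set_recovering()
    strategy.recover()


strategy = FakeStrategy()
on_preemption(FakeResources(cleanup_needed=True), strategy, lambda: None)
assert strategy.calls == ['terminate', 'recover']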
8 changes: 3 additions & 5 deletions sky/spot/recovery_strategy.py
@@ -276,10 +276,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
launched_resources.region)
return launch_time

- def terminate_cluster(self, max_retry: int = 3) -> None:
- super().terminate_cluster(max_retry)
- self._launched_cloud_region = None

def recover(self) -> float:
# 1. Cancel the jobs and launch the cluster with the STOPPED status,
# so that it will try on the current region first until timeout.
@@ -313,7 +309,9 @@ def recover(self) -> float:
return launched_time

# Step 2
- logger.debug('Terminating unhealthy spot cluster.')
+ logger.debug('Terminating unhealthy spot cluster and '
+ 'reset cloud region.')
+ self._launched_cloud_region = None
self.terminate_cluster()

# Step 3
17 changes: 9 additions & 8 deletions sky/utils/tpu_utils.py
@@ -4,28 +4,29 @@
from sky import resources as resources_lib


- def is_tpu(resources: resources_lib.Resources) -> bool:
- if resources.accelerators is None:
+ def is_tpu(resources: Optional[resources_lib.Resources]) -> bool:
+ if resources is None or resources.accelerators is None:
return False
acc, _ = list(resources.accelerators.items())[0]
return acc.startswith('tpu')


- def is_tpu_vm(resources: resources_lib.Resources) -> bool:
- if resources.accelerator_args is None:
+ def is_tpu_vm(resources: Optional[resources_lib.Resources]) -> bool:
+ if resources is None or resources.accelerator_args is None:
return False
return resources.accelerator_args.get('tpu_vm', False)


- def is_tpu_vm_pod(resources: resources_lib.Resources) -> bool:
- if not is_tpu_vm(resources):
+ def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool:
+ if resources is None or not is_tpu_vm(resources):
return False
acc, _ = list(resources.accelerators.items())[0]
return acc not in ['tpu-v2-8', 'tpu-v3-8']


- def get_num_tpu_devices(resources: resources_lib.Resources) -> Optional[int]:
- if not is_tpu(resources):
+ def get_num_tpu_devices(
+ resources: Optional[resources_lib.Resources]) -> Optional[int]:
+ if resources is None or not is_tpu(resources):
return None
acc, _ = list(resources.accelerators.items())[0]
num_tpu_devices = int(int(acc.split('-')[2]) / 8)
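
A few hedged usage examples for the now Optional-tolerant helpers, assuming SkyPilot is installed so sky.utils.tpu_utils can be imported; the SimpleNamespace object below only mimics the two Resources fields these helpers read.

from types import SimpleNamespace

from sky.utils.tpu_utils import (get_num_tpu_devices, is_tpu, is_tpu_vm,
                                 is_tpu_vm_pod)

# None is now tolerated everywhere instead of raising AttributeError.
assert is_tpu(None) is False
assert is_tpu_vm(None) is False
assert is_tpu_vm_pod(None) is False
assert get_num_tpu_devices(None) is None

# Hypothetical stand-in for a TPU VM pod resource ('tpu-v3-32' is a pod slice).
pod = SimpleNamespace(accelerators={'tpu-v3-32': 1},
                      accelerator_args={'tpu_vm': True})
assert is_tpu(pod) and is_tpu_vm(pod) and is_tpu_vm_pod(pod)
assert get_num_tpu_devices(pod) == 4  # int(32 / 8)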
