Add safe guard for provisioning/terminating TPU VM and fix spot launch TPU resource leak #1500

Merged · 16 commits · Dec 15, 2022
81 changes: 57 additions & 24 deletions sky/backends/backend_utils.py
@@ -1228,34 +1228,67 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zone={zone} --format=value\\(name\\)')
if not get_internal_ips:
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
f' --zone {zone} --format="value[delimiter=\'\\n\']'
'(networkEndpoints.accessConfig.externalIp)"')
else:
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
f' --zone {zone} --format="value[delimiter=\'\\n\']'
'(networkEndpoints.ipAddress)"')

rcode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
if rcode != 0:
failure_massage = ('Failed to run gcloud to get TPU VM Pod IPs.\n'
returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
if returncode != 0:
failure_massage = ('Failed to run gcloud to get TPU VM IDs.\n'
'**** STDOUT ****\n'
'{stdout}\n'
f'{stdout}\n'
'**** STDERR ****\n'
'{stderr}\n'
f'{stderr}\n'
'**** CMD ****\n'
'{tpuvm_cmd}')
f'{query_cmd}\n')
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
failure_massage.format(stdout=stdout,
stderr=stderr,
tpuvm_cmd=tpuvm_cmd))
all_ips = re.findall(IP_ADDR_REGEX, stdout)
raise RuntimeError(failure_massage)
if len(stdout) == 0:
logger.warning('No TPU VMs found with cluster name '
f'{cluster_name} in zone {zone}.')
if len(stdout.splitlines()) > 1:
logger.warning('Found more than one TPU VM with cluster name '
f'{cluster_name} in zone {zone}.')
Collaborator:

Should this be a warning? Also, we must be careful about the logs printed during the status refresh, since they will corrupt the progress bar output of sky status -r. How about we change them to logger.debug?

Member Author:

Ah, I chose logger.warning because multiple TPU VMs/Pods with the same cluster name is an abnormal case that is not supposed to happen. When it does happen, it means there is a resource leak. I think in this case we'd like to let the user know?

Collaborator (@Michaelvll), Dec 14, 2022:

Isn't it a normal case for a spot VM? I think for a non-TPU cluster we don't show the warning, and we handle the case where the number of IPs does not match the expected amount in the caller function.

Also, is it true that a user can have multiple TPU VMs with the same name in the same zone?

Member Author (@infwinston), Dec 14, 2022:

Ah, sorry for the confusion, let me explain again. For spot VMs it also shouldn't happen that multiple Spot TPU VMs have the same labels.ray-cluster-name.

Basically, the query command should only return one VM/Pod in the normal case:

query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
             f'\\(labels.ray-cluster-name={cluster_name}\\) '
             f'--zone={zone} --format=value\\(name\\)')

But if there is a leaked resource (e.g., the controller failed to terminate a preempted spot TPU), then this query command will return two VMs, which is the abnormal case.

> Also, is it true that a user can have multiple TPU VMs with the same name in the same zone?

Note that I was not referring to the "TPU name" shown on the console but to labels.ray-cluster-name, so yes, multiple TPU VMs can have the same labels.ray-cluster-name.
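To make the leak scenario concrete, here is a small illustrative sketch in Python (not part of the PR; the TPU names are made up, and the real code obtains the listing by running the gcloud query via log_lib.run_with_log):

    # Hypothetical output of the list query when a preempted Spot TPU VM leaked:
    # two TPU names carry the same labels.ray-cluster-name.
    stdout = 'ray-mycluster-head-1a2b\nray-mycluster-head-3c4d\n'
    tpu_ids = stdout.splitlines()
    if len(tpu_ids) > 1:
        # Abnormal case: more than one TPU VM/Pod for one cluster name,
        # i.e., a leaked (preempted but never terminated) resource.
        print(f'Found {len(tpu_ids)} TPU VMs with the same cluster name.')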

Member Author:

I'm fine with changing it to logger.debug, but I'm also afraid that the user will never find out there is a leaked resource unless they manually check the console.

Collaborator:

I am confused, then: why does the problem not happen for a non-TPU VM cluster? What ensures those clusters are not leaked?

Member Author (@infwinston), Dec 14, 2022:

For a non-TPU VM cluster, since it doesn't require manual cleanup after preemption, resources won't be leaked this way? But I'm not sure if there are other scenarios that could trigger leakage. Also, we mostly rely on ray up to handle non-TPU VM clusters (probably irrelevant).


all_ips = []
for tpu_id in stdout.splitlines():
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
f' --zone {zone} --format=json')
returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
'/dev/null',
shell=True,
stream_logs=False,
require_outputs=True)
if returncode != 0:
failure_massage = ('Failed to run gcloud tpu-vm describe.\n'
'**** STDOUT ****\n'
f'{stdout}\n'
'**** STDERR ****\n'
f'{stderr}\n'
'**** CMD ****\n'
f'{tpuvm_cmd}\n')
with ux_utils.print_exception_no_traceback():
raise RuntimeError(failure_massage)

tpuvm_json = json.loads(stdout)
if tpuvm_json['state'] != 'READY':
# May be a leaked preempted resource.
logger.warning(f'TPU VM {tpu_id} is not in READY state. '
'Could be a garbage resource. Skipping...')
continue
Collaborator (comment on lines +1264 to +1269):

Will this state be different for different tpu_id?

Member Author:

Each TPU VM or TPU Pod maps to a single tpu_id, so yes, different tpu_id can have different states. But when this multiple-tpu_id situation happens, it means there is a leaked resource with the same cluster name as the current one. That's why I print the garbage-resource message.

Normally only one tpu_id should be returned by the query command below.

query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
             f'\\(labels.ray-cluster-name={cluster_name}\\) '
             f'--zone={zone} --format=value\\(name\\)')


if not get_internal_ips:
ips = [
endpoint['accessConfig']['externalIp']
for endpoint in tpuvm_json['networkEndpoints']
]
else:
ips = [
endpoint['ipAddress']
for endpoint in tpuvm_json['networkEndpoints']
]
all_ips.extend(ips)

return all_ips


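For reference, a minimal self-contained sketch of how the per-TPU describe output maps to IP addresses. The JSON payload below is invented for illustration (the field names mirror the accesses in the hunk above: state, networkEndpoints, accessConfig.externalIp, ipAddress); it is not captured from a real TPU VM.

    import json

    # Hypothetical, trimmed output of
    # `gcloud compute tpus tpu-vm describe <tpu_id> --zone <zone> --format=json`.
    sample = json.loads('''
    {
      "name": "projects/my-proj/locations/us-central1-b/nodes/ray-mycluster-head",
      "state": "READY",
      "networkEndpoints": [
        {"ipAddress": "10.128.0.5",
         "accessConfig": {"externalIp": "35.222.0.10"}}
      ]
    }
    ''')

    if sample['state'] == 'READY':
        external_ips = [
            endpoint['accessConfig']['externalIp']
            for endpoint in sample['networkEndpoints']
        ]
        internal_ips = [
            endpoint['ipAddress'] for endpoint in sample['networkEndpoints']
        ]
        print(external_ips, internal_ips)  # ['35.222.0.10'] ['10.128.0.5']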
34 changes: 23 additions & 11 deletions sky/backends/cloud_vm_ray_backend.py
@@ -2657,6 +2657,7 @@ def teardown_no_lock(self,
elif (terminate and
(prev_status == global_user_state.ClusterStatus.STOPPED or
use_tpu_vm)):
terminate_cmds = []
# For TPU VMs, gcloud CLI is used for VM termination.
if isinstance(cloud, clouds.AWS):
# TODO(zhwu): Room for optimization. We can move these cloud
@@ -2669,7 +2670,7 @@
f'Name=tag:ray-cluster-name,Values={handle.cluster_name} '
f'--query Reservations[].Instances[].InstanceId '
'--output text')
terminate_cmd = (
terminate_cmds.append(
f'aws ec2 terminate-instances --region {region} '
f'--instance-ids $({query_cmd})')
elif isinstance(cloud, clouds.GCP):
@@ -2684,15 +2685,25 @@
f'gcloud compute tpus tpu-vm list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zone={zone} --format=value\\(name\\)')
terminate_cmd = (
f'gcloud compute tpus tpu-vm delete --zone={zone}'
f' --quiet $({query_cmd})')
returncode, stdout, stderr = log_lib.run_with_log(
query_cmd,
log_abs_path,
shell=True,
stream_logs=False,
require_outputs=True)

# Need to create a list, as GCP does not allow deleting
# multiple TPU VMs at once.
for tpu_id in stdout.splitlines():
terminate_cmds.append(
f'gcloud compute tpus tpu-vm delete --zone={zone} '
f'--quiet {tpu_id}')
else:
query_cmd = (
f'gcloud compute instances list --filter='
f'\\(labels.ray-cluster-name={cluster_name}\\) '
f'--zones={zone} --format=value\\(name\\)')
terminate_cmd = (
terminate_cmds.append(
f'gcloud compute instances delete --zone={zone}'
f' --quiet $({query_cmd})')
else:
@@ -2701,12 +2712,13 @@
f'cluster {cluster_name!r}.')
with backend_utils.safe_console_status(f'[bold cyan]Terminating '
f'[green]{cluster_name}'):
returncode, stdout, stderr = log_lib.run_with_log(
terminate_cmd,
log_abs_path,
shell=True,
stream_logs=False,
require_outputs=True)
for terminate_cmd in terminate_cmds:
returncode, stdout, stderr = log_lib.run_with_log(
terminate_cmd,
log_abs_path,
shell=True,
stream_logs=False,
require_outputs=True)
else:
config['provider']['cache_stopped_nodes'] = not terminate
with tempfile.NamedTemporaryFile('w',
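The switch from a single terminate_cmd to a terminate_cmds list follows from the comment in the hunk above: gcloud cannot delete multiple TPU VMs in one call, so teardown issues one delete per TPU ID returned by the list query. A minimal sketch of that fan-out (zone and TPU names are made up; the real code runs each command via log_lib.run_with_log):

    zone = 'us-central1-b'
    # Hypothetical output of the tpu-vm list query: one leaked duplicate.
    stdout = 'ray-mycluster-head-1a2b\nray-mycluster-head-3c4d\n'
    terminate_cmds = [
        f'gcloud compute tpus tpu-vm delete --zone={zone} --quiet {tpu_id}'
        for tpu_id in stdout.splitlines()
    ]
    for cmd in terminate_cmds:
        print(cmd)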
9 changes: 9 additions & 0 deletions sky/clouds/aws.py
@@ -351,3 +351,12 @@ def accelerator_in_region_or_zone(self,
zone: Optional[str] = None) -> bool:
return service_catalog.accelerator_in_region_or_zone(
accelerator, acc_count, region, zone, 'aws')

def need_cleanup_after_preemption(
self, resources: 'resources_lib.Resources') -> bool:
"""Returns whether a spot resource needs cleanup after preemption."""
# By default, AWS Spot instances are not restartable after preemption.
# "Terminate interrupted Spot Instances (this is the default behavior)"
# See: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/interruption-behavior.html # pylint: disable=line-too-long
del resources # unused
return False
10 changes: 10 additions & 0 deletions sky/clouds/azure.py
@@ -305,6 +305,16 @@ def accelerator_in_region_or_zone(self,
return service_catalog.accelerator_in_region_or_zone(
accelerator, acc_count, region, zone, 'azure')

def need_cleanup_after_preemption(self,
resources: 'resources.Resources') -> bool:
"""Returns whether a spot resource needs cleanup after preemption."""
# By default, Azure Spot instances are restartable after preemption.
# "When creating an Azure Spot Virtual Machine, you can set
# the eviction policy to Deallocate (default) or Delete."
# See: https://learn.microsoft.com/en-us/azure/virtual-machines/spot-vms#eviction-policy # pylint: disable=line-too-long
del resources # Unused.
return False

@classmethod
def get_project_id(cls, dryrun: bool = False) -> str:
if dryrun:
5 changes: 5 additions & 0 deletions sky/clouds/cloud.py
@@ -214,5 +214,10 @@ def accelerator_in_region_or_zone(self,
"""Returns whether the accelerator is valid in the region or zone."""
raise NotImplementedError

def need_cleanup_after_preemption(self,
resource: 'resources.Resources') -> bool:
"""Returns whether a spot resource needs cleanup after preemption."""
raise NotImplementedError

def __repr__(self):
return self._REPR
16 changes: 16 additions & 0 deletions sky/clouds/gcp.py
@@ -468,6 +468,22 @@ def accelerator_in_region_or_zone(self,
return service_catalog.accelerator_in_region_or_zone(
accelerator, acc_count, region, zone, 'gcp')

def need_cleanup_after_preemption(self,
resources: 'resources.Resources') -> bool:
"""Returns whether a spot resource needs cleanup after preemption."""
# By default, GCP Compute VMs are restartable after preemption.
# "If ... not specified, then Compute Engine stops the VM,
# transitioning the VM to a TERMINATED state."
# See: https://cloud.google.com/compute/docs/instances/spot#preemption-process # pylint: disable=line-too-long
# However, Spot TPU VMs are not restartable after preemption.
# "If your Cloud TPU is preempted,
# you must delete it and create a new one ..."
# See: https://cloud.google.com/tpu/docs/preemptible#tpu-vm

# pylint: disable=import-outside-toplevel
from sky.utils import tpu_utils
return tpu_utils.is_tpu_vm(resources)

@classmethod
def get_project_id(cls, dryrun: bool = False) -> str:
# TODO(zhwu): change the project id fetching with the following command
4 changes: 4 additions & 0 deletions sky/resources.py
@@ -260,6 +260,10 @@ def _set_accelerators(
def is_launchable(self) -> bool:
return self.cloud is not None and self._instance_type is not None

def need_cleanup_after_preemption(self) -> bool:
"""Returns whether a spot resource needs cleanup after preemption."""
return self.cloud.need_cleanup_after_preemption(self)

def _set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
if region is None and zone is None:
7 changes: 7 additions & 0 deletions sky/spot/controller.py
@@ -145,6 +145,13 @@ def _run(self):
'cluster is healthy. Try to recover the job '
'(the cluster will not be restarted).')

resources = list(self._task.resources)[0]
if resources.need_cleanup_after_preemption():
# Some spot resource may need to be cleaned up after
# preemption, if the resource is not reusable.
logger.info('Cleaning up the preempted spot cluster...')
self._strategy_executor.terminate_cluster()

# Try to recover the spot jobs, when the cluster is preempted
# or the job status is failed to be fetched.
spot_state.set_recovering(self._job_id)
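Putting the controller change together with the new per-cloud hooks, the decision path is roughly the following. This is a simplified, self-contained sketch with stand-in classes, not the real SkyPilot objects:

    # Sketch of the preemption-cleanup decision introduced by this PR:
    # controller -> Resources -> Cloud -> TPU check. Per the diffs above, GCP
    # Compute spot VMs are restartable after preemption (no cleanup needed),
    # while Spot TPU VMs must be deleted and recreated, so the controller
    # terminates the leftover cluster before recovering.
    class FakeGCP:
        def need_cleanup_after_preemption(self, resources) -> bool:
            return resources.is_tpu_vm  # stands in for tpu_utils.is_tpu_vm(resources)

    class FakeResources:
        def __init__(self, cloud, is_tpu_vm: bool):
            self.cloud = cloud
            self.is_tpu_vm = is_tpu_vm

        def need_cleanup_after_preemption(self) -> bool:
            return self.cloud.need_cleanup_after_preemption(self)

    def on_preemption(resources: FakeResources) -> None:
        if resources.need_cleanup_after_preemption():
            print('Cleaning up the preempted spot cluster...')
        print('Recovering the spot job...')

    on_preemption(FakeResources(FakeGCP(), is_tpu_vm=True))   # cleanup, then recover
    on_preemption(FakeResources(FakeGCP(), is_tpu_vm=False))  # recover directly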
8 changes: 3 additions & 5 deletions sky/spot/recovery_strategy.py
@@ -276,10 +276,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
launched_resources.region)
return launch_time

def terminate_cluster(self, max_retry: int = 3) -> None:
super().terminate_cluster(max_retry)
self._launched_cloud_region = None

def recover(self) -> float:
# 1. Cancel the jobs and launch the cluster with the STOPPED status,
# so that it will try on the current region first until timeout.
@@ -313,7 +309,9 @@ def recover(self) -> float:
return launched_time

# Step 2
logger.debug('Terminating unhealthy spot cluster.')
logger.debug('Terminating unhealthy spot cluster and '
'reset cloud region.')
self._launched_cloud_region = None
self.terminate_cluster()

# Step 3
17 changes: 9 additions & 8 deletions sky/utils/tpu_utils.py
@@ -4,28 +4,29 @@
from sky import resources as resources_lib


def is_tpu(resources: resources_lib.Resources) -> bool:
if resources.accelerators is None:
def is_tpu(resources: Optional[resources_lib.Resources]) -> bool:
if resources is None or resources.accelerators is None:
return False
acc, _ = list(resources.accelerators.items())[0]
return acc.startswith('tpu')


def is_tpu_vm(resources: resources_lib.Resources) -> bool:
if resources.accelerator_args is None:
def is_tpu_vm(resources: Optional[resources_lib.Resources]) -> bool:
if resources is None or resources.accelerator_args is None:
return False
return resources.accelerator_args.get('tpu_vm', False)


def is_tpu_vm_pod(resources: resources_lib.Resources) -> bool:
if not is_tpu_vm(resources):
def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool:
if resources is None or not is_tpu_vm(resources):
return False
acc, _ = list(resources.accelerators.items())[0]
return acc not in ['tpu-v2-8', 'tpu-v3-8']


def get_num_tpu_devices(resources: resources_lib.Resources) -> Optional[int]:
if not is_tpu(resources):
def get_num_tpu_devices(
resources: Optional[resources_lib.Resources]) -> Optional[int]:
if resources is None or not is_tpu(resources):
return None
acc, _ = list(resources.accelerators.items())[0]
num_tpu_devices = int(int(acc.split('-')[2]) / 8)
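The helpers above now accept Optional[Resources], so a None input short-circuits to False (or None) instead of raising an AttributeError. A small usage sketch (assuming the sky package is importable; the outputs follow directly from the code above):

    from sky.utils import tpu_utils

    # None is now a valid input to all four helpers.
    print(tpu_utils.is_tpu(None))               # False
    print(tpu_utils.is_tpu_vm(None))            # False
    print(tpu_utils.is_tpu_vm_pod(None))        # False
    print(tpu_utils.get_num_tpu_devices(None))  # None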