From fe3360d17e4885eb31174d6bdb1006da035338f9 Mon Sep 17 00:00:00 2001 From: "Siyuan (Ryans) Zhuang" Date: Wed, 26 Jul 2023 16:19:18 -0700 Subject: [PATCH] [AWS] Adopt new provisioner to query clusters (#2288) Adopt new provisioner to query cluster status, i.e. instances and their status. --- sky/backends/backend_utils.py | 33 +++++++++++++-- sky/clouds/aws.py | 75 ++++++----------------------------- sky/provision/__init__.py | 19 +++++++++ sky/provision/aws/__init__.py | 3 +- sky/provision/aws/instance.py | 72 +++++++++++++++++++++++++-------- 5 files changed, 118 insertions(+), 84 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 32152464a62..4b0516ec4e7 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -34,6 +34,7 @@ from sky import clouds from sky import exceptions from sky import global_user_state +from sky import provision as provision_lib from sky import skypilot_config from sky import sky_logging from sky import spot as spot_lib @@ -1684,7 +1685,12 @@ def tag_filter_for_cluster(cluster_name: str) -> Dict[str, str]: def _query_cluster_status_via_cloud_api( handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle' ) -> List[status_lib.ClusterStatus]: - """Returns the status of the cluster.""" + """Returns the status of the cluster. + + Raises: + exceptions.ClusterStatusFetchingError: the cluster status cannot be + fetched from the cloud provider. + """ cluster_name = handle.cluster_name # Use region and zone from the cluster config, instead of the # handle.launched_resources, because the latter may not be set @@ -1698,9 +1704,22 @@ def _query_cluster_status_via_cloud_api( kwargs['use_tpu_vm'] = ray_config['provider'].get('_has_tpus', False) # Query the cloud provider. 
- node_statuses = handle.launched_resources.cloud.query_status( - cluster_name, tag_filter_for_cluster(cluster_name), region, zone, - **kwargs) + # TODO(suquark): move implementations of more clouds here + if isinstance(handle.launched_resources.cloud, clouds.AWS): + cloud_name = repr(handle.launched_resources.cloud) + try: + node_status_dict = provision_lib.query_instances( + cloud_name, cluster_name, provider_config) + node_statuses = list(node_status_dict.values()) + except Exception as e: # pylint: disable=broad-except + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterStatusFetchingError( + f'Failed to query {cloud_name} cluster {cluster_name!r} ' + f'status: {e}') + else: + node_statuses = handle.launched_resources.cloud.query_status( + cluster_name, tag_filter_for_cluster(cluster_name), region, zone, + **kwargs) # GCP does not clean up preempted TPU VMs. We remove it ourselves. # TODO(wei-lin): handle multi-node cases. # TODO(zhwu): this should be moved into the GCP class, after we refactor @@ -1814,6 +1833,12 @@ def check_can_clone_disk_and_override_task( def _update_cluster_status_no_lock( cluster_name: str) -> Optional[Dict[str, Any]]: + """Updates the status of the cluster. + + Raises: + exceptions.ClusterStatusFetchingError: the cluster status cannot be + fetched from the cloud provider. 
+ """ record = global_user_state.get_cluster_from_name(cluster_name) if record is None: return None diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 2fbee67c4dd..677b4be6141 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -11,8 +11,8 @@ from sky import clouds from sky import exceptions +from sky import provision as provision_lib from sky import sky_logging -from sky import status_lib from sky.adaptors import aws from sky.clouds import service_catalog from sky.utils import common_utils @@ -23,6 +23,7 @@ if typing.TYPE_CHECKING: # renaming to avoid shadowing variables from sky import resources as resources_lib + from sky import status_lib logger = sky_logging.init_logger(__name__) @@ -741,80 +742,30 @@ def check_quota_available(cls, # Quota found to be greater than zero, try provisioning return True - @classmethod - def _query_instance_property_with_retries( - cls, - tag_filters: Dict[str, str], - region: str, - query: str, - ) -> Tuple[int, str, str]: - filter_str = ' '.join(f'Name=tag:{key},Values={value}' - for key, value in tag_filters.items()) - query_cmd = (f'aws ec2 describe-instances --filters {filter_str} ' - f'--region {region} --query "{query}" --output json') - returncode, stdout, stderr = subprocess_utils.run_with_retries( - query_cmd, - retry_returncode=[255], - retry_stderrs=[ - 'Unable to locate credentials. You can configure credentials by ' - 'running "aws configure"' - ]) - return returncode, stdout, stderr - @classmethod def query_status(cls, name: str, tag_filters: Dict[str, str], region: Optional[str], zone: Optional[str], **kwargs) -> List['status_lib.ClusterStatus']: - del zone # unused - status_map = { - 'pending': status_lib.ClusterStatus.INIT, - 'running': status_lib.ClusterStatus.UP, - # TODO(zhwu): stopping and shutting-down could occasionally fail - # due to internal errors of AWS. We should cover that case. 
- 'stopping': status_lib.ClusterStatus.STOPPED, - 'stopped': status_lib.ClusterStatus.STOPPED, - 'shutting-down': None, - 'terminated': None, - } - - assert region is not None, (tag_filters, region) - returncode, stdout, stderr = cls._query_instance_property_with_retries( - tag_filters, region, query='Reservations[].Instances[].State.Name') - - if returncode != 0: - with ux_utils.print_exception_no_traceback(): - raise exceptions.ClusterStatusFetchingError( - f'Failed to query AWS cluster {name!r} status: ' - f'{stdout + stderr}') - - original_statuses = json.loads(stdout.strip()) - - statuses = [] - for s in original_statuses: - node_status = status_map[s] - if node_status is not None: - statuses.append(node_status) - return statuses + # TODO(suquark): deprecate this method + assert False, 'This code path should not be used.' @classmethod def create_image_from_cluster(cls, cluster_name: str, tag_filters: Dict[str, str], region: Optional[str], zone: Optional[str]) -> str: - del zone # unused assert region is not None, (tag_filters, region) + del tag_filters, zone # unused + image_name = f'skypilot-{cluster_name}-{int(time.time())}' - returncode, stdout, stderr = cls._query_instance_property_with_retries( - tag_filters, region, query='Reservations[].Instances[].InstanceId') - subprocess_utils.handle_returncode( - returncode, - '', - error_msg='Failed to find the source cluster on AWS.', - stderr=stderr, - stream_logs=False) + status = provision_lib.query_instances('AWS', cluster_name, + {'region': region}) + instance_ids = list(status.keys()) + if not instance_ids: + with ux_utils.print_exception_no_traceback(): + raise RuntimeError('Failed to find the source cluster on AWS.') - instance_ids = json.loads(stdout.strip()) if len(instance_ids) != 1: with ux_utils.print_exception_no_traceback(): raise exceptions.NotSupportedError( @@ -845,7 +796,7 @@ def create_image_from_cluster(cls, cluster_name: str, wait_image_cmd = ( f'aws ec2 wait image-available --region 
{region} --image-ids {image_id}' ) - returncode, stdout, stderr = subprocess_utils.run_with_retries( + returncode, _, stderr = subprocess_utils.run_with_retries( wait_image_cmd, retry_returncode=[255], ) diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 079bb8acc7b..c3206ac5e4d 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -9,6 +9,8 @@ import importlib import inspect +from sky import status_lib + def _route_to_cloud_impl(func): @@ -36,6 +38,23 @@ def _wrapper(*args, **kwargs): # TODO(suquark): Bring all other functions here from the +@_route_to_cloud_impl +def query_instances( + provider_name: str, + cluster_name: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """Query instances. + + Returns a dictionary of instance IDs and status. + + A None status means the instance is marked as "terminated" + or "terminating". + """ + raise NotImplementedError + + @_route_to_cloud_impl def stop_instances( provider_name: str, diff --git a/sky/provision/aws/__init__.py b/sky/provision/aws/__init__.py index 7dc4ad5acfd..c231cde3c80 100644 --- a/sky/provision/aws/__init__.py +++ b/sky/provision/aws/__init__.py @@ -1,3 +1,4 @@ """AWS provisioner for SkyPilot.""" -from sky.provision.aws.instance import stop_instances, terminate_instances +from sky.provision.aws.instance import (query_instances, terminate_instances, + stop_instances) diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index 53eccb77b22..dd3a973aa12 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -2,7 +2,9 @@ from typing import Dict, List, Any, Optional from botocore import config + from sky.adaptors import aws +from sky import status_lib BOTO_MAX_RETRIES = 12 # Tag uniquely identifying all nodes of a cluster @@ -10,6 +12,20 @@ TAG_RAY_NODE_KIND = 'ray-node-type' +def _default_ec2_resource(region: str) -> Any: + 
return aws.resource( + 'ec2', + region_name=region, + config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) + + +def _cluster_name_filter(cluster_name: str) -> List[Dict[str, Any]]: + return [{ + 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', + 'Values': [cluster_name], + }] + + def _filter_instances(ec2, filters: List[Dict[str, Any]], included_instances: Optional[List[str]], excluded_instances: Optional[List[str]]): @@ -28,6 +44,40 @@ def _filter_instances(ec2, filters: List[Dict[str, Any]], return instances +# TODO(suquark): Does it make sense to not expose this and always assume +# non_terminated_only=True? +# Will there be callers who would want this to be False? +# stop() and terminate() for example already implicitly assume non-terminated. +def query_instances( + cluster_name: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" + assert provider_config is not None, (cluster_name, provider_config) + region = provider_config['region'] + ec2 = _default_ec2_resource(region) + filters = _cluster_name_filter(cluster_name) + instances = ec2.instances.filter(Filters=filters) + status_map = { + 'pending': status_lib.ClusterStatus.INIT, + 'running': status_lib.ClusterStatus.UP, + # TODO(zhwu): stopping and shutting-down could occasionally fail + # due to internal errors of AWS. We should cover that case. 
+ 'stopping': status_lib.ClusterStatus.STOPPED, + 'stopped': status_lib.ClusterStatus.STOPPED, + 'shutting-down': None, + 'terminated': None, + } + statuses = {} + for inst in instances: + status = status_map[inst.state['Name']] + if non_terminated_only and status is None: + continue + statuses[inst.id] = status + return statuses + + def stop_instances( cluster_name: str, provider_config: Optional[Dict[str, Any]] = None, @@ -36,19 +86,13 @@ def stop_instances( """See sky/provision/__init__.py""" assert provider_config is not None, (cluster_name, provider_config) region = provider_config['region'] - ec2 = aws.resource( - 'ec2', - region_name=region, - config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) - filters = [ + ec2 = _default_ec2_resource(region) + filters: List[Dict[str, Any]] = [ { 'Name': 'instance-state-name', 'Values': ['pending', 'running'], }, - { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', - 'Values': [cluster_name], - }, + *_cluster_name_filter(cluster_name), ] if worker_only: filters.append({ @@ -73,20 +117,14 @@ def terminate_instances( """See sky/provision/__init__.py""" assert provider_config is not None, (cluster_name, provider_config) region = provider_config['region'] - ec2 = aws.resource( - 'ec2', - region_name=region, - config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) + ec2 = _default_ec2_resource(region) filters = [ { 'Name': 'instance-state-name', # exclude 'shutting-down' or 'terminated' states 'Values': ['pending', 'running', 'stopping', 'stopped'], }, - { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', - 'Values': [cluster_name], - }, + *_cluster_name_filter(cluster_name), ] if worker_only: filters.append({