Skip to content

Commit

Permalink
[AWS] Adopt new provisioner to query clusters (#2288)
Browse files Browse the repository at this point in the history
Adopt new provisioner to query cluster status, i.e. instances and their status.
  • Loading branch information
suquark authored Jul 26, 2023
1 parent cfcae0d commit fe3360d
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 84 deletions.
33 changes: 29 additions & 4 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sky import clouds
from sky import exceptions
from sky import global_user_state
from sky import provision as provision_lib
from sky import skypilot_config
from sky import sky_logging
from sky import spot as spot_lib
Expand Down Expand Up @@ -1684,7 +1685,12 @@ def tag_filter_for_cluster(cluster_name: str) -> Dict[str, str]:
def _query_cluster_status_via_cloud_api(
handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle'
) -> List[status_lib.ClusterStatus]:
"""Returns the status of the cluster."""
"""Returns the status of the cluster.
Raises:
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
"""
cluster_name = handle.cluster_name
# Use region and zone from the cluster config, instead of the
# handle.launched_resources, because the latter may not be set
Expand All @@ -1698,9 +1704,22 @@ def _query_cluster_status_via_cloud_api(
kwargs['use_tpu_vm'] = ray_config['provider'].get('_has_tpus', False)

# Query the cloud provider.
node_statuses = handle.launched_resources.cloud.query_status(
cluster_name, tag_filter_for_cluster(cluster_name), region, zone,
**kwargs)
# TODO(suquark): move implementations of more clouds here
if isinstance(handle.launched_resources.cloud, clouds.AWS):
cloud_name = repr(handle.launched_resources.cloud)
try:
node_status_dict = provision_lib.query_instances(
cloud_name, cluster_name, provider_config)
node_statuses = list(node_status_dict.values())
except Exception as e: # pylint: disable=broad-except
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterStatusFetchingError(
f'Failed to query {cloud_name} cluster {cluster_name!r} '
f'status: {e}')
else:
node_statuses = handle.launched_resources.cloud.query_status(
cluster_name, tag_filter_for_cluster(cluster_name), region, zone,
**kwargs)
# GCP does not clean up preempted TPU VMs. We remove it ourselves.
# TODO(wei-lin): handle multi-node cases.
# TODO(zhwu): this should be moved into the GCP class, after we refactor
Expand Down Expand Up @@ -1814,6 +1833,12 @@ def check_can_clone_disk_and_override_task(

def _update_cluster_status_no_lock(
cluster_name: str) -> Optional[Dict[str, Any]]:
"""Updates the status of the cluster.
Raises:
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
"""
record = global_user_state.get_cluster_from_name(cluster_name)
if record is None:
return None
Expand Down
75 changes: 13 additions & 62 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from sky import clouds
from sky import exceptions
from sky import provision as provision_lib
from sky import sky_logging
from sky import status_lib
from sky.adaptors import aws
from sky.clouds import service_catalog
from sky.utils import common_utils
Expand All @@ -23,6 +23,7 @@
if typing.TYPE_CHECKING:
# renaming to avoid shadowing variables
from sky import resources as resources_lib
from sky import status_lib

logger = sky_logging.init_logger(__name__)

Expand Down Expand Up @@ -741,80 +742,30 @@ def check_quota_available(cls,
# Quota found to be greater than zero, try provisioning
return True

@classmethod
def _query_instance_property_with_retries(
cls,
tag_filters: Dict[str, str],
region: str,
query: str,
) -> Tuple[int, str, str]:
filter_str = ' '.join(f'Name=tag:{key},Values={value}'
for key, value in tag_filters.items())
query_cmd = (f'aws ec2 describe-instances --filters {filter_str} '
f'--region {region} --query "{query}" --output json')
returncode, stdout, stderr = subprocess_utils.run_with_retries(
query_cmd,
retry_returncode=[255],
retry_stderrs=[
'Unable to locate credentials. You can configure credentials by '
'running "aws configure"'
])
return returncode, stdout, stderr

@classmethod
def query_status(cls, name: str, tag_filters: Dict[str, str],
region: Optional[str], zone: Optional[str],
**kwargs) -> List['status_lib.ClusterStatus']:
del zone # unused
status_map = {
'pending': status_lib.ClusterStatus.INIT,
'running': status_lib.ClusterStatus.UP,
# TODO(zhwu): stopping and shutting-down could occasionally fail
# due to internal errors of AWS. We should cover that case.
'stopping': status_lib.ClusterStatus.STOPPED,
'stopped': status_lib.ClusterStatus.STOPPED,
'shutting-down': None,
'terminated': None,
}

assert region is not None, (tag_filters, region)
returncode, stdout, stderr = cls._query_instance_property_with_retries(
tag_filters, region, query='Reservations[].Instances[].State.Name')

if returncode != 0:
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterStatusFetchingError(
f'Failed to query AWS cluster {name!r} status: '
f'{stdout + stderr}')

original_statuses = json.loads(stdout.strip())

statuses = []
for s in original_statuses:
node_status = status_map[s]
if node_status is not None:
statuses.append(node_status)
return statuses
# TODO(suquark): deprecate this method
assert False, 'This could path should not be used.'

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
tag_filters: Dict[str,
str], region: Optional[str],
zone: Optional[str]) -> str:
del zone # unused
assert region is not None, (tag_filters, region)
del tag_filters, zone # unused

image_name = f'skypilot-{cluster_name}-{int(time.time())}'
returncode, stdout, stderr = cls._query_instance_property_with_retries(
tag_filters, region, query='Reservations[].Instances[].InstanceId')

subprocess_utils.handle_returncode(
returncode,
'',
error_msg='Failed to find the source cluster on AWS.',
stderr=stderr,
stream_logs=False)
status = provision_lib.query_instances('AWS', cluster_name,
{'region': region})
instance_ids = list(status.keys())
if not instance_ids:
with ux_utils.print_exception_no_traceback():
raise RuntimeError('Failed to find the source cluster on AWS.')

instance_ids = json.loads(stdout.strip())
if len(instance_ids) != 1:
with ux_utils.print_exception_no_traceback():
raise exceptions.NotSupportedError(
Expand Down Expand Up @@ -845,7 +796,7 @@ def create_image_from_cluster(cls, cluster_name: str,
wait_image_cmd = (
f'aws ec2 wait image-available --region {region} --image-ids {image_id}'
)
returncode, stdout, stderr = subprocess_utils.run_with_retries(
returncode, _, stderr = subprocess_utils.run_with_retries(
wait_image_cmd,
retry_returncode=[255],
)
Expand Down
19 changes: 19 additions & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import importlib
import inspect

from sky import status_lib


def _route_to_cloud_impl(func):

Expand Down Expand Up @@ -36,6 +38,23 @@ def _wrapper(*args, **kwargs):
# TODO(suquark): Bring all other functions here from the


@_route_to_cloud_impl
def query_instances(
provider_name: str,
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
"""Query instances.
Returns a dictionary of instance IDs and status.
A None status means the instance is marked as "terminated"
or "terminating".
"""
raise NotImplementedError


@_route_to_cloud_impl
def stop_instances(
provider_name: str,
Expand Down
3 changes: 2 additions & 1 deletion sky/provision/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""AWS provisioner for SkyPilot."""

from sky.provision.aws.instance import stop_instances, terminate_instances
from sky.provision.aws.instance import (query_instances, terminate_instances,
stop_instances)
72 changes: 55 additions & 17 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,30 @@
from typing import Dict, List, Any, Optional

from botocore import config

from sky.adaptors import aws
from sky import status_lib

BOTO_MAX_RETRIES = 12
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'


def _default_ec2_resource(region: str) -> Any:
return aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))


def _cluster_name_filter(cluster_name: str) -> List[Dict[str, Any]]:
return [{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
}]


def _filter_instances(ec2, filters: List[Dict[str, Any]],
included_instances: Optional[List[str]],
excluded_instances: Optional[List[str]]):
Expand All @@ -28,6 +44,40 @@ def _filter_instances(ec2, filters: List[Dict[str, Any]],
return instances


# TODO(suquark): Does it make sense to not expose this and always assume
# non_terminated_only=True?
# Will there be callers who would want this to be False?
# stop() and terminate() for example already implicitly assume non-terminated.
def query_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = _default_ec2_resource(region)
filters = _cluster_name_filter(cluster_name)
instances = ec2.instances.filter(Filters=filters)
status_map = {
'pending': status_lib.ClusterStatus.INIT,
'running': status_lib.ClusterStatus.UP,
# TODO(zhwu): stopping and shutting-down could occasionally fail
# due to internal errors of AWS. We should cover that case.
'stopping': status_lib.ClusterStatus.STOPPED,
'stopped': status_lib.ClusterStatus.STOPPED,
'shutting-down': None,
'terminated': None,
}
statuses = {}
for inst in instances:
status = status_map[inst.state['Name']]
if non_terminated_only and status is None:
continue
statuses[inst.id] = status
return statuses


def stop_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
Expand All @@ -36,19 +86,13 @@ def stop_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
filters = [
ec2 = _default_ec2_resource(region)
filters: List[Dict[str, Any]] = [
{
'Name': 'instance-state-name',
'Values': ['pending', 'running'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand All @@ -73,20 +117,14 @@ def terminate_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
ec2 = _default_ec2_resource(region)
filters = [
{
'Name': 'instance-state-name',
# exclude 'shutting-down' or 'terminated' states
'Values': ['pending', 'running', 'stopping', 'stopped'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand Down

0 comments on commit fe3360d

Please sign in to comment.