Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS] Adopt new provisioner to query clusters #2288

Merged
merged 7 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sky import clouds
from sky import exceptions
from sky import global_user_state
from sky import provision as provision_lib
from sky import skypilot_config
from sky import sky_logging
from sky import spot as spot_lib
Expand Down Expand Up @@ -1684,7 +1685,12 @@ def tag_filter_for_cluster(cluster_name: str) -> Dict[str, str]:
def _query_cluster_status_via_cloud_api(
handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle'
) -> List[status_lib.ClusterStatus]:
"""Returns the status of the cluster."""
"""Returns the status of the cluster.

Raises:
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
"""
cluster_name = handle.cluster_name
# Use region and zone from the cluster config, instead of the
# handle.launched_resources, because the latter may not be set
Expand All @@ -1698,9 +1704,22 @@ def _query_cluster_status_via_cloud_api(
kwargs['use_tpu_vm'] = ray_config['provider'].get('_has_tpus', False)

# Query the cloud provider.
node_statuses = handle.launched_resources.cloud.query_status(
cluster_name, tag_filter_for_cluster(cluster_name), region, zone,
**kwargs)
# TODO(suquark): move implementations of more clouds here
if isinstance(handle.launched_resources.cloud, clouds.AWS):
cloud_name = repr(handle.launched_resources.cloud)
try:
node_status_dict = provision_lib.query_instances(
cloud_name, cluster_name, provider_config)
node_statuses = list(node_status_dict.values())
except Exception as e: # pylint: disable=broad-except
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exceptions can be thrown by provision_lib.query_instances()? Should we document that?

Also, how would the caller of _query_cluster_status_via_cloud_api() handle them?

Does the previous codepath allow throwing any exceptions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the previous codepath only checks the returncode of aws cli. so it catches general exceptions. We inherit this behavior here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a

    Raises:
        exceptions.ClusterStatusFetchingError: the cluster status cannot be
          fetched from the cloud provider.

to

  • _query_cluster_status_via_cloud_api
  • _update_cluster_status_no_lock

Just some code gardening.

with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterStatusFetchingError(
f'Failed to query {cloud_name} cluster {cluster_name!r} '
f'status: {e}')
else:
node_statuses = handle.launched_resources.cloud.query_status(
cluster_name, tag_filter_for_cluster(cluster_name), region, zone,
**kwargs)
# GCP does not clean up preempted TPU VMs. We remove it ourselves.
# TODO(wei-lin): handle multi-node cases.
# TODO(zhwu): this should be moved into the GCP class, after we refactor
Expand Down Expand Up @@ -1814,6 +1833,12 @@ def check_can_clone_disk_and_override_task(

def _update_cluster_status_no_lock(
cluster_name: str) -> Optional[Dict[str, Any]]:
"""Updates the status of the cluster.

Raises:
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
"""
record = global_user_state.get_cluster_from_name(cluster_name)
if record is None:
return None
Expand Down
75 changes: 13 additions & 62 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from sky import clouds
from sky import exceptions
from sky import provision as provision_lib
from sky import sky_logging
from sky import status_lib
from sky.adaptors import aws
from sky.clouds import service_catalog
from sky.utils import common_utils
Expand All @@ -23,6 +23,7 @@
if typing.TYPE_CHECKING:
# renaming to avoid shadowing variables
from sky import resources as resources_lib
from sky import status_lib

logger = sky_logging.init_logger(__name__)

Expand Down Expand Up @@ -741,80 +742,30 @@ def check_quota_available(cls,
# Quota found to be greater than zero, try provisioning
return True

@classmethod
def _query_instance_property_with_retries(
cls,
tag_filters: Dict[str, str],
region: str,
query: str,
) -> Tuple[int, str, str]:
filter_str = ' '.join(f'Name=tag:{key},Values={value}'
for key, value in tag_filters.items())
query_cmd = (f'aws ec2 describe-instances --filters {filter_str} '
f'--region {region} --query "{query}" --output json')
returncode, stdout, stderr = subprocess_utils.run_with_retries(
query_cmd,
retry_returncode=[255],
retry_stderrs=[
'Unable to locate credentials. You can configure credentials by '
'running "aws configure"'
])
Comment on lines -758 to -761
Copy link
Member

@concretevitamin concretevitamin Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When reviewing #2314, I realized this code (originally added in #1988) was accidentally left out from this PR/master branch.

Could we add it back? #1988 has context. Tldr: previously users have encountered "ec2 describe-instances" throwing NoCredentialsError with this message (the programmatic client may only throw the Unable to locate credentials part as the exception message).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this reminds of another problem.

Previously, we retry by issuing another CLI cmd "ec2 describe-instances" with a subprocess. This probably means a new underlying boto3 client is created for each retry. This could be the reason the retry has mitigated this problem. E.g., abandoning a malfunctioning client.

With this PR, even if we add retry back, it'll access

@functools.lru_cache()
def client(service_name: str, **kwargs):
"""Create an AWS client of a certain service.
Args:
service_name: AWS service name (e.g., 's3', 'ec2').
kwargs: Other options.
"""
# Need to use the client retrieved from the per-thread session
# to avoid thread-safety issues (Directly creating the client
# with boto3.client() is not thread-safe).
# Reference: https://stackoverflow.com/a/59635814
return session().client(service_name, **kwargs)
which is LRU-cached per thread(?). So if we retry using the same thread, it may not have the same effect.

This is all speculation since we don't have a reliable repro. That said, could we somehow force create a new boto3 client when we retry?

return returncode, stdout, stderr

@classmethod
def query_status(cls, name: str, tag_filters: Dict[str, str],
region: Optional[str], zone: Optional[str],
**kwargs) -> List['status_lib.ClusterStatus']:
del zone # unused
status_map = {
'pending': status_lib.ClusterStatus.INIT,
'running': status_lib.ClusterStatus.UP,
# TODO(zhwu): stopping and shutting-down could occasionally fail
# due to internal errors of AWS. We should cover that case.
'stopping': status_lib.ClusterStatus.STOPPED,
'stopped': status_lib.ClusterStatus.STOPPED,
'shutting-down': None,
'terminated': None,
}

assert region is not None, (tag_filters, region)
returncode, stdout, stderr = cls._query_instance_property_with_retries(
tag_filters, region, query='Reservations[].Instances[].State.Name')

if returncode != 0:
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterStatusFetchingError(
f'Failed to query AWS cluster {name!r} status: '
f'{stdout + stderr}')

original_statuses = json.loads(stdout.strip())

statuses = []
for s in original_statuses:
node_status = status_map[s]
if node_status is not None:
statuses.append(node_status)
return statuses
# TODO(suquark): deprecate this method
assert False, 'This could path should not be used.'

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
tag_filters: Dict[str,
str], region: Optional[str],
zone: Optional[str]) -> str:
del zone # unused
assert region is not None, (tag_filters, region)
del tag_filters, zone # unused

image_name = f'skypilot-{cluster_name}-{int(time.time())}'
returncode, stdout, stderr = cls._query_instance_property_with_retries(
tag_filters, region, query='Reservations[].Instances[].InstanceId')

subprocess_utils.handle_returncode(
returncode,
'',
error_msg='Failed to find the source cluster on AWS.',
stderr=stderr,
stream_logs=False)
status = provision_lib.query_instances('AWS', cluster_name,
{'region': region})
instance_ids = list(status.keys())
if not instance_ids:
with ux_utils.print_exception_no_traceback():
raise RuntimeError('Failed to find the source cluster on AWS.')

instance_ids = json.loads(stdout.strip())
if len(instance_ids) != 1:
with ux_utils.print_exception_no_traceback():
raise exceptions.NotSupportedError(
Expand Down Expand Up @@ -845,7 +796,7 @@ def create_image_from_cluster(cls, cluster_name: str,
wait_image_cmd = (
f'aws ec2 wait image-available --region {region} --image-ids {image_id}'
)
returncode, stdout, stderr = subprocess_utils.run_with_retries(
returncode, _, stderr = subprocess_utils.run_with_retries(
wait_image_cmd,
retry_returncode=[255],
)
Expand Down
19 changes: 19 additions & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import importlib
import inspect

from sky import status_lib
suquark marked this conversation as resolved.
Show resolved Hide resolved


def _route_to_cloud_impl(func):

Expand Down Expand Up @@ -36,6 +38,23 @@ def _wrapper(*args, **kwargs):
# TODO(suquark): Bring all other functions here from the


@_route_to_cloud_impl
def query_instances(
provider_name: str,
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
non_terminated_only: bool = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For discussion: Does it make sense to not expose this and always assume non_terminated_only=True? Will there be callers who would want this to be False?

stop() and terminate() for example already implicitly assume non-terminated, e.g.,

    filters = [
        {
            'Name': 'instance-state-name',
            # exclude 'shutting-down' or 'terminated' states
            'Values': ['pending', 'running', 'stopping', 'stopped'],
        },
        *_cluster_name_filter(cluster_name),
    ]

Also similar to node providers' design of get_nonterminated_nodes().

We can certainly leave this for the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me leave a comment about it

) -> Dict[str, Optional[status_lib.ClusterStatus]]:
"""Query instances.

Returns a dictionary of instance IDs and status.
suquark marked this conversation as resolved.
Show resolved Hide resolved

A None status means the instance is marked as "terminated"
or "terminating".
"""
raise NotImplementedError


@_route_to_cloud_impl
def stop_instances(
provider_name: str,
Expand Down
3 changes: 2 additions & 1 deletion sky/provision/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""AWS provisioner for SkyPilot."""

from sky.provision.aws.instance import stop_instances, terminate_instances
from sky.provision.aws.instance import (query_instances, terminate_instances,
stop_instances)
68 changes: 51 additions & 17 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,30 @@
from typing import Dict, List, Any, Optional

from botocore import config
suquark marked this conversation as resolved.
Show resolved Hide resolved

from sky.adaptors import aws
from sky import status_lib

BOTO_MAX_RETRIES = 12
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'


def _default_ec2_resource(region: str) -> Any:
return aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))


def _cluster_name_filter(cluster_name: str) -> List[Dict[str, Any]]:
return [{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
}]


def _filter_instances(ec2, filters: List[Dict[str, Any]],
included_instances: Optional[List[str]],
excluded_instances: Optional[List[str]]):
Expand All @@ -28,6 +44,36 @@ def _filter_instances(ec2, filters: List[Dict[str, Any]],
return instances


def query_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = _default_ec2_resource(region)
filters = _cluster_name_filter(cluster_name)
instances = ec2.instances.filter(Filters=filters)
status_map = {
'pending': status_lib.ClusterStatus.INIT,
'running': status_lib.ClusterStatus.UP,
# TODO(zhwu): stopping and shutting-down could occasionally fail
# due to internal errors of AWS. We should cover that case.
'stopping': status_lib.ClusterStatus.STOPPED,
'stopped': status_lib.ClusterStatus.STOPPED,
'shutting-down': None,
'terminated': None,
}
statuses = {}
for inst in instances:
status = status_map[inst.state['Name']]
if non_terminated_only and status is None:
continue
statuses[inst.id] = status
return statuses


def stop_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
Expand All @@ -36,19 +82,13 @@ def stop_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
filters = [
ec2 = _default_ec2_resource(region)
filters: List[Dict[str, Any]] = [
{
'Name': 'instance-state-name',
'Values': ['pending', 'running'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand All @@ -73,20 +113,14 @@ def terminate_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
ec2 = _default_ec2_resource(region)
filters = [
{
'Name': 'instance-state-name',
# exclude 'shutting-down' or 'terminated' states
'Values': ['pending', 'running', 'stopping', 'stopped'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand Down