Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS] Adopt new provisioner to query clusters #2288

Merged
merged 7 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sky import clouds
from sky import exceptions
from sky import global_user_state
from sky import provision as provision_lib
from sky import skypilot_config
from sky import sky_logging
from sky import spot as spot_lib
Expand Down Expand Up @@ -1698,9 +1699,22 @@ def _query_cluster_status_via_cloud_api(
kwargs['use_tpu_vm'] = ray_config['provider'].get('_has_tpus', False)

# Query the cloud provider.
node_statuses = handle.launched_resources.cloud.query_status(
cluster_name, tag_filter_for_cluster(cluster_name), region, zone,
**kwargs)
# TODO(suquark): move implementations of more clouds here
if isinstance(handle.launched_resources.cloud, clouds.AWS):
cloud_name = repr(handle.launched_resources.cloud)
try:
node_status_dict = provision_lib.query_instances(
cloud_name, cluster_name, provider_config)
node_statuses = list(node_status_dict.values())
except Exception as e: # pylint: disable=broad-except
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exceptions can be thrown by provision_lib.query_instances()? Should we document that?

Also, how would the caller of _query_cluster_status_via_cloud_api() handle them?

Does the previous codepath allow throwing any exceptions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous codepath only checks the return code of the AWS CLI, so it effectively catches general exceptions. We inherit this behavior here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a

    Raises:
        exceptions.ClusterStatusFetchingError: the cluster status cannot be
          fetched from the cloud provider.

to

  • _query_cluster_status_via_cloud_api
  • _update_cluster_status_no_lock

Just some code gardening.

with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterStatusFetchingError(
f'Failed to query {cloud_name} cluster {cluster_name!r} '
f'status: {e}')
else:
node_statuses = handle.launched_resources.cloud.query_status(
cluster_name, tag_filter_for_cluster(cluster_name), region, zone,
**kwargs)
# GCP does not clean up preempted TPU VMs. We remove it ourselves.
# TODO(wei-lin): handle multi-node cases.
# TODO(zhwu): this should be moved into the GCP class, after we refactor
Expand Down
75 changes: 13 additions & 62 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from sky import clouds
from sky import exceptions
from sky import provision as provision_lib
from sky import sky_logging
from sky import status_lib
from sky.adaptors import aws
from sky.clouds import service_catalog
from sky.utils import common_utils
Expand All @@ -23,6 +23,7 @@
if typing.TYPE_CHECKING:
# renaming to avoid shadowing variables
from sky import resources as resources_lib
from sky import status_lib

logger = sky_logging.init_logger(__name__)

Expand Down Expand Up @@ -741,80 +742,30 @@ def check_quota_available(cls,
# Quota found to be greater than zero, try provisioning
return True

@classmethod
def _query_instance_property_with_retries(
cls,
tag_filters: Dict[str, str],
region: str,
query: str,
) -> Tuple[int, str, str]:
filter_str = ' '.join(f'Name=tag:{key},Values={value}'
for key, value in tag_filters.items())
query_cmd = (f'aws ec2 describe-instances --filters {filter_str} '
f'--region {region} --query "{query}" --output json')
returncode, stdout, stderr = subprocess_utils.run_with_retries(
query_cmd,
retry_returncode=[255],
retry_stderrs=[
'Unable to locate credentials. You can configure credentials by '
'running "aws configure"'
])
Comment on lines -758 to -761
Copy link
Member

@concretevitamin concretevitamin Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When reviewing #2314, I realized this code (originally added in #1988) was accidentally left out from this PR/master branch.

Could we add it back? #1988 has context. TL;DR: users have previously encountered "ec2 describe-instances" throwing NoCredentialsError with this message (the programmatic client may throw only the "Unable to locate credentials" part as the exception message).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this reminds me of another problem.

Previously, we retry by issuing another CLI cmd "ec2 describe-instances" with a subprocess. This probably means a new underlying boto3 client is created for each retry. This could be the reason the retry has mitigated this problem. E.g., abandoning a malfunctioning client.

With this PR, even if we add retry back, it'll access

@functools.lru_cache()
def client(service_name: str, **kwargs):
"""Create an AWS client of a certain service.
Args:
service_name: AWS service name (e.g., 's3', 'ec2').
kwargs: Other options.
"""
# Need to use the client retrieved from the per-thread session
# to avoid thread-safety issues (Directly creating the client
# with boto3.client() is not thread-safe).
# Reference: https://stackoverflow.com/a/59635814
return session().client(service_name, **kwargs)
which is LRU-cached per thread(?). So if we retry using the same thread, it may not have the same effect.

This is all speculation since we don't have a reliable repro. That said, could we somehow force create a new boto3 client when we retry?

return returncode, stdout, stderr

@classmethod
def query_status(cls, name: str, tag_filters: Dict[str, str],
region: Optional[str], zone: Optional[str],
**kwargs) -> List['status_lib.ClusterStatus']:
del zone # unused
status_map = {
'pending': status_lib.ClusterStatus.INIT,
'running': status_lib.ClusterStatus.UP,
# TODO(zhwu): stopping and shutting-down could occasionally fail
# due to internal errors of AWS. We should cover that case.
'stopping': status_lib.ClusterStatus.STOPPED,
'stopped': status_lib.ClusterStatus.STOPPED,
'shutting-down': None,
'terminated': None,
}

assert region is not None, (tag_filters, region)
returncode, stdout, stderr = cls._query_instance_property_with_retries(
tag_filters, region, query='Reservations[].Instances[].State.Name')

if returncode != 0:
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterStatusFetchingError(
f'Failed to query AWS cluster {name!r} status: '
f'{stdout + stderr}')

original_statuses = json.loads(stdout.strip())

statuses = []
for s in original_statuses:
node_status = status_map[s]
if node_status is not None:
statuses.append(node_status)
return statuses
# TODO(suquark): deprecate this method
assert False, 'This could path should not be used.'

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
tag_filters: Dict[str,
str], region: Optional[str],
zone: Optional[str]) -> str:
del zone # unused
assert region is not None, (tag_filters, region)
del tag_filters, zone # unused

image_name = f'skypilot-{cluster_name}-{int(time.time())}'
returncode, stdout, stderr = cls._query_instance_property_with_retries(
tag_filters, region, query='Reservations[].Instances[].InstanceId')

subprocess_utils.handle_returncode(
returncode,
'',
error_msg='Failed to find the source cluster on AWS.',
stderr=stderr,
stream_logs=False)
status = provision_lib.query_instances('AWS', cluster_name,
{'region': region})
instance_ids = list(status.keys())
if not instance_ids:
with ux_utils.print_exception_no_traceback():
raise RuntimeError('Failed to find the source cluster on AWS.')

instance_ids = json.loads(stdout.strip())
if len(instance_ids) != 1:
with ux_utils.print_exception_no_traceback():
raise exceptions.NotSupportedError(
Expand Down Expand Up @@ -845,7 +796,7 @@ def create_image_from_cluster(cls, cluster_name: str,
wait_image_cmd = (
f'aws ec2 wait image-available --region {region} --image-ids {image_id}'
)
returncode, stdout, stderr = subprocess_utils.run_with_retries(
returncode, _, stderr = subprocess_utils.run_with_retries(
wait_image_cmd,
retry_returncode=[255],
)
Expand Down
18 changes: 18 additions & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import importlib
import inspect

from sky import status_lib
suquark marked this conversation as resolved.
Show resolved Hide resolved


def _route_to_cloud_impl(func):

Expand Down Expand Up @@ -36,6 +38,22 @@ def _wrapper(*args, **kwargs):
# TODO(suquark): Bring all other functions here from the


@_route_to_cloud_impl
def query_instances(
    provider_name: str,
    cluster_name: str,
    provider_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
    """Query the status of the instances of a cluster.

    Args:
        provider_name: Name of the cloud provider to query.
        cluster_name: Name of the cluster whose instances are queried.
        provider_config: Optional provider-specific configuration.

    Returns:
        A dictionary mapping instance IDs to their status. A None status
        means the instance is marked as "terminated" or "terminating".
    """
    # NOTE(review): this body is only a placeholder; the
    # `_route_to_cloud_impl` decorator presumably dispatches the call to the
    # provider-specific implementation -- confirm against its definition.
    raise NotImplementedError


@_route_to_cloud_impl
def stop_instances(
provider_name: str,
Expand Down
3 changes: 2 additions & 1 deletion sky/provision/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""AWS provisioner for SkyPilot."""

from sky.provision.aws.instance import stop_instances, terminate_instances
from sky.provision.aws.instance import (query_instances, terminate_instances,
stop_instances)
65 changes: 48 additions & 17 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,30 @@
from typing import Dict, List, Any, Optional

from botocore import config
suquark marked this conversation as resolved.
Show resolved Hide resolved

from sky.adaptors import aws
from sky import status_lib

BOTO_MAX_RETRIES = 12
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'


def _default_ec2_resource(region: str) -> Any:
    """Return the EC2 resource for `region` used by this module.

    The resource is obtained through the `aws` adaptor and is configured
    with a botocore `Config` whose `retries` setting carries
    `max_attempts=BOTO_MAX_RETRIES`.
    """
    retry_settings = {'max_attempts': BOTO_MAX_RETRIES}
    boto_config = config.Config(retries=retry_settings)
    return aws.resource('ec2', region_name=region, config=boto_config)


def _cluster_name_filter(cluster_name: str) -> List[Dict[str, Any]]:
    """Return the EC2 filter list selecting instances of `cluster_name`.

    The list holds a single filter that matches the TAG_RAY_CLUSTER_NAME
    tag against the given cluster name; being a list, it can be unpacked
    into a longer filter list by callers.
    """
    tag_filter: Dict[str, Any] = {
        'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
        'Values': [cluster_name],
    }
    return [tag_filter]


def _filter_instances(ec2, filters: List[Dict[str, Any]],
included_instances: Optional[List[str]],
excluded_instances: Optional[List[str]]):
Expand All @@ -28,6 +44,33 @@ def _filter_instances(ec2, filters: List[Dict[str, Any]],
return instances


def query_instances(
    cluster_name: str,
    provider_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
    """Query the EC2 instances of a cluster and report their statuses.

    See sky/provision/__init__.py for the provider-agnostic contract.

    Args:
        cluster_name: Cluster name; instances are selected by the
            TAG_RAY_CLUSTER_NAME tag via `_cluster_name_filter`.
        provider_config: Provider configuration; must not be None and must
            contain the 'region' key.

    Returns:
        A dictionary mapping each matching instance ID to its cluster
        status. Instances in the 'shutting-down' or 'terminated' state map
        to None.
    """
    assert provider_config is not None, (cluster_name, provider_config)
    ec2 = _default_ec2_resource(provider_config['region'])
    cluster_instances = ec2.instances.filter(
        Filters=_cluster_name_filter(cluster_name))

    # Translation from the EC2 instance state name to the cluster status.
    ec2_state_to_status = {
        'pending': status_lib.ClusterStatus.INIT,
        'running': status_lib.ClusterStatus.UP,
        # TODO(zhwu): stopping and shutting-down could occasionally fail
        # due to internal errors of AWS. We should cover that case.
        'stopping': status_lib.ClusterStatus.STOPPED,
        'stopped': status_lib.ClusterStatus.STOPPED,
        'shutting-down': None,
        'terminated': None,
    }
    return {
        instance.id: ec2_state_to_status[instance.state['Name']]
        for instance in cluster_instances
    }


def stop_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
Expand All @@ -36,19 +79,13 @@ def stop_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
filters = [
ec2 = _default_ec2_resource(region)
filters: List[Dict[str, Any]] = [
{
'Name': 'instance-state-name',
'Values': ['pending', 'running'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand All @@ -73,20 +110,14 @@ def terminate_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
ec2 = _default_ec2_resource(region)
filters = [
{
'Name': 'instance-state-name',
# exclude 'shutting-down' or 'terminated' states
'Values': ['pending', 'running', 'stopping', 'stopped'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand Down