Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS] Fix fetching AZ when describe zones permission does not exist in all regions #2463

Merged
merged 12 commits into from
Aug 30, 2023
4 changes: 2 additions & 2 deletions sky/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def check(quiet: bool = False, verbose: bool = False) -> None:
if not isinstance(cloud, clouds.Local):
echo(' ' + click.style(
f'{cloud}: {status_msg}', fg=status_color, bold=True) +
' ' * 10)
' ' * 30)
if ok:
enabled_clouds.append(str(cloud))
if verbose:
Expand All @@ -48,7 +48,7 @@ def check(quiet: bool = False, verbose: bool = False) -> None:
status_color = 'green' if r2_is_enabled else 'red'
echo(' ' +
click.style(f'{cloud}: {status_msg}', fg=status_color, bold=True) +
' ' * 10)
' ' * 30)
if not r2_is_enabled:
echo(f' Reason: {reason}')

Expand Down
6 changes: 4 additions & 2 deletions sky/clouds/service_catalog/aws_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def _get_az_mappings(aws_user_hash: str) -> Optional[pd.DataFrame]:
az_mappings = None
if aws_user_hash != 'default':
# Fetch az mapping from AWS.
logger.info(f'{colorama.Style.DIM}Fetching availability zones '
f'mapping for AWS...{colorama.Style.RESET_ALL}')
print(
f'\r{colorama.Style.DIM}AWS: Fetching availability zones '
f'mapping...{colorama.Style.RESET_ALL}',
end='')
az_mappings = fetch_aws.fetch_availability_zone_mappings()
else:
return None
Expand Down
70 changes: 49 additions & 21 deletions sky/clouds/service_catalog/data_fetchers/fetch_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
This script takes about 1 minute to finish.
"""
import argparse
import collections
import datetime
import itertools
from multiprocessing import pool as mp_pool
import os
import subprocess
import sys
import textwrap
from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np
import pandas as pd

from sky import exceptions
from sky.adaptors import aws
from sky.utils import common_utils
from sky.utils import log_utils
from sky.utils import ux_utils

# Enable most of the regions. Each user's account may have a subset of these
Expand Down Expand Up @@ -117,7 +120,7 @@ def _get_instance_type_offerings(region: str) -> pd.DataFrame:
columns={'Location': 'AvailabilityZoneName'})


def _get_availability_zones(region: str) -> Optional[pd.DataFrame]:
def _get_availability_zones(region: str) -> pd.DataFrame:
client = aws.client('ec2', region_name=region)
zones = []
try:
Expand All @@ -130,16 +133,17 @@ def _get_availability_zones(region: str) -> Optional[pd.DataFrame]:
# (AuthFailure) when calling the DescribeAvailabilityZones
# operation: AWS was not able to validate the provided
# access credentials
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
return None
elif e.response['Error']['Code'] == 'UnauthorizedOperation':
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
'Failed to retrieve availability zones. '
'Please ensure that the `ec2:DescribeAvailabilityZones` '
'action is enabled for your AWS account in IAM. '
'Ref: https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeAvailabilityZones.html.' # pylint: disable=line-too-long
f'{common_utils.format_exception(e, use_bracket=True)}'
raise exceptions.AWSAzFetchingError(
region,
reason=exceptions.AWSAzFetchingError.Reason.AUTH_FAILURE
) from None
elif e.response['Error']['Code'] == 'UnauthorizedOperation':
with ux_utils.print_exception_no_traceback():
raise exceptions.AWSAzFetchingError(
region,
reason=exceptions.AWSAzFetchingError.Reason.
AZ_PERMISSION_DENIED) from None
else:
raise
for resp in response['AvailabilityZones']:
Expand Down Expand Up @@ -231,8 +235,6 @@ def _get_instance_types_df(region: str) -> Union[str, pd.DataFrame]:
# Fetch the zone info first to make sure the account has access to the
# region.
zone_df = _get_availability_zones(region)
if zone_df is None:
raise RuntimeError(f'No access to region {region}')

# Use ThreadPool instead of Pool because this function can be called
# within a multiprocessing.Pool, and Pool cannot be nested.
Expand Down Expand Up @@ -413,20 +415,46 @@ def fetch_availability_zone_mappings() -> pd.DataFrame:
use1-az2 us-east-1a
"""
regions = list(get_enabled_regions())

errored_region_reasons = []

def _get_availability_zones_with_error_handling(
        region: str) -> Optional[pd.DataFrame]:
    """Fetch the availability zones of `region`, recording known failures.

    Returns the result of `_get_availability_zones(region)`. If that call
    raises `exceptions.AWSAzFetchingError`, the `(region, reason)` pair is
    appended to the enclosing `errored_region_reasons` list and None is
    returned instead of propagating the error.
    """
    try:
        return _get_availability_zones(region)
    except exceptions.AWSAzFetchingError as err:
        # GIL means it's thread-safe.
        errored_region_reasons.append((region, err.reason))
    return None

# Use ThreadPool instead of Pool because this function can be called within
# a Pool, and Pool cannot be nested.
with mp_pool.ThreadPool() as pool:
az_mappings = pool.map(_get_availability_zones, regions)
missing_regions = {
regions[i] for i, m in enumerate(az_mappings) if m is None
}
if missing_regions:
# This could happen if a AWS API glitch happens, it is to make sure
# that the availability zone does not get lost silently.
print('WARNING: Missing availability zone mappings for the following '
f'enabled regions: {missing_regions}')
az_mappings = pool.map(_get_availability_zones_with_error_handling,
regions)
# Remove the regions that the user does not have access to.
az_mappings = [m for m in az_mappings if m is not None]
errored_regions = collections.defaultdict(set)
for region, reason in errored_region_reasons:
errored_regions[reason].add(region)
if errored_regions:
# This could happen if (1) an AWS API glitch happens, or (2) a permission
# error happens for specific regions. We print those regions to make sure
# that they do not get lost silently.
table = log_utils.create_table(['Regions', 'Reason'])
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
for reason, region_set in errored_regions.items():
reason_str = '\n'.join(textwrap.wrap(str(reason.message), 80))
region_str = '\n'.join(
textwrap.wrap(', '.join(region_set), 60,
break_on_hyphens=False))
table.add_row([region_str, reason_str])
if not az_mappings:
raise RuntimeError('Failed to fetch availability zone mappings for '
f'all enabled regions.\n{table}')
else:
print('\rAWS: [WARNING] Missing availability zone mappings for the '
f'following enabled regions:\n{table}')
az_mappings = pd.concat(az_mappings)
return az_mappings

Expand Down
32 changes: 32 additions & 0 deletions sky/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,35 @@ class ClusterOwnerIdentityMismatchError(Exception):
class NoCloudAccessError(Exception):
"""Raised when all clouds are disabled."""
pass


class AWSAzFetchingError(Exception):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not feeling strongly) nit: AwsAzFetchingError

https://google.github.io/styleguide/cppguide.html#General_Naming_Rules

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I slightly prefer `AWS`, since `AWSNodeProvider` in the Ray code already uses this style, but feel free to push back.

"""Raised when fetching the AWS availability zone fails."""

class Reason(enum.Enum):
    """Reason for fetching availability zone failure.

    Each member's value is its own name; the user-facing explanation is
    exposed through the `message` property.
    """

    AUTH_FAILURE = 'AUTH_FAILURE'
    AZ_PERMISSION_DENIED = 'AZ_PERMISSION_DENIED'

    @property
    def message(self) -> str:
        """Return the user-facing explanation for this failure reason.

        Raises:
            ValueError: If no message is defined for this member.
        """
        # Look the members up on the enum class instead of through `self`:
        # member-to-member attribute access (`self.AUTH_FAILURE`) is
        # discouraged by the enum documentation, and identity comparison is
        # the idiomatic way to compare enum members.
        reason_cls = type(self)
        if self is reason_cls.AUTH_FAILURE:
            return ('Failed to access AWS services. Please check your AWS '
                    'credentials.')
        if self is reason_cls.AZ_PERMISSION_DENIED:
            return (
                'Failed to retrieve availability zones. '
                'Please ensure that the `ec2:DescribeAvailabilityZones` '
                'action is enabled for your AWS account in IAM. '
                'Ref: https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeAvailabilityZones.html.'  # pylint: disable=line-too-long
            )
        raise ValueError(f'Unknown reason {self}')

def __init__(self, region: str,
             reason: 'AWSAzFetchingError.Reason') -> None:
    """Store the failing region and its reason on the exception.

    The exception text passed to the base class is `reason.message`.
    """
    self.region, self.reason = region, reason

    super().__init__(reason.message)