diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_aws.py b/sky/clouds/service_catalog/data_fetchers/fetch_aws.py
index 400882abf82..b0a57bdf59d 100644
--- a/sky/clouds/service_catalog/data_fetchers/fetch_aws.py
+++ b/sky/clouds/service_catalog/data_fetchers/fetch_aws.py
@@ -6,6 +6,7 @@
 import itertools
 from multiprocessing import pool as mp_pool
 import os
+import sys
 import subprocess
 from typing import Dict, List, Optional, Set, Tuple, Union
 
@@ -59,6 +60,10 @@
 # only available in this region, but it serves pricing information for all
 # regions.
 PRICING_TABLE_URL_FMT = 'https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/{region}/index.csv'  # pylint: disable=line-too-long
+# Hardcode the regions that offer p4de.24xlarge as our credential does not have
+# the permission to query the offerings of the instance.
+# Ref: https://aws.amazon.com/ec2/instance-types/p4/
+P4DE_REGIONS = ['us-east-1', 'us-west-2']
 
 regions_enabled: Optional[Set[str]] = None
 
@@ -166,6 +171,34 @@ def _get_spot_pricing_table(region: str) -> pd.DataFrame:
     return df
 
 
+def _patch_p4de(region: str, df: pd.DataFrame,
+                pricing_df: pd.DataFrame) -> pd.DataFrame:
+    # Hardcoded patch for p4de.24xlarge, as our credentials doesn't have access
+    # to the instance type.
+    # Columns:
+    # InstanceType,AcceleratorName,AcceleratorCount,vCPUs,MemoryGiB,GpuInfo,
+    # Price,SpotPrice,Region,AvailabilityZone
+    for zone in df[df['Region'] == region]['AvailabilityZone'].unique():
+        df = df.append(pd.Series({
+            'InstanceType': 'p4de.24xlarge',
+            'AcceleratorName': 'A100-80GB',
+            'AcceleratorCount': 8,
+            'vCPUs': 96,
+            'MemoryGiB': 1152,
+            'GpuInfo':
+                ('{\'Gpus\': [{\'Name\': \'A100-80GB\', \'Manufacturer\': '
+                 '\'NVIDIA\', \'Count\': 8, \'MemoryInfo\': {\'SizeInMiB\': '
+                 '81920}}], \'TotalGpuMemoryInMiB\': 655360}'),
+            'AvailabilityZone': zone,
+            'Region': region,
+            'Price': pricing_df[pricing_df['InstanceType'] == 'p4de.24xlarge']
+            ['Price'].values[0],
+            'SpotPrice': np.nan,
+        }),
+                       ignore_index=True)
+    return df
+
+
 def _get_instance_types_df(region: str) -> Union[str, pd.DataFrame]:
     try:
         # Fetch the zone info first to make sure the account has access to the
@@ -247,11 +280,14 @@ def get_additional_columns(row) -> pd.Series:
         df = pd.concat(
             [df, df.apply(get_additional_columns, axis='columns')],
             axis='columns')
-        # patch the GpuInfo for p4de.24xlarge
-        df.loc[df['InstanceType'] == 'p4de.24xlarge', 'GpuInfo'] = 'A100-80GB'
+        # patch the df for p4de.24xlarge
+        if region in P4DE_REGIONS:
+            df = _patch_p4de(region, df, pricing_df)
+        if 'GpuInfo' not in df.columns:
+            df['GpuInfo'] = np.nan
         df = df[USEFUL_COLUMNS]
     except Exception as e:  # pylint: disable=broad-except
-        print(f'{region} failed with {e}')
+        print(f'{region} failed with {e}', file=sys.stderr)
         return region
     return df
 
@@ -267,7 +303,7 @@ def get_all_regions_instance_types_df(regions: Set[str]) -> pd.DataFrame:
             new_dfs.append(df_or_region)
 
     df = pd.concat(new_dfs)
-    df.sort_values(['InstanceType', 'Region'], inplace=True)
+    df.sort_values(['InstanceType', 'Region', 'AvailabilityZone'], inplace=True)
     return df
 
 
@@ -402,9 +438,10 @@ def _check_regions_integrity(df: pd.DataFrame, name: str):
             # requested are the same as the ones we fetched.
             # The mismatch could happen for network issues or glitches
             # in the AWS API.
+            diff = user_regions - fetched_regions
             raise RuntimeError(
                 f'{name}: Fetched regions {fetched_regions} does not match '
-                f'requested regions {user_regions}.')
+                f'requested regions {user_regions}; Diff: {diff}')
 
     instance_df = get_all_regions_instance_types_df(user_regions)
     _check_regions_integrity(instance_df, 'instance types')
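Note: the new `_patch_p4de` helper relies on `pd.DataFrame.append`, which is deprecated since pandas 1.4 and removed in pandas 2.0. The sketch below shows an equivalent per-zone row insertion using `pd.concat`; it is an illustration only, and the helper name `_append_p4de_row` and its standalone signature are assumptions, not code from this diff (the `GpuInfo` string is omitted for brevity).

    import numpy as np
    import pandas as pd


    def _append_p4de_row(df: pd.DataFrame, region: str, zone: str,
                         price: float) -> pd.DataFrame:
        # Build the hardcoded p4de.24xlarge entry for one availability zone,
        # mirroring the columns used by _patch_p4de in the diff above.
        row = {
            'InstanceType': 'p4de.24xlarge',
            'AcceleratorName': 'A100-80GB',
            'AcceleratorCount': 8,
            'vCPUs': 96,
            'MemoryGiB': 1152,
            'AvailabilityZone': zone,
            'Region': region,
            'Price': price,
            'SpotPrice': np.nan,
        }
        # Concatenating a one-row DataFrame replaces
        # df.append(pd.Series(row), ignore_index=True) on pandas >= 2.0.
        return pd.concat([df, pd.DataFrame([row])], ignore_index=True)

Under this sketch, the per-zone loop and the price lookup from `pricing_df` would stay as written in the patch; only the append call changes.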