Skip to content

Commit

Permalink
[AWS] Enable p4de in the catalog (#1827)
Browse files Browse the repository at this point in the history
* patch p4de for AWS catalog

* format

* format

* sort by zone

* Address comment
  • Loading branch information
Michaelvll authored Apr 1, 2023
1 parent 2e3bb48 commit f8eeeea
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions sky/clouds/service_catalog/data_fetchers/fetch_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import itertools
from multiprocessing import pool as mp_pool
import os
import sys
import subprocess
from typing import Dict, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -59,6 +60,10 @@
# only available in this region, but it serves pricing information for all
# regions.
PRICING_TABLE_URL_FMT = 'https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/{region}/index.csv' # pylint: disable=line-too-long
# Hardcode the regions that offer p4de.24xlarge as our credential does not have
# the permission to query the offerings of the instance.
# Ref: https://aws.amazon.com/ec2/instance-types/p4/
P4DE_REGIONS = ['us-east-1', 'us-west-2']

regions_enabled: Optional[Set[str]] = None

Expand Down Expand Up @@ -166,6 +171,34 @@ def _get_spot_pricing_table(region: str) -> pd.DataFrame:
return df


def _patch_p4de(region: str, df: pd.DataFrame,
pricing_df: pd.DataFrame) -> pd.DataFrame:
# Hardcoded patch for p4de.24xlarge, as our credentials doesn't have access
# to the instance type.
# Columns:
# InstanceType,AcceleratorName,AcceleratorCount,vCPUs,MemoryGiB,GpuInfo,
# Price,SpotPrice,Region,AvailabilityZone
for zone in df[df['Region'] == region]['AvailabilityZone'].unique():
df = df.append(pd.Series({
'InstanceType': 'p4de.24xlarge',
'AcceleratorName': 'A100-80GB',
'AcceleratorCount': 8,
'vCPUs': 96,
'MemoryGiB': 1152,
'GpuInfo':
('{\'Gpus\': [{\'Name\': \'A100-80GB\', \'Manufacturer\': '
'\'NVIDIA\', \'Count\': 8, \'MemoryInfo\': {\'SizeInMiB\': '
'81920}}], \'TotalGpuMemoryInMiB\': 655360}'),
'AvailabilityZone': zone,
'Region': region,
'Price': pricing_df[pricing_df['InstanceType'] == 'p4de.24xlarge']
['Price'].values[0],
'SpotPrice': np.nan,
}),
ignore_index=True)
return df


def _get_instance_types_df(region: str) -> Union[str, pd.DataFrame]:
try:
# Fetch the zone info first to make sure the account has access to the
Expand Down Expand Up @@ -247,11 +280,14 @@ def get_additional_columns(row) -> pd.Series:
df = pd.concat(
[df, df.apply(get_additional_columns, axis='columns')],
axis='columns')
# patch the GpuInfo for p4de.24xlarge
df.loc[df['InstanceType'] == 'p4de.24xlarge', 'GpuInfo'] = 'A100-80GB'
# patch the df for p4de.24xlarge
if region in P4DE_REGIONS:
df = _patch_p4de(region, df, pricing_df)
if 'GpuInfo' not in df.columns:
df['GpuInfo'] = np.nan
df = df[USEFUL_COLUMNS]
except Exception as e: # pylint: disable=broad-except
print(f'{region} failed with {e}')
print(f'{region} failed with {e}', file=sys.stderr)
return region
return df

Expand All @@ -267,7 +303,7 @@ def get_all_regions_instance_types_df(regions: Set[str]) -> pd.DataFrame:
new_dfs.append(df_or_region)

df = pd.concat(new_dfs)
df.sort_values(['InstanceType', 'Region'], inplace=True)
df.sort_values(['InstanceType', 'Region', 'AvailabilityZone'], inplace=True)
return df


Expand Down Expand Up @@ -402,9 +438,10 @@ def _check_regions_integrity(df: pd.DataFrame, name: str):
# requested are the same as the ones we fetched.
# The mismatch could happen for network issues or glitches
# in the AWS API.
diff = user_regions - fetched_regions
raise RuntimeError(
f'{name}: Fetched regions {fetched_regions} does not match '
f'requested regions {user_regions}.')
f'requested regions {user_regions}; Diff: {diff}')

instance_df = get_all_regions_instance_types_df(user_regions)
_check_regions_integrity(instance_df, 'instance types')
Expand Down

0 comments on commit f8eeeea

Please sign in to comment.