Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AWS] Enable p4de in the catalog #1827

Merged
merged 5 commits into from
Apr 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions sky/clouds/service_catalog/data_fetchers/fetch_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import itertools
from multiprocessing import pool as mp_pool
import os
import sys
import subprocess
from typing import Dict, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -59,6 +60,10 @@
# only available in this region, but it serves pricing information for all
# regions.
PRICING_TABLE_URL_FMT = 'https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/{region}/index.csv' # pylint: disable=line-too-long
# Regions that offer p4de.24xlarge. The list is hardcoded because our
# credential does not have the permission to query the offerings of this
# instance type.
# Ref: https://aws.amazon.com/ec2/instance-types/p4/
P4DE_REGIONS = ['us-east-1', 'us-west-2']
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved

regions_enabled: Optional[Set[str]] = None

Expand Down Expand Up @@ -166,6 +171,34 @@ def _get_spot_pricing_table(region: str) -> pd.DataFrame:
return df


def _patch_p4de(region: str, df: pd.DataFrame,
pricing_df: pd.DataFrame) -> pd.DataFrame:
# Hardcoded patch for p4de.24xlarge, as our credentials doesn't have access
# to the instance type.
# Columns:
# InstanceType,AcceleratorName,AcceleratorCount,vCPUs,MemoryGiB,GpuInfo,
# Price,SpotPrice,Region,AvailabilityZone
for zone in df[df['Region'] == region]['AvailabilityZone'].unique():
df = df.append(pd.Series({
'InstanceType': 'p4de.24xlarge',
'AcceleratorName': 'A100-80GB',
'AcceleratorCount': 8,
'vCPUs': 96,
'MemoryGiB': 1152,
'GpuInfo':
('{\'Gpus\': [{\'Name\': \'A100-80GB\', \'Manufacturer\': '
'\'NVIDIA\', \'Count\': 8, \'MemoryInfo\': {\'SizeInMiB\': '
'81920}}], \'TotalGpuMemoryInMiB\': 655360}'),
'AvailabilityZone': zone,
'Region': region,
'Price': pricing_df[pricing_df['InstanceType'] == 'p4de.24xlarge']
['Price'].values[0],
'SpotPrice': np.nan,
}),
ignore_index=True)
return df


def _get_instance_types_df(region: str) -> Union[str, pd.DataFrame]:
try:
# Fetch the zone info first to make sure the account has access to the
Expand Down Expand Up @@ -247,11 +280,14 @@ def get_additional_columns(row) -> pd.Series:
df = pd.concat(
[df, df.apply(get_additional_columns, axis='columns')],
axis='columns')
# patch the GpuInfo for p4de.24xlarge
df.loc[df['InstanceType'] == 'p4de.24xlarge', 'GpuInfo'] = 'A100-80GB'
# patch the df for p4de.24xlarge
if region in P4DE_REGIONS:
df = _patch_p4de(region, df, pricing_df)
if 'GpuInfo' not in df.columns:
df['GpuInfo'] = np.nan
df = df[USEFUL_COLUMNS]
except Exception as e: # pylint: disable=broad-except
print(f'{region} failed with {e}')
print(f'{region} failed with {e}', file=sys.stderr)
return region
return df

Expand All @@ -267,7 +303,7 @@ def get_all_regions_instance_types_df(regions: Set[str]) -> pd.DataFrame:
new_dfs.append(df_or_region)

df = pd.concat(new_dfs)
df.sort_values(['InstanceType', 'Region'], inplace=True)
df.sort_values(['InstanceType', 'Region', 'AvailabilityZone'], inplace=True)
return df


Expand Down Expand Up @@ -402,9 +438,10 @@ def _check_regions_integrity(df: pd.DataFrame, name: str):
# requested are the same as the ones we fetched.
# The mismatch could happen for network issues or glitches
# in the AWS API.
diff = user_regions - fetched_regions
raise RuntimeError(
f'{name}: Fetched regions {fetched_regions} does not match '
f'requested regions {user_regions}.')
f'requested regions {user_regions}; Diff: {diff}')

instance_df = get_all_regions_instance_types_df(user_regions)
_check_regions_integrity(instance_df, 'instance types')
Expand Down