Skip to content

Commit

Permalink
[sky/feat/show-gpus] adding region based filtering for show-gpus comm…
Browse files Browse the repository at this point in the history
…and. (#1187)

* [sky/feat] adding region based filtering for show-gpus command.

* [tests] added test for the new region flag.

* [workflows] fixed the pylint and yapf workflows.

* [merge] rebase after a long time.

* [fix] fixes the region column display with -a option

* [fix] yapf and pylint formatting

* [fix] yapf and pylint formatting

* [nits] code review nits and pylint fixes.
  • Loading branch information
Vivek Khimani authored Jan 21, 2023
1 parent 9fca91b commit b42e01f
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 31 deletions.
58 changes: 46 additions & 12 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,8 +2667,20 @@ def check():
default=None,
type=str,
help='Cloud provider to query.')
@click.option(
'--region',
required=False,
type=str,
help=
('The region to use. If not specified, shows accelerators from all regions.'
),
)
@usage_lib.entrypoint
def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pylint: disable=redefined-builtin
def show_gpus(
gpu_name: Optional[str],
all: bool, # pylint: disable=redefined-builtin
cloud: Optional[str],
region: Optional[str]):
"""Show supported GPU/TPU/accelerators.
The names and counts shown can be set in the ``accelerators`` field in task
Expand All @@ -2682,9 +2694,14 @@ def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pyli
To show all accelerators, including less common ones and their detailed
information, use ``sky show-gpus --all``.
NOTE: The price displayed for each instance type is the lowest across all
regions for both on-demand and spot instances.
NOTE: If region is not specified, the price displayed for each instance type
is the lowest across all regions for both on-demand and spot instances.
"""
# validation for the --region flag
if region is not None and cloud is None:
raise click.UsageError(
'The --region flag is only valid when the --cloud flag is set.')
service_catalog.validate_region_zone(region, None, clouds=cloud)
show_all = all
if show_all and gpu_name is not None:
raise click.UsageError('--all is only allowed without a GPU name.')
Expand All @@ -2701,8 +2718,11 @@ def _output():
['OTHER_GPU', 'AVAILABLE_QUANTITIES'])

if gpu_name is None:
result = service_catalog.list_accelerator_counts(gpus_only=True,
clouds=cloud)
result = service_catalog.list_accelerator_counts(
gpus_only=True,
clouds=cloud,
region_filter=region,
)
# NVIDIA GPUs
for gpu in service_catalog.get_common_gpus():
if gpu in result:
Expand Down Expand Up @@ -2730,6 +2750,7 @@ def _output():
# Show detailed accelerator information
result = service_catalog.list_accelerators(gpus_only=True,
name_filter=gpu_name,
region_filter=region,
clouds=cloud)
if len(result) == 0:
yield f'Resources \'{gpu_name}\' not found. '
Expand All @@ -2742,7 +2763,7 @@ def _output():
yield 'the host VM\'s cost is not included.\n\n'
import pandas as pd # pylint: disable=import-outside-toplevel
for i, (gpu, items) in enumerate(result.items()):
accelerator_table = log_utils.create_table([
accelerator_table_headers = [
'GPU',
'QTY',
'CLOUD',
Expand All @@ -2751,7 +2772,11 @@ def _output():
'HOST_MEMORY',
'HOURLY_PRICE',
'HOURLY_SPOT_PRICE',
])
]
if not show_all:
accelerator_table_headers.append('REGION')
accelerator_table = log_utils.create_table(
accelerator_table_headers)
for item in items:
instance_type_str = item.instance_type if not pd.isna(
item.instance_type) else '(attachable)'
Expand All @@ -2769,11 +2794,20 @@ def _output():
item.price) else '-'
spot_price_str = f'$ {item.spot_price:.3f}' if not pd.isna(
item.spot_price) else '-'
accelerator_table.add_row([
item.accelerator_name, item.accelerator_count, item.cloud,
instance_type_str, cpu_str, mem_str, price_str,
spot_price_str
])
region_str = item.region if not pd.isna(item.region) else '-'
accelerator_table_vals = [
item.accelerator_name,
item.accelerator_count,
item.cloud,
instance_type_str,
cpu_str,
mem_str,
price_str,
spot_price_str,
]
if not show_all:
accelerator_table_vals.append(region_str)
accelerator_table.add_row(accelerator_table_vals)

if i != 0:
yield '\n\n'
Expand Down
6 changes: 4 additions & 2 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
def list_accelerators(
gpus_only: bool = True,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
clouds: CloudFilter = None,
case_sensitive: bool = True,
) -> 'Dict[str, List[common.InstanceTypeInfo]]':
Expand All @@ -58,7 +59,7 @@ def list_accelerators(
of instance type offerings. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, case_sensitive)
name_filter, region_filter, case_sensitive)
if not isinstance(results, list):
results = [results]
ret: Dict[str,
Expand All @@ -72,6 +73,7 @@ def list_accelerators(
def list_accelerator_counts(
gpus_only: bool = True,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
clouds: CloudFilter = None,
) -> Dict[str, List[int]]:
"""List all accelerators offered by Sky and available counts.
Expand All @@ -80,7 +82,7 @@ def list_accelerator_counts(
of available counts. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter)
name_filter, region_filter, False)
if not isinstance(results, list):
results = [results]
accelerator_counts: Dict[str, Set[int]] = collections.defaultdict(set)
Expand Down
14 changes: 8 additions & 6 deletions sky/clouds/service_catalog/aws_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def instance_type_exists(instance_type: str) -> bool:
def validate_region_zone(
region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
return common.validate_region_zone_impl(_df, region, zone)
return common.validate_region_zone_impl('aws', _df, region, zone)


def accelerator_in_region_or_zone(acc_name: str,
Expand Down Expand Up @@ -134,13 +134,15 @@ def get_region_zones_for_instance_type(instance_type: str,
return us_region_list + other_region_list


def list_accelerators(gpus_only: bool,
name_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
"""Returns all instance types in AWS offering accelerators."""
return common.list_accelerators_impl('AWS', _df, gpus_only, name_filter,
case_sensitive)
region_filter, case_sensitive)


def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:
Expand Down
14 changes: 8 additions & 6 deletions sky/clouds/service_catalog/azure_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def validate_region_zone(
if zone is not None:
with ux_utils.print_exception_no_traceback():
raise ValueError('Azure does not support zones.')
return common.validate_region_zone_impl(_df, region, zone)
return common.validate_region_zone_impl('azure', _df, region, zone)


def accelerator_in_region_or_zone(acc_name: str,
Expand Down Expand Up @@ -89,10 +89,12 @@ def get_gen_version_from_instance_type(instance_type: str) -> Optional[int]:
return _df[_df['InstanceType'] == instance_type]['Generation'].iloc[0]


def list_accelerators(gpus_only: bool,
name_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
"""Returns all instance types in Azure offering GPUs."""
return common.list_accelerators_impl('Azure', _df, gpus_only, name_filter,
case_sensitive)
region_filter, case_sensitive)
23 changes: 21 additions & 2 deletions sky/clouds/service_catalog/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class InstanceTypeInfo(NamedTuple):
- memory: Instance memory in GiB.
- price: Regular instance price per hour (cheapest across all regions).
- spot_price: Spot instance price per hour (cheapest across all regions).
- region: Region that this instance type belongs to.
"""
cloud: str
instance_type: Optional[str]
Expand All @@ -43,6 +44,7 @@ class InstanceTypeInfo(NamedTuple):
memory: Optional[float]
price: float
spot_price: float
region: str


def get_catalog_path(filename: str) -> str:
Expand Down Expand Up @@ -161,7 +163,7 @@ def instance_type_exists_impl(df: pd.DataFrame, instance_type: str) -> bool:


def validate_region_zone_impl(
df: pd.DataFrame, region: Optional[str],
cloud_name: str, df: pd.DataFrame, region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
"""Validates whether region and zone exist in the catalog."""

Expand All @@ -174,6 +176,11 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?'
return candidate_strs

def _get_all_supported_regions_str() -> str:
    """Return a message suffix listing every region in the catalog.

    Reads ``df`` and ``cloud_name`` from the enclosing scope; the unique
    values of the 'Region' column are listed in sorted order.
    """
    region_names = df['Region'].unique().tolist()
    region_names.sort()
    return (
        f'\nList of supported {cloud_name} regions: {", ".join(region_names)!r}'
    )

validated_region, validated_zone = region, zone

filter_df = df
Expand All @@ -182,7 +189,12 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
if len(filter_df) == 0:
with ux_utils.print_exception_no_traceback():
error_msg = (f'Invalid region {region!r}')
error_msg += _get_candidate_str(region, df['Region'].unique())
candidate_strs = _get_candidate_str(region,
df['Region'].unique())
if not candidate_strs:
error_msg += _get_all_supported_regions_str()
raise ValueError(error_msg)
error_msg += candidate_strs
raise ValueError(error_msg)

if zone is not None:
Expand Down Expand Up @@ -321,6 +333,7 @@ def list_accelerators_impl(
df: pd.DataFrame,
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
case_sensitive: bool = True,
) -> Dict[str, List[InstanceTypeInfo]]:
"""Lists accelerators offered in a cloud service catalog.
Expand All @@ -341,11 +354,16 @@ def list_accelerators_impl(
'MemoryGiB',
'Price',
'SpotPrice',
'Region',
]].dropna(subset=['AcceleratorName']).drop_duplicates()
if name_filter is not None:
df = df[df['AcceleratorName'].str.contains(name_filter,
case=case_sensitive,
regex=True)]
if region_filter is not None:
df = df[df['Region'].str.contains(region_filter,
case=case_sensitive,
regex=True)]
df['AcceleratorCount'] = df['AcceleratorCount'].astype(int)
grouped = df.groupby('AcceleratorName')

Expand All @@ -366,6 +384,7 @@ def make_list_from_df(rows):
row['MemoryGiB'],
row['Price'],
row['SpotPrice'],
row['Region'],
),
axis='columns',
).tolist()
Expand Down
5 changes: 3 additions & 2 deletions sky/clouds/service_catalog/gcp_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def get_instance_type_for_accelerator(
def validate_region_zone(
region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
return common.validate_region_zone_impl(_df, region, zone)
return common.validate_region_zone_impl('gcp', _df, region, zone)


def accelerator_in_region_or_zone(acc_name: str,
Expand Down Expand Up @@ -273,11 +273,12 @@ def get_accelerator_hourly_cost(accelerator: str,
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
case_sensitive: bool = True,
) -> Dict[str, List[common.InstanceTypeInfo]]:
"""Returns all instance types in GCP offering GPUs."""
results = common.list_accelerators_impl('GCP', _df, gpus_only, name_filter,
case_sensitive)
region_filter, case_sensitive)

a100_infos = results.get('A100', []) + results.get('A100-80GB', [])
if not a100_infos:
Expand Down
13 changes: 12 additions & 1 deletion tests/test_list_accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ def test_list_ccelerators_all():
assert 'A100-80GB' in result, result


def test_list_accelerators_filters():
def test_list_accelerators_name_filter():
result = sky.list_accelerators(gpus_only=False, name_filter='V100')
assert sorted(result.keys()) == ['V100', 'V100-32GB'], result


def test_list_accelerators_region_filter():
    """Check that ``region_filter`` restricts results to the given region."""
    result = sky.list_accelerators(gpus_only=False,
                                   clouds='aws',
                                   region_filter='us-west-1')
    all_regions = [
        instance.region for res in result.values() for instance in res
    ]
    # all() over an empty iterable is True, so without this check the test
    # would pass even if the filter dropped every row.
    assert all_regions, result
    assert all(region == 'us-west-1' for region in all_regions), all_regions

0 comments on commit b42e01f

Please sign in to comment.