From b42e01f734de51f387796ef7cef7fd7edb299c4d Mon Sep 17 00:00:00 2001 From: Vivek Khimani Date: Sat, 21 Jan 2023 14:30:42 -0800 Subject: [PATCH] [sky/feat/show-gpus] adding region based filtering for show-gpus command. (#1187) * [sky/feat] adding region based filtering for show-gpus command. * [tests] added test for the new region flag. * [workflows] fixed the pylint and yapf workflows. * [merge] rebase after a long time. * [fix] fixes the region column display with -a option * [fix] yapf and pylint formatting * [fix] yapf and pylint formatting * [nits] code review nits and pylint fixes. --- sky/cli.py | 58 ++++++++++++++++----- sky/clouds/service_catalog/__init__.py | 6 ++- sky/clouds/service_catalog/aws_catalog.py | 14 ++--- sky/clouds/service_catalog/azure_catalog.py | 14 ++--- sky/clouds/service_catalog/common.py | 23 +++++++- sky/clouds/service_catalog/gcp_catalog.py | 5 +- tests/test_list_accelerators.py | 13 ++++- 7 files changed, 102 insertions(+), 31 deletions(-) diff --git a/sky/cli.py b/sky/cli.py index e26411c076e..8d504a0ce59 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -2667,8 +2667,20 @@ def check(): default=None, type=str, help='Cloud provider to query.') +@click.option( + '--region', + required=False, + type=str, + help= + ('The region to use. If not specified, shows accelerators from all regions.' + ), +) @usage_lib.entrypoint -def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pylint: disable=redefined-builtin +def show_gpus( + gpu_name: Optional[str], + all: bool, # pylint: disable=redefined-builtin + cloud: Optional[str], + region: Optional[str]): """Show supported GPU/TPU/accelerators. The names and counts shown can be set in the ``accelerators`` field in task @@ -2682,9 +2694,14 @@ def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pyli To show all accelerators, including less common ones and their detailed information, use ``sky show-gpus --all``. - NOTE: The price displayed for each instance type is the lowest across all - regions for both on-demand and spot instances. + NOTE: If region is not specified, the price displayed for each instance type + is the lowest across all regions for both on-demand and spot instances. """ + # validation for the --region flag + if region is not None and cloud is None: + raise click.UsageError( + 'The --region flag is only valid when the --cloud flag is set.') + service_catalog.validate_region_zone(region, None, clouds=cloud) show_all = all if show_all and gpu_name is not None: raise click.UsageError('--all is only allowed without a GPU name.') @@ -2701,8 +2718,11 @@ def _output(): ['OTHER_GPU', 'AVAILABLE_QUANTITIES']) if gpu_name is None: - result = service_catalog.list_accelerator_counts(gpus_only=True, - clouds=cloud) + result = service_catalog.list_accelerator_counts( + gpus_only=True, + clouds=cloud, + region_filter=region, + ) # NVIDIA GPUs for gpu in service_catalog.get_common_gpus(): if gpu in result: @@ -2730,6 +2750,7 @@ def _output(): # Show detailed accelerator information result = service_catalog.list_accelerators(gpus_only=True, name_filter=gpu_name, + region_filter=region, clouds=cloud) if len(result) == 0: yield f'Resources \'{gpu_name}\' not found. ' @@ -2742,7 +2763,7 @@ def _output(): yield 'the host VM\'s cost is not included.\n\n' import pandas as pd # pylint: disable=import-outside-toplevel for i, (gpu, items) in enumerate(result.items()): - accelerator_table = log_utils.create_table([ + accelerator_table_headers = [ 'GPU', 'QTY', 'CLOUD', @@ -2751,7 +2772,11 @@ def _output(): 'HOST_MEMORY', 'HOURLY_PRICE', 'HOURLY_SPOT_PRICE', - ]) + ] + if not show_all: + accelerator_table_headers.append('REGION') + accelerator_table = log_utils.create_table( + accelerator_table_headers) for item in items: instance_type_str = item.instance_type if not pd.isna( item.instance_type) else '(attachable)' @@ -2769,11 +2794,20 @@ def _output(): item.price) else '-' spot_price_str = f'$ {item.spot_price:.3f}' if not pd.isna( item.spot_price) else '-' - accelerator_table.add_row([ - item.accelerator_name, item.accelerator_count, item.cloud, - instance_type_str, cpu_str, mem_str, price_str, - spot_price_str - ]) + region_str = item.region if not pd.isna(item.region) else '-' + accelerator_table_vals = [ + item.accelerator_name, + item.accelerator_count, + item.cloud, + instance_type_str, + cpu_str, + mem_str, + price_str, + spot_price_str, + ] + if not show_all: + accelerator_table_vals.append(region_str) + accelerator_table.add_row(accelerator_table_vals) if i != 0: yield '\n\n' diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 2acd61d527c..a722701cc53 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -49,6 +49,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): def list_accelerators( gpus_only: bool = True, name_filter: Optional[str] = None, + region_filter: Optional[str] = None, clouds: CloudFilter = None, case_sensitive: bool = True, ) -> 'Dict[str, List[common.InstanceTypeInfo]]': @@ -58,7 +59,7 @@ def list_accelerators( of instance type offerings. See usage in cli.py. """ results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, - name_filter, case_sensitive) + name_filter, region_filter, case_sensitive) if not isinstance(results, list): results = [results] ret: Dict[str, @@ -72,6 +73,7 @@ def list_accelerators( def list_accelerator_counts( gpus_only: bool = True, name_filter: Optional[str] = None, + region_filter: Optional[str] = None, clouds: CloudFilter = None, ) -> Dict[str, List[int]]: """List all accelerators offered by Sky and available counts. @@ -80,7 +82,7 @@ def list_accelerator_counts( of available counts. See usage in cli.py. """ results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, - name_filter) + name_filter, region_filter, False) if not isinstance(results, list): results = [results] accelerator_counts: Dict[str, Set[int]] = collections.defaultdict(set) diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 9fccd495b5c..150e45b384e 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -71,7 +71,7 @@ def instance_type_exists(instance_type: str) -> bool: def validate_region_zone( region: Optional[str], zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: - return common.validate_region_zone_impl(_df, region, zone) + return common.validate_region_zone_impl('aws', _df, region, zone) def accelerator_in_region_or_zone(acc_name: str, @@ -134,13 +134,15 @@ def get_region_zones_for_instance_type(instance_type: str, return us_region_list + other_region_list -def list_accelerators(gpus_only: bool, - name_filter: Optional[str], - case_sensitive: bool = True - ) -> Dict[str, List[common.InstanceTypeInfo]]: +def list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + case_sensitive: bool = True +) -> Dict[str, List[common.InstanceTypeInfo]]: """Returns all instance types in AWS offering accelerators.""" return common.list_accelerators_impl('AWS', _df, gpus_only, name_filter, - case_sensitive) + region_filter, case_sensitive) def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]: diff --git a/sky/clouds/service_catalog/azure_catalog.py b/sky/clouds/service_catalog/azure_catalog.py index fc74fc79540..c07f8e8ac7c 100644 --- a/sky/clouds/service_catalog/azure_catalog.py +++ b/sky/clouds/service_catalog/azure_catalog.py @@ -22,7 +22,7 @@ def validate_region_zone( if zone is not None: with ux_utils.print_exception_no_traceback(): raise ValueError('Azure does not support zones.') - return common.validate_region_zone_impl(_df, region, zone) + return common.validate_region_zone_impl('azure', _df, region, zone) def accelerator_in_region_or_zone(acc_name: str, @@ -89,10 +89,12 @@ def get_gen_version_from_instance_type(instance_type: str) -> Optional[int]: return _df[_df['InstanceType'] == instance_type]['Generation'].iloc[0] -def list_accelerators(gpus_only: bool, - name_filter: Optional[str], - case_sensitive: bool = True - ) -> Dict[str, List[common.InstanceTypeInfo]]: +def list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + case_sensitive: bool = True +) -> Dict[str, List[common.InstanceTypeInfo]]: """Returns all instance types in Azure offering GPUs.""" return common.list_accelerators_impl('Azure', _df, gpus_only, name_filter, - case_sensitive) + region_filter, case_sensitive) diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index d646673117f..84b2c67be12 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -34,6 +34,7 @@ class InstanceTypeInfo(NamedTuple): - memory: Instance memory in GiB. - price: Regular instance price per hour (cheapest across all regions). - spot_price: Spot instance price per hour (cheapest across all regions). + - region: Region where this instance type belongs to. """ cloud: str instance_type: Optional[str] @@ -43,6 +44,7 @@ class InstanceTypeInfo(NamedTuple): memory: Optional[float] price: float spot_price: float + region: str def get_catalog_path(filename: str) -> str: @@ -161,7 +163,7 @@ def instance_type_exists_impl(df: pd.DataFrame, instance_type: str) -> bool: def validate_region_zone_impl( - df: pd.DataFrame, region: Optional[str], + cloud_name: str, df: pd.DataFrame, region: Optional[str], zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: """Validates whether region and zone exist in the catalog.""" @@ -174,6 +176,11 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str: candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?' return candidate_strs + def _get_all_supported_regions_str() -> str: + all_regions: List[str] = sorted(df['Region'].unique().tolist()) + return \ + f'\nList of supported {cloud_name} regions: {", ".join(all_regions)!r}' + validated_region, validated_zone = region, zone filter_df = df @@ -182,7 +189,12 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str: if len(filter_df) == 0: with ux_utils.print_exception_no_traceback(): error_msg = (f'Invalid region {region!r}') - error_msg += _get_candidate_str(region, df['Region'].unique()) + candidate_strs = _get_candidate_str(region, + df['Region'].unique()) + if not candidate_strs: + error_msg += _get_all_supported_regions_str() + raise ValueError(error_msg) + error_msg += candidate_strs raise ValueError(error_msg) if zone is not None: @@ -321,6 +333,7 @@ def list_accelerators_impl( df: pd.DataFrame, gpus_only: bool, name_filter: Optional[str], + region_filter: Optional[str], case_sensitive: bool = True, ) -> Dict[str, List[InstanceTypeInfo]]: """Lists accelerators offered in a cloud service catalog. @@ -341,11 +354,16 @@ def list_accelerators_impl( 'MemoryGiB', 'Price', 'SpotPrice', + 'Region', ]].dropna(subset=['AcceleratorName']).drop_duplicates() if name_filter is not None: df = df[df['AcceleratorName'].str.contains(name_filter, case=case_sensitive, regex=True)] + if region_filter is not None: + df = df[df['Region'].str.contains(region_filter, + case=case_sensitive, + regex=True)] df['AcceleratorCount'] = df['AcceleratorCount'].astype(int) grouped = df.groupby('AcceleratorName') @@ -366,6 +384,7 @@ def make_list_from_df(rows): row['MemoryGiB'], row['Price'], row['SpotPrice'], + row['Region'], ), axis='columns', ).tolist() diff --git a/sky/clouds/service_catalog/gcp_catalog.py b/sky/clouds/service_catalog/gcp_catalog.py index 39e3eecb966..b2cb8f4b893 100644 --- a/sky/clouds/service_catalog/gcp_catalog.py +++ b/sky/clouds/service_catalog/gcp_catalog.py @@ -210,7 +210,7 @@ def get_instance_type_for_accelerator( def validate_region_zone( region: Optional[str], zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: - return common.validate_region_zone_impl(_df, region, zone) + return common.validate_region_zone_impl('gcp', _df, region, zone) def accelerator_in_region_or_zone(acc_name: str, @@ -273,11 +273,12 @@ def get_accelerator_hourly_cost(accelerator: str, def list_accelerators( gpus_only: bool, name_filter: Optional[str] = None, + region_filter: Optional[str] = None, case_sensitive: bool = True, ) -> Dict[str, List[common.InstanceTypeInfo]]: """Returns all instance types in GCP offering GPUs.""" results = common.list_accelerators_impl('GCP', _df, gpus_only, name_filter, - case_sensitive) + region_filter, case_sensitive) a100_infos = results.get('A100', []) + results.get('A100-80GB', []) if not a100_infos: diff --git a/tests/test_list_accelerators.py b/tests/test_list_accelerators.py index 56bde6b4f9d..c8e0579b203 100644 --- a/tests/test_list_accelerators.py +++ b/tests/test_list_accelerators.py @@ -17,6 +17,17 @@ def test_list_ccelerators_all(): assert 'A100-80GB' in result, result -def test_list_accelerators_filters(): +def test_list_accelerators_name_filter(): result = sky.list_accelerators(gpus_only=False, name_filter='V100') assert sorted(result.keys()) == ['V100', 'V100-32GB'], result + + +def test_list_accelerators_region_filter(): + result = sky.list_accelerators(gpus_only=False, + clouds="aws", + region_filter='us-west-1') + all_regions = [] + for res in result.values(): + for instance in res: + all_regions.append(instance.region) + assert all([region == 'us-west-1' for region in all_regions])