Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sky/feat/show-gpus] adding region based filtering for show-gpus command. #1187

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2558,8 +2558,20 @@ def check():
default=None,
type=str,
help='Cloud provider to query.')
@click.option(
'--region',
required=False,
type=str,
help=(
'The region to use. If not given, shows accelerators from all regions. '
vivekkhimani marked this conversation as resolved.
Show resolved Hide resolved
),
)
@usage_lib.entrypoint
def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pylint: disable=redefined-builtin
def show_gpus(
gpu_name: Optional[str],
all: bool, # pylint: disable=redefined-builtin
cloud: Optional[str],
region: Optional[str]):
"""Show supported GPU/TPU/accelerators.

The names and counts shown can be set in the ``accelerators`` field in task
Expand All @@ -2576,6 +2588,11 @@ def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pyli
NOTE: The price displayed for each instance type is the lowest across all
regions for both on-demand and spot instances.
vivekkhimani marked this conversation as resolved.
Show resolved Hide resolved
"""
# validation for the --region flag
if region and not cloud:
vivekkhimani marked this conversation as resolved.
Show resolved Hide resolved
raise click.UsageError(
'The --region flag is only valid when the --cloud flag is set.')
service_catalog.validate_region_zone(region, None, clouds=cloud)
show_all = all
if show_all and gpu_name is not None:
raise click.UsageError('--all is only allowed without a GPU name.')
Expand All @@ -2592,8 +2609,11 @@ def _output():
['OTHER_GPU', 'AVAILABLE_QUANTITIES'])

if gpu_name is None:
result = service_catalog.list_accelerator_counts(gpus_only=True,
clouds=cloud)
result = service_catalog.list_accelerator_counts(
gpus_only=True,
clouds=cloud,
region_filter=region,
)
# NVIDIA GPUs
for gpu in service_catalog.get_common_gpus():
if gpu in result:
Expand Down Expand Up @@ -2621,6 +2641,7 @@ def _output():
# Show detailed accelerator information
result = service_catalog.list_accelerators(gpus_only=True,
name_filter=gpu_name,
region_filter=region,
clouds=cloud)
if len(result) == 0:
yield f'Resources \'{gpu_name}\' not found. '
Expand All @@ -2633,7 +2654,7 @@ def _output():
yield 'the host VM\'s cost is not included.\n\n'
import pandas as pd # pylint: disable=import-outside-toplevel
for i, (gpu, items) in enumerate(result.items()):
accelerator_table = log_utils.create_table([
accelerator_table_headers = [
'GPU',
'QTY',
'CLOUD',
Expand All @@ -2642,7 +2663,11 @@ def _output():
'HOST_MEMORY',
'HOURLY_PRICE',
'HOURLY_SPOT_PRICE',
])
]
if not show_all:
accelerator_table_headers.append('REGION')
accelerator_table = log_utils.create_table(
accelerator_table_headers)
for item in items:
instance_type_str = item.instance_type if not pd.isna(
item.instance_type) else '(attachable)'
Expand All @@ -2660,11 +2685,20 @@ def _output():
item.price) else '-'
spot_price_str = f'$ {item.spot_price:.3f}' if not pd.isna(
item.spot_price) else '-'
accelerator_table.add_row([
item.accelerator_name, item.accelerator_count, item.cloud,
instance_type_str, cpu_str, mem_str, price_str,
spot_price_str
])
region_str = item.region if not pd.isna(item.region) else '-'
accelerator_table_vals = [
item.accelerator_name,
item.accelerator_count,
item.cloud,
instance_type_str,
cpu_str,
mem_str,
price_str,
spot_price_str,
]
if not show_all:
accelerator_table_vals.append(region_str)
accelerator_table.add_row(accelerator_table_vals)

if i != 0:
yield '\n\n'
Expand Down
6 changes: 4 additions & 2 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
def list_accelerators(
gpus_only: bool = True,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
clouds: CloudFilter = None,
case_sensitive: bool = True,
) -> 'Dict[str, List[common.InstanceTypeInfo]]':
Expand All @@ -58,7 +59,7 @@ def list_accelerators(
of instance type offerings. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, case_sensitive)
name_filter, region_filter, case_sensitive)
if not isinstance(results, list):
results = [results]
ret: Dict[str,
Expand All @@ -72,6 +73,7 @@ def list_accelerators(
def list_accelerator_counts(
gpus_only: bool = True,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
clouds: CloudFilter = None,
) -> Dict[str, List[int]]:
"""List all accelerators offered by Sky and available counts.
Expand All @@ -80,7 +82,7 @@ def list_accelerator_counts(
of available counts. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter)
name_filter, region_filter, False)
if not isinstance(results, list):
results = [results]
accelerator_counts: Dict[str, Set[int]] = collections.defaultdict(set)
Expand Down
12 changes: 7 additions & 5 deletions sky/clouds/service_catalog/aws_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,15 @@ def get_region_zones_for_instance_type(instance_type: str,
return us_region_list + other_region_list


def list_accelerators(gpus_only: bool,
name_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
"""Returns all instance types in AWS offering accelerators."""
return common.list_accelerators_impl('AWS', _df, gpus_only, name_filter,
case_sensitive)
region_filter, case_sensitive)


def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:
Expand Down
12 changes: 7 additions & 5 deletions sky/clouds/service_catalog/azure_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ def get_gen_version_from_instance_type(instance_type: str) -> Optional[int]:
return _df[_df['InstanceType'] == instance_type]['Generation'].iloc[0]


def list_accelerators(gpus_only: bool,
name_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
"""Returns all instance types in Azure offering GPUs."""
return common.list_accelerators_impl('Azure', _df, gpus_only, name_filter,
case_sensitive)
region_filter, case_sensitive)
20 changes: 19 additions & 1 deletion sky/clouds/service_catalog/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class InstanceTypeInfo(NamedTuple):
- memory: Instance memory in GiB.
- price: Regular instance price per hour (cheapest across all regions).
- spot_price: Spot instance price per hour (cheapest across all regions).
- region: Region where this instance type belongs to.
"""
cloud: str
instance_type: Optional[str]
Expand All @@ -43,6 +44,7 @@ class InstanceTypeInfo(NamedTuple):
memory: Optional[float]
price: float
spot_price: float
region: str


def get_catalog_path(filename: str) -> str:
Expand Down Expand Up @@ -174,6 +176,10 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?'
return candidate_strs

def _get_all_supported_regions_str() -> str:
all_regions: List[str] = sorted(df['Region'].unique().tolist())
return f'\nList of supported regions: {", ".join(all_regions)!r}'
vivekkhimani marked this conversation as resolved.
Show resolved Hide resolved

validated_region, validated_zone = region, zone

filter_df = df
Expand All @@ -182,7 +188,12 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
if len(filter_df) == 0:
with ux_utils.print_exception_no_traceback():
error_msg = (f'Invalid region {region!r}')
error_msg += _get_candidate_str(region, df['Region'].unique())
candidate_strs = _get_candidate_str(region,
df['Region'].unique())
if not candidate_strs:
error_msg += _get_all_supported_regions_str()
raise ValueError(error_msg)
error_msg += candidate_strs
raise ValueError(error_msg)

if zone is not None:
Expand Down Expand Up @@ -321,6 +332,7 @@ def list_accelerators_impl(
df: pd.DataFrame,
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
case_sensitive: bool = True,
) -> Dict[str, List[InstanceTypeInfo]]:
"""Lists accelerators offered in a cloud service catalog.
Expand All @@ -341,11 +353,16 @@ def list_accelerators_impl(
'MemoryGiB',
'Price',
'SpotPrice',
'Region',
]].dropna(subset=['AcceleratorName']).drop_duplicates()
if name_filter is not None:
df = df[df['AcceleratorName'].str.contains(name_filter,
case=case_sensitive,
regex=True)]
if region_filter is not None:
df = df[df['Region'].str.contains(region_filter,
case=case_sensitive,
regex=True)]
df['AcceleratorCount'] = df['AcceleratorCount'].astype(int)
grouped = df.groupby('AcceleratorName')

Expand All @@ -366,6 +383,7 @@ def make_list_from_df(rows):
row['MemoryGiB'],
row['Price'],
row['SpotPrice'],
row['Region'],
),
axis='columns',
).tolist()
Expand Down
3 changes: 2 additions & 1 deletion sky/clouds/service_catalog/gcp_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,12 @@ def get_accelerator_hourly_cost(accelerator: str,
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
case_sensitive: bool = True,
) -> Dict[str, List[common.InstanceTypeInfo]]:
"""Returns all instance types in GCP offering GPUs."""
results = common.list_accelerators_impl('GCP', _df, gpus_only, name_filter,
case_sensitive)
region_filter, case_sensitive)

a100_infos = results.get('A100', []) + results.get('A100-80GB', [])
if not a100_infos:
Expand Down
13 changes: 12 additions & 1 deletion tests/test_list_accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ def test_list_ccelerators_all():
assert 'A100-80GB' in result, result


def test_list_accelerators_filters():
def test_list_accelerators_name_filter():
result = sky.list_accelerators(gpus_only=False, name_filter='V100')
assert sorted(result.keys()) == ['V100', 'V100-32GB'], result


def test_list_accelerators_region_filter():
result = sky.list_accelerators(gpus_only=False,
clouds="aws",
region_filter='us-west-1')
all_regions = []
for res in result.values():
for instance in res:
all_regions.append(instance.region)
assert all([region == 'us-west-1' for region in all_regions])