Skip to content

Commit

Permalink
feat: 🎸 Add --zones and --exclude flag to fetch_gcp (#2562)
Browse files Browse the repository at this point in the history
* feat: 🎸 Add --zones and --exclude flag to fetch_gcp

✅ Closes: #2555

* fix

* fix

* fix

* fix

* fix

* Clean up comments

* Lint

---------

Co-authored-by: Isaac Ong <[email protected]>
  • Loading branch information
sunny0826 and iojw authored Sep 21, 2023
1 parent 68a4be7 commit b6fc1ec
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions sky/clouds/service_catalog/data_fetchers/fetch_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import multiprocessing
import os
import textwrap
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Set

import google.auth
from googleapiclient import discovery
Expand Down Expand Up @@ -83,6 +83,8 @@
tpu_client = discovery.build('tpu', 'v1')

SINGLE_THREADED = False
ZONES: Set[str] = set()
EXCLUDED_REGIONS: Set[str] = set()


def get_skus(service_id: str) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -140,6 +142,26 @@ def _get_unit_price(sku: Dict[str, Any]) -> float:
return units + nanos


def filter_zones(func: Callable[[], List[str]]) -> Callable[[], List[str]]:
    """Decorator to filter the zones returned by the decorated function.

    The result is first intersected with the global ZONES (when non-empty).
    Then every entry listed in the global EXCLUDED_REGIONS is removed, and so
    is every zone whose region part (the zone name without its trailing
    '-<suffix>', e.g. 'us-central1' for 'us-central1-a') is listed there, so
    that excluding a region actually drops all of its zones.

    Args:
        func: Callable returning the candidate zone names.

    Returns:
        A wrapper around ``func`` that returns the remaining zone names as a
        sorted list, so the order is reproducible from run to run.

    Raises:
        ValueError: Raised by the wrapper when no zone is left after
            filtering.
    """

    # Keep the wrapped function's name/docstring and expose it through
    # ``__wrapped__`` for introspection.
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> List[str]:  # pylint: disable=redefined-outer-name
        remaining = set(func(*args, **kwargs))
        # The globals are read at call time, so values assigned after
        # decoration (e.g. from parsed CLI arguments) are honoured.
        if ZONES:
            remaining &= ZONES
        if EXCLUDED_REGIONS:
            remaining = {
                zone for zone in remaining
                if zone not in EXCLUDED_REGIONS and
                zone.rpartition('-')[0] not in EXCLUDED_REGIONS
            }
        if not remaining:
            raise ValueError('No zones to fetch. Please check your arguments.')
        # A set has no stable iteration order for strings; sort for
        # deterministic output.
        return sorted(remaining)

    return wrapper


@filter_zones
@functools.lru_cache(maxsize=None)
def _get_all_zones() -> List[str]:
zones_request = gcp_client.zones().list(project=project_id)
Expand Down Expand Up @@ -187,6 +209,9 @@ def _get_machine_types(region_prefix: str) -> pd.DataFrame:

def get_vm_df(skus: List[Dict[str, Any]], region_prefix: str) -> pd.DataFrame:
df = _get_machine_types(region_prefix)
if df.empty:
return df

# Drop the unsupported series.
df = df[df['InstanceType'].str.startswith(
tuple(f'{series}-' for series in SERIES_TO_DISCRIPTION))]
Expand Down Expand Up @@ -304,6 +329,8 @@ def get_gpu_df(skus: List[Dict[str, Any]], region_prefix: str) -> pd.DataFrame:
sku for sku in skus if sku['category']['resourceGroup'] == 'GPU'
]
df = _get_gpus(region_prefix)
if df.empty:
return df

def get_gpu_price(row: pd.Series, spot: bool) -> Optional[float]:
ondemand_or_spot = 'OnDemand' if not spot else 'Preemptible'
Expand Down Expand Up @@ -388,6 +415,8 @@ def _get_tpus() -> pd.DataFrame:
# TODO: the fetched TPUs fail to include us-east1
def get_tpu_df(skus: List[Dict[str, Any]]) -> pd.DataFrame:
df = _get_tpus()
if df.empty:
return df

def get_tpu_price(row: pd.Series, spot: bool) -> float:
tpu_price = None
Expand Down Expand Up @@ -455,7 +484,8 @@ def get_catalog_df(region_prefix: str) -> pd.DataFrame:
# Drop regions without the given prefix.
# NOTE: We intentionally do not drop any TPU regions.
vm_df = vm_df[vm_df['Region'].str.startswith(region_prefix)]
gpu_df = gpu_df[gpu_df['Region'].str.startswith(region_prefix)]
gpu_df = gpu_df[gpu_df['Region'].str.startswith(
region_prefix)] if not gpu_df.empty else gpu_df

gcp_tpu_skus = get_skus(TPU_SERVICE_ID)
tpu_df = get_tpu_df(gcp_tpu_skus)
Expand Down Expand Up @@ -485,10 +515,16 @@ def get_catalog_df(region_prefix: str) -> pd.DataFrame:

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--all-regions',
action='store_true',
help='Fetch all global regions, not just the U.S. ones.')
group = parser.add_mutually_exclusive_group()
group.add_argument('--all-regions',
action='store_true',
help='Fetch all global regions, not just the U.S. ones.')
group.add_argument('--zones',
nargs='+',
help='Fetch the list of specified zones.')
parser.add_argument('--exclude',
nargs='+',
help='Exclude the list of specified regions.')
parser.add_argument('--single-threaded',
action='store_true',
help='Run in single-threaded mode. This is useful when '
Expand All @@ -498,8 +534,10 @@ def get_catalog_df(region_prefix: str) -> pd.DataFrame:
args = parser.parse_args()

SINGLE_THREADED = args.single_threaded
ZONES = set(args.zones) if args.zones else set()
EXCLUDED_REGIONS = set(args.exclude) if args.exclude else set()

region_prefix_filter = '' if args.all_regions else 'us-'
region_prefix_filter = '' if args.zones or args.all_regions else 'us-'
catalog_df = get_catalog_df(region_prefix_filter)

os.makedirs('gcp', exist_ok=True)
Expand Down

0 comments on commit b6fc1ec

Please sign in to comment.