diff --git a/pyproject.toml b/pyproject.toml
index 565a91243f9..53a88eddab2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ env = [
     "SKYPILOT_DEBUG=1",
     "SKYPILOT_DISABLE_USAGE_COLLECTION=1"
 ]
-addopts = "-s -n 16 -q --tb=short --disable-warnings"
+addopts = "-s -n 16 -q --tb=short --dist loadgroup --disable-warnings"
 
 [tool.mypy]
 python_version = "3.8"
diff --git a/sky/__init__.py b/sky/__init__.py
index 21eae850858..3462fd67dc5 100644
--- a/sky/__init__.py
+++ b/sky/__init__.py
@@ -28,6 +28,7 @@
 AWS = clouds.AWS
 Azure = clouds.Azure
 GCP = clouds.GCP
+Lambda = clouds.Lambda
 Local = clouds.Local
 
 optimize = Optimizer.optimize
@@ -36,6 +37,7 @@
     'AWS',
     'Azure',
     'GCP',
+    'Lambda',
     'Local',
     'Optimizer',
     'OptimizeTarget',
diff --git a/sky/authentication.py b/sky/authentication.py
index 981d25a5b9e..ae1aaa66c8c 100644
--- a/sky/authentication.py
+++ b/sky/authentication.py
@@ -21,6 +21,7 @@
 from sky.utils import common_utils
 from sky.utils import subprocess_utils
 from sky.utils import ux_utils
+from sky.skylet.providers.lambda_cloud import lambda_utils
 
 logger = sky_logging.init_logger(__name__)
 
@@ -310,3 +311,26 @@ def setup_azure_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
     config['file_mounts'] = file_mounts
 
     return config
+
+
+def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
+    get_or_generate_keys()
+
+    # Ensure the ssh key is registered with Lambda Cloud.
+    lambda_client = lambda_utils.LambdaCloudClient()
+    if lambda_client.ssh_key_name is None:
+        public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
+        with open(public_key_path, 'r') as f:
+            public_key = f.read()
+        name = f'sky-key-{common_utils.get_user_hash()}'
+        lambda_client.set_ssh_key(name, public_key)
+
+    # Use a ~-relative path because Ray looks up the public key at the same
+    # path on both the local machine and the head node.
+    config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH
+
+    file_mounts = config['file_mounts']
+    file_mounts[PUBLIC_SSH_KEY_PATH] = PUBLIC_SSH_KEY_PATH
+    config['file_mounts'] = file_mounts
+
+    return config
diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index 65124347889..40a7fed429a 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -44,6 +44,7 @@
 from sky.backends import onprem_utils
 from sky.skylet import constants
 from sky.skylet import log_lib
+from sky.skylet.providers.lambda_cloud import lambda_utils
 from sky.utils import common_utils
 from sky.utils import command_runner
 from sky.utils import env_options
@@ -951,6 +952,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
         config = auth.setup_gcp_authentication(config)
     elif isinstance(cloud, clouds.Azure):
         config = auth.setup_azure_authentication(config)
+    elif isinstance(cloud, clouds.Lambda):
+        config = auth.setup_lambda_authentication(config)
     else:
         assert isinstance(cloud, clouds.Local), cloud
         # Local cluster case, authentication is already filled by the user
@@ -1802,10 +1805,29 @@ def _query_status_azure(
     return _process_cli_query('Azure', cluster, query_cmd, '\t', status_map)
 
 
+def _query_status_lambda(
+        cluster: str,
+        ray_config: Dict[str, Any],  # pylint: disable=unused-argument
+) -> List[global_user_state.ClusterStatus]:
+    status_map = {
+        'booting': global_user_state.ClusterStatus.INIT,
+        'active': global_user_state.ClusterStatus.UP,
+        'unhealthy': global_user_state.ClusterStatus.INIT,
+        'terminated': None,
+    }
+    # TODO(ewzeng): filter by hash_filter_string to be safe
+    vms = lambda_utils.LambdaCloudClient().list_instances()
+    for node in vms:
+        if node['name'] == cluster:
+            return [status_map[node['status']]]
+    return []
+
+
 _QUERY_STATUS_FUNCS = {
     'AWS': _query_status_aws,
     'GCP': _query_status_gcp,
     'Azure': _query_status_azure,
+    'Lambda': _query_status_lambda,
 }
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 209b3a6ccab..b9d580b1f7f 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -114,6 +114,7 @@ def _get_cluster_config_template(cloud):
         clouds.AWS: 'aws-ray.yml.j2',
         clouds.Azure: 'azure-ray.yml.j2',
         clouds.GCP: 'gcp-ray.yml.j2',
+        clouds.Lambda: 'lambda-ray.yml.j2',
         clouds.Local: 'local-ray.yml.j2',
     }
     return cloud_to_template[type(cloud)]
@@ -574,12 +575,14 @@ class GangSchedulingStatus(enum.Enum):
     def __init__(self, log_dir: str, dag: 'dag.Dag',
                  optimize_target: OptimizeTarget,
+                 requested_features: Set[clouds.CloudImplementationFeatures],
                  local_wheel_path: pathlib.Path, wheel_hash: str):
         self._blocked_resources = set()
         self.log_dir = os.path.expanduser(log_dir)
         self._dag = dag
         self._optimize_target = optimize_target
+        self._requested_features = requested_features
         self._local_wheel_path = local_wheel_path
         self._wheel_hash = wheel_hash
@@ -775,6 +778,42 @@ def _update_blocklist_on_azure_error(
         else:
             self._blocked_resources.add(launchable_resources.copy(zone=None))
 
+    def _update_blocklist_on_lambda_error(
+            self, launchable_resources: 'resources_lib.Resources', region,
+            zones, stdout, stderr):
+        del zones  # Unused.
+        style = colorama.Style
+        stdout_splits = stdout.split('\n')
+        stderr_splits = stderr.split('\n')
+        errors = [
+            s.strip()
+            for s in stdout_splits + stderr_splits
+            if 'LambdaCloudError:' in s.strip()
+        ]
+        if not errors:
+            logger.info('====== stdout ======')
+            for s in stdout_splits:
+                print(s)
+            logger.info('====== stderr ======')
+            for s in stderr_splits:
+                print(s)
+            with ux_utils.print_exception_no_traceback():
+                raise RuntimeError('Errors occurred during provision; '
+                                   'check logs above.')
+
+        logger.warning(f'Got error(s) in {region.name}:')
+        messages = '\n\t'.join(errors)
+        logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}')
+        self._blocked_resources.add(launchable_resources.copy(zone=None))
+
+        # Sometimes, LambdaCloudError will list available regions.
+        for e in errors:
+            if e.find('Regions with capacity available:') != -1:
+                for r in clouds.Lambda.regions():
+                    if e.find(r.name) == -1:
+                        self._blocked_resources.add(
+                            launchable_resources.copy(region=r.name,
+                                                      zone=None))
+
     def _update_blocklist_on_local_error(
             self, launchable_resources: 'resources_lib.Resources', region,
             zones, stdout, stderr):
@@ -834,6 +873,7 @@ def _update_blocklist_on_error(
             clouds.AWS: self._update_blocklist_on_aws_error,
             clouds.Azure: self._update_blocklist_on_azure_error,
             clouds.GCP: self._update_blocklist_on_gcp_error,
+            clouds.Lambda: self._update_blocklist_on_lambda_error,
             clouds.Local: self._update_blocklist_on_local_error,
         }
         cloud = launchable_resources.cloud
@@ -888,6 +928,9 @@ def _yield_region_zones(self, to_provision: resources_lib.Resources,
             elif cloud.is_same_cloud(clouds.Azure()):
                 region = config['provider']['location']
                 zones = None
+            elif cloud.is_same_cloud(clouds.Lambda()):
+                region = config['provider']['region']
+                zones = None
             elif cloud.is_same_cloud(clouds.Local()):
                 local_regions = clouds.Local.regions()
                 region = local_regions[0].name
@@ -1636,6 +1679,15 @@ def provision_with_retries(
                 cloud_user = None
             else:
                 cloud_user = to_provision.cloud.get_current_user_identity()
+            # Skip if to_provision.cloud does not support requested features.
+            if not to_provision.cloud.supports(self._requested_features):
+                self._blocked_resources.add(
+                    resources_lib.Resources(cloud=to_provision.cloud))
+                requested_features_str = ', '.join(
+                    [f.value for f in self._requested_features])
+                raise exceptions.ResourcesUnavailableError(
+                    f'{to_provision.cloud} does not support all the '
+                    f'features in [{requested_features_str}].')
             config_dict = self._retry_region_zones(
                 to_provision,
                 num_nodes,
@@ -1952,6 +2004,7 @@ def __init__(self):
 
         self._dag = None
         self._optimize_target = None
+        self._requested_features = set()
 
         # Command for running the setup script. It is only set when the
         # setup needs to be run outside the self._setup() and as part of
@@ -1964,6 +2017,8 @@ def register_info(self, **kwargs) -> None:
         self._dag = kwargs.pop('dag', self._dag)
         self._optimize_target = kwargs.pop(
             'optimize_target', self._optimize_target) or OptimizeTarget.COST
+        self._requested_features = kwargs.pop('requested_features',
+                                              self._requested_features)
         assert len(kwargs) == 0, f'Unexpected kwargs: {kwargs}'
 
     def check_resources_fit_cluster(self, handle: ResourceHandle,
@@ -2087,10 +2142,9 @@ def _provision(self,
             # in if retry_until_up is set, which will kick off new "rounds"
             # of optimization infinitely.
             try:
-                provisioner = RetryingVmProvisioner(self.log_dir, self._dag,
-                                                    self._optimize_target,
-                                                    local_wheel_path,
-                                                    wheel_hash)
+                provisioner = RetryingVmProvisioner(
+                    self.log_dir, self._dag, self._optimize_target,
+                    self._requested_features, local_wheel_path, wheel_hash)
                 config_dict = provisioner.provision_with_retries(
                     task, to_provision_config, dryrun, stream_logs)
                 break
diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py
index 053cfecc351..ca442dcb5b3 100644
--- a/sky/clouds/__init__.py
+++ b/sky/clouds/__init__.py
@@ -1,11 +1,13 @@
 """Clouds in Sky."""
 from sky.clouds.cloud import Cloud
 from sky.clouds.cloud import CLOUD_REGISTRY
+from sky.clouds.cloud import CloudImplementationFeatures
 from sky.clouds.cloud import Region
 from sky.clouds.cloud import Zone
 from sky.clouds.aws import AWS
 from sky.clouds.azure import Azure
 from sky.clouds.gcp import GCP
+from sky.clouds.lambda_cloud import Lambda
 from sky.clouds.local import Local
 
 __all__ = [
@@ -13,7 +15,9 @@
     'Azure',
     'Cloud',
     'GCP',
+    'Lambda',
     'Local',
+    'CloudImplementationFeatures',
     'Region',
     'Zone',
     'CLOUD_REGISTRY',
diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py
index 48c9282420a..8892fc70b79 100644
--- a/sky/clouds/aws.py
+++ b/sky/clouds/aws.py
@@ -6,7 +6,7 @@
 import os
 import subprocess
 import typing
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Set, Tuple
 
 from sky import clouds
 from sky import exceptions
@@ -544,3 +544,11 @@ def accelerator_in_region_or_zone(self,
                                       zone: Optional[str] = None) -> bool:
         return service_catalog.accelerator_in_region_or_zone(
             accelerator, acc_count, region, zone, 'aws')
+
+    @classmethod
+    def supports(
+            cls, requested_features: Set[clouds.CloudImplementationFeatures]
+    ) -> bool:
+        # All clouds.CloudImplementationFeatures are implemented.
+        del requested_features
+        return True
diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py
index 299f4458457..85fb2baadca 100644
--- a/sky/clouds/azure.py
+++ b/sky/clouds/azure.py
@@ -3,7 +3,7 @@
 import os
 import subprocess
 import typing
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Set, Tuple
 
 from sky import clouds
 from sky import exceptions
@@ -406,3 +406,11 @@ def get_project_id(cls, dryrun: bool = False) -> str:
                 'cli command: "az account set -s <subscription_id>".'
             ) from e
         return azure_subscription_id
+
+    @classmethod
+    def supports(
+            cls, requested_features: Set[clouds.CloudImplementationFeatures]
+    ) -> bool:
+        # All clouds.CloudImplementationFeatures are implemented.
+        del requested_features
+        return True
diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py
index 9c50b93b7b5..82afb38e41f 100644
--- a/sky/clouds/cloud.py
+++ b/sky/clouds/cloud.py
@@ -1,7 +1,8 @@
 """Interfaces: clouds, regions, and zones."""
+import enum
 import collections
 import typing
-from typing import Dict, Iterator, List, Optional, Tuple, Type
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Type
 
 from sky.clouds import service_catalog
 from sky.utils import ux_utils
@@ -10,6 +11,16 @@
     from sky import resources
 
 
+class CloudImplementationFeatures(enum.Enum):
+    """Features that might not be implemented for all clouds.
+
+    Used by Cloud.supports().
+    """
+    STOP = 'stop'
+    AUTOSTOP = 'autostop'
+    MULTI_NODE = 'multi-node'
+
+
 class Region(collections.namedtuple('Region', ['name'])):
     """A region."""
     name: str
@@ -308,5 +319,15 @@ def need_cleanup_after_preemption(self,
         del resource
         return False
 
+    @classmethod
+    def supports(cls,
+                 requested_features: Set[CloudImplementationFeatures]) -> bool:
+        """Returns whether all of the requested features are supported.
+
+        For instance, Lambda Cloud does not support autostop, so
+        Lambda.supports({CloudImplementationFeatures.AUTOSTOP}) returns False.
+        """
+        raise NotImplementedError
+
     def __repr__(self):
         return self._REPR
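
Reviewer note on the feature-gating API introduced above: a minimal sketch of how `Cloud.supports()` is meant to be called. The snippet is illustrative, not part of the diff; it assumes the names are importable as wired up in this PR.

```python
from sky import clouds

# sky.execution builds this set from the task, e.g. MULTI_NODE when
# task.num_nodes > 1, and AUTOSTOP when autostop is requested without --down.
requested = {
    clouds.CloudImplementationFeatures.AUTOSTOP,
    clouds.CloudImplementationFeatures.MULTI_NODE,
}

clouds.AWS.supports(requested)     # True: AWS implements every feature.
clouds.Lambda.supports(requested)  # False: _LAMBDA_IMPLEMENTATION_FEATURES
                                   # is currently the empty set.
```

When `supports()` returns False, `RetryingVmProvisioner.provision_with_retries()` blocks the cloud and raises `ResourcesUnavailableError`, so provisioning can fail over to a cloud that does implement the requested features.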
diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py
index 28c79308678..2ae332237c8 100644
--- a/sky/clouds/gcp.py
+++ b/sky/clouds/gcp.py
@@ -4,7 +4,7 @@
 import subprocess
 import time
 import typing
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Set, Tuple
 
 from sky import clouds
 from sky import exceptions
@@ -664,3 +664,11 @@ def check_accelerator_attachable_to_host(
                                              zone: Optional[str] = None) -> None:
         service_catalog.check_accelerator_attachable_to_host(
             instance_type, accelerators, zone, 'gcp')
+
+    @classmethod
+    def supports(
+            cls, requested_features: Set[clouds.CloudImplementationFeatures]
+    ) -> bool:
+        # All clouds.CloudImplementationFeatures are implemented.
+        del requested_features
+        return True
diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py
new file mode 100644
index 00000000000..bd52f4af609
--- /dev/null
+++ b/sky/clouds/lambda_cloud.py
@@ -0,0 +1,256 @@
+"""Lambda Cloud."""
+import json
+import typing
+from typing import Dict, Iterator, List, Optional, Set, Tuple
+
+from sky import clouds
+from sky.clouds import service_catalog
+from sky.skylet.providers.lambda_cloud import lambda_utils
+
+if typing.TYPE_CHECKING:
+    # Renaming to avoid shadowing variables.
+    from sky import resources as resources_lib
+
+# Minimum set of files under ~/.lambda_cloud that grant Lambda Cloud access.
+_CREDENTIAL_FILES = [
+    'lambda_keys',
+]
+
+# Currently, none of clouds.CloudImplementationFeatures are implemented
+# for Lambda Cloud.
+_LAMBDA_IMPLEMENTATION_FEATURES: Set[clouds.CloudImplementationFeatures] = set()
+
+
+@clouds.CLOUD_REGISTRY.register
+class Lambda(clouds.Cloud):
+    """Lambda Labs GPU Cloud."""
+
+    _REPR = 'Lambda'
+    _regions: List[clouds.Region] = []
+
+    @classmethod
+    def regions(cls) -> List[clouds.Region]:
+        if not cls._regions:
+            cls._regions = [
+                # Popular US regions
+                clouds.Region('us-east-1'),
+                clouds.Region('us-west-2'),
+                clouds.Region('us-west-1'),
+                clouds.Region('us-south-1'),
+
+                # Everyone else
+                clouds.Region('asia-northeast-1'),
+                clouds.Region('asia-northeast-2'),
+                clouds.Region('asia-south-1'),
+                clouds.Region('australia-southeast-1'),
+                clouds.Region('europe-central-1'),
+                clouds.Region('europe-south-1'),
+                clouds.Region('me-west-1'),
+            ]
+        return cls._regions
+
+    @classmethod
+    def regions_with_offering(cls, instance_type: Optional[str],
+                              accelerators: Optional[Dict[str, int]],
+                              use_spot: bool, region: Optional[str],
+                              zone: Optional[str]) -> List[clouds.Region]:
+        del accelerators, zone  # unused
+        if use_spot:
+            return []
+        if instance_type is None:
+            # Fall back to the default regions.
+            regions = cls.regions()
+        else:
+            regions = service_catalog.get_region_zones_for_instance_type(
+                instance_type, use_spot, 'lambda')
+
+        if region is not None:
+            regions = [r for r in regions if r.name == region]
+        return regions
+
+    @classmethod
+    def region_zones_provision_loop(
+        cls,
+        *,
+        instance_type: Optional[str] = None,
+        accelerators: Optional[Dict[str, int]] = None,
+        use_spot: bool = False,
+    ) -> Iterator[Tuple[clouds.Region, List[clouds.Zone]]]:
+        regions = cls.regions_with_offering(instance_type,
+                                            accelerators,
+                                            use_spot,
+                                            region=None,
+                                            zone=None)
+        for region in regions:
+            yield region, region.zones
+
+    def instance_type_to_hourly_cost(self,
+                                     instance_type: str,
+                                     use_spot: bool,
+                                     region: Optional[str] = None,
+                                     zone: Optional[str] = None) -> float:
+        return service_catalog.get_hourly_cost(instance_type,
+                                               use_spot=use_spot,
+                                               region=region,
+                                               zone=zone,
+                                               clouds='lambda')
+
+    def accelerators_to_hourly_cost(self,
+                                    accelerators: Dict[str, int],
+                                    use_spot: bool,
+                                    region: Optional[str] = None,
+                                    zone: Optional[str] = None) -> float:
+        del accelerators, use_spot, region, zone  # unused
+        # Lambda includes accelerators as part of the instance type.
+        return 0.0
+
+    def get_egress_cost(self, num_gigabytes: float) -> float:
+        return 0.0
+
+    def __repr__(self):
+        return 'Lambda'
+
+    def is_same_cloud(self, other: clouds.Cloud) -> bool:
+        # Returns true if the two clouds are the same cloud type.
+        return isinstance(other, Lambda)
+
+    @classmethod
+    def get_default_instance_type(cls,
+                                  cpus: Optional[str] = None) -> Optional[str]:
+        return service_catalog.get_default_instance_type(cpus=cpus,
+                                                         clouds='lambda')
+
+    @classmethod
+    def get_accelerators_from_instance_type(
+        cls,
+        instance_type: str,
+    ) -> Optional[Dict[str, int]]:
+        return service_catalog.get_accelerators_from_instance_type(
+            instance_type, clouds='lambda')
+
+    @classmethod
+    def get_vcpus_from_instance_type(
+        cls,
+        instance_type: str,
+    ) -> Optional[float]:
+        return service_catalog.get_vcpus_from_instance_type(instance_type,
+                                                            clouds='lambda')
+
+    @classmethod
+    def get_zone_shell_cmd(cls) -> Optional[str]:
+        return None
+
+    def make_deploy_resources_variables(
+            self, resources: 'resources_lib.Resources',
+            region: Optional['clouds.Region'],
+            zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
+        del zones
+        if region is None:
+            region = self._get_default_region()
+
+        r = resources
+        acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
+        if acc_dict is not None:
+            custom_resources = json.dumps(acc_dict, separators=(',', ':'))
+        else:
+            custom_resources = None
+
+        return {
+            'instance_type': resources.instance_type,
+            'custom_resources': custom_resources,
+            'region': region.name,
+        }
+
+    def get_feasible_launchable_resources(
+            self, resources: 'resources_lib.Resources'):
+        if resources.use_spot:
+            return ([], [])
+        if resources.instance_type is not None:
+            assert resources.is_launchable(), resources
+            # Accelerators are part of the instance type in Lambda Cloud.
+            resources = resources.copy(accelerators=None)
+            return ([resources], [])
+
+        def _make(instance_list):
+            resource_list = []
+            for instance_type in instance_list:
+                r = resources.copy(
+                    cloud=Lambda(),
+                    instance_type=instance_type,
+                    # Setting this to None as Lambda doesn't separately bill /
+                    # attach the accelerators. Billed as part of the VM type.
+                    accelerators=None,
+                    cpus=None,
+                )
+                resource_list.append(r)
+            return resource_list
+
+        # Currently, handle a filter on accelerators only.
+        accelerators = resources.accelerators
+        if accelerators is None:
+            # Return a default instance type with the given number of vCPUs.
+            default_instance_type = Lambda.get_default_instance_type(
+                cpus=resources.cpus)
+            if default_instance_type is None:
+                return ([], [])
+            else:
+                return (_make([default_instance_type]), [])
+
+        assert len(accelerators) == 1, resources
+        acc, acc_count = list(accelerators.items())[0]
+        (instance_list, fuzzy_candidate_list
+        ) = service_catalog.get_instance_type_for_accelerator(
+            acc,
+            acc_count,
+            use_spot=resources.use_spot,
+            cpus=resources.cpus,
+            region=resources.region,
+            zone=resources.zone,
+            clouds='lambda')
+        if instance_list is None:
+            return ([], fuzzy_candidate_list)
+        return (_make(instance_list), fuzzy_candidate_list)
+
+    def check_credentials(self) -> Tuple[bool, Optional[str]]:
+        try:
+            lambda_utils.LambdaCloudClient().list_instances()
+        except (AssertionError, KeyError, lambda_utils.LambdaCloudError):
+            return False, ('Failed to access Lambda Cloud with credentials. '
+                           'To configure credentials, go to:\n    '
+                           '  https://cloud.lambdalabs.com/api-keys\n    '
+                           'to generate an API key and add the line\n    '
+                           '  api_key = [YOUR API KEY]\n    '
+                           'to ~/.lambda_cloud/lambda_keys')
+        return True, None
+
+    def get_credential_file_mounts(self) -> Dict[str, str]:
+        return {
+            f'~/.lambda_cloud/{filename}': f'~/.lambda_cloud/{filename}'
+            for filename in _CREDENTIAL_FILES
+        }
+
+    def get_current_user_identity(self) -> Optional[str]:
+        # TODO(ewzeng): Implement get_current_user_identity for Lambda.
+        return None
+
+    def instance_type_exists(self, instance_type: str) -> bool:
+        return service_catalog.instance_type_exists(instance_type, 'lambda')
+
+    def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
+        return service_catalog.validate_region_zone(region,
+                                                    zone,
+                                                    clouds='lambda')
+
+    def accelerator_in_region_or_zone(self,
+                                      accelerator: str,
+                                      acc_count: int,
+                                      region: Optional[str] = None,
+                                      zone: Optional[str] = None) -> bool:
+        return service_catalog.accelerator_in_region_or_zone(
+            accelerator, acc_count, region, zone, 'lambda')
+
+    @classmethod
+    def supports(
+            cls, requested_features: Set[clouds.CloudImplementationFeatures]
+    ) -> bool:
+        return requested_features.issubset(_LAMBDA_IMPLEMENTATION_FEATURES)
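
As background for `check_credentials()` above and the parser in `lambda_utils.py` further down: `~/.lambda_cloud/lambda_keys` is a plain `key = value` file. A sketch with placeholder values (the `ssh_key_name` line is optional and is written back automatically by `LambdaCloudClient.set_ssh_key()`):

```
api_key = <YOUR_LAMBDA_CLOUD_API_KEY>
ssh_key_name = sky-key-<user-hash>
```

Note that the parser splits on the literal string `' = '`, so the spaces around `=` are required.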
+""" +import typing +from typing import Dict, List, Optional, Tuple + +from sky.clouds.service_catalog import common +from sky.utils import ux_utils + +if typing.TYPE_CHECKING: + from sky.clouds import cloud + +_df = common.read_catalog('lambda/vms.csv') + +# Number of vCPUS for gpu_1x_a100_sxm4 +_DEFAULT_NUM_VCPUS = 30 + + +def instance_type_exists(instance_type: str) -> bool: + return common.instance_type_exists_impl(_df, instance_type) + + +def validate_region_zone( + region: Optional[str], + zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Lambda Cloud does not support zones.') + return common.validate_region_zone_impl('lambda', _df, region, zone) + + +def accelerator_in_region_or_zone(acc_name: str, + acc_count: int, + region: Optional[str] = None, + zone: Optional[str] = None) -> bool: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Lambda Cloud does not support zones.') + return common.accelerator_in_region_or_zone_impl(_df, acc_name, acc_count, + region, zone) + + +def get_hourly_cost(instance_type: str, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + """Returns the cost, or the cheapest cost among all zones for spot.""" + assert not use_spot, 'Lambda Cloud does not support spot.' + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Lambda Cloud does not support zones.') + return common.get_hourly_cost_impl(_df, instance_type, use_spot, region, + zone) + + +def get_vcpus_from_instance_type(instance_type: str) -> Optional[float]: + return common.get_vcpus_from_instance_type_impl(_df, instance_type) + + +def get_default_instance_type(cpus: Optional[str] = None) -> Optional[str]: + if cpus is None: + cpus = str(_DEFAULT_NUM_VCPUS) + # Set to gpu_1x_a100_sxm4 to be the default instance type if match vCPU + # requirement. + df = _df[_df['InstanceType'].eq('gpu_1x_a100_sxm4')] + instance = common.get_instance_type_for_cpus_impl(df, cpus) + if not instance: + instance = common.get_instance_type_for_cpus_impl(_df, cpus) + return instance + + +def get_accelerators_from_instance_type( + instance_type: str) -> Optional[Dict[str, int]]: + return common.get_accelerators_from_instance_type_impl(_df, instance_type) + + +def get_instance_type_for_accelerator( + acc_name: str, + acc_count: int, + cpus: Optional[str] = None, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None) -> Tuple[Optional[List[str]], List[str]]: + """ + Returns a list of instance types satisfying the required count of + accelerators with sorted prices and a list of candidates with fuzzy search. 
+ """ + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Lambda Cloud does not support zones.') + return common.get_instance_type_for_accelerator_impl(df=_df, + acc_name=acc_name, + acc_count=acc_count, + cpus=cpus, + use_spot=use_spot, + region=region, + zone=zone) + + +def get_region_zones_for_instance_type(instance_type: str, + use_spot: bool) -> List['cloud.Region']: + df = _df[_df['InstanceType'] == instance_type] + region_list = common.get_region_zones(df, use_spot) + # Hack: Enforce US regions are always tried first + us_region_list = [] + other_region_list = [] + for region in region_list: + if region.name.startswith('us-'): + us_region_list.append(region) + else: + other_region_list.append(region) + return us_region_list + other_region_list + + +def list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + case_sensitive: bool = True +) -> Dict[str, List[common.InstanceTypeInfo]]: + """Returns all instance types in Lambda offering GPUs.""" + return common.list_accelerators_impl('Lambda', _df, gpus_only, name_filter, + region_filter, case_sensitive) diff --git a/sky/core.py b/sky/core.py index 1242c0fa3aa..1ad78b01fe1 100644 --- a/sky/core.py +++ b/sky/core.py @@ -4,6 +4,7 @@ import sys from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from sky import clouds from sky import dag from sky import task from sky import backends @@ -309,6 +310,12 @@ def stop(cluster_name: str, purge: bool = False) -> None: f'Stopping cluster {cluster_name!r} with TPU VM Pod ' 'is not supported.') + # Check cloud supports stopping instances + cloud = handle.launched_resources.cloud + if not cloud.supports({clouds.CloudImplementationFeatures.STOP}): + raise exceptions.NotSupportedError( + (f'{cloud} does not support stopping instances.')) + backend = backend_utils.get_backend_from_handle(handle) if (isinstance(backend, backends.CloudVmRayBackend) and handle.launched_resources.use_spot): @@ -427,6 +434,13 @@ def autostop( f'{operation} cluster {cluster_name!r} with TPU VM Pod ' 'is not supported.') + # Check autostop is implemented for cloud + cloud = handle.launched_resources.cloud + if not down and idle_minutes >= 0: + if not cloud.supports({clouds.CloudImplementationFeatures.AUTOSTOP}): + raise exceptions.NotSupportedError( + (f'autostop not implemented for {cloud}.')) + backend = backend_utils.get_backend_from_handle(handle) usage_lib.record_cluster_name_for_current_operation(cluster_name) backend.set_autostop(handle, idle_minutes, down) diff --git a/sky/data/storage.py b/sky/data/storage.py index 513a173cc11..f16b73282fe 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -87,6 +87,9 @@ def get_storetype_from_cloud(cloud: clouds.Cloud) -> StoreType: elif isinstance(cloud, clouds.Azure): with ux_utils.print_exception_no_traceback(): raise ValueError('Azure Blob Storage is not supported yet.') + elif isinstance(cloud, clouds.Lambda): + with ux_utils.print_exception_no_traceback(): + raise ValueError('Lambda Cloud does not provide cloud storage.') else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown cloud type: {cloud}') diff --git a/sky/execution.py b/sky/execution.py index 8eb85451bae..6038ba16c19 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -23,6 +23,7 @@ import sky from sky import backends +from sky import clouds from sky import exceptions from sky import global_user_state from sky import optimizer @@ -188,6 +189,12 @@ def _execute( stages = stages if 
diff --git a/sky/data/storage.py b/sky/data/storage.py
index 513a173cc11..f16b73282fe 100644
--- a/sky/data/storage.py
+++ b/sky/data/storage.py
@@ -87,6 +87,9 @@ def get_storetype_from_cloud(cloud: clouds.Cloud) -> StoreType:
     elif isinstance(cloud, clouds.Azure):
         with ux_utils.print_exception_no_traceback():
             raise ValueError('Azure Blob Storage is not supported yet.')
+    elif isinstance(cloud, clouds.Lambda):
+        with ux_utils.print_exception_no_traceback():
+            raise ValueError('Lambda Cloud does not provide cloud storage.')
     else:
         with ux_utils.print_exception_no_traceback():
             raise ValueError(f'Unknown cloud type: {cloud}')
diff --git a/sky/execution.py b/sky/execution.py
index 8eb85451bae..6038ba16c19 100644
--- a/sky/execution.py
+++ b/sky/execution.py
@@ -23,6 +23,7 @@
 import sky
 from sky import backends
+from sky import clouds
 from sky import exceptions
 from sky import global_user_state
 from sky import optimizer
@@ -188,6 +189,12 @@ def _execute(
     stages = stages if stages is not None else list(Stage)
 
+    # Requested features that some clouds support and others don't.
+    requested_features = set()
+
+    if task.num_nodes > 1:
+        requested_features.add(clouds.CloudImplementationFeatures.MULTI_NODE)
+
     backend = backend if backend is not None else backends.CloudVmRayBackend()
     if isinstance(backend, backends.CloudVmRayBackend):
         if down and idle_minutes_to_autostop is None:
@@ -208,6 +215,11 @@ def _execute(
                 f'{colorama.Style.RESET_ALL}')
             idle_minutes_to_autostop = 1
         stages.remove(Stage.DOWN)
+
+        if not down:
+            requested_features.add(
+                clouds.CloudImplementationFeatures.AUTOSTOP)
+
     elif idle_minutes_to_autostop is not None:
         # TODO(zhwu): Autostop is not supported for non-CloudVmRayBackend.
         with ux_utils.print_exception_no_traceback():
@@ -241,7 +253,9 @@ def _execute(
     task = dag.tasks[0]  # Keep: dag may have been deep-copied.
     assert task.best_resources is not None, task
 
-    backend.register_info(dag=dag, optimize_target=optimize_target)
+    backend.register_info(dag=dag,
+                          optimize_target=optimize_target,
+                          requested_features=requested_features)
 
     if task.storage_mounts is not None:
         # Optimizer should eventually choose where to store bucket
diff --git a/sky/registry.py b/sky/registry.py
index bd1606700a9..553464f21f2 100644
--- a/sky/registry.py
+++ b/sky/registry.py
@@ -13,6 +13,7 @@
     clouds.AWS(),
     clouds.Azure(),
     clouds.GCP(),
+    clouds.Lambda(),
 ]
diff --git a/sky/setup_files/MANIFEST.in b/sky/setup_files/MANIFEST.in
index 7981e7aaef9..f6ba9e298de 100644
--- a/sky/setup_files/MANIFEST.in
+++ b/sky/setup_files/MANIFEST.in
@@ -4,6 +4,7 @@
 include sky/skylet/providers/aws/*
 include sky/skylet/providers/aws/cloudwatch/*
 include sky/skylet/providers/azure/*
 include sky/skylet/providers/gcp/*
+include sky/skylet/providers/lambda_cloud/*
 include sky/skylet/ray_patches/*.patch
 include sky/templates/*
 include sky/setup_files/*
diff --git a/sky/skylet/providers/lambda_cloud/__init__.py b/sky/skylet/providers/lambda_cloud/__init__.py
new file mode 100644
index 00000000000..64dac295eb5
--- /dev/null
+++ b/sky/skylet/providers/lambda_cloud/__init__.py
@@ -0,0 +1,2 @@
+"""Lambda Cloud node provider."""
+from sky.skylet.providers.lambda_cloud.node_provider import LambdaNodeProvider
diff --git a/sky/skylet/providers/lambda_cloud/lambda_utils.py b/sky/skylet/providers/lambda_cloud/lambda_utils.py
new file mode 100644
index 00000000000..0ae886070e1
--- /dev/null
+++ b/sky/skylet/providers/lambda_cloud/lambda_utils.py
@@ -0,0 +1,185 @@
+"""Lambda Cloud helper functions."""
+import json
+import os
+from typing import Any, Dict, List
+
+import requests
+
+CREDENTIALS_PATH = '~/.lambda_cloud/lambda_keys'
+API_ENDPOINT = 'https://cloud.lambdalabs.com/api/v1'
+
+
+class LambdaCloudError(Exception):
+    pass
+
+
+class Metadata:
+    """Per-cluster metadata file."""
+
+    def __init__(self, path_prefix: str, cluster_name: str) -> None:
+        # TODO(ewzeng): The metadata file is not thread safe. This is fine
+        # for now since SkyPilot uses a per-cluster lock for ray-related
+        # operations. In the future, add a filelock around __getitem__,
+        # __setitem__ and refresh.
+        self.path = os.path.expanduser(f'{path_prefix}-{cluster_name}')
+        # In case the parent directory does not exist.
+        os.makedirs(os.path.dirname(self.path), exist_ok=True)
+
+    def __getitem__(self, instance_id: str) -> Dict[str, Any]:
+        assert os.path.exists(self.path), 'Metadata file not found'
+        with open(self.path, 'r') as f:
+            metadata = json.load(f)
+        return metadata.get(instance_id)
+
+    def __setitem__(self, instance_id: str, value: Dict[str, Any]) -> None:
+        # Read from the metadata file.
+        if os.path.exists(self.path):
+            with open(self.path, 'r') as f:
+                metadata = json.load(f)
+        else:
+            metadata = {}
+        # Update the metadata.
+        if value is None:
+            if instance_id in metadata:
+                metadata.pop(instance_id)  # del entry
+            if len(metadata) == 0:
+                if os.path.exists(self.path):
+                    os.remove(self.path)
+                return
+        else:
+            metadata[instance_id] = value
+        # Write to the metadata file.
+        with open(self.path, 'w') as f:
+            json.dump(metadata, f)
+
+    def refresh(self, instance_ids: List[str]) -> None:
+        """Remove all tags for instances not in instance_ids."""
+        if not os.path.exists(self.path):
+            return
+        with open(self.path, 'r') as f:
+            metadata = json.load(f)
+        for instance_id in list(metadata.keys()):
+            if instance_id not in instance_ids:
+                del metadata[instance_id]
+        if len(metadata) == 0:
+            os.remove(self.path)
+            return
+        with open(self.path, 'w') as f:
+            json.dump(metadata, f)
+
+
+def raise_lambda_error(response: requests.Response) -> None:
+    """Raise LambdaCloudError if appropriate."""
+    status_code = response.status_code
+    if status_code == 200:
+        return
+    if status_code == 429:
+        # https://docs.lambdalabs.com/cloud/rate-limiting/
+        raise LambdaCloudError('Your API requests are being rate limited.')
+    try:
+        resp_json = response.json()
+        code = resp_json['error']['code']
+        message = resp_json['error']['message']
+    except (KeyError, json.decoder.JSONDecodeError):
+        raise LambdaCloudError(f'Unexpected error. Status code: {status_code}')
+    raise LambdaCloudError(f'{code}: {message}')
+
+
+class LambdaCloudClient:
+    """Wrapper functions for the Lambda Cloud API."""
+
+    def __init__(self) -> None:
+        self.credentials = os.path.expanduser(CREDENTIALS_PATH)
+        assert os.path.exists(self.credentials), 'Credentials not found'
+        with open(self.credentials, 'r') as f:
+            lines = [line.strip() for line in f.readlines() if ' = ' in line]
+            self._credentials = {
+                line.split(' = ')[0]: line.split(' = ')[1]
+                for line in lines
+            }
+        self.api_key = self._credentials['api_key']
+        self.ssh_key_name = self._credentials.get('ssh_key_name', None)
+        self.headers = {'Authorization': f'Bearer {self.api_key}'}
+
+    def create_instances(self,
+                         instance_type: str = 'gpu_1x_a100_sxm4',
+                         region: str = 'us-east-1',
+                         quantity: int = 1,
+                         name: str = '') -> List[str]:
+        """Launch new instances."""
+        assert self.ssh_key_name is not None
+
+        # Optimization:
+        # Most API requests are rate limited at ~1 request every second but
+        # launch requests are rate limited at ~1 request every 10 seconds.
+        # So don't use launch requests to check availability.
+        # See https://docs.lambdalabs.com/cloud/rate-limiting/ for more.
+        available_regions = self.list_catalog()[instance_type]\
+                ['regions_with_capacity_available']
+        available_regions = [reg['name'] for reg in available_regions]
+        if region not in available_regions:
+            if len(available_regions) > 0:
+                aval_reg = ' '.join(available_regions)
+            else:
+                aval_reg = 'None'
+            raise LambdaCloudError(('instance-operations/launch/'
+                                    'insufficient-capacity: Not enough '
+                                    'capacity to fulfill launch request. '
+                                    'Regions with capacity available: '
+                                    f'{aval_reg}'))
+
+        # Try to launch the instance.
+        data = json.dumps({
+            'region_name': region,
+            'instance_type_name': instance_type,
+            'ssh_key_names': [self.ssh_key_name],
+            'quantity': quantity,
+            'name': name
+        })
+        response = requests.post(f'{API_ENDPOINT}/instance-operations/launch',
+                                 data=data,
+                                 headers=self.headers)
+        raise_lambda_error(response)
+        return response.json().get('data', {}).get('instance_ids', [])
+
+    def remove_instances(self, *instance_ids: str) -> List[Dict[str, Any]]:
+        """Terminate instances."""
+        data = json.dumps({
+            'instance_ids': [
+                instance_ids[0]  # TODO(ewzeng) don't hardcode
+            ]
+        })
+        response = requests.post(
+            f'{API_ENDPOINT}/instance-operations/terminate',
+            data=data,
+            headers=self.headers)
+        raise_lambda_error(response)
+        return response.json().get('data', {}).get('terminated_instances', [])
+
+    def list_instances(self) -> List[Dict[str, Any]]:
+        """List existing instances."""
+        response = requests.get(f'{API_ENDPOINT}/instances',
+                                headers=self.headers)
+        raise_lambda_error(response)
+        return response.json().get('data', [])
+
+    def set_ssh_key(self, name: str, pub_key: str) -> None:
+        """Set ssh key."""
+        data = json.dumps({'name': name, 'public_key': pub_key})
+        response = requests.post(f'{API_ENDPOINT}/ssh-keys',
+                                 data=data,
+                                 headers=self.headers)
+        raise_lambda_error(response)
+        self.ssh_key_name = name
+        with open(self.credentials, 'w') as f:
+            f.write(f'api_key = {self.api_key}\n')
+            f.write(f'ssh_key_name = {self.ssh_key_name}\n')
+
+    def list_catalog(self) -> Dict[str, Any]:
+        """List offered instances and their availability."""
+        response = requests.get(f'{API_ENDPOINT}/instance-types',
+                                headers=self.headers)
+        raise_lambda_error(response)
+        return response.json().get('data', [])
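
A small usage sketch of the `Metadata` store above (the path prefix matches `TAG_PATH_PREFIX` in `node_provider.py` below; the cluster name and instance ids are made up):

```python
meta = Metadata('~/.sky/generated/lambda_cloud/metadata', 'sky-demo')

meta['inst-0'] = {'tags': {'ray-cluster-name': 'sky-demo'}}
meta['inst-0']            # -> {'tags': {'ray-cluster-name': 'sky-demo'}}
meta.refresh(['inst-0'])  # Drop entries whose instance id is no longer live.
meta['inst-0'] = None     # Deleting the last entry removes the file itself.
```

The backing file here would be `~/.sky/generated/lambda_cloud/metadata-sky-demo`, since `__init__` joins the prefix and cluster name with a dash.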
diff --git a/sky/skylet/providers/lambda_cloud/node_provider.py b/sky/skylet/providers/lambda_cloud/node_provider.py
new file mode 100644
index 00000000000..47ef9b16225
--- /dev/null
+++ b/sky/skylet/providers/lambda_cloud/node_provider.py
@@ -0,0 +1,217 @@
+import logging
+import os
+import time
+from threading import RLock
+from typing import Any, Dict, List, Optional
+
+from ray.autoscaler.node_provider import NodeProvider
+from ray.autoscaler.tags import (
+    TAG_RAY_CLUSTER_NAME,
+    TAG_RAY_USER_NODE_TYPE,
+    TAG_RAY_NODE_NAME,
+    TAG_RAY_LAUNCH_CONFIG,
+    TAG_RAY_NODE_STATUS,
+    STATUS_UP_TO_DATE,
+    TAG_RAY_NODE_KIND,
+    NODE_KIND_WORKER,
+    NODE_KIND_HEAD,
+)
+from ray.autoscaler._private.util import hash_launch_conf
+from sky.skylet.providers.lambda_cloud import lambda_utils
+from sky.utils import common_utils
+
+TAG_PATH_PREFIX = '~/.sky/generated/lambda_cloud/metadata'
+REMOTE_RAY_YAML = '~/ray_bootstrap_config.yaml'
+
+logger = logging.getLogger(__name__)
+
+
+def synchronized(f):
+
+    def wrapper(self, *args, **kwargs):
+        self.lock.acquire()
+        try:
+            return f(self, *args, **kwargs)
+        finally:
+            self.lock.release()
+
+    return wrapper
+
+
+class LambdaNodeProvider(NodeProvider):
+    """Node Provider for Lambda Cloud.
+
+    This provider assumes Lambda Cloud credentials are set.
+    """
+
+    def __init__(self, provider_config: Dict[str, Any],
+                 cluster_name: str) -> None:
+        NodeProvider.__init__(self, provider_config, cluster_name)
+        self.lock = RLock()
+        self.lambda_client = lambda_utils.LambdaCloudClient()
+        self.cached_nodes = {}
+        self.metadata = lambda_utils.Metadata(TAG_PATH_PREFIX, cluster_name)
+        vms = self._list_instances_in_cluster()
+
+        # The tag file for autodowned clusters is not autoremoved. Hence, if
+        # a previous cluster was autodowned and has the same name as the
+        # current cluster, then self.metadata might load the old tag file.
+        # We prevent this by removing any old vms in the tag file.
+        self.metadata.refresh([node['id'] for node in vms])
+
+        # If the tag file does not exist on the head, create it and add basic
+        # tags. This is a hack to make sure that ray on the head node can
+        # access some important tags.
+        # TODO(ewzeng): change when Lambda Cloud adds tag support.
+        ray_yaml_path = os.path.expanduser(REMOTE_RAY_YAML)
+        if os.path.exists(ray_yaml_path) and not os.path.exists(
+                self.metadata.path):
+            config = common_utils.read_yaml(ray_yaml_path)
+            # Ensure the correct cluster so sky launch on the head node works
+            # correctly.
+            if config['cluster_name'] != cluster_name:
+                return
+            # Compute the launch hash.
+            head_node_config = config.get('head_node', {})
+            head_node_type = config.get('head_node_type')
+            if head_node_type:
+                head_config = config['available_node_types'][head_node_type]
+                head_node_config.update(head_config['node_config'])
+            launch_hash = hash_launch_conf(head_node_config, config['auth'])
+            # Populate the tags.
+            for node in vms:
+                self.metadata[node['id']] = {
+                    'tags': {
+                        TAG_RAY_CLUSTER_NAME: cluster_name,
+                        TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
+                        TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
+                        TAG_RAY_USER_NODE_TYPE: 'ray_head_default',
+                        TAG_RAY_NODE_NAME: f'ray-{cluster_name}-head',
+                        TAG_RAY_LAUNCH_CONFIG: launch_hash,
+                    }
+                }
+
+    def _list_instances_in_cluster(self) -> List[Dict[str, Any]]:
+        """List running instances in the cluster."""
+        vms = self.lambda_client.list_instances()
+        return [node for node in vms if node['name'] == self.cluster_name]
+
+    @synchronized
+    def _get_filtered_nodes(self,
+                            tag_filters: Dict[str, str]) -> Dict[str, Any]:
+
+        def match_tags(vm):
+            vm_info = self.metadata[vm['id']]
+            tags = {} if vm_info is None else vm_info['tags']
+            for k, v in tag_filters.items():
+                if tags.get(k) != v:
+                    return False
+            return True
+
+        vms = self._list_instances_in_cluster()
+        nodes = [self._extract_metadata(vm) for vm in filter(match_tags, vms)]
+        self.cached_nodes = {node['id']: node for node in nodes}
+        return self.cached_nodes
+
+    def _extract_metadata(self, vm: Dict[str, Any]) -> Dict[str, Any]:
+        metadata = {'id': vm['id'], 'status': vm['status'], 'tags': {}}
+        instance_info = self.metadata[vm['id']]
+        if instance_info is not None:
+            metadata['tags'] = instance_info['tags']
+        ip = vm['ip']
+        metadata['external_ip'] = ip
+        # TODO(ewzeng): The internal ip is hard to get, so set it to the
+        # external ip as a hack. This should be changed in the future.
+        # https://docs.lambdalabs.com/cloud/learn-private-ip-address/
+        metadata['internal_ip'] = ip
+        return metadata
+
+    def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
+        """Return a list of node ids filtered by the specified tags dict.
+
+        This list must not include terminated nodes. For performance reasons,
+        providers are allowed to cache the result of a call to
+        non_terminated_nodes() to serve single-node queries
+        (e.g. is_running(node_id)). This means that non_terminated_nodes()
+        must be called again to refresh results.
+
+        Examples:
+            >>> provider.non_terminated_nodes({TAG_RAY_NODE_KIND: "worker"})
+            ["node-1", "node-2"]
+        """
+        nodes = self._get_filtered_nodes(tag_filters=tag_filters)
+        return [k for k, _ in nodes.items()]
+
+    def is_running(self, node_id: str) -> bool:
+        """Return whether the specified node is running."""
+        return self._get_cached_node(node_id=node_id) is not None
+
+    def is_terminated(self, node_id: str) -> bool:
+        """Return whether the specified node is terminated."""
+        return self._get_cached_node(node_id=node_id) is None
+
+    def node_tags(self, node_id: str) -> Dict[str, str]:
+        """Returns the tags of the given node (string dict)."""
+        return self._get_cached_node(node_id=node_id)['tags']
+
+    def external_ip(self, node_id: str) -> str:
+        """Returns the external ip of the given node."""
+        return self._get_cached_node(node_id=node_id)['external_ip']
+
+    def internal_ip(self, node_id: str) -> str:
+        """Returns the internal ip (Ray ip) of the given node."""
+        return self._get_cached_node(node_id=node_id)['internal_ip']
+
+    def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
+                    count: int) -> None:
+        """Creates a number of nodes within the namespace."""
+        assert count == 1, count  # Only support 1-node clusters for now.
+
+        # Get the tags.
+        config_tags = node_config.get('tags', {}).copy()
+        config_tags.update(tags)
+        config_tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name
+
+        # Create the node.
+        ttype = node_config['InstanceType']
+        region = self.provider_config['region']
+        vm_list = self.lambda_client.create_instances(instance_type=ttype,
+                                                      region=region,
+                                                      quantity=1,
+                                                      name=self.cluster_name)
+        assert len(vm_list) == 1, len(vm_list)
+        vm_id = vm_list[0]
+        self.metadata[vm_id] = {'tags': config_tags}
+
+        # Wait for booting to finish.
+        # TODO(ewzeng): For multi-node, launch all vms first and then wait.
+        while True:
+            vms = self.lambda_client.list_instances()
+            for vm in vms:
+                if vm['id'] == vm_id and vm['status'] == 'active':
+                    return
+            time.sleep(10)
+
+    @synchronized
+    def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None:
+        """Sets the tag values (string dict) for the specified node."""
+        node = self._get_node(node_id)
+        node['tags'].update(tags)
+        self.metadata[node_id] = {'tags': node['tags']}
+
+    def terminate_node(self, node_id: str) -> None:
+        """Terminates the specified node."""
+        self.lambda_client.remove_instances(node_id)
+        self.metadata[node_id] = None
+
+    def _get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
+        self._get_filtered_nodes({})  # Side effect: updates the cache.
+        return self.cached_nodes.get(node_id, None)
+
+    def _get_cached_node(self, node_id: str) -> Optional[Dict[str, Any]]:
+        if node_id in self.cached_nodes:
+            return self.cached_nodes[node_id]
+        return self._get_node(node_id=node_id)
diff --git a/sky/templates/lambda-ray.yml.j2 b/sky/templates/lambda-ray.yml.j2
new file mode 100644
index 00000000000..fb6306fa61f
--- /dev/null
+++ b/sky/templates/lambda-ray.yml.j2
@@ -0,0 +1,110 @@
+cluster_name: {{cluster_name}}
+
+# The maximum number of worker nodes to launch in addition to the head node.
+max_workers: {{num_nodes - 1}}
+upscaling_speed: {{num_nodes - 1}}
+idle_timeout_minutes: 60
+
+provider:
+  type: external
+  module: sky.skylet.providers.lambda_cloud.LambdaNodeProvider
+  region: {{region}}
+
+auth:
+  ssh_user: ubuntu
+  ssh_private_key: {{ssh_private_key}}
+
+available_node_types:
+  ray_head_default:
+    resources: {}
+    node_config:
+      InstanceType: {{instance_type}}
+{% if num_nodes > 1 %}
+  ray_worker_default:
+    min_workers: {{num_nodes - 1}}
+    max_workers: {{num_nodes - 1}}
+    resources: {}
+    node_config:
+      InstanceType: {{instance_type}}
+{%- endif %}
+
+head_node_type: ray_head_default
+
+# Format: `REMOTE_PATH : LOCAL_PATH`
+file_mounts: {
+  "{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}",
+  "{{sky_remote_path}}/{{sky_wheel_hash}}": "{{sky_local_path}}",
+{%- for remote_path, local_path in credentials.items() %}
+  "{{remote_path}}": "{{local_path}}",
+{%- endfor %}
+}
+
+rsync_exclude: []
+
+initialization_commands: []
+
+# List of shell commands to run to set up nodes.
+# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH
+# connection, which is expensive. Try your best to co-locate commands into fewer
+# items!
+#
+# Increment the following for catching performance bugs easier:
+# current num items (num SSH connections): 1
+setup_commands:
+  # Disable `unattended-upgrades` to prevent apt-get from hanging. It should be called at the beginning before the process starts, to avoid being blocked. (This is a temporary fix.)
+  # Create the ~/.ssh/config file in case the file does not exist in the image.
+  # Line 'rm ..': there is another installation of pip.
+  # Line 'sudo bash ..': set the ulimit as suggested by the ray docs for performance. https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html#system-configuration
+  # Line 'sudo grep ..': set the number of threads per process to unlimited, to avoid `ray job submit` getting stuck when the number of running ray jobs increases.
+  # Line 'mkdir -p ..': disable the host key check.
+  # Line 'python3 -c ..': patch the buggy ray files and enable the `-o allow_other` option for `goofys`.
+  - sudo systemctl stop unattended-upgrades || true;
+    sudo systemctl disable unattended-upgrades || true;
+    sudo sed -i 's/Unattended-Upgrade "1"/Unattended-Upgrade "0"/g' /etc/apt/apt.conf.d/20auto-upgrades || true;
+    sudo kill -9 `sudo lsof /var/lib/dpkg/lock-frontend | awk '{print $2}' | tail -n 1` || true;
+    sudo pkill -9 apt-get;
+    sudo pkill -9 dpkg;
+    sudo dpkg --configure -a;
+    mkdir -p ~/.ssh; touch ~/.ssh/config;
+    rm ~/.local/bin/pip ~/.local/bin/pip3 ~/.local/bin/pip3.8 ~/.local/bin/pip3.10;
+    (type -a python | grep -q python3) || echo 'alias python=python3' >> ~/.bashrc;
+    (type -a pip | grep -q pip3) || echo 'alias pip=pip3' >> ~/.bashrc;
+    which conda > /dev/null 2>&1 || (wget -nc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && bash Miniconda3-latest-Linux-x86_64.sh -b && eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true);
+    source ~/.bashrc;
+    (pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app && touch ~/.sudo_as_admin_successful;
+    (pip3 list | grep skypilot && [ "$(cat {{sky_remote_path}}/current_sky_wheel_hash)" == "{{sky_wheel_hash}}" ]) || (pip3 uninstall skypilot -y; pip3 install "$(echo {{sky_remote_path}}/{{sky_wheel_hash}}/skypilot-{{sky_version}}*.whl)" && echo "{{sky_wheel_hash}}" > {{sky_remote_path}}/current_sky_wheel_hash || exit 1);
+    sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf';
+    sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload;
+    mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n  StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n  StrictHostKeyChecking no\n" >> ~/.ssh/config;
+    python3 -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
+    [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf');
+
+# Command to start ray on the head node. You don't need to change this.
+# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH
+# connection, which is expensive. Try your best to co-locate commands into fewer
+# items! The same comment applies for worker_start_ray_commands.
+#
+# Increment the following for catching performance bugs easier:
+# current num items (num SSH connections): 2
+head_start_ray_commands:
+  # Start the skylet daemon. (It should not be placed in head_setup_commands, otherwise it would run before skypilot is installed.)
+  - ((ps aux | grep -v nohup | grep -v grep | grep -q -- "python3 -m sky.skylet.skylet") || nohup python3 -m sky.skylet.skylet >> ~/.sky/skylet.log 2>&1 &);
+    ray stop; RAY_SCHEDULER_EVENTS=0 ray start --disable-usage-stats --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml {{"--resources='%s'" % custom_resources if custom_resources}} || exit 1;
+    which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done;
+
+{%- if num_nodes > 1 %}
+worker_start_ray_commands:
+  - ray stop; RAY_SCHEDULER_EVENTS=0 ray start --disable-usage-stats --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 {{"--resources='%s'" % custom_resources if custom_resources}} || exit 1;
+    which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done;
+{%- else %}
+worker_start_ray_commands: []
+{%- endif %}
+
+head_node: {}
+worker_nodes: {}
+
+# These fields are required for external cloud providers.
+head_setup_commands: []
+worker_setup_commands: []
+cluster_synced_files: []
+file_mounts_sync_continuously: False
diff --git a/tests/conftest.py b/tests/conftest.py
index 4849d9c76f9..8383a926601 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,14 +12,24 @@
 # By default, only run generic tests and cloud-specific tests for GCP and Azure,
 # due to the cloud credit limit for the development account.
 # To only run tests for a specific cloud (as well as generic tests), use
-# --aws, --gcp, or --azure.
+# --aws, --gcp, --azure, or --lambda.
 # To only run tests for managed spot (without generic tests), use --managed-spot.
 # A "generic test" tests a generic functionality (e.g., autostop) that
 # should work on any cloud we support. The cloud used for such a test
 # is controlled by `--generic-cloud` (typically you do not need to set it).
-all_clouds_in_smoke_tests = ['aws', 'gcp', 'azure']
+all_clouds_in_smoke_tests = ['aws', 'gcp', 'azure', 'lambda']
 default_clouds_to_run = ['gcp', 'azure']
 
+# Translate a cloud name to its pytest keyword. We need this because
+# @pytest.mark.lambda is not allowed, so we use @pytest.mark.lambda_cloud
+# instead.
+cloud_to_pytest_keyword = {
+    'aws': 'aws',
+    'gcp': 'gcp',
+    'azure': 'azure',
+    'lambda': 'lambda_cloud'
+}
+
 
 def pytest_addoption(parser):
     # tests marked as `slow` will be skipped by default, use --runslow to run
@@ -49,8 +59,9 @@ def pytest_addoption(parser):
 def pytest_configure(config):
     config.addinivalue_line('markers', 'slow: mark test as slow to run')
     for cloud in all_clouds_in_smoke_tests:
-        config.addinivalue_line('markers',
-                                f'{cloud}: mark test as {cloud} specific')
+        cloud_keyword = cloud_to_pytest_keyword[cloud]
+        config.addinivalue_line(
+            'markers', f'{cloud_keyword}: mark test as {cloud} specific')
 
 
 def _get_cloud_to_run(config) -> List[str]:
@@ -78,23 +89,56 @@ def pytest_collection_modifyitems(config, items):
         if 'slow' in item.keywords and not config.getoption('--runslow'):
             item.add_marker(skip_marks['slow'])
         for cloud in all_clouds_in_smoke_tests:
-            if cloud in item.keywords and cloud not in cloud_to_run:
+            cloud_keyword = cloud_to_pytest_keyword[cloud]
+            if (f'no_{cloud_keyword}' in item.keywords or
+                    (cloud_keyword in item.keywords and
+                     cloud not in cloud_to_run)):
                 item.add_marker(skip_marks[cloud])
         if (not 'managed_spot' in item.keywords) and config.getoption('--managed-spot'):
             item.add_marker(skip_marks['managed_spot'])
 
+    # We run Lambda Cloud tests serially because Lambda Cloud rate limits its
+    # launch API to one launch every 10 seconds.
+    serial_mark = pytest.mark.xdist_group(name='serial_lambda_cloud')
+    # Handle generic tests.
+    if _generic_cloud(config) == 'lambda':
+        for item in items:
+            if (_is_generic_test(item) and
+                    'no_lambda_cloud' not in item.keywords):
+                item.add_marker(serial_mark)
+                # Adding the serial mark does not update the item.nodeid,
+                # but item.nodeid is important for pytest.xdist_group, e.g.
+                # https://github.com/pytest-dev/pytest-xdist/blob/master/src/xdist/scheduler/loadgroup.py
+                # This is a hack to update item.nodeid.
+                item._nodeid = f'{item.nodeid}@serial_lambda_cloud'
+    # Handle Lambda Cloud specific tests.
+    for item in items:
+        if 'lambda_cloud' in item.keywords:
+            item.add_marker(serial_mark)
+            item._nodeid = f'{item.nodeid}@serial_lambda_cloud'  # See comment on item.nodeid above
 
-@pytest.fixture
-def generic_cloud(request) -> str:
-    c = request.config.getoption('--generic-cloud')
-    cloud_to_run = _get_cloud_to_run(request.config)
+
+def _is_generic_test(item) -> bool:
+    for cloud in all_clouds_in_smoke_tests:
+        if cloud_to_pytest_keyword[cloud] in item.keywords:
+            return False
+    return True
+
+
+def _generic_cloud(config) -> str:
+    c = config.getoption('--generic-cloud')
+    cloud_to_run = _get_cloud_to_run(config)
     if c not in cloud_to_run:
         c = cloud_to_run[0]
     return c
 
 
+@pytest.fixture
+def generic_cloud(request) -> str:
+    return _generic_cloud(request.config)
+
+
 def pytest_sessionstart(session):
     from sky.clouds.service_catalog import common
     aws_az_mapping_path = common.get_catalog_path('aws/az_mappings.csv')
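
For context on the `--dist loadgroup` flag added to pyproject.toml and the `xdist_group` mark used above: pytest-xdist's loadgroup scheduler routes all tests sharing a group name to the same worker, so they run one at a time even under `-n 16`. A minimal illustration (hypothetical test file, not part of this PR):

```python
import pytest

# With `pytest -n 16 --dist loadgroup`, both tests below are sent to the
# same xdist worker and therefore run serially, never concurrently.
@pytest.mark.xdist_group(name='serial_lambda_cloud')
def test_lambda_launch_one():
    ...

@pytest.mark.xdist_group(name='serial_lambda_cloud')
def test_lambda_launch_two():
    ...
```

The `item._nodeid` rewrite above is needed because the loadgroup scheduler reads the `@<group>` suffix from the node id, which `add_marker()` alone does not update.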
test_id = str(uuid.uuid4())[-2:] +LAMBDA_TYPE = '--cloud lambda --gpus A100' + storage_setup_commands = [ 'touch ~/tmpfile', 'mkdir -p ~/tmp-workdir', 'touch ~/tmp-workdir/tmp\ file', 'touch ~/tmp-workdir/foo', @@ -484,6 +486,7 @@ def test_image_no_conda(): # ------------ Test stale job ------------ +@pytest.mark.no_lambda_cloud # Lambda Cloud does not support stopping instances def test_stale_job(generic_cloud: str): name = _get_cluster_name() test = Test( @@ -561,6 +564,7 @@ def test_gcp_stale_job_manual_restart(): # ---------- Check Sky's environment variables; workdir. ---------- +@pytest.mark.no_lambda_cloud # Lambda Cloud does not support num_nodes > 1 yet def test_env_check(generic_cloud: str): name = _get_cluster_name() test = Test( @@ -575,6 +579,7 @@ def test_env_check(generic_cloud: str): # ---------- file_mounts ---------- +@pytest.mark.no_lambda_cloud # Lambda Cloud does not support num_nodes > 1 yet def test_file_mounts(generic_cloud: str): name = _get_cluster_name() test_commands = [ @@ -591,6 +596,24 @@ def test_file_mounts(generic_cloud: str): run_one_test(test) +# TODO(ewzeng): merge this with 'test_file_mounts' when multi-node is supported. +@pytest.mark.lambda_cloud +def test_lambda_file_mounts(): + name = _get_cluster_name() + test_commands = [ + *storage_setup_commands, + f'sky launch -y -c {name} {LAMBDA_TYPE} --num-nodes 1 examples/using_file_mounts.yaml', + f'sky logs {name} 1 --status', # Ensure the job succeeded. + ] + test = Test( + 'lambda_using_file_mounts', + test_commands, + f'sky down -y {name}', + timeout=20 * 60, # 20 mins + ) + run_one_test(test) + + # ---------- storage ---------- @pytest.mark.aws def test_aws_storage_mounts(): @@ -647,6 +670,7 @@ def test_gcp_storage_mounts(): # ---------- CLI logs ---------- +@pytest.mark.no_lambda_cloud # Lambda Cloud does not support num_nodes > 1 yet def test_cli_logs(generic_cloud: str): name = _get_cluster_name() timestamp = time.time() @@ -668,7 +692,31 @@ def test_cli_logs(generic_cloud: str): run_one_test(test) +# TODO(ewzeng): merge this with 'test_cli_logs' when multi-node is supported. +@pytest.mark.lambda_cloud +def test_lambda_logs(): + name = _get_cluster_name() + timestamp = time.time() + test = Test( + 'lambda_cli_logs', + [ + f'sky launch -y -c {name} {LAMBDA_TYPE} "echo {timestamp} 1"', + f'sky exec {name} "echo {timestamp} 2"', + f'sky exec {name} "echo {timestamp} 3"', + f'sky exec {name} "echo {timestamp} 4"', + f'sky logs {name} 2 --status', + f'sky logs {name} 3 4 --sync-down', + f'sky logs {name} * --sync-down', + f'sky logs {name} 1 | grep "{timestamp} 1"', + f'sky logs {name} | grep "{timestamp} 4"', + ], + f'sky down -y {name}', + ) + run_one_test(test) + + # ---------- Job Queue. 
@@ -695,6 +743,30 @@
     run_one_test(test)


+@pytest.mark.lambda_cloud
+def test_lambda_job_queue():
+    name = _get_cluster_name()
+    test = Test(
+        'lambda_job_queue',
+        [
+            f'sky launch -y -c {name} {LAMBDA_TYPE} examples/job_queue/cluster.yaml',
+            f'sky exec {name} -n {name}-1 --gpus A100:0.5 -d examples/job_queue/job.yaml',
+            f'sky exec {name} -n {name}-2 --gpus A100:0.5 -d examples/job_queue/job.yaml',
+            f'sky exec {name} -n {name}-3 --gpus A100:0.5 -d examples/job_queue/job.yaml',
+            f'sky queue {name} | grep {name}-1 | grep RUNNING',
+            f'sky queue {name} | grep {name}-2 | grep RUNNING',
+            f'sky queue {name} | grep {name}-3 | grep PENDING',
+            f'sky cancel -y {name} 2',
+            'sleep 5',
+            f'sky queue {name} | grep {name}-3 | grep RUNNING',
+            f'sky cancel -y {name} 3',
+        ],
+        f'sky down -y {name}',
+    )
+    run_one_test(test)
+
+
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support num_nodes > 1 yet
 def test_job_queue_multinode(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -731,6 +803,7 @@ def test_job_queue_multinode(generic_cloud: str):
     run_one_test(test)


+@pytest.mark.no_lambda_cloud  # No Lambda Cloud VM has 8 CPUs
 def test_large_job_queue(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -751,6 +824,7 @@ def test_large_job_queue(generic_cloud: str):


 # ---------- Submitting multiple tasks to the same cluster. ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support num_nodes > 1 yet
 def test_multi_echo(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -771,6 +845,7 @@ def test_multi_echo(generic_cloud: str):


 # ---------- Task: 1 node training. ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not have V100 instances
 def test_huggingface(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -786,6 +861,22 @@ def test_huggingface(generic_cloud: str):
     run_one_test(test)


+@pytest.mark.lambda_cloud
+def test_lambda_huggingface():
+    name = _get_cluster_name()
+    test = Test(
+        'lambda_huggingface_glue_imdb_app',
+        [
+            f'sky launch -y -c {name} {LAMBDA_TYPE} examples/huggingface_glue_imdb_app.yaml',
+            f'sky logs {name} 1 --status',  # Ensure the job succeeded.
+            f'sky exec {name} {LAMBDA_TYPE} examples/huggingface_glue_imdb_app.yaml',
+            f'sky logs {name} 2 --status',  # Ensure the job succeeded.
+        ],
+        f'sky down -y {name}',
+    )
+    run_one_test(test)
+
+
 # ---------- TPU. ----------
 @pytest.mark.gcp
 def test_tpu():
@@ -847,6 +938,7 @@ def test_tpu_vm_pod():


 # ---------- Simple apps. ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support num_nodes > 1 yet
 def test_multi_hostname(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
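Note: test_lambda_job_queue exercises fractional GPU scheduling: the three
jobs each request half of the single A100 pinned by LAMBDA_TYPE, so two run
immediately and the third stays PENDING until a slot frees up. The same
request can be sketched with the Python API (sky.Lambda is the cloud class
introduced by this PR; the task body and cluster name are illustrative):

    import sky

    with sky.Dag() as dag:
        task = sky.Task(run='python train.py')
        task.set_resources(
            sky.Resources(sky.Lambda(), accelerators={'A100': 0.5}))

    # Two such jobs share one A100; a third queues as PENDING.
    sky.launch(dag, cluster_name='lambda-queue-test')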
@@ -863,6 +955,7 @@ def test_multi_hostname(generic_cloud: str):
     run_one_test(test)


 # ---------- Task: n=2 nodes with setups. ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support num_nodes > 1 yet
 def test_distributed_tf(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -926,6 +1019,7 @@ def test_azure_start_stop():


 # ---------- Testing Autostopping ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support stopping instances
 def test_autostop(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -981,6 +1075,7 @@ def test_autostop(generic_cloud: str):


 # ---------- Testing Autodowning ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support num_nodes > 1 yet
 def test_autodown(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -1015,6 +1110,41 @@ def test_autodown(generic_cloud: str):
     run_one_test(test)


+@pytest.mark.lambda_cloud
+def test_lambda_autodown():
+    name = _get_cluster_name()
+    test = Test(
+        'lambda_autodown',
+        [
+            f'sky launch -y -d -c {name} {LAMBDA_TYPE} tests/test_yamls/minimal.yaml',
+            f'sky autostop -y {name} --down -i 1',
+            # Ensure autostop is set.
+            f'sky status | grep {name} | grep "1m (down)"',
+            # Ensure the cluster is not terminated early.
+            'sleep 45',
+            f'sky status --refresh | grep {name} | grep UP',
+            # Ensure the cluster is terminated.
+            'sleep 200',
+            f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
+            f'sky launch -y -d -c {name} {LAMBDA_TYPE} --down tests/test_yamls/minimal.yaml',
+            f'sky status | grep {name} | grep UP',  # Ensure the cluster is UP.
+            f'sky exec {name} {LAMBDA_TYPE} tests/test_yamls/minimal.yaml',
+            f'sky status | grep {name} | grep "1m (down)"',
+            'sleep 200',
+            # Ensure the cluster is terminated.
+            f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && {{ echo "$s" | grep {name} | grep "Autodowned cluster\|terminated on the cloud"; }} || {{ echo "$s" | grep {name} && exit 1 || exit 0; }}',
+            f'sky launch -y -d -c {name} {LAMBDA_TYPE} --down tests/test_yamls/minimal.yaml',
+            f'sky autostop -y {name} --cancel',
+            'sleep 200',
+            # Ensure the cluster is still UP.
+            f's=$(SKYPILOT_DEBUG=0 sky status --refresh) && printf "$s" && echo "$s" | grep {name} | grep UP',
+        ],
+        f'sky down -y {name}',
+        timeout=25 * 60,
+    )
+    run_one_test(test)
+
+
 def _get_cancel_task_with_cloud(name, cloud, timeout=15 * 60):
     test = Test(
         f'{cloud}-cancel-task',
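Note: the `s=$(SKYPILOT_DEBUG=0 sky status --refresh) && ...` one-liners in
test_lambda_autodown above all encode the same check: capture the refreshed
status table, print it, and pass if the cluster's row either carries an
autodown/termination message or has disappeared entirely. The same logic in
Python form (the helper name is illustrative, not from the repo):

    import subprocess

    def cluster_is_gone(name: str) -> bool:
        """True if `sky status --refresh` no longer lists the cluster, or
        reports it as autodowned / terminated on the cloud."""
        out = subprocess.run(['sky', 'status', '--refresh'],
                             capture_output=True, text=True).stdout
        rows = [line for line in out.splitlines() if name in line]
        if not rows:
            return True  # Absent from the table: already cleaned up.
        return any('Autodowned cluster' in line or
                   'terminated on the cloud' in line for line in rows)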
@@ -1058,7 +1188,8 @@ def test_cancel_azure():
     run_one_test(test)


 # ---------- Testing `sky cancel` ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support num_nodes > 1 yet
 def test_cancel_pytorch(generic_cloud: str):
     name = _get_cluster_name()
     test = Test(
@@ -1082,6 +1212,7 @@ def test_cancel_pytorch(generic_cloud: str):


 # ---------- Testing use-spot option ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 def test_use_spot(generic_cloud: str):
     """Test use-spot and sky exec."""
     name = _get_cluster_name()
@@ -1099,6 +1230,7 @@ def test_use_spot(generic_cloud: str):


 # ---------- Testing managed spot ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.managed_spot
 def test_spot(generic_cloud: str):
     """Test the spot yaml."""
@@ -1124,6 +1256,7 @@ def test_spot(generic_cloud: str):
     run_one_test(test)


+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.managed_spot
 def test_spot_failed_setup(generic_cloud: str):
     """Test managed spot job with failed setup."""
@@ -1207,6 +1340,7 @@ def test_spot_recovery_gcp():
     run_one_test(test)


+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.managed_spot
 def test_spot_recovery_default_resources(generic_cloud: str):
     """Test managed spot recovery for default resources."""
@@ -1408,6 +1542,7 @@ def test_spot_cancellation_gcp():


 # ---------- Testing storage for managed spot ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.managed_spot
 def test_spot_storage(generic_cloud: str):
     """Test storage with managed spot"""
@@ -1459,6 +1594,7 @@ def test_spot_tpu():


 # ---------- Testing env for spot ----------
+@pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.managed_spot
 def test_spot_inline_env(generic_cloud: str):
     """Test spot env"""