Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Lambda Labs #1557

Merged
merged 47 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
54b4dea
Apply gmittal's lambda lab PR (#1136) on top of commit ad37a47
ewzeng Dec 9, 2022
eefa6c6
Basic working Lambda Labs support
ewzeng Dec 16, 2022
2c4d72a
Add error handling for Lambda Labs API and small lambda-ray.yml bugfix
ewzeng Dec 19, 2022
652064b
Add automatic key generation, improve sky check, and resolve import bug
ewzeng Dec 19, 2022
1679e5f
Improve Lambda Labs launch code and error handling
ewzeng Dec 20, 2022
962a9c6
Remove bootstrap_config, change metadata file design, and resolve
ewzeng Dec 21, 2022
6333ed7
Make autodown work on Lambda Labs
ewzeng Dec 22, 2022
ac0a336
Add basic tests and improve lambda-ray.yml.j2 bugfix
ewzeng Dec 22, 2022
b711534
Add sky cancel test and do not allow Lambda nodes to stop
ewzeng Dec 23, 2022
9e47585
Polish provider code and change local metadata path to avoid clutter
ewzeng Dec 23, 2022
3191847
Update and move catalog out of repo
ewzeng Dec 24, 2022
fc0b771
Clean up code
ewzeng Dec 24, 2022
a9a8df7
Cleanup and add CLI logs test
ewzeng Dec 26, 2022
8fda3a7
Disallow --num-nodes > 1 and rename some variables
ewzeng Dec 27, 2022
61f3ccd
Do not let optimizer consider Lambda Labs when launching spot
ewzeng Jan 3, 2023
e7a3cb7
Merge branch 'master' into lambda-labs-v3
ewzeng Jan 4, 2023
9e95dd3
Fix issues arising from merge
ewzeng Jan 4, 2023
245a0c5
Address Michaelvll comments
ewzeng Jan 6, 2023
786fcb6
Address infwinston comments
ewzeng Jan 6, 2023
5344a48
Update Lambda Labs help string
ewzeng Jan 7, 2023
59b93a9
Move Lambda Lab tests into smoke tests and change local tag file
ewzeng Jan 8, 2023
aa37002
Improve remote node detection
ewzeng Jan 11, 2023
ddca707
Change tag file scheme
ewzeng Jan 17, 2023
3af5de7
Add comments and change region_zone lookup
ewzeng Jan 17, 2023
a0ec422
Use same tag file path for local and remote
ewzeng Jan 17, 2023
08869d2
Merge branch 'master' into lambda-labs-v3
ewzeng Jan 18, 2023
fcb6466
Remove is_remote file
ewzeng Jan 18, 2023
750b1a5
Clean up imports in Lambda Labs node_provider
ewzeng Jan 19, 2023
2564baf
Make optimizer skip clouds that do not implement requested_features
ewzeng Jan 19, 2023
2934040
Rename Lambda Labs client functions, nits
ewzeng Jan 19, 2023
0661757
Improve requested_features implementation, nits
ewzeng Jan 23, 2023
aa119fc
Add type annotations, nits
ewzeng Jan 24, 2023
eedbc3e
Merge branch 'master' into lambda-labs-v3, update Lambda Labs testing
ewzeng Jan 25, 2023
0da8638
Improve pytest serialization logic
ewzeng Jan 25, 2023
29c31ac
Improve requested_features, introduce CloudImplementationFeatures enums
ewzeng Jan 26, 2023
88d02c0
Update lambda_utils.Metadata, address nits
ewzeng Jan 26, 2023
f4bcef9
Fix conftest.py bug introduced in previous commit
ewzeng Jan 26, 2023
8b83718
Update test comment
ewzeng Jan 27, 2023
5831d71
Rename Lambda Labs -> Lambda Cloud
ewzeng Jan 27, 2023
14992ae
Fix tag file reuse bug
ewzeng Jan 27, 2023
131f8d3
Testing nit
ewzeng Jan 27, 2023
dc2f2f1
Fix auth bug and address nits
ewzeng Jan 28, 2023
65375c5
Address final nits
ewzeng Jan 30, 2023
e37c7b2
Merge branch 'master' into lambda-labs-v3
ewzeng Jan 30, 2023
bd79715
Fix typing issues from merge
ewzeng Jan 30, 2023
30ea3d7
Provide basic support for cpus in resource specification
ewzeng Jan 30, 2023
8570d62
Improve 'cpu' resource specification for Lambda Cloud
ewzeng Jan 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ env = [
"SKYPILOT_DEBUG=1",
"SKYPILOT_DISABLE_USAGE_COLLECTION=1"
]
addopts = "-s -n 16 -q --tb=short --disable-warnings"
addopts = "-s -n 16 -q --tb=short --dist loadgroup --disable-warnings"

[tool.mypy]
python_version = "3.8"
Expand Down
2 changes: 2 additions & 0 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AWS = clouds.AWS
Azure = clouds.Azure
GCP = clouds.GCP
Lambda = clouds.Lambda
Local = clouds.Local
optimize = Optimizer.optimize

Expand All @@ -36,6 +37,7 @@
'AWS',
'Azure',
'GCP',
'Lambda',
'Local',
'Optimizer',
'OptimizeTarget',
Expand Down
24 changes: 24 additions & 0 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sky.utils import common_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils
from sky.skylet.providers.lambda_cloud import lambda_utils

logger = sky_logging.init_logger(__name__)

Expand Down Expand Up @@ -310,3 +311,26 @@ def setup_azure_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
config['file_mounts'] = file_mounts

return config


def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
get_or_generate_keys()

# Ensure ssh key is registered with Lambda Cloud
lambda_client = lambda_utils.LambdaCloudClient()
if lambda_client.ssh_key_name is None:
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
with open(public_key_path, 'r') as f:
public_key = f.read()
name = f'sky-key-{common_utils.get_user_hash()}'
lambda_client.set_ssh_key(name, public_key)

# Need to use ~ relative path because Ray uses the same
# path for finding the public key path on both local and head node.
config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH

file_mounts = config['file_mounts']
file_mounts[PUBLIC_SSH_KEY_PATH] = PUBLIC_SSH_KEY_PATH
config['file_mounts'] = file_mounts

return config
22 changes: 22 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from sky.backends import onprem_utils
from sky.skylet import constants
from sky.skylet import log_lib
from sky.skylet.providers.lambda_cloud import lambda_utils
from sky.utils import common_utils
from sky.utils import command_runner
from sky.utils import env_options
Expand Down Expand Up @@ -951,6 +952,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
config = auth.setup_gcp_authentication(config)
elif isinstance(cloud, clouds.Azure):
config = auth.setup_azure_authentication(config)
elif isinstance(cloud, clouds.Lambda):
config = auth.setup_lambda_authentication(config)
else:
assert isinstance(cloud, clouds.Local), cloud
# Local cluster case, authentication is already filled by the user
Expand Down Expand Up @@ -1802,10 +1805,29 @@ def _query_status_azure(
return _process_cli_query('Azure', cluster, query_cmd, '\t', status_map)


def _query_status_lambda(
cluster: str,
ray_config: Dict[str, Any], # pylint: disable=unused-argument
) -> List[global_user_state.ClusterStatus]:
status_map = {
'booting': global_user_state.ClusterStatus.INIT,
'active': global_user_state.ClusterStatus.UP,
ewzeng marked this conversation as resolved.
Show resolved Hide resolved
'unhealthy': global_user_state.ClusterStatus.INIT,
'terminated': None,
}
# TODO(ewzeng): filter by hash_filter_string to be safe
vms = lambda_utils.LambdaCloudClient().list_instances()
for node in vms:
if node['name'] == cluster:
return [status_map[node['status']]]
return []


_QUERY_STATUS_FUNCS = {
'AWS': _query_status_aws,
'GCP': _query_status_gcp,
'Azure': _query_status_azure,
'Lambda': _query_status_lambda,
}


Expand Down
62 changes: 58 additions & 4 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _get_cluster_config_template(cloud):
clouds.AWS: 'aws-ray.yml.j2',
clouds.Azure: 'azure-ray.yml.j2',
clouds.GCP: 'gcp-ray.yml.j2',
clouds.Lambda: 'lambda-ray.yml.j2',
clouds.Local: 'local-ray.yml.j2',
}
return cloud_to_template[type(cloud)]
Expand Down Expand Up @@ -574,12 +575,14 @@ class GangSchedulingStatus(enum.Enum):

def __init__(self, log_dir: str, dag: 'dag.Dag',
optimize_target: OptimizeTarget,
requested_features: Set[clouds.CloudImplementationFeatures],
local_wheel_path: pathlib.Path, wheel_hash: str):
self._blocked_resources = set()

self.log_dir = os.path.expanduser(log_dir)
self._dag = dag
self._optimize_target = optimize_target
self._requested_features = requested_features
self._local_wheel_path = local_wheel_path
self._wheel_hash = wheel_hash

Expand Down Expand Up @@ -775,6 +778,42 @@ def _update_blocklist_on_azure_error(
else:
self._blocked_resources.add(launchable_resources.copy(zone=None))

def _update_blocklist_on_lambda_error(
self, launchable_resources: 'resources_lib.Resources', region,
zones, stdout, stderr):
del zones # Unused.
style = colorama.Style
stdout_splits = stdout.split('\n')
stderr_splits = stderr.split('\n')
errors = [
s.strip()
for s in stdout_splits + stderr_splits
if 'LambdaCloudError:' in s.strip()
]
if not errors:
logger.info('====== stdout ======')
for s in stdout_splits:
print(s)
logger.info('====== stderr ======')
for s in stderr_splits:
print(s)
with ux_utils.print_exception_no_traceback():
raise RuntimeError('Errors occurred during provision; '
'check logs above.')

logger.warning(f'Got error(s) in {region.name}:')
messages = '\n\t'.join(errors)
logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}')
self._blocked_resources.add(launchable_resources.copy(zone=None))

# Sometimes, LambdaCloudError will list available regions.
for e in errors:
if e.find('Regions with capacity available:') != -1:
for r in clouds.Lambda.regions():
if e.find(r.name) == -1:
self._blocked_resources.add(
launchable_resources.copy(region=r.name, zone=None))

def _update_blocklist_on_local_error(
self, launchable_resources: 'resources_lib.Resources', region,
zones, stdout, stderr):
Expand Down Expand Up @@ -834,6 +873,7 @@ def _update_blocklist_on_error(
clouds.AWS: self._update_blocklist_on_aws_error,
clouds.Azure: self._update_blocklist_on_azure_error,
clouds.GCP: self._update_blocklist_on_gcp_error,
clouds.Lambda: self._update_blocklist_on_lambda_error,
clouds.Local: self._update_blocklist_on_local_error,
}
cloud = launchable_resources.cloud
Expand Down Expand Up @@ -888,6 +928,9 @@ def _yield_region_zones(self, to_provision: resources_lib.Resources,
elif cloud.is_same_cloud(clouds.Azure()):
region = config['provider']['location']
zones = None
elif cloud.is_same_cloud(clouds.Lambda()):
region = config['provider']['region']
zones = None
elif cloud.is_same_cloud(clouds.Local()):
local_regions = clouds.Local.regions()
region = local_regions[0].name
Expand Down Expand Up @@ -1636,6 +1679,15 @@ def provision_with_retries(
cloud_user = None
else:
cloud_user = to_provision.cloud.get_current_user_identity()
# Skip if to_provision.cloud does not support requested features
if not to_provision.cloud.supports(self._requested_features):
self._blocked_resources.add(
resources_lib.Resources(cloud=to_provision.cloud))
requested_features_str = ', '.join(
[f.value for f in self._requested_features])
raise exceptions.ResourcesUnavailableError(
f'{to_provision.cloud} does not support all the '
f'features in [{requested_features_str}].')
config_dict = self._retry_region_zones(
to_provision,
num_nodes,
Expand Down Expand Up @@ -1952,6 +2004,7 @@ def __init__(self):

self._dag = None
self._optimize_target = None
self._requested_features = set()

# Command for running the setup script. It is only set when the
# setup needs to be run outside the self._setup() and as part of
Expand All @@ -1964,6 +2017,8 @@ def register_info(self, **kwargs) -> None:
self._dag = kwargs.pop('dag', self._dag)
self._optimize_target = kwargs.pop(
'optimize_target', self._optimize_target) or OptimizeTarget.COST
self._requested_features = kwargs.pop('requested_features',
self._requested_features)
assert len(kwargs) == 0, f'Unexpected kwargs: {kwargs}'

def check_resources_fit_cluster(self, handle: ResourceHandle,
Expand Down Expand Up @@ -2087,10 +2142,9 @@ def _provision(self,
# in if retry_until_up is set, which will kick off new "rounds"
# of optimization infinitely.
try:
provisioner = RetryingVmProvisioner(self.log_dir, self._dag,
self._optimize_target,
local_wheel_path,
wheel_hash)
provisioner = RetryingVmProvisioner(
self.log_dir, self._dag, self._optimize_target,
self._requested_features, local_wheel_path, wheel_hash)
config_dict = provisioner.provision_with_retries(
task, to_provision_config, dryrun, stream_logs)
break
Expand Down
4 changes: 4 additions & 0 deletions sky/clouds/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
"""Clouds in Sky."""
from sky.clouds.cloud import Cloud
from sky.clouds.cloud import CLOUD_REGISTRY
from sky.clouds.cloud import CloudImplementationFeatures
from sky.clouds.cloud import Region
from sky.clouds.cloud import Zone
from sky.clouds.aws import AWS
from sky.clouds.azure import Azure
from sky.clouds.gcp import GCP
from sky.clouds.lambda_cloud import Lambda
from sky.clouds.local import Local

__all__ = [
'AWS',
'Azure',
'Cloud',
'GCP',
'Lambda',
'Local',
'CloudImplementationFeatures',
'Region',
'Zone',
'CLOUD_REGISTRY',
Expand Down
10 changes: 9 additions & 1 deletion sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import subprocess
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Set, Tuple

from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -544,3 +544,11 @@ def accelerator_in_region_or_zone(self,
zone: Optional[str] = None) -> bool:
return service_catalog.accelerator_in_region_or_zone(
accelerator, acc_count, region, zone, 'aws')

@classmethod
def supports(
cls, requested_features: Set[clouds.CloudImplementationFeatures]
) -> bool:
# All clouds.CloudImplementationFeatures implemented
del requested_features
return True
10 changes: 9 additions & 1 deletion sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import subprocess
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Set, Tuple

from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -406,3 +406,11 @@ def get_project_id(cls, dryrun: bool = False) -> str:
'cli command: "az account set -s <subscription_id>".'
) from e
return azure_subscription_id

@classmethod
def supports(
cls, requested_features: Set[clouds.CloudImplementationFeatures]
) -> bool:
# All clouds.CloudImplementationFeatures implemented
del requested_features
return True
23 changes: 22 additions & 1 deletion sky/clouds/cloud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Interfaces: clouds, regions, and zones."""
import enum
import collections
import typing
from typing import Dict, Iterator, List, Optional, Tuple, Type
from typing import Dict, Iterator, List, Optional, Set, Tuple, Type

from sky.clouds import service_catalog
from sky.utils import ux_utils
Expand All @@ -10,6 +11,16 @@
from sky import resources


class CloudImplementationFeatures(enum.Enum):
"""Features that might not be implemented for all clouds.

Used by Cloud.supports()
"""
STOP = 'stop'
AUTOSTOP = 'autostop'
MULTI_NODE = 'multi-node'


class Region(collections.namedtuple('Region', ['name'])):
"""A region."""
name: str
Expand Down Expand Up @@ -308,5 +319,15 @@ def need_cleanup_after_preemption(self,
del resource
return False

@classmethod
def supports(cls,
requested_features: Set[CloudImplementationFeatures]) -> bool:
"""Returns whether all of the requested features are supported.

For instance, Lambda Cloud does not support autostop, so
Lambda.support({CloudImplementationFeatures.AUTOSTOP}) returns False.
"""
raise NotImplementedError

def __repr__(self):
return self._REPR
10 changes: 9 additions & 1 deletion sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import subprocess
import time
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Set, Tuple

from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -664,3 +664,11 @@ def check_accelerator_attachable_to_host(
zone: Optional[str] = None) -> None:
service_catalog.check_accelerator_attachable_to_host(
instance_type, accelerators, zone, 'gcp')

@classmethod
def supports(
cls, requested_features: Set[clouds.CloudImplementationFeatures]
) -> bool:
# All clouds.CloudImplementationFeatures implemented
del requested_features
return True
Loading