[Core] Fix resource validation for existing cluster and refactor #1259

Merged · 5 commits · Jul 15, 2023
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
@@ -19,6 +19,7 @@ jobs:
           - tests/test_cli.py
           - tests/test_config.py
           - tests/test_global_user_state.py
+          - tests/test_jobs.py
           - tests/test_list_accelerators.py
           - tests/test_optimizer_dryruns.py
           - tests/test_optimizer_random_dag.py
16 changes: 11 additions & 5 deletions sky/backends/backend.py
@@ -82,12 +82,15 @@ def add_storage_objects(self, task: 'task_lib.Task') -> None:

     @timeline.event
     @usage_lib.messages.usage.update_runtime('execute')
-    def execute(self, handle: _ResourceHandleType, task: 'task_lib.Task',
-                detach_run: bool) -> None:
+    def execute(self,
+                handle: _ResourceHandleType,
+                task: 'task_lib.Task',
+                detach_run: bool,
+                dryrun: bool = False) -> None:
         usage_lib.record_cluster_name_for_current_operation(
             handle.get_cluster_name())
         usage_lib.messages.usage.update_actual_task(task)
-        return self._execute(handle, task, detach_run)
+        return self._execute(handle, task, detach_run, dryrun)

     @timeline.event
     def post_execute(self, handle: _ResourceHandleType, down: bool) -> None:
@@ -136,8 +139,11 @@ def _setup(self, handle: _ResourceHandleType, task: 'task_lib.Task',
                detach_setup: bool) -> None:
         raise NotImplementedError

-    def _execute(self, handle: _ResourceHandleType, task: 'task_lib.Task',
-                 detach_run: bool) -> None:
+    def _execute(self,
+                 handle: _ResourceHandleType,
+                 task: 'task_lib.Task',
+                 detach_run: bool,
+                 dryrun: bool = False) -> None:
         raise NotImplementedError

     def _post_execute(self, handle: _ResourceHandleType, down: bool) -> None:
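Note on the change above: `execute` now accepts `dryrun` and simply forwards it to the backend-specific `_execute` hook. Below is a minimal, self-contained sketch of that contract; `ToyBackend` and its print statements are illustrative stand-ins, not SkyPilot code.

```python
# Illustrative only: mirrors the execute/_execute split above, with
# `dryrun` forwarded from the public wrapper to the backend hook.
class ToyBackend:

    def execute(self,
                handle: str,
                task: str,
                detach_run: bool,
                dryrun: bool = False) -> None:
        # Bookkeeping (usage recording in the real code) still runs.
        print(f'Recording usage for cluster {handle!r}')
        return self._execute(handle, task, detach_run, dryrun)

    def _execute(self,
                 handle: str,
                 task: str,
                 detach_run: bool,
                 dryrun: bool = False) -> None:
        if dryrun:
            # Backends are expected to stop before submitting any job.
            print(f'Dryrun complete. Would have run:\n{task}')
            return
        print(f'Submitting {task!r} (detach_run={detach_run})')


ToyBackend().execute('my-cluster', 'echo hi', detach_run=False, dryrun=True)
```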
7 changes: 7 additions & 0 deletions sky/backends/backend_utils.py
@@ -2136,6 +2136,7 @@ def check_cluster_available(
     *,
     operation: str,
     check_cloud_vm_ray_backend: Literal[True] = True,
+    dryrun: bool = ...,
 ) -> 'cloud_vm_ray_backend.CloudVmRayResourceHandle':
     ...

@@ -2146,6 +2147,7 @@ def check_cluster_available(
     *,
     operation: str,
     check_cloud_vm_ray_backend: Literal[False],
+    dryrun: bool = ...,
 ) -> backends.ResourceHandle:
     ...

@@ -2155,6 +2157,7 @@ def check_cluster_available(
     *,
     operation: str,
     check_cloud_vm_ray_backend: bool = True,
+    dryrun: bool = False,
 ) -> backends.ResourceHandle:
     """Check if the cluster is available.

@@ -2168,6 +2171,10 @@ def check_cluster_available(
         exceptions.CloudUserIdentityError: if we fail to get the current user
             identity.
     """
+    if dryrun:
+        record = global_user_state.get_cluster_from_name(cluster_name)
+        assert record is not None, cluster_name
+        return record['handle']
     try:
         cluster_status, handle = refresh_cluster_status_handle(cluster_name)
     except exceptions.ClusterStatusFetchingError as e:
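Note on the change above: on a dryrun, `check_cluster_available` trusts the locally cached cluster record and skips the network-bound status refresh entirely. A standalone sketch of that fast path; the in-memory `_STATE` table is a hypothetical stand-in for `global_user_state`.

```python
from typing import Any, Dict, Optional

# Hypothetical stand-in for the global_user_state cluster table.
_STATE: Dict[str, Dict[str, Any]] = {
    'my-cluster': {'handle': '<cached-handle>', 'status': 'UP'},
}


def check_cluster_available(cluster_name: str, *, dryrun: bool = False) -> Any:
    if dryrun:
        # Trust the local record; skip the expensive status refresh.
        record: Optional[Dict[str, Any]] = _STATE.get(cluster_name)
        assert record is not None, cluster_name
        return record['handle']
    raise NotImplementedError('Non-dryrun path refreshes cluster status.')


print(check_cluster_available('my-cluster', dryrun=True))
```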
67 changes: 40 additions & 27 deletions sky/backends/cloud_vm_ray_backend.py
@@ -1440,7 +1440,7 @@ def _retry_zones(
                     f'Failed to find catalog in region {region.name}: {e}')
                 continue
             if dryrun:
-                return
+                return config_dict
             cluster_config_file = config_dict['ray']

             # Record early, so if anything goes wrong, 'sky status' will show
@@ -2414,10 +2414,9 @@ def _provision(
                     to_provision,
                     task.num_nodes,
                     prev_cluster_status=None)
-            if not dryrun:  # dry run doesn't need to check existing cluster.
-                # Try to launch the exiting cluster first
-                to_provision_config = self._check_existing_cluster(
-                    task, to_provision, cluster_name)
+            # Try to launch the existing cluster first
+            to_provision_config = self._check_existing_cluster(
+                task, to_provision, cluster_name, dryrun)
             assert to_provision_config.resources is not None, (
                 'to_provision should not be None', to_provision_config)

@@ -2495,7 +2494,8 @@ def _provision(
                         error_message,
                         failover_history=e.failover_history) from None
             if dryrun:
-                return None
+                record = global_user_state.get_cluster_from_name(cluster_name)
+                return record['handle'] if record is not None else None
             cluster_config_file = config_dict['ray']

             handle = CloudVmRayResourceHandle(
@@ -3008,6 +3008,7 @@ def _execute(
         handle: CloudVmRayResourceHandle,
         task: task_lib.Task,
         detach_run: bool,
+        dryrun: bool = False,
     ) -> None:
         if task.run is None:
             logger.info('Run commands not specified or empty.')

@@ -3017,6 +3018,11 @@ def _execute(
         self.check_resources_fit_cluster(handle, task)

         resources_str = backend_utils.get_task_resources_str(task)
+
+        if dryrun:
+            logger.info(f'Dryrun complete. Would have run:\n{task}')
+            return
+
         job_id = self._add_job(handle, task.name, resources_str)

         is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(handle.launched_resources)
@@ -3767,9 +3773,11 @@ def run_on_head(

     @timeline.event
     def _check_existing_cluster(
-            self, task: task_lib.Task,
+            self,
+            task: task_lib.Task,
             to_provision: Optional[resources_lib.Resources],
-            cluster_name: str) -> RetryingVmProvisioner.ToProvisionConfig:
+            cluster_name: str,
+            dryrun: bool = False) -> RetryingVmProvisioner.ToProvisionConfig:
         """Checks if the cluster exists and returns the provision config.

         Raises:
@@ -3782,25 +3790,30 @@ def _check_existing_cluster(
         handle_before_refresh = None if record is None else record['handle']
         status_before_refresh = None if record is None else record['status']

-        prev_cluster_status, handle = (
-            backend_utils.refresh_cluster_status_handle(
-                cluster_name,
-                # We force refresh for the init status to determine the actual
-                # state of a previous cluster in INIT state.
-                #
-                # This is important for the case, where an existing cluster is
-                # transitioned into INIT state due to key interruption during
-                # launching, with the following steps:
-                # (1) launch, after answering prompt immediately ctrl-c;
-                # (2) launch again.
-                # If we don't refresh the state of the cluster and reset it back
-                # to STOPPED, our failover logic will consider it as an abnormal
-                # cluster after hitting resources capacity limit on the cloud,
-                # and will start failover. This is not desired, because the user
-                # may want to keep the data on the disk of that cluster.
-                force_refresh_statuses={status_lib.ClusterStatus.INIT},
-                acquire_per_cluster_status_lock=False,
-            ))
+        prev_cluster_status, handle = (status_before_refresh,
+                                       handle_before_refresh)
+
+        if not dryrun:
+            prev_cluster_status, handle = (
+                backend_utils.refresh_cluster_status_handle(
+                    cluster_name,
+                    # We force refresh for the init status to determine the
+                    # actual state of a previous cluster in INIT state.
+                    #
+                    # This is important for the case, where an existing cluster
+                    # is transitioned into INIT state due to key interruption
+                    # during launching, with the following steps:
+                    # (1) launch, after answering prompt immediately ctrl-c;
+                    # (2) launch again.
+                    # If we don't refresh the state of the cluster and reset it
+                    # back to STOPPED, our failover logic will consider it as an
+                    # abnormal cluster after hitting resources capacity limit on
+                    # the cloud, and will start failover. This is not desired,
+                    # because the user may want to keep the data on the disk of
+                    # that cluster.
+                    force_refresh_statuses={status_lib.ClusterStatus.INIT},
+                    acquire_per_cluster_status_lock=False,
+                ))
         if prev_cluster_status is not None:
             assert handle is not None
             # Cluster already exists.
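Note on the change above: `_check_existing_cluster` now seeds `(prev_cluster_status, handle)` from the cached record and only pays for a forced status refresh when the call is not a dryrun. A standalone sketch of that control flow; `record` and `refresh` are hypothetical stand-ins for the cached state and `refresh_cluster_status_handle`.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def check_existing_cluster(
        record: Optional[Dict[str, Any]],
        refresh: Callable[[], Tuple[Optional[str], Optional[Any]]],
        dryrun: bool = False) -> Tuple[Optional[str], Optional[Any]]:
    handle_before_refresh = None if record is None else record['handle']
    status_before_refresh = None if record is None else record['status']
    # Start from the cached pair; refresh only when not a dryrun.
    prev_status, handle = status_before_refresh, handle_before_refresh
    if not dryrun:
        prev_status, handle = refresh()
    return prev_status, handle


cached = {'handle': '<cached-handle>', 'status': 'INIT'}
print(check_existing_cluster(cached, lambda: ('UP', '<fresh-handle>')))
print(check_existing_cluster(cached, lambda: ('UP', '<fresh-handle>'),
                             dryrun=True))
```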
11 changes: 9 additions & 2 deletions sky/backends/local_docker_backend.py
@@ -265,8 +265,11 @@ def _setup(self, handle: LocalDockerResourceHandle, task: 'task_lib.Task',
             requested_resources=task.resources,
             ready=True)

-    def _execute(self, handle: LocalDockerResourceHandle, task: 'task_lib.Task',
-                 detach_run: bool) -> None:
+    def _execute(self,
+                 handle: LocalDockerResourceHandle,
+                 task: 'task_lib.Task',
+                 detach_run: bool,
+                 dryrun: bool = False) -> None:
         """ Launches the container."""

         if detach_run:
@@ -283,6 +286,10 @@ def _execute(self, handle: LocalDockerResourceHandle, task: 'task_lib.Task',
             logger.info(f'Nothing to run; run command not specified:\n{task}')
             return

+        if dryrun:
+            logger.info(f'Dryrun complete. Would have run:\n{task}')
+            return
+
         self._execute_task_one_node(handle, task)

     def _post_execute(self, handle: LocalDockerResourceHandle,
4 changes: 2 additions & 2 deletions sky/clouds/aws.py
@@ -360,8 +360,8 @@ def make_deploy_resources_variables(
             **AWS._get_disk_specs(r.disk_tier)
         }

-    def get_feasible_launchable_resources(self,
-                                          resources: 'resources_lib.Resources'):
+    def _get_feasible_launchable_resources(
+            self, resources: 'resources_lib.Resources'):
         if resources.instance_type is not None:
             assert resources.is_launchable(), resources
             # Treat Resources(AWS, p3.2x, V100) as Resources(AWS, p3.2x).
12 changes: 8 additions & 4 deletions sky/clouds/azure.py
@@ -248,7 +248,7 @@ def make_deploy_resources_variables(
             'disk_tier': Azure._get_disk_type(r.disk_tier)
         }

-    def get_feasible_launchable_resources(self, resources):
+    def _get_feasible_launchable_resources(self, resources):

         def failover_disk_tier(
             instance_type: str,
@@ -288,11 +288,15 @@ def failover_disk_tier(
                 return ([], [])
         if resources.instance_type is not None:
             assert resources.is_launchable(), resources
-            # Treat Resources(AWS, p3.2x, V100) as Resources(AWS, p3.2x).
+            ok, disk_tier = failover_disk_tier(resources.instance_type,
+                                               resources.disk_tier)
+            if not ok:
+                return ([], [])
+            # Treat Resources(Azure, Standard_NC4as_T4_v3, T4) as
+            # Resources(Azure, Standard_NC4as_T4_v3).
             resources = resources.copy(
                 accelerators=None,
-                disk_tier=failover_disk_tier(resources.instance_type,
-                                             resources.disk_tier),
+                disk_tier=disk_tier,
             )
             return ([resources], [])

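Note on the change above: `failover_disk_tier` now returns an explicit `(ok, disk_tier)` pair instead of just a tier, so the caller can return `([], [])` (no feasible resources) as soon as no tier works. A standalone sketch of that convention; the tier table and failover policy here are hypothetical, not Azure's actual catalog logic.

```python
from typing import List, Optional, Tuple

# Hypothetical instance-type -> supported-disk-tier table.
_SUPPORTED_TIERS = {
    'Standard_D8s_v3': ['high', 'medium', 'low'],
    'Standard_B1s': [],
}


def failover_disk_tier(instance_type: str,
                       disk_tier: Optional[str]) -> Tuple[bool, Optional[str]]:
    if disk_tier is None:
        return True, None  # Nothing requested, nothing to fail over.
    tiers: List[str] = _SUPPORTED_TIERS.get(instance_type, [])
    if disk_tier in tiers:
        return True, disk_tier
    if tiers:
        return True, tiers[0]  # Fail over to a tier the type does support.
    return False, None  # Infeasible: caller should return ([], []).


print(failover_disk_tier('Standard_D8s_v3', 'low'))  # (True, 'low')
print(failover_disk_tier('Standard_B1s', 'high'))    # (False, None)
```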
63 changes: 58 additions & 5 deletions sky/clouds/cloud.py
@@ -12,7 +12,7 @@

 if typing.TYPE_CHECKING:
     from sky import status_lib
-    from sky import resources
+    from sky import resources as resources_lib


 class CloudImplementationFeatures(enum.Enum):
@@ -219,7 +219,7 @@ def is_same_cloud(self, other):

     def make_deploy_resources_variables(
         self,
-        resources: 'resources.Resources',
+        resources: 'resources_lib.Resources',
         region: 'Region',
         zones: Optional[List['Zone']],
     ) -> Dict[str, Optional[str]]:
@@ -294,6 +294,11 @@ def get_feasible_launchable_resources(self, resources):

         Launchable resources require a cloud and an instance type be assigned.
         """
+        if resources.is_launchable():
+            self._check_instance_type_accelerators_combination(resources)
+        return self._get_feasible_launchable_resources(resources)
+
+    def _get_feasible_launchable_resources(self, resources):
         raise NotImplementedError

     @classmethod
@@ -395,8 +400,8 @@ def accelerator_in_region_or_zone(self,
         """Returns whether the accelerator is valid in the region or zone."""
         raise NotImplementedError

-    def need_cleanup_after_preemption(self,
-                                      resource: 'resources.Resources') -> bool:
+    def need_cleanup_after_preemption(
+            self, resource: 'resources_lib.Resources') -> bool:
         """Returns whether a spot resource needs cleanup after preemption.

         In most cases, spot resources do not need cleanup after preemption,
@@ -478,7 +483,54 @@ def check_disk_tier_enabled(cls, instance_type: str,
         raise NotImplementedError

     @classmethod
+    # pylint: disable=unused-argument
+    def _check_instance_type_accelerators_combination(
+            cls, resources: 'resources_lib.Resources') -> None:
+        """Errors out if the accelerator is not supported by the instance type.
+
+        This function is overridden by GCP for host-accelerator logic.
+
+        Raises:
+            ResourcesMismatchError: If the accelerator is not supported.
+        """
+        assert resources.is_launchable(), resources
+
+        def _equal_accelerators(
+                acc_requested: Optional[Dict[str, int]],
+                acc_from_instance_type: Optional[Dict[str, int]]) -> bool:
+            """Checks whether the requested accelerators match the instance type.
+
+            Returns True if the requested accelerators equal the accelerators
+            inferred from the instance type (both the accelerator type and the
+            count).
+            """
+            if acc_requested is None:
+                return acc_from_instance_type is None
+            if acc_from_instance_type is None:
+                return False
+
+            for acc in acc_requested:
+                if acc not in acc_from_instance_type:
+                    return False
+                if acc_requested[acc] != acc_from_instance_type[acc]:
+                    return False
+            return True
+
+        acc_from_instance_type = (cls.get_accelerators_from_instance_type(
+            resources.instance_type))
+        if not _equal_accelerators(resources.accelerators,
+                                   acc_from_instance_type):
+            with ux_utils.print_exception_no_traceback():
+                raise exceptions.ResourcesMismatchError(
+                    'Infeasible resource demands found:'
+                    '\n  Instance type requested: '
+                    f'{resources.instance_type}\n'
+                    f'  Accelerators for {resources.instance_type}: '
+                    f'{acc_from_instance_type}\n'
+                    f'  Accelerators requested: {resources.accelerators}\n'
+                    f'To fix: either only specify instance_type, or '
+                    'change the accelerators field to be consistent.')
+
+    @classmethod
     def check_quota_available(cls,
                               region: str,
                               instance_type: str,

@@ -526,6 +578,7 @@ def check_quota_available(cls,
         Returns:
             False if the quota is found to be zero, and true otherwise.
         """
+        del region, instance_type, use_spot  # unused

         return True
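Note on the change above: the exact-match rule in `_equal_accelerators` requires both the accelerator type and the count to agree; as written, the loop only iterates over the requested accelerators, so extra accelerators on the instance-type side that were never requested would not trip the check. A standalone sketch mirroring that loop:

```python
from typing import Dict, Optional


def equal_accelerators(
        acc_requested: Optional[Dict[str, int]],
        acc_from_instance_type: Optional[Dict[str, int]]) -> bool:
    # Mirrors _equal_accelerators above: every requested accelerator must
    # appear on the instance type with exactly the requested count.
    if acc_requested is None:
        return acc_from_instance_type is None
    if acc_from_instance_type is None:
        return False
    for acc, count in acc_requested.items():
        if acc_from_instance_type.get(acc) != count:
            return False
    return True


# e.g. a p3.2xlarge-style instance type bundles exactly one V100:
assert equal_accelerators({'V100': 1}, {'V100': 1})
assert not equal_accelerators({'V100': 8}, {'V100': 1})  # count mismatch
assert not equal_accelerators({'K80': 1}, {'V100': 1})   # type mismatch
assert equal_accelerators(None, None)  # CPU-only on both sides
```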
18 changes: 6 additions & 12 deletions sky/clouds/gcp.py
@@ -394,7 +394,7 @@ def make_deploy_resources_variables(

         return resources_vars

-    def get_feasible_launchable_resources(self, resources):
+    def _get_feasible_launchable_resources(self, resources):
         if resources.instance_type is not None:
             assert resources.is_launchable(), resources
             return ([resources], [])
@@ -766,18 +766,12 @@ def get_project_id(cls, dryrun: bool = False) -> str:
         return project_id

     @staticmethod
-    def check_host_accelerator_compatibility(
-            instance_type: str, accelerators: Optional[Dict[str, int]]) -> None:
-        service_catalog.check_host_accelerator_compatibility(
-            instance_type, accelerators, 'gcp')
-
-    @staticmethod
-    def check_accelerator_attachable_to_host(
-            instance_type: str,
-            accelerators: Optional[Dict[str, int]],
-            zone: Optional[str] = None) -> None:
+    def _check_instance_type_accelerators_combination(
+            resources: 'resources.Resources') -> None:
+        assert resources.is_launchable(), resources
         service_catalog.check_accelerator_attachable_to_host(
-            instance_type, accelerators, zone, 'gcp')
+            resources.instance_type, resources.accelerators, resources.zone,
+            'gcp')

     @classmethod
     def check_disk_tier_enabled(cls, instance_type: str,
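Note on the refactor as a whole: callers never change, because `Cloud.get_feasible_launchable_resources` owns the public entry point and always runs the validation hook before delegating to the private `_get_feasible_launchable_resources`. GCP alone overrides the hook, swapping the exact-match check for a service-catalog attachability check, since on GCP accelerators attach to a host VM rather than being bundled with the instance type. A standalone sketch of that dispatch; the classes and method bodies are illustrative, not SkyPilot code.

```python
class Cloud:

    def get_feasible_launchable_resources(self, resources) -> list:
        # Validation always runs before feasibility filtering.
        self._check_instance_type_accelerators_combination(resources)
        return self._get_feasible_launchable_resources(resources)

    def _check_instance_type_accelerators_combination(self, resources) -> None:
        print('exact-match check of requested vs. bundled accelerators')

    def _get_feasible_launchable_resources(self, resources) -> list:
        return [resources]


class GCP(Cloud):

    def _check_instance_type_accelerators_combination(self, resources) -> None:
        # GCP swaps in a catalog lookup: accelerators are attached to a
        # host VM, not bundled with the instance type.
        print('catalog check: is the accelerator attachable to this host?')


GCP().get_feasible_launchable_resources('n1-standard-8 + V100')
```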