Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make --fast robust against credential or wheel updates #4289

Merged
merged 21 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions sky/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,23 @@ def check_resources_fit_cluster(self, handle: _ResourceHandleType,
@timeline.event
@usage_lib.messages.usage.update_runtime('provision')
def provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: Optional[str] = None,
retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: Optional[str] = None,
retry_until_up: bool = False,
skip_if_config_hash_matches: Optional[str] = None
) -> Optional[_ResourceHandleType]:
if cluster_name is None:
cluster_name = sky.backends.backend_utils.generate_cluster_name()
usage_lib.record_cluster_name_for_current_operation(cluster_name)
usage_lib.messages.usage.update_actual_task(task)
with rich_utils.safe_status(ux_utils.spinner_message('Launching')):
return self._provision(task, to_provision, dryrun, stream_logs,
cluster_name, retry_until_up)
cluster_name, retry_until_up,
skip_if_config_hash_matches)

@timeline.event
@usage_lib.messages.usage.update_runtime('sync_workdir')
Expand Down Expand Up @@ -126,13 +129,15 @@ def register_info(self, **kwargs) -> None:

# --- Implementations of the APIs ---
def _provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False,
skip_if_config_hash_matches: Optional[str] = None
cg505 marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[_ResourceHandleType]:
raise NotImplementedError

def _sync_workdir(self, handle: _ResourceHandleType, workdir: Path) -> None:
Expand Down
133 changes: 128 additions & 5 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import fnmatch
import functools
import hashlib
import os
import pathlib
import pprint
Expand Down Expand Up @@ -640,11 +641,14 @@ def write_cluster_config(
keep_launch_fields_in_existing_config: bool = True) -> Dict[str, str]:
"""Fills in cluster configuration templates and writes them out.

Returns: {provisioner: path to yaml, the provisioning spec}.
'provisioner' can be
- 'ray'
- 'tpu-create-script' (if TPU is requested)
- 'tpu-delete-script' (if TPU is requested)
Returns:
Dict with the following keys:
- 'ray': Path to the generated Ray yaml config file
- 'cluster_name': Name of the cluster
- 'cluster_name_on_cloud': Name of the cluster as it appears in the
cloud provider
- 'config_hash': Hash of the cluster config and file mounts contents

Raises:
exceptions.ResourcesUnavailableError: if the region/zones requested does
not appear in the catalog, or an ssh_proxy_command is specified but
Expand Down Expand Up @@ -860,6 +864,7 @@ def write_cluster_config(
if dryrun:
# If dryrun, return the unfinished tmp yaml path.
config_dict['ray'] = tmp_yaml_path
config_dict['config_hash'] = _deterministic_yaml_hash(tmp_yaml_path)
return config_dict
_add_auth_to_cluster_config(cloud, tmp_yaml_path)

Expand All @@ -882,6 +887,11 @@ def write_cluster_config(
yaml_config = common_utils.read_yaml(tmp_yaml_path)
config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name']

# Make sure to do this before we optimize file mounts. Optimization is
# non-deterministic, but everything else before this point should be
# deterministic.
config_dict['config_hash'] = _deterministic_yaml_hash(tmp_yaml_path)
cg505 marked this conversation as resolved.
Show resolved Hide resolved

# Optimization: copy the contents of source files in file_mounts to a
# special dir, and upload that as the only file_mount instead. Delay
# calling this optimization until now, when all source files have been
Expand Down Expand Up @@ -990,6 +1000,119 @@ def get_ready_nodes_counts(pattern, output):
return ready_head, ready_workers


@timeline.event
def _deterministic_yaml_hash(yaml_path: str) -> str:
    """Hashes the cluster yaml and contents of file mounts.

    Two invocations of this function should return the same string if and only
    if the contents of the yaml are the same and the file contents of all the
    file_mounts specified in the yaml are the same.

    Limitations:
    - This function can be expensive if the file mounts are large. (E.g. a few
      seconds for ~1GB.)
    - Symbolic links are not explicitly handled. Some symbolic link changes may
      not be detected.

    Implementation: We create a byte sequence that captures the state of the
    yaml file and all the files in the file mounts, then hash the byte
    sequence.

    The format of the byte sequence is:
    32 bytes - sha256 hash of the yaml file
    for each file mount:
      file mount remote destination (UTF-8), \0
      if the file mount source is a file:
        'file' encoded to UTF-8
        32 byte sha256 hash of the file contents
      if the file mount source is a directory:
        'dir' encoded to UTF-8
        for each directory and subdirectory within the file mount (starting
        from the root and descending recursively):
          name of the directory (UTF-8), \0
          name of each subdirectory within the directory (UTF-8) terminated
          by \0
          \0
          for each file in the directory:
            name of the file (UTF-8), \0
            32 bytes - sha256 hash of the file contents
          \0
      if the file mount source is something else or does not exist, nothing
      \0\0

    Rather than constructing the whole byte sequence, which may be quite large,
    we construct it incrementally by using hash.update() to add new bytes.
    """

    # In python 3.11, hashlib.file_digest is available, but for <3.11 we have
    # to do it manually.
    # This implementation is simplified from the implementation in CPython.
    # Beware of f.read() as some files may be larger than memory.
    def _hash_file(path: str) -> bytes:
        with open(path, 'rb') as f:
            file_hash = hashlib.sha256()
            buf = bytearray(2**18)
            view = memoryview(buf)
            while True:
                size = f.readinto(buf)
                if size == 0:
                    # EOF
                    break
                file_hash.update(view[:size])
        return file_hash.digest()

    config_hash = hashlib.sha256()

    config_hash.update(_hash_file(yaml_path))

    yaml_config = common_utils.read_yaml(yaml_path)
    file_mounts = yaml_config.get('file_mounts', {})
    # Remove the file mounts added by the newline.
    if '' in file_mounts:
        assert file_mounts[''] == '', file_mounts['']
        file_mounts.pop('')

    for dst, src in sorted(file_mounts.items()):
        expanded_src = os.path.expanduser(src)
        config_hash.update(dst.encode('utf-8') + b'\0')

        if os.path.isfile(expanded_src):
            # b'file' is byte-identical to 'file'.encode('utf-8'), so the
            # resulting hash is unchanged.
            config_hash.update(b'file')
            config_hash.update(_hash_file(expanded_src))

        elif os.path.isdir(expanded_src):
            config_hash.update(b'dir')

            for (dirpath, dirnames, filenames) in os.walk(expanded_src):
                config_hash.update(dirpath.encode('utf-8') + b'\0')

                # Note: inplace sort will also affect the traversal order of
                # os.walk. We need it so that the os.walk order is
                # deterministic.
                dirnames.sort()
                # This includes symlinks to directories. We will recurse into
                # all the directories here but not the symlinks. We don't hash
                # the link destination.
                for dirname in dirnames:
                    config_hash.update(dirname.encode('utf-8') + b'\0')
                config_hash.update(b'\0')

                filenames.sort()
                # This includes symlinks to files. We could hash the symlink
                # destination itself but instead just hash the destination
                # contents.
                for filename in filenames:
                    config_hash.update(filename.encode('utf-8') + b'\0')
                    config_hash.update(
                        _hash_file(os.path.join(dirpath, filename)))
                config_hash.update(b'\0')

        else:
            logger.debug(
                f'Unexpected file_mount that is not a file or dir: {src}')

        config_hash.update(b'\0\0')

    return config_hash.hexdigest()


def get_docker_user(ip: str, cluster_config_file: str) -> str:
"""Find docker container username."""
ssh_credentials = ssh_credential_from_yaml(cluster_config_file)
Expand Down
50 changes: 37 additions & 13 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,7 @@ def _retry_zones(
prev_cluster_status: Optional[status_lib.ClusterStatus],
prev_handle: Optional['CloudVmRayResourceHandle'],
prev_cluster_ever_up: bool,
skip_if_config_hash_matches: Optional[str],
) -> Dict[str, Any]:
"""The provision retry loop."""
# Get log_path name
Expand Down Expand Up @@ -1424,8 +1425,15 @@ def _retry_zones(
raise exceptions.ResourcesUnavailableError(
f'Failed to provision on cloud {to_provision.cloud} due to '
f'invalid cloud config: {common_utils.format_exception(e)}')

if skip_if_config_hash_matches == config_dict['config_hash']:
logger.info('Skipping provisioning of cluster with matching '
'config hash.')
cg505 marked this conversation as resolved.
Show resolved Hide resolved
return config_dict

if dryrun:
return config_dict

cluster_config_file = config_dict['ray']

launched_resources = to_provision.copy(region=region.name)
Expand Down Expand Up @@ -1937,6 +1945,7 @@ def provision_with_retries(
to_provision_config: ToProvisionConfig,
dryrun: bool,
stream_logs: bool,
skip_if_config_hash_matches: Optional[str],
) -> Dict[str, Any]:
"""Provision with retries for all launchable resources."""
cluster_name = to_provision_config.cluster_name
Expand Down Expand Up @@ -1986,7 +1995,8 @@ def provision_with_retries(
cloud_user_identity=cloud_user,
prev_cluster_status=prev_cluster_status,
prev_handle=prev_handle,
prev_cluster_ever_up=prev_cluster_ever_up)
prev_cluster_ever_up=prev_cluster_ever_up,
skip_if_config_hash_matches=skip_if_config_hash_matches)
if dryrun:
return config_dict
except (exceptions.InvalidClusterNameError,
Expand Down Expand Up @@ -2687,13 +2697,15 @@ def check_resources_fit_cluster(
return valid_resource

def _provision(
self,
task: task_lib.Task,
to_provision: Optional[resources_lib.Resources],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False) -> Optional[CloudVmRayResourceHandle]:
self,
task: task_lib.Task,
to_provision: Optional[resources_lib.Resources],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False,
skip_if_config_hash_matches: Optional[str] = None
cg505 marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[CloudVmRayResourceHandle]:
"""Provisions using 'ray up'.
cg505 marked this conversation as resolved.
Show resolved Hide resolved

Raises:
Expand Down Expand Up @@ -2779,7 +2791,8 @@ def _provision(
rich_utils.force_update_status(
ux_utils.spinner_message('Launching', log_path))
config_dict = retry_provisioner.provision_with_retries(
task, to_provision_config, dryrun, stream_logs)
task, to_provision_config, dryrun, stream_logs,
skip_if_config_hash_matches)
cg505 marked this conversation as resolved.
Show resolved Hide resolved
break
except exceptions.ResourcesUnavailableError as e:
# Do not remove the stopped cluster from the global state
Expand Down Expand Up @@ -2829,6 +2842,15 @@ def _provision(
record = global_user_state.get_cluster_from_name(cluster_name)
return record['handle'] if record is not None else None

config_hash = config_dict['config_hash']

if skip_if_config_hash_matches is not None:
record = global_user_state.get_cluster_from_name(cluster_name)
if (record is not None and skip_if_config_hash_matches ==
config_hash == record['config_hash']):
logger.info('skip remaining')
return record['handle']

if 'provision_record' in config_dict:
# New provisioner is used here.
handle = config_dict['handle']
Expand Down Expand Up @@ -2868,7 +2890,7 @@ def _provision(
self._update_after_cluster_provisioned(
handle, to_provision_config.prev_handle, task,
prev_cluster_status, handle.external_ips(),
handle.external_ssh_ports(), lock_path)
handle.external_ssh_ports(), lock_path, config_hash)
return handle

cluster_config_file = config_dict['ray']
Expand Down Expand Up @@ -2940,7 +2962,8 @@ def _get_zone(runner):

self._update_after_cluster_provisioned(
handle, to_provision_config.prev_handle, task,
prev_cluster_status, ip_list, ssh_port_list, lock_path)
prev_cluster_status, ip_list, ssh_port_list, lock_path,
config_hash)
return handle

def _open_ports(self, handle: CloudVmRayResourceHandle) -> None:
Expand All @@ -2958,8 +2981,8 @@ def _update_after_cluster_provisioned(
prev_handle: Optional[CloudVmRayResourceHandle],
task: task_lib.Task,
prev_cluster_status: Optional[status_lib.ClusterStatus],
ip_list: List[str], ssh_port_list: List[int],
lock_path: str) -> None:
ip_list: List[str], ssh_port_list: List[int], lock_path: str,
config_hash: str) -> None:
usage_lib.messages.usage.update_cluster_resources(
handle.launched_nodes, handle.launched_resources)
usage_lib.messages.usage.update_final_cluster_status(
Expand Down Expand Up @@ -3019,6 +3042,7 @@ def _update_after_cluster_provisioned(
handle,
set(task.resources),
ready=True,
config_hash=config_hash,
)
usage_lib.messages.usage.update_final_cluster_status(
status_lib.ClusterStatus.UP)
Expand Down
18 changes: 11 additions & 7 deletions sky/backends/local_docker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,14 @@ def check_resources_fit_cluster(self, handle: 'LocalDockerResourceHandle',
pass

def _provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False,
skip_if_config_hash_matches: Optional[str] = None
) -> Optional[LocalDockerResourceHandle]:
"""Builds docker image for the task and returns cluster name as handle.

Expand All @@ -153,6 +154,9 @@ def _provision(
logger.warning(
f'Retrying until up is not supported in backend: {self.NAME}. '
'Ignored the flag.')
if skip_if_config_hash_matches is not None:
logger.warning(f'Config hashing is not supported in backend: '
f'{self.NAME}. Ignored skip_if_config_hash_matches.')
if stream_logs:
logger.info(
'Streaming build logs is not supported in LocalDockerBackend. '
Expand Down
Loading
Loading