Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make --fast robust against credential or wheel updates #4289

Merged
merged 21 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions sky/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,23 @@ def check_resources_fit_cluster(self, handle: _ResourceHandleType,
@timeline.event
@usage_lib.messages.usage.update_runtime('provision')
def provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: Optional[str] = None,
retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: Optional[str] = None,
retry_until_up: bool = False,
skip_if_no_updates: bool = False,
cg505 marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[_ResourceHandleType]:
if cluster_name is None:
cluster_name = sky.backends.backend_utils.generate_cluster_name()
usage_lib.record_cluster_name_for_current_operation(cluster_name)
usage_lib.messages.usage.update_actual_task(task)
with rich_utils.safe_status(ux_utils.spinner_message('Launching')):
return self._provision(task, to_provision, dryrun, stream_logs,
cluster_name, retry_until_up)
cluster_name, retry_until_up,
skip_if_no_updates)

@timeline.event
@usage_lib.messages.usage.update_runtime('sync_workdir')
Expand Down Expand Up @@ -126,13 +129,15 @@ def register_info(self, **kwargs) -> None:

# --- Implementations of the APIs ---
def _provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False,
skip_if_no_updates: bool = False,
) -> Optional[_ResourceHandleType]:
raise NotImplementedError

def _sync_workdir(self, handle: _ResourceHandleType, workdir: Path) -> None:
Expand Down
129 changes: 124 additions & 5 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import fnmatch
import functools
import hashlib
import os
import pathlib
import pprint
Expand Down Expand Up @@ -640,11 +641,14 @@ def write_cluster_config(
keep_launch_fields_in_existing_config: bool = True) -> Dict[str, str]:
"""Fills in cluster configuration templates and writes them out.

Returns: {provisioner: path to yaml, the provisioning spec}.
'provisioner' can be
- 'ray'
- 'tpu-create-script' (if TPU is requested)
- 'tpu-delete-script' (if TPU is requested)
Returns:
Dict with the following keys:
- 'ray': Path to the generated Ray yaml config file
- 'cluster_name': Name of the cluster
- 'cluster_name_on_cloud': Name of the cluster as it appears in the
cloud provider
- 'config_hash': Hash of the cluster config and file mounts contents

Raises:
exceptions.ResourcesUnavailableError: if the region/zones requested does
not appear in the catalog, or an ssh_proxy_command is specified but
Expand Down Expand Up @@ -860,6 +864,7 @@ def write_cluster_config(
if dryrun:
# If dryrun, return the unfinished tmp yaml path.
config_dict['ray'] = tmp_yaml_path
config_dict['config_hash'] = _deterministic_yaml_hash(tmp_yaml_path)
return config_dict
_add_auth_to_cluster_config(cloud, tmp_yaml_path)

Expand All @@ -882,6 +887,11 @@ def write_cluster_config(
yaml_config = common_utils.read_yaml(tmp_yaml_path)
config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name']

# Make sure to do this before we optimize file mounts. Optimization is
# non-deterministic, but everything else before this point should be
# deterministic.
config_dict['config_hash'] = _deterministic_yaml_hash(tmp_yaml_path)
cg505 marked this conversation as resolved.
Show resolved Hide resolved

# Optimization: copy the contents of source files in file_mounts to a
# special dir, and upload that as the only file_mount instead. Delay
# calling this optimization until now, when all source files have been
Expand Down Expand Up @@ -990,6 +1000,115 @@ def get_ready_nodes_counts(pattern, output):
return ready_head, ready_workers


@timeline.event
def _deterministic_yaml_hash(yaml_path: str) -> str:
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
"""Hash the cluster yaml and contents of file mounts to a unique string.

Two invocations of this function should return the same string if and only
if the contents of the yaml are the same and the file contents of all the
file_mounts specified in the yaml are the same.

Limitations:
- This function can be expensive if the file mounts are large. (E.g. a few
seconds for ~1GB.) This should be okay since we expect that the
file_mounts in the cluster yaml (the wheel and cloud credentials) will be
small.
- Symbolic links are not explicitly handled. Some symbolic link changes may
not be detected.

Implementation: We create a byte sequence that captures the state of the
yaml file and all the files in the file mounts, then hash the byte sequence.

The format of the byte sequence is:
32 bytes - sha256 hash of the yaml file
for each file mount:
file mount remote destination (UTF-8), \0
if the file mount source is a file:
'file' encoded to UTF-8
32 byte sha256 hash of the file contents
if the file mount source is a directory:
'dir' encoded to UTF-8
for each directory and subdirectory withinin the file mount (starting from
the root and descending recursively):
name of the directory (UTF-8), \0
name of each subdirectory within the directory (UTF-8) terminated by \0
\0
for each file in the directory:
name of the file (UTF-8), \0
32 bytes - sha256 hash of the file contents
\0
if the file mount source is something else or does not exist, nothing
\0\0

Rather than constructing the whole byte sequence, which may be quite large,
we construct it incrementally by using hash.update() to add new bytes.
"""

def _hash_file(path: str) -> bytes:
return common_utils.hash_file(path, 'sha256').digest()

config_hash = hashlib.sha256()

config_hash.update(_hash_file(yaml_path))

yaml_config = common_utils.read_yaml(yaml_path)
file_mounts = yaml_config.get('file_mounts', {})
# Remove the file mounts added by the newline.
if '' in file_mounts:
assert file_mounts[''] == '', file_mounts['']
file_mounts.pop('')
cg505 marked this conversation as resolved.
Show resolved Hide resolved

for dst, src in sorted(file_mounts.items()):
expanded_src = os.path.expanduser(src)
cg505 marked this conversation as resolved.
Show resolved Hide resolved
config_hash.update(dst.encode('utf-8') + b'\0')

# If the file mount source is a symlink, this should be true. In that
# case we hash the contents of the symlink destination.
if os.path.isfile(expanded_src):
config_hash.update('file'.encode('utf-8'))
config_hash.update(_hash_file(expanded_src))

# This can also be a symlink to a directory. os.walk will treat it as a
# normal directory and list the contents of the symlink destination.
elif os.path.isdir(expanded_src):
config_hash.update('dir'.encode('utf-8'))

# Aside from expanded_src, os.walk will list symlinks to directories
# but will not recurse into them.
for (dirpath, dirnames, filenames) in os.walk(expanded_src):
config_hash.update(dirpath.encode('utf-8') + b'\0')

# Note: inplace sort will also affect the traversal order of
# os.walk. We need it so that the os.walk order is
# deterministic.
dirnames.sort()
cg505 marked this conversation as resolved.
Show resolved Hide resolved
# This includes symlinks to directories. os.walk will recurse
# into all the directories but not the symlinks. We don't hash
# the link destination, so if a symlink to a directory changes,
# we won't notice.
for dirname in dirnames:
config_hash.update(dirname.encode('utf-8') + b'\0')
config_hash.update(b'\0')

filenames.sort()
# This includes symlinks to files. We could hash the symlink
# destination itself but instead just hash the destination
# contents.
for filename in filenames:
config_hash.update(filename.encode('utf-8') + b'\0')
config_hash.update(
_hash_file(os.path.join(dirpath, filename)))
config_hash.update(b'\0')

else:
logger.debug(
f'Unexpected file_mount that is not a file or dir: {src}')

config_hash.update(b'\0\0')

return config_hash.hexdigest()


def get_docker_user(ip: str, cluster_config_file: str) -> str:
"""Find docker container username."""
ssh_credentials = ssh_credential_from_yaml(cluster_config_file)
Expand Down
Loading
Loading