API/YAML: Support using env vars to define file_mounts. #2146

Merged Jun 30, 2023 · 9 commits · Changes from 8 commits
52 changes: 52 additions & 0 deletions examples/using_file_mounts_with_env_vars.yaml
@@ -0,0 +1,52 @@
# Example showcasing how env vars can be used in the file_mounts section.

# You can set the default values for env vars in this 'envs' section.
# When launching, `sky launch --env ENV=val` will override the default.
envs:
  MY_BUCKET: skypilot-temp-gcs-test
  MY_LOCAL_PATH: tmp-workdir
  MODEL_SIZE: 13b

resources:
  cloud: gcp

# You can use env vars in
#  - the destination: source paths
#  - for SkyPilot Storage, the "name" (bucket name) and "source"
#    (local dir/file to upload) fields.
#
# Both syntaxes work: ${MY_BUCKET} and $MY_BUCKET.
file_mounts:
  /mydir:
    name: ${MY_BUCKET}  # Name of the bucket.
    store: gcs
    mode: MOUNT

  /mydir2:
    name: $MY_BUCKET  # Name of the bucket.
    store: gcs
    mode: MOUNT

  /another-dir:
    name: ${MY_BUCKET}-2
    source: ["~/${MY_LOCAL_PATH}"]
    store: gcs
    mode: MOUNT

  /another-dir2:
    name: $MY_BUCKET-2
    source: ["~/${MY_LOCAL_PATH}"]
    store: gcs
    mode: MOUNT

  /checkpoint/${MODEL_SIZE}: ~/${MY_LOCAL_PATH}

run: |
  echo Env var MY_BUCKET has value: ${MY_BUCKET}
  echo Env var MY_LOCAL_PATH has value: ${MY_LOCAL_PATH}

  ls -lthr /mydir
  ls -lthr /mydir2
  ls -lthr /another-dir
  ls -lthr /another-dir2
  ls -lthr /checkpoint/${MODEL_SIZE}
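
As a quick illustration of what launch-time substitution does to this file's
file_mounts, here is a minimal sketch mirroring the
`_fill_in_env_vars_in_file_mounts` helper added in sky/task.py below (the
dict literals are hypothetical stand-ins for the parsed YAML, with MY_BUCKET
overridden as if by `--env MY_BUCKET=my-bucket`):

    import json
    import re

    task_envs = {'MY_BUCKET': 'my-bucket', 'MY_LOCAL_PATH': 'tmp-workdir',
                 'MODEL_SIZE': '13b'}
    # A slice of the parsed file_mounts section from the YAML above.
    file_mounts = {
        '/mydir': {'name': '${MY_BUCKET}', 'store': 'gcs', 'mode': 'MOUNT'},
        '/checkpoint/${MODEL_SIZE}': '~/${MY_LOCAL_PATH}',
    }

    # Dump to a JSON string and substitute both ${ENV} and $ENV occurrences.
    file_mounts_str = json.dumps(file_mounts)
    for key, value in task_envs.items():
        file_mounts_str = re.sub(r'\$\{?\b' + key + r'\b\}?', value,
                                 file_mounts_str)
    print(json.loads(file_mounts_str))
    # {'/mydir': {'name': 'my-bucket', ...}, '/checkpoint/13b': '~/tmp-workdir'}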
33 changes: 22 additions & 11 deletions sky/backends/backend_utils.py
@@ -2502,17 +2502,28 @@ def validate_schema(obj, schema, err_msg_prefix=''):
         validator.SchemaValidator(schema).validate(obj)
     except jsonschema.ValidationError as e:
         if e.validator == 'additionalProperties':
-            err_msg = err_msg_prefix + 'The following fields are invalid:'
-            known_fields = set(e.schema.get('properties', {}).keys())
-            for field in e.instance:
-                if field not in known_fields:
-                    most_similar_field = difflib.get_close_matches(
-                        field, known_fields, 1)
-                    if most_similar_field:
-                        err_msg += (f'\nInstead of {field!r}, did you mean '
-                                    f'{most_similar_field[0]!r}?')
-                    else:
-                        err_msg += f'\nFound unsupported field {field!r}.'
+            if tuple(e.schema_path) == ('properties', 'envs',
+                                        'additionalProperties'):
+                # Hack. Here the error is Task.envs having some invalid keys,
+                # so we should not print "unsupported field".
+                #
+                # This will print something like:
+                # 'hello world' does not match any of the regexes: <regex>
+                err_msg = (err_msg_prefix +
+                           'The `envs` field contains invalid keys:\n' +
+                           e.message)
+            else:
+                err_msg = err_msg_prefix + 'The following fields are invalid:'
+                known_fields = set(e.schema.get('properties', {}).keys())
+                for field in e.instance:
+                    if field not in known_fields:
+                        most_similar_field = difflib.get_close_matches(
+                            field, known_fields, 1)
+                        if most_similar_field:
+                            err_msg += (f'\nInstead of {field!r}, did you mean '
+                                        f'{most_similar_field[0]!r}?')
+                        else:
+                            err_msg += f'\nFound unsupported field {field!r}.'
         else:
             # Example e.json_path value: '$.resources'
             err_msg = (err_msg_prefix + e.message +
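
For reference, a minimal sketch of the error this new branch intercepts,
using jsonschema directly rather than SkyPilot's `validator.SchemaValidator`
wrapper (the schema literal is a trimmed, hypothetical stand-in for
`schemas.get_task_schema()`):

    import jsonschema

    schema = {
        'type': 'object',
        'properties': {
            'envs': {
                'type': 'object',
                'patternProperties': {
                    '^[a-zA-Z_][a-zA-Z0-9_]*$': {'type': 'string'}
                },
                'additionalProperties': False,
            },
        },
    }
    try:
        jsonschema.Draft7Validator(schema).validate({'envs': {'hello world': 'x'}})
    except jsonschema.ValidationError as e:
        print(e.validator)           # additionalProperties
        print(tuple(e.schema_path))  # ('properties', 'envs', 'additionalProperties')
        print(e.message)  # 'hello world' does not match any of the regexes: ...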
26 changes: 18 additions & 8 deletions sky/cli.py
@@ -941,15 +941,23 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
         entrypoint: Path to a YAML file.
     """
     is_yaml = True
-    config = None
+    config: Optional[List[Dict[str, Any]]] = None
+    result = None
     shell_splits = shlex.split(entrypoint)
-    yaml_file_provided = len(shell_splits) == 1 and \
-        (shell_splits[0].endswith('yaml') or shell_splits[0].endswith('.yml'))
+    yaml_file_provided = (len(shell_splits) == 1 and
+                          (shell_splits[0].endswith('yaml') or
+                           shell_splits[0].endswith('.yml')))
     try:
         with open(entrypoint, 'r') as f:
             try:
-                config = list(yaml.safe_load_all(f))[0]
-                if isinstance(config, str):
+                config = list(yaml.safe_load_all(f))
+                if config:
+                    # FIXME(zongheng): in a chain DAG YAML it only returns the
+                    # first section. OK for downstream but is weird.
+                    result = config[0]
+                else:
+                    result = {}
+                if isinstance(result, str):
                     # 'sky exec cluster ./my_script.sh'
                     is_yaml = False
             except yaml.YAMLError as e:
@@ -977,7 +985,7 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
             f'{entrypoint!r} looks like a yaml path but {invalid_reason}\n'
             'It will be treated as a command to be run remotely. Continue?',
             abort=True)
-    return is_yaml, config
+    return is_yaml, result


 def _make_task_or_dag_from_entrypoint_with_overrides(
@@ -1041,7 +1049,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
     if is_yaml:
         assert entrypoint is not None
         usage_lib.messages.usage.update_user_task_yaml(entrypoint)
-        dag = dag_utils.load_chain_dag_from_yaml(entrypoint)
+        dag = dag_utils.load_chain_dag_from_yaml(entrypoint, env_overrides=env)
         if len(dag.tasks) > 1:
             # When the dag has more than 1 task. It is unclear how to
             # override the params for the dag. So we just ignore the
@@ -1052,7 +1060,9 @@
                 'since the yaml file contains multiple tasks.',
                 fg='yellow')
             return dag
-        task = sky.Task.from_yaml(entrypoint)
+        assert len(dag.tasks) == 1, (
+            f'If you see this, please file an issue; tasks: {dag.tasks}')
+        task = dag.tasks[0]
     else:
         task = sky.Task(name='sky-cmd', run=entrypoint)
         task.set_resources({sky.Resources()})
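
The `_check_yaml` change hinges on yaml.safe_load_all returning one parsed
document per `---`-separated section. A small sketch of the three cases the
new code distinguishes:

    import yaml

    # A chain-DAG YAML: a name-only section followed by two task sections.
    chain = 'name: pipeline\n---\nrun: echo first\n---\nrun: echo second'
    print(list(yaml.safe_load_all(chain)))
    # [{'name': 'pipeline'}, {'run': 'echo first'}, {'run': 'echo second'}]

    # A plain script path parses as a bare string -> treated as a command.
    print(list(yaml.safe_load_all('./my_script.sh')))  # ['./my_script.sh']

    # An empty file yields no documents -> `result` falls back to {}.
    print(list(yaml.safe_load_all('')))  # []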
2 changes: 1 addition & 1 deletion sky/clouds/gcp.py
@@ -622,7 +622,7 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
         # be modified on the remote cluster by ray causing authentication
         # problems. The backup file will be updated to the remote cluster
         # whenever the original file is not empty and will be applied
-        # appropriately on the remote cluster when neccessary.
+        # appropriately on the remote cluster when necessary.
         if (os.path.exists(os.path.expanduser(GCP_CONFIG_PATH)) and
                 os.path.getsize(os.path.expanduser(GCP_CONFIG_PATH)) > 0):
             subprocess.run(f'cp {GCP_CONFIG_PATH} {GCP_CONFIG_SKY_BACKUP_PATH}',
9 changes: 4 additions & 5 deletions sky/data/storage.py
@@ -1321,11 +1321,10 @@ def _validate(self):
         if not _is_storage_cloud_enabled(str(clouds.GCP())):
             with ux_utils.print_exception_no_traceback():
                 raise exceptions.ResourcesUnavailableError(
-                    'Storage \'store: gcs\' specified, but ' \
-                    'GCP access is disabled. To fix, enable '\
-                    'GCP by running `sky check`. '\
-                    'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
-                    )
+                    'Storage \'store: gcs\' specified, but '
+                    'GCP access is disabled. To fix, enable '
+                    'GCP by running `sky check`. '
+                    'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.')  # pylint: disable=line-too-long

     @classmethod
     def validate_name(cls, name) -> str:
64 changes: 60 additions & 4 deletions sky/task.py
@@ -1,5 +1,6 @@
 """Task: a coarse-grained stage in an application."""
 import inspect
+import json
 import os
 import re
 import typing
@@ -68,6 +69,35 @@ def _is_valid_env_var(name: str) -> bool:
     return bool(re.fullmatch(_VALID_ENV_VAR_REGEX, name))


+def _fill_in_env_vars_in_file_mounts(
+    file_mounts: Dict[str, Any],
+    task_envs: Dict[str, str],
+) -> Dict[str, Any]:
+    """Detects env vars in file_mounts and fills them with task_envs.
+
+    Use cases of env vars in file_mounts:
+      - dst/src paths; e.g.,
+          /model_path/llama-${SIZE}b: s3://llama-weights/llama-${SIZE}b
+      - storage's name (bucket name)
+      - storage's source (local path)
+
+    We simply dump file_mounts into a json string, and replace env vars using
+    regex. This should be safe as file_mounts has been schema-validated.
+
+    Env vars of the following forms are detected:
+      - ${ENV}
+      - $ENV
+    where <ENV> must appear in task.envs.
+    """
+    # TODO(zongheng): support ${ENV:-default}?
+    file_mounts_str = json.dumps(file_mounts)
+    for key, value in task_envs.items():
+        pattern = r'\$\{?\b' + key + r'\b\}?'
[Review thread on the `pattern = ...` line above — Collaborator:]

    nit: it seems bash replaces an env var by first looking for the env var
    name in the string and then matching it against the defined env vars,
    while here we do it in the reversed way. That will cause a slight
    inconsistency in the behavior:

        HELLO_WORLD=1
        echo $HELLO_WORLD_1

    The block above will output an empty string, but in our implementation,
    we will output 1_1.

    Probably, an alternative way to do this is:

        # Create a replacement function.
        def replace_var(match):
            # Directly access the variable name.
            var_name = match.group(1)
            # If the variable isn't in the dictionary, return it unchanged.
            return task_envs.get(var_name, match.group(0))

        pattern = r'\$\{?\b([a-zA-Z_][a-zA-Z0-9_]*)\b\}?'
        # Use re.sub with the replacement function.
        result = re.sub(pattern, replace_var, file_mounts_str)

    This is a very corner case; we can also leave it as is to see how people
    feel about it.

[Member Author:] Great catch! Used this & reran the smoke test.
+        env_var_pattern = re.compile(pattern)
+        file_mounts_str = env_var_pattern.sub(value, file_mounts_str)
+    return json.loads(file_mounts_str)
+
+
 class Task:
     """Task: a computation to be run on the cloud."""

@@ -231,10 +261,37 @@ def _validate(self):
             f'a symlink to a directory). {self.workdir} not found.')

     @staticmethod
-    def from_yaml_config(config: Dict[str, Any]) -> 'Task':
+    def from_yaml_config(
+        config: Dict[str, Any],
+        env_overrides: Optional[List[Tuple[str, str]]] = None,
+    ) -> 'Task':
+        if env_overrides is not None:
+            # We must override env vars before constructing the Task, because
+            # the Storage object creation is eager and it (its name/source
+            # fields) may depend on env vars.
+            #
+            # FIXME(zongheng): The eagerness / how we construct Task's from
+            # entrypoint (YAML, CLI args) should be fixed.
+            new_envs = config.get('envs', {})
+            new_envs.update(env_overrides)
+            config['envs'] = new_envs
+
+        # More robust handling for 'envs': explicitly convert keys and values
+        # to str, since users may pass '123' as keys/values which will get
+        # parsed as int causing validate_schema() to fail.
+        envs = config.get('envs')
+        if envs is not None and isinstance(envs, dict):
+            config['envs'] = {str(k): str(v) for k, v in envs.items()}
+
         backend_utils.validate_schema(config, schemas.get_task_schema(),
                                       'Invalid task YAML: ')

+        # Fill in any Task.envs into file_mounts (src/dst paths, storage
+        # name/source).
+        if config.get('file_mounts') is not None:
+            config['file_mounts'] = _fill_in_env_vars_in_file_mounts(
+                config['file_mounts'], config.get('envs', {}))
+
         task = Task(
             config.pop('name', None),
             run=config.pop('run', None),
@@ -269,8 +326,7 @@ def from_yaml_config(config: Dict[str, Any]) -> 'Task':
         all_storages = fm_storages
         for storage in all_storages:
             mount_path = storage[0]
-            assert mount_path, \
-                'Storage mount path cannot be empty.'
+            assert mount_path, 'Storage mount path cannot be empty.'
             try:
                 storage_obj = storage_lib.Storage.from_yaml_config(storage[1])
             except exceptions.StorageSourceError as e:
@@ -649,7 +705,7 @@ def update_storage_mounts(

         Different from set_storage_mounts(), this function updates into the
         existing storage_mounts (calls ``dict.update()``), rather than
-        overwritting it.
+        overwriting it.

         This should be called before provisioning in order to take effect.
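
A minimal sketch of the override ordering that from_yaml_config now
implements (plain dicts stand in for the parsed YAML; env_overrides mirrors
repeated `--env` flags on the CLI):

    from typing import Any, Dict, List, Tuple

    config: Dict[str, Any] = {'envs': {'MY_BUCKET': 'skypilot-temp-gcs-test',
                                       'MODEL_SIZE': 13}}
    env_overrides: List[Tuple[str, str]] = [('MY_BUCKET', 'my-bucket')]

    # 1) CLI overrides are merged in before the Task is constructed, since
    #    eager Storage creation may read the resolved env values.
    new_envs = config.get('envs', {})
    new_envs.update(env_overrides)
    config['envs'] = new_envs

    # 2) Keys/values are coerced to str so YAML scalars like 13 pass the
    #    string-typed schema.
    config['envs'] = {str(k): str(v) for k, v in config['envs'].items()}

    print(config['envs'])  # {'MY_BUCKET': 'my-bucket', 'MODEL_SIZE': '13'}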
6 changes: 5 additions & 1 deletion sky/utils/common_utils.py
@@ -182,7 +182,11 @@ def read_yaml(path) -> Dict[str, Any]:
 def read_yaml_all(path: str) -> List[Dict[str, Any]]:
     with open(path, 'r') as f:
         config = yaml.safe_load_all(f)
-        return list(config)
+        configs = list(config)
+        if not configs:
+            # Empty YAML file.
+            return [{}]
+        return configs


 def dump_yaml(path, config) -> None:
32 changes: 30 additions & 2 deletions sky/utils/dag_utils.py
@@ -1,11 +1,29 @@
 """Utilities for loading and dumping DAGs from/to YAML files."""
+from typing import List, Optional, Tuple
+
 from sky import dag as dag_lib
 from sky import task as task_lib
 from sky.backends import backend_utils
 from sky.utils import common_utils


-def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag:
+def load_chain_dag_from_yaml(
+    path: str,
+    env_overrides: Optional[List[Tuple[str, str]]] = None,
+) -> dag_lib.Dag:
+    """Loads a chain DAG from a YAML file.
+
+    Has special handling for an initial section in YAML that contains only the
+    'name' field, which is the DAG name.
+
+    'env_overrides' is in effect only when there's exactly one task. It is a
+    list of (key, value) pairs that will be used to update the task's 'envs'
+    section.
+
+    Returns:
+      A chain Dag with 1 or more tasks (an empty entrypoint would create a
+      trivial task).
+    """
     configs = common_utils.read_yaml_all(path)
     dag_name = None
     if set(configs[0].keys()) == {'name'}:
@@ -14,12 +32,22 @@ def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag:
     elif len(configs) == 1:
         dag_name = configs[0].get('name')

+    if len(configs) == 0:
+        # YAML has only `name: xxx`. Still instantiate a task.
+        configs = [{'name': dag_name}]
+
+    if len(configs) > 1:
+        # TODO(zongheng): in a chain DAG of N tasks, cli.py currently makes the
+        # decision to not apply overrides. Here we maintain this behavior. We
+        # can listen to user feedback to change this.
+        env_overrides = None
+
     current_task = None
     with dag_lib.Dag() as dag:
         for task_config in configs:
             if task_config is None:
                 continue
-            task = task_lib.Task.from_yaml_config(task_config)
+            task = task_lib.Task.from_yaml_config(task_config, env_overrides)
             if current_task is not None:
                 current_task >> task  # pylint: disable=pointless-statement
             current_task = task
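
For intuition, a sketch of the name-only first section this loader
special-cases and of the override gating (the YAML string and env values are
hypothetical):

    import yaml

    yaml_str = ('name: train-then-eval\n---\n'
                'run: echo training\n---\nrun: echo evaluating')
    configs = list(yaml.safe_load_all(yaml_str))

    # The first section carries only the DAG name; the remaining sections
    # are the chain's tasks.
    if set(configs[0].keys()) == {'name'}:
        dag_name, configs = configs[0]['name'], configs[1:]

    # Overrides apply only to single-task DAGs; a chain keeps its YAML envs.
    env_overrides = [('MY_BUCKET', 'my-bucket')]
    if len(configs) > 1:
        env_overrides = None

    print(dag_name, len(configs), env_overrides)  # train-then-eval 2 None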
8 changes: 7 additions & 1 deletion sky/utils/schemas.py
@@ -170,7 +170,13 @@ def get_task_schema():
         'envs': {
             'type': 'object',
             'required': [],
-            'additionalProperties': True,
+            'patternProperties': {
+                # Checks env keys are valid env var names.
+                '^[a-zA-Z_][a-zA-Z0-9_]*$': {
+                    'type': 'string'
+                }
+            },
+            'additionalProperties': False,
         },
         # inputs and outputs are experimental
         'inputs': {
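
A quick check of what the new patternProperties regex accepts, mirroring
`_is_valid_env_var` in sky/task.py (the constant's exact value is assumed
here to match the schema's pattern):

    import re

    _VALID_ENV_VAR_REGEX = '[a-zA-Z_][a-zA-Z0-9_]*'

    def is_valid_env_var(name: str) -> bool:
        return bool(re.fullmatch(_VALID_ENV_VAR_REGEX, name))

    print(is_valid_env_var('MY_BUCKET'))    # True
    print(is_valid_env_var('_private_1'))   # True
    print(is_valid_env_var('hello world'))  # False: contains a space
    print(is_valid_env_var('1ABC'))         # False: leading digit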