Skip to content

Commit

Permalink
API/YAML: Support using env vars to define file_mounts. (#2146)
Browse files Browse the repository at this point in the history
* WIP: storage creation is eager

* Fixes.

* Fixes

* Updates

* Guard against invalid env keys.

* Updates

* cleanup

* Pylint

* Update
  • Loading branch information
concretevitamin authored Jun 30, 2023
1 parent eee216a commit 3385975
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 49 deletions.
52 changes: 52 additions & 0 deletions examples/using_file_mounts_with_env_vars.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Example showcasing how env vars can be used in the file_mounts section.

# You can set the default values for env vars in this 'envs' section.
# When launching, `sky launch --env ENV=val` will override the default.
envs:
MY_BUCKET: skypilot-temp-gcs-test
MY_LOCAL_PATH: tmp-workdir
MODEL_SIZE: 13b

resources:
cloud: gcp

# You can use env vars in
# - the destination: source paths
# - for SkyPilot Storage
# - the "name" (bucket name) and "source" (local dir/file to upload) fields.
#
# Both syntaxes work: ${MY_BUCKET} and $MY_BUCKET.
file_mounts:
/mydir:
name: ${MY_BUCKET} # Name of the bucket.
store: gcs
mode: MOUNT

/mydir2:
name: $MY_BUCKET # Name of the bucket.
store: gcs
mode: MOUNT

/another-dir:
name: ${MY_BUCKET}-2
source: ["~/${MY_LOCAL_PATH}"]
store: gcs
mode: MOUNT

/another-dir2:
name: $MY_BUCKET-2
source: ["~/${MY_LOCAL_PATH}"]
store: gcs
mode: MOUNT

/checkpoint/${MODEL_SIZE}: ~/${MY_LOCAL_PATH}

run: |
echo Env var MY_BUCKET has value: ${MY_BUCKET}
echo Env var MY_LOCAL_PATH has value: ${MY_LOCAL_PATH}
ls -lthr /mydir
ls -lthr /mydir2
ls -lthr /another-dir
ls -lthr /another-dir2
ls -lthr /checkpoint/${MODEL_SIZE}
33 changes: 22 additions & 11 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,17 +2502,28 @@ def validate_schema(obj, schema, err_msg_prefix=''):
validator.SchemaValidator(schema).validate(obj)
except jsonschema.ValidationError as e:
if e.validator == 'additionalProperties':
err_msg = err_msg_prefix + 'The following fields are invalid:'
known_fields = set(e.schema.get('properties', {}).keys())
for field in e.instance:
if field not in known_fields:
most_similar_field = difflib.get_close_matches(
field, known_fields, 1)
if most_similar_field:
err_msg += (f'\nInstead of {field!r}, did you mean '
f'{most_similar_field[0]!r}?')
else:
err_msg += f'\nFound unsupported field {field!r}.'
if tuple(e.schema_path) == ('properties', 'envs',
'additionalProperties'):
# Hack. Here the error is Task.envs having some invalid keys. So
# we should not print "unsupported field".
#
# This will print something like:
# 'hello world' does not match any of the regexes: <regex>
err_msg = (err_msg_prefix +
'The `envs` field contains invalid keys:\n' +
e.message)
else:
err_msg = err_msg_prefix + 'The following fields are invalid:'
known_fields = set(e.schema.get('properties', {}).keys())
for field in e.instance:
if field not in known_fields:
most_similar_field = difflib.get_close_matches(
field, known_fields, 1)
if most_similar_field:
err_msg += (f'\nInstead of {field!r}, did you mean '
f'{most_similar_field[0]!r}?')
else:
err_msg += f'\nFound unsupported field {field!r}.'
else:
# Example e.json_path value: '$.resources'
err_msg = (err_msg_prefix + e.message +
Expand Down
26 changes: 18 additions & 8 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,15 +941,23 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
entrypoint: Path to a YAML file.
"""
is_yaml = True
config = None
config: Optional[List[Dict[str, Any]]] = None
result = None
shell_splits = shlex.split(entrypoint)
yaml_file_provided = len(shell_splits) == 1 and \
(shell_splits[0].endswith('yaml') or shell_splits[0].endswith('.yml'))
yaml_file_provided = (len(shell_splits) == 1 and
(shell_splits[0].endswith('yaml') or
shell_splits[0].endswith('.yml')))
try:
with open(entrypoint, 'r') as f:
try:
config = list(yaml.safe_load_all(f))[0]
if isinstance(config, str):
config = list(yaml.safe_load_all(f))
if config:
# FIXME(zongheng): in a chain DAG YAML it only returns the
# first section. OK for downstream but is weird.
result = config[0]
else:
result = {}
if isinstance(result, str):
# 'sky exec cluster ./my_script.sh'
is_yaml = False
except yaml.YAMLError as e:
Expand Down Expand Up @@ -977,7 +985,7 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
f'{entrypoint!r} looks like a yaml path but {invalid_reason}\n'
'It will be treated as a command to be run remotely. Continue?',
abort=True)
return is_yaml, config
return is_yaml, result


def _make_task_or_dag_from_entrypoint_with_overrides(
Expand Down Expand Up @@ -1041,7 +1049,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
if is_yaml:
assert entrypoint is not None
usage_lib.messages.usage.update_user_task_yaml(entrypoint)
dag = dag_utils.load_chain_dag_from_yaml(entrypoint)
dag = dag_utils.load_chain_dag_from_yaml(entrypoint, env_overrides=env)
if len(dag.tasks) > 1:
# When the dag has more than 1 task. It is unclear how to
# override the params for the dag. So we just ignore the
Expand All @@ -1052,7 +1060,9 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
'since the yaml file contains multiple tasks.',
fg='yellow')
return dag
task = sky.Task.from_yaml(entrypoint)
assert len(dag.tasks) == 1, (
f'If you see this, please file an issue; tasks: {dag.tasks}')
task = dag.tasks[0]
else:
task = sky.Task(name='sky-cmd', run=entrypoint)
task.set_resources({sky.Resources()})
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
# be modified on the remote cluster by ray causing authentication
# problems. The backup file will be updated to the remote cluster
# whenever the original file is not empty and will be applied
# appropriately on the remote cluster when neccessary.
# appropriately on the remote cluster when necessary.
if (os.path.exists(os.path.expanduser(GCP_CONFIG_PATH)) and
os.path.getsize(os.path.expanduser(GCP_CONFIG_PATH)) > 0):
subprocess.run(f'cp {GCP_CONFIG_PATH} {GCP_CONFIG_SKY_BACKUP_PATH}',
Expand Down
9 changes: 4 additions & 5 deletions sky/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,11 +1321,10 @@ def _validate(self):
if not _is_storage_cloud_enabled(str(clouds.GCP())):
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesUnavailableError(
'Storage \'store: gcs\' specified, but ' \
'GCP access is disabled. To fix, enable '\
'GCP by running `sky check`. '\
'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
)
'Storage \'store: gcs\' specified, but '
'GCP access is disabled. To fix, enable '
'GCP by running `sky check`. '
'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.') # pylint: disable=line-too-long

@classmethod
def validate_name(cls, name) -> str:
Expand Down
69 changes: 65 additions & 4 deletions sky/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Task: a coarse-grained stage in an application."""
import inspect
import json
import os
import re
import typing
Expand Down Expand Up @@ -68,6 +69,40 @@ def _is_valid_env_var(name: str) -> bool:
return bool(re.fullmatch(_VALID_ENV_VAR_REGEX, name))


def _fill_in_env_vars_in_file_mounts(
    file_mounts: Dict[str, Any],
    task_envs: Dict[str, str],
) -> Dict[str, Any]:
    """Detects env vars in file_mounts and fills them with task_envs.

    Use cases of env vars in file_mounts:
      - dst/src paths; e.g.,
          /model_path/llama-${SIZE}b: s3://llama-weights/llama-${SIZE}b
      - storage's name (bucket name)
      - storage's source (local path)

    We simply dump file_mounts into a JSON string, and replace env vars
    using regex. This should be safe as file_mounts has been
    schema-validated.

    Env vars of the following forms are detected:
      - ${ENV}
      - $ENV
    where <ENV> must appear in task_envs; names not present in task_envs
    are left untouched.

    Args:
        file_mounts: The schema-validated 'file_mounts' section of a task.
        task_envs: Mapping of env var name -> value used for substitution.

    Returns:
        A new file_mounts dict with env vars substituted.
    """
    # TODO(zongheng): support ${ENV:-default}?
    file_mounts_str = json.dumps(file_mounts)

    def replace_var(match: 're.Match') -> str:
        var_name = match.group(1)
        if var_name not in task_envs:
            # If the variable isn't in the dictionary, return it unchanged.
            return match.group(0)
        # JSON-escape the substituted value: we are editing a JSON
        # document, so a raw value containing '"' or '\' would otherwise
        # corrupt it and make the final json.loads() fail.  json.dumps of
        # a str includes the surrounding quotes, hence the [1:-1] slice.
        return json.dumps(task_envs[var_name])[1:-1]

    # Pattern for valid env var names in bash, optionally wrapped in {}.
    pattern = r'\$\{?\b([a-zA-Z_][a-zA-Z0-9_]*)\b\}?'
    file_mounts_str = re.sub(pattern, replace_var, file_mounts_str)
    return json.loads(file_mounts_str)


class Task:
"""Task: a computation to be run on the cloud."""

Expand Down Expand Up @@ -231,10 +266,37 @@ def _validate(self):
f'a symlink to a directory). {self.workdir} not found.')

@staticmethod
def from_yaml_config(config: Dict[str, Any]) -> 'Task':
def from_yaml_config(
config: Dict[str, Any],
env_overrides: Optional[List[Tuple[str, str]]] = None,
) -> 'Task':
if env_overrides is not None:
# We must override env vars before constructing the Task, because
# the Storage object creation is eager and it (its name/source
# fields) may depend on env vars.
#
# FIXME(zongheng): The eagerness / how we construct Task's from
# entrypoint (YAML, CLI args) should be fixed.
new_envs = config.get('envs', {})
new_envs.update(env_overrides)
config['envs'] = new_envs

# More robust handling for 'envs': explicitly convert keys and values to
# str, since users may pass '123' as keys/values which will get parsed
# as int causing validate_schema() to fail.
envs = config.get('envs')
if envs is not None and isinstance(envs, dict):
config['envs'] = {str(k): str(v) for k, v in envs.items()}

backend_utils.validate_schema(config, schemas.get_task_schema(),
'Invalid task YAML: ')

# Fill in any Task.envs into file_mounts (src/dst paths, storage
# name/source).
if config.get('file_mounts') is not None:
config['file_mounts'] = _fill_in_env_vars_in_file_mounts(
config['file_mounts'], config.get('envs', {}))

task = Task(
config.pop('name', None),
run=config.pop('run', None),
Expand Down Expand Up @@ -269,8 +331,7 @@ def from_yaml_config(config: Dict[str, Any]) -> 'Task':
all_storages = fm_storages
for storage in all_storages:
mount_path = storage[0]
assert mount_path, \
'Storage mount path cannot be empty.'
assert mount_path, 'Storage mount path cannot be empty.'
try:
storage_obj = storage_lib.Storage.from_yaml_config(storage[1])
except exceptions.StorageSourceError as e:
Expand Down Expand Up @@ -649,7 +710,7 @@ def update_storage_mounts(
Different from set_storage_mounts(), this function updates into the
existing storage_mounts (calls ``dict.update()``), rather than
overwritting it.
overwriting it.
This should be called before provisioning in order to take effect.
Expand Down
6 changes: 5 additions & 1 deletion sky/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def read_yaml(path) -> Dict[str, Any]:
def read_yaml_all(path: str) -> List[Dict[str, Any]]:
    """Reads every YAML document from the file at `path`.

    Returns:
        One dict per YAML document.  An empty file is normalized to a
        single empty config, so callers always receive at least one entry.
    """
    with open(path, 'r') as f:
        documents = list(yaml.safe_load_all(f))
    if documents:
        return documents
    # Empty YAML file: behave as if it held one empty document.
    return [{}]


def dump_yaml(path, config) -> None:
Expand Down
32 changes: 30 additions & 2 deletions sky/utils/dag_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
"""Utilities for loading and dumping DAGs from/to YAML files."""
from typing import List, Optional, Tuple

from sky import dag as dag_lib
from sky import task as task_lib
from sky.backends import backend_utils
from sky.utils import common_utils


def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag:
def load_chain_dag_from_yaml(
path: str,
env_overrides: Optional[List[Tuple[str, str]]] = None,
) -> dag_lib.Dag:
"""Loads a chain DAG from a YAML file.
Has special handling for an initial section in YAML that contains only the
'name' field, which is the DAG name.
'env_overrides' is in effect only when there's exactly one task. It is a
list of (key, value) pairs that will be used to update the task's 'envs'
section.
Returns:
A chain Dag with 1 or more tasks (an empty entrypoint would create a
trivial task).
"""
configs = common_utils.read_yaml_all(path)
dag_name = None
if set(configs[0].keys()) == {'name'}:
Expand All @@ -14,12 +32,22 @@ def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag:
elif len(configs) == 1:
dag_name = configs[0].get('name')

if len(configs) == 0:
# YAML has only `name: xxx`. Still instantiate a task.
configs = [{'name': dag_name}]

if len(configs) > 1:
# TODO(zongheng): in a chain DAG of N tasks, cli.py currently makes the
# decision to not apply overrides. Here we maintain this behavior. We
# can listen to user feedback to change this.
env_overrides = None

current_task = None
with dag_lib.Dag() as dag:
for task_config in configs:
if task_config is None:
continue
task = task_lib.Task.from_yaml_config(task_config)
task = task_lib.Task.from_yaml_config(task_config, env_overrides)
if current_task is not None:
current_task >> task # pylint: disable=pointless-statement
current_task = task
Expand Down
8 changes: 7 additions & 1 deletion sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,13 @@ def get_task_schema():
'envs': {
'type': 'object',
'required': [],
'additionalProperties': True,
'patternProperties': {
# Checks env keys are valid env var names.
'^[a-zA-Z_][a-zA-Z0-9_]*$': {
'type': 'string'
}
},
'additionalProperties': False,
},
# inputs and outputs are experimental
'inputs': {
Expand Down
Loading

0 comments on commit 3385975

Please sign in to comment.