API/YAML: Support using env vars to define file_mounts. #2146

Merged · 9 commits · Jun 30, 2023

Changes from 3 commits.
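To make the feature concrete before the diff: below is a sketch of a task YAML exercising it. The bucket name and paths are made up for illustration; the smoke test in this PR runs a similar examples/using_file_mounts_with_env_vars.yaml.

    # Hypothetical task.yaml (illustrative values only).
    envs:
      MY_BUCKET: sky-demo-bucket
      MY_LOCAL_PATH: tmp-workdir

    file_mounts:
      /mydir:
        name: ${MY_BUCKET}          # Bucket name, filled in from envs.
        mode: MOUNT
      /mydir2:
        source: ~/${MY_LOCAL_PATH}  # Local source, filled in from envs.

    run: |
      ls /mydir /mydir2

Each env var can then be overridden at launch time, e.g. `sky launch task.yaml --env MY_BUCKET=another-bucket`.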
12 changes: 9 additions & 3 deletions sky/cli.py
@@ -948,7 +948,11 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
     try:
         with open(entrypoint, 'r') as f:
             try:
-                config = list(yaml.safe_load_all(f))[0]
+                config = list(yaml.safe_load_all(f))
+                if config:
+                    config = config[0]
+                else:
+                    config = {}
                 if isinstance(config, str):
                     # 'sky exec cluster ./my_script.sh'
                     is_yaml = False
@@ -1041,7 +1045,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
     if is_yaml:
         assert entrypoint is not None
         usage_lib.messages.usage.update_user_task_yaml(entrypoint)
-        dag = dag_utils.load_chain_dag_from_yaml(entrypoint)
+        dag = dag_utils.load_chain_dag_from_yaml(entrypoint, env_overrides=env)
         if len(dag.tasks) > 1:
             # When the dag has more than 1 task. It is unclear how to
             # override the params for the dag. So we just ignore the
@@ -1052,7 +1056,9 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
                 'since the yaml file contains multiple tasks.',
                 fg='yellow')
             return dag
-        task = sky.Task.from_yaml(entrypoint)
+        assert len(dag.tasks) == 1, (
+            f'If you see this, please file an issue; tasks: {dag.tasks}')
+        task = dag.tasks[0]
     else:
         task = sky.Task(name='sky-cmd', run=entrypoint)
         task.set_resources({sky.Resources()})
2 changes: 1 addition & 1 deletion sky/clouds/gcp.py
@@ -622,7 +622,7 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
         # be modified on the remote cluster by ray causing authentication
         # problems. The backup file will be updated to the remote cluster
         # whenever the original file is not empty and will be applied
-        # appropriately on the remote cluster when neccessary.
+        # appropriately on the remote cluster when necessary.
         if (os.path.exists(os.path.expanduser(GCP_CONFIG_PATH)) and
                 os.path.getsize(os.path.expanduser(GCP_CONFIG_PATH)) > 0):
             subprocess.run(f'cp {GCP_CONFIG_PATH} {GCP_CONFIG_SKY_BACKUP_PATH}',
9 changes: 4 additions & 5 deletions sky/data/storage.py
@@ -1321,11 +1321,10 @@ def _validate(self):
         if not _is_storage_cloud_enabled(str(clouds.GCP())):
             with ux_utils.print_exception_no_traceback():
                 raise exceptions.ResourcesUnavailableError(
-                    'Storage \'store: gcs\' specified, but ' \
-                    'GCP access is disabled. To fix, enable '\
-                    'GCP by running `sky check`. '\
-                    'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long
-                )
+                    'Storage \'store: gcs\' specified, but '
+                    'GCP access is disabled. To fix, enable '
+                    'GCP by running `sky check`. '
+                    'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.')  # pylint: disable=line-too-long
 
     @classmethod
     def validate_name(cls, name) -> str:
1 change: 1 addition & 0 deletions sky/execution.py
@@ -337,6 +337,7 @@ def _execute(
         backend.sync_workdir(handle, task.workdir)
 
     if Stage.SYNC_FILE_MOUNTS in stages:
+        task.fill_env_vars_in_storage_mounts()
         backend.sync_file_mounts(handle, task.file_mounts,
                                  task.storage_mounts)
 
70 changes: 66 additions & 4 deletions sky/task.py
@@ -68,6 +68,45 @@ def _is_valid_env_var(name: str) -> bool:
     return bool(re.fullmatch(_VALID_ENV_VAR_REGEX, name))
 
 
+def _fill_in_env_vars_in_storage_config(
+        storage_config: Dict[str, Any],
+        task_envs: Dict[str, str]) -> Dict[str, Any]:
+    """Detects env vars in storage_config and fills them with task.envs.
+
+    This func tries to replace two fields in storage_config: 'source' and
+    'name'.
+
+    Env vars of the following forms are detected:
+      - ${ENV}
+      - $ENV
+    where <ENV> must appear in task.envs.
+
+    Example:
+        {'mode': 'MOUNT', 'name': '${GSBUCKET}', 'store': 'gcs'}
+        -> {mode: 'MOUNT', name: 'mybucket', store: 'gcs'}
+    """
+
+    # TODO(zongheng): support ${ENV:-default}?
+    def _substitute_env_vars(target: Union[str, List[str]],
+                             envs: Dict[str, str]) -> Union[str, List[str]]:
+        for key, value in envs.items():
+            pattern = r'\$\{?\b' + key + r'\b\}?'
[Review thread on the line above]

Collaborator: We may need to add a check for the key of the task.envs somewhere in our code, i.e. we need to make sure the key ~ '[a-zA-Z][a-zA-Z0-9]*'.

concretevitamin (Author): This will get checked by storage.py's validate_name().

Michaelvll (Collaborator, Jun 29, 2023): Ahh, I was trying to say that if the user provides an env var like the following:

    envs:
      hello world: some_value

    run: |
      echo ${hello world}

Here the run section will fail, as it is not valid syntax in bash. However, if this env var is used in the file_mounts section, the env var will be replaced correctly:

    file_mounts:
      /dst:
        source: ${hello world}

concretevitamin (Author): Good call! Done.
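(Aside: a sketch of the kind of key check discussed above, built on the `_is_valid_env_var` helper visible at the top of this hunk. The exact `_VALID_ENV_VAR_REGEX` constant is not shown in this diff, so the pattern below is an assumption:)

    import re

    # Assumed shape of _VALID_ENV_VAR_REGEX; the real constant lives
    # elsewhere in sky/task.py and may differ.
    _VALID_ENV_VAR_REGEX = r'[a-zA-Z_][a-zA-Z0-9_]*'

    def _is_valid_env_var(name: str) -> bool:
        return bool(re.fullmatch(_VALID_ENV_VAR_REGEX, name))

    print(_is_valid_env_var('MY_BUCKET'))    # True
    print(_is_valid_env_var('hello world'))  # False: rejected up front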

+            env_var_pattern = re.compile(pattern)
+            if isinstance(target, str):
+                target = env_var_pattern.sub(value, target)
+            else:
+                # source: [/local/path, /local/path2]
+                target = [env_var_pattern.sub(value, t) for t in target]
+        return target
+
+    fields_to_fill_in = ('name', 'source')
+    for field in fields_to_fill_in:
+        if field in storage_config:
+            storage_config[field] = _substitute_env_vars(
+                storage_config[field], task_envs)
+    return storage_config
+
+
 class Task:
     """Task: a computation to be run on the cloud."""
 
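For illustration, here are the substitution semantics of the new helper, re-implemented standalone so it can be run without SkyPilot installed (the `demo_` name is mine; the regex is copied from the hunk above):

    import re
    from typing import Dict, List, Union

    def demo_substitute(target: Union[str, List[str]],
                        envs: Dict[str, str]) -> Union[str, List[str]]:
        # Same pattern as the PR: matches $KEY and ${KEY} on word boundaries.
        for key, value in envs.items():
            env_var_pattern = re.compile(r'\$\{?\b' + key + r'\b\}?')
            if isinstance(target, str):
                target = env_var_pattern.sub(value, target)
            else:
                target = [env_var_pattern.sub(value, t) for t in target]
        return target

    envs = {'GSBUCKET': 'mybucket'}
    print(demo_substitute('${GSBUCKET}', envs))    # mybucket
    print(demo_substitute('$GSBUCKET/sub', envs))  # mybucket/sub
    print(demo_substitute('$GSBUCKETX', envs))     # $GSBUCKETX (word boundary blocks the match)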
@@ -231,7 +270,19 @@ def _validate(self):
             f'a symlink to a directory). {self.workdir} not found.')
 
     @staticmethod
-    def from_yaml_config(config: Dict[str, Any]) -> 'Task':
+    def from_yaml_config(
+        config: Dict[str, Any],
+        env_overrides: Optional[List[Tuple[str, str]]] = None,
+    ) -> 'Task':
+        if env_overrides is not None:
+            # We must override env vars before constructing the Task, because
+            # the Storage object creation is eager and it (its name/source
+            # fields) may depend on env vars. The eagerness / how we construct
+            # Task's from entrypoint (YAML, CLI args) should be fixed (FIXME).
+            new_envs = config.get('envs', {})
+            new_envs.update(env_overrides)
+            config['envs'] = new_envs
+
         backend_utils.validate_schema(config, schemas.get_task_schema(),
                                       'Invalid task YAML: ')
 
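A runnable sketch of the override flow above, assuming `env_overrides` carries the `(key, value)` tuples parsed from repeated `--env KEY=VALUE` flags (the `demo_` wrapper is mine):

    from typing import Any, Dict, List, Tuple

    def demo_apply_env_overrides(config: Dict[str, Any],
                                 env_overrides: List[Tuple[str, str]]) -> None:
        # dict.update accepts an iterable of (key, value) pairs, so the
        # CLI tuples can be applied directly; later entries win.
        new_envs = config.get('envs', {})
        new_envs.update(env_overrides)
        config['envs'] = new_envs

    config = {'envs': {'MY_BUCKET': 'default-bucket'}}
    demo_apply_env_overrides(config, [('MY_BUCKET', 'override-bucket')])
    print(config['envs'])  # {'MY_BUCKET': 'override-bucket'}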
@@ -269,10 +320,12 @@ def from_yaml_config(config: Dict[str, Any]) -> 'Task':
         all_storages = fm_storages
         for storage in all_storages:
             mount_path = storage[0]
-            assert mount_path, \
-                'Storage mount path cannot be empty.'
+            assert mount_path, 'Storage mount path cannot be empty.'
             try:
-                storage_obj = storage_lib.Storage.from_yaml_config(storage[1])
+                storage_config = _fill_in_env_vars_in_storage_config(
+                    storage[1], task.envs)
+                storage_obj = storage_lib.Storage.from_yaml_config(
+                    storage_config)
[Review thread on the lines above]

Collaborator: Instead of overriding the storage_config, should we just replace all the env vars in the file_mounts section? For example, in L285 we do:

    if config.get('file_mounts') is not None:
        file_mounts_str = yaml.dumps(config['file_mounts'])
        new_file_mounts_str = _replace_env_var(file_mounts_str, config.get('envs', {}))
        config['file_mounts'] = yaml.loads(new_file_mounts_str)

With that, the user can do something like the following:

    file_mounts:
      /model_path/llama-${SIZE}b: s3://llama-weights/llama-${SIZE}b

concretevitamin (Author): One worry with string replacement is that it'll make some weird things "work":

    file_mounts:
      /path:
        mo${empty_envvar}de: xxxx

I think we should make replacement more structured. That said, supporting env vars in src/dst seems a useful case. How about we support it when users ask?

Collaborator: Sounds good, though I still think the replacement for dst and src might be quite important, as that was one thing we wanted for Vicuna: downloading the correct llama weights from the bucket based on the env var. For the problem you mentioned, we can do the schema validation before the replacement happens to avoid that issue.

concretevitamin (Author): Good call! Updated as such.

             except exceptions.StorageSourceError as e:
                 # Patch the error message to include the mount path, if included
                 e.args = (e.args[0].replace('<destination_path>',
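To make the concern in the thread above concrete, here is a sketch contrasting whole-string replacement over dumped YAML with the structured per-field approach the PR takes (values are made up):

    import re

    env = {'empty_envvar': ''}

    # Naive string replacement over serialized YAML rewrites *keys* too:
    # 'mo${empty_envvar}de' silently becomes 'mode'.
    raw = 'mo${empty_envvar}de: xxxx'
    for k, v in env.items():
        raw = re.sub(r'\$\{?\b' + k + r'\b\}?', v, raw)
    print(raw)  # mode: xxxx  (a config key synthesized by substitution)

    # The structured approach substitutes only inside whitelisted value
    # fields ('name', 'source'), so keys can never be synthesized this way.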
@@ -670,6 +723,15 @@ def update_storage_mounts(
         task_storage_mounts.update(storage_mounts)
         return self.set_storage_mounts(task_storage_mounts)
 
+    def fill_env_vars_in_storage_mounts(self) -> None:
+        """Updates self.storage_mounts with self.envs."""
+        new_storage_mounts = {}
+        for mnt_path, storage in self.storage_mounts.items():
+            new_storage_mounts[mnt_path] = storage_lib.Storage.from_yaml_config(
+                _fill_in_env_vars_in_storage_config(storage.to_yaml_config(),
+                                                    self.envs))
+        self.update_storage_mounts(new_storage_mounts)
+
     def get_preferred_store_type(self) -> storage_lib.StoreType:
         # TODO(zhwu, romilb): The optimizer should look at the source and
         # destination to figure out the right stores to use. For now, we
6 changes: 5 additions & 1 deletion sky/utils/common_utils.py
@@ -182,7 +182,11 @@ def read_yaml(path) -> Dict[str, Any]:
 def read_yaml_all(path: str) -> List[Dict[str, Any]]:
     with open(path, 'r') as f:
         config = yaml.safe_load_all(f)
-        return list(config)
+        configs = list(config)
+        if not configs:
+            # Empty YAML file.
+            return [{}]
+        return configs
 
 
 def dump_yaml(path, config) -> None:
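For reference, a minimal check of the PyYAML behavior motivating this change (runnable as-is):

    import yaml

    # An empty stream yields zero documents, so the previous
    # `list(yaml.safe_load_all(f))[0]` raised IndexError on empty files.
    print(list(yaml.safe_load_all('')))            # []
    print(list(yaml.safe_load_all('name: task')))  # [{'name': 'task'}]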
21 changes: 19 additions & 2 deletions sky/utils/dag_utils.py
@@ -1,11 +1,24 @@
 """Utilities for loading and dumping DAGs from/to YAML files."""
+from typing import List, Optional, Tuple
+
 from sky import dag as dag_lib
 from sky import task as task_lib
 from sky.backends import backend_utils
 from sky.utils import common_utils
 
 
-def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag:
+def load_chain_dag_from_yaml(
+    path: str,
+    env_overrides: Optional[List[Tuple[str, str]]] = None,
+) -> dag_lib.Dag:
+    """Loads a chain DAG from a YAML file.
+
+    Has special handling for an initial section in YAML that contains only the
+    'name' field, which is the DAG name.
+
+    Returns: a chain Dag with 1 or more tasks (an empty entrypoint would create
+    a trivial task).
+    """
     configs = common_utils.read_yaml_all(path)
     dag_name = None
     if set(configs[0].keys()) == {'name'}:
@@ -14,12 +27,16 @@ def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag:
     elif len(configs) == 1:
         dag_name = configs[0].get('name')
 
+    if len(configs) == 0:
+        # YAML has only `name: xxx`. Still instantiate a task.
+        configs = [{'name': dag_name}]
+
     current_task = None
     with dag_lib.Dag() as dag:
         for task_config in configs:
             if task_config is None:
                 continue
-            task = task_lib.Task.from_yaml_config(task_config)
+            task = task_lib.Task.from_yaml_config(task_config, env_overrides)
             if current_task is not None:
                 current_task >> task  # pylint: disable=pointless-statement
             current_task = task
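A hypothetical call showing how the overrides propagate (the path and values are made up; the signature is the one added above):

    from sky.utils import dag_utils

    # Each (key, value) pair mirrors one `--env KEY=VALUE` CLI flag; the
    # pairs are merged into each task's envs before Storage objects are
    # constructed, so env vars in file_mounts resolve correctly.
    dag = dag_utils.load_chain_dag_from_yaml(
        'task.yaml', env_overrides=[('MY_BUCKET', 'my-override-bucket')])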
56 changes: 42 additions & 14 deletions tests/test_optimizer_dryruns.py
@@ -9,31 +9,38 @@
 from sky import exceptions
 
 
-def _test_parse_cpus(spec, expected_cpus):
+def _test_parse_task_yaml(spec: str, test_fn):
+    """Tests parsing a task from a YAML spec and running a test_fn."""
     with tempfile.NamedTemporaryFile('w') as f:
         f.write(spec)
         f.flush()
         with sky.Dag():
             task = sky.Task.from_yaml(f.name)
-            assert list(task.resources)[0].cpus == expected_cpus
+            test_fn(task)
+
+
+def _test_parse_cpus(spec, expected_cpus):
+
+    def test_fn(task):
+        assert list(task.resources)[0].cpus == expected_cpus
+
+    _test_parse_task_yaml(spec, test_fn)
 
 
 def _test_parse_memory(spec, expected_memory):
-    with tempfile.NamedTemporaryFile('w') as f:
-        f.write(spec)
-        f.flush()
-        with sky.Dag():
-            task = sky.Task.from_yaml(f.name)
-            assert list(task.resources)[0].memory == expected_memory
+
+    def test_fn(task):
+        assert list(task.resources)[0].memory == expected_memory
+
+    _test_parse_task_yaml(spec, test_fn)
 
 
 def _test_parse_accelerators(spec, expected_accelerators):
-    with tempfile.NamedTemporaryFile('w') as f:
-        f.write(spec)
-        f.flush()
-        with sky.Dag():
-            task = sky.Task.from_yaml(f.name)
-            assert list(task.resources)[0].accelerators == expected_accelerators
+
+    def test_fn(task):
+        assert list(task.resources)[0].accelerators == expected_accelerators
+
+    _test_parse_task_yaml(spec, test_fn)
 
 
 # Monkey-patching is required because in the test environment, no cloud is
@@ -547,3 +554,24 @@ def test_invalid_num_nodes():
         task = sky.Task()
         task.num_nodes = invalid_value
     assert 'num_nodes should be a positive int' in str(e.value)
+
+
+def test_parse_empty_yaml():
+    spec = textwrap.dedent("""\
+        """)
+
+    def test_fn(task):
+        assert task.num_nodes == 1
+
+    _test_parse_task_yaml(spec, test_fn)
+
+
+def test_parse_name_only_yaml():
+    spec = textwrap.dedent("""\
+        name: test_task
+        """)
+
+    def test_fn(task):
+        assert task.name == 'test_task'
+
+    _test_parse_task_yaml(spec, test_fn)
24 changes: 23 additions & 1 deletion tests/test_smoke.py
@@ -4,7 +4,7 @@
 # Run all tests except for AWS and Lambda Cloud
 # > pytest tests/test_smoke.py
 #
-# Terminate failed clusetrs after test finishes
+# Terminate failed clusters after test finishes
 # > pytest tests/test_smoke.py --terminate-on-failure
 #
 # Re-run last failed tests
@@ -738,6 +738,28 @@ def test_scp_file_mounts():
     run_one_test(test)
 
 
+def test_using_file_mounts_with_env_vars(generic_cloud: str):
+    name = _get_cluster_name()
+    test_commands = [
+        *storage_setup_commands,
+        (f'sky launch -y -c {name} --cpus 2+ --cloud {generic_cloud} '
+         'examples/using_file_mounts_with_env_vars.yaml'),

[Review thread on the `sky launch` line above]
Collaborator: nice, we should have all our tests be --cpus 2+ in the future to save cost. ; )

+        f'sky logs {name} 1 --status',  # Ensure the job succeeded.
+        # Override with --env:
+        (f'sky launch -y -c {name}-2 --cpus 2+ --cloud {generic_cloud} '
+         'examples/using_file_mounts_with_env_vars.yaml '
+         '--env MY_LOCAL_PATH=tmpfile'),
+        f'sky logs {name}-2 1 --status',  # Ensure the job succeeded.
+    ]
+    test = Test(
+        'using_file_mounts_with_env_vars',
+        test_commands,
+        f'sky down -y {name} {name}-2',
+        timeout=20 * 60,  # 20 mins
+    )
+    run_one_test(test)
+
+
 # ---------- storage ----------
 @pytest.mark.aws
 def test_aws_storage_mounts():