diff --git a/examples/using_file_mounts_with_env_vars.yaml b/examples/using_file_mounts_with_env_vars.yaml new file mode 100644 index 00000000000..74881a2ff88 --- /dev/null +++ b/examples/using_file_mounts_with_env_vars.yaml @@ -0,0 +1,52 @@ +# Example showcasing how env vars can be used in the file_mounts section. + +# You can set the default values for env vars in this 'envs' section. +# When launching, `sky launch --env ENV=val` will override the default. +envs: + MY_BUCKET: skypilot-temp-gcs-test + MY_LOCAL_PATH: tmp-workdir + MODEL_SIZE: 13b + +resources: + cloud: gcp + +# You can use env vars in +# - the destination: source paths +# - for SkyPilot Storage +# - the "name" (bucket name) and "source" (local dir/file to upload) fields. +# +# Both syntaxes work: ${MY_BUCKET} and $MY_BUCKET. +file_mounts: + /mydir: + name: ${MY_BUCKET} # Name of the bucket. + store: gcs + mode: MOUNT + + /mydir2: + name: $MY_BUCKET # Name of the bucket. + store: gcs + mode: MOUNT + + /another-dir: + name: ${MY_BUCKET}-2 + source: ["~/${MY_LOCAL_PATH}"] + store: gcs + mode: MOUNT + + /another-dir2: + name: $MY_BUCKET-2 + source: ["~/${MY_LOCAL_PATH}"] + store: gcs + mode: MOUNT + + /checkpoint/${MODEL_SIZE}: ~/${MY_LOCAL_PATH} + +run: | + echo Env var MY_BUCKET has value: ${MY_BUCKET} + echo Env var MY_LOCAL_PATH has value: ${MY_LOCAL_PATH} + + ls -lthr /mydir + ls -lthr /mydir2 + ls -lthr /another-dir + ls -lthr /another-dir2 + ls -lthr /checkpoint/${MODEL_SIZE} diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index c030bbaea39..4a51324bdbb 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -2502,17 +2502,28 @@ def validate_schema(obj, schema, err_msg_prefix=''): validator.SchemaValidator(schema).validate(obj) except jsonschema.ValidationError as e: if e.validator == 'additionalProperties': - err_msg = err_msg_prefix + 'The following fields are invalid:' - known_fields = set(e.schema.get('properties', {}).keys()) - for field in e.instance: - if field not in known_fields: - most_similar_field = difflib.get_close_matches( - field, known_fields, 1) - if most_similar_field: - err_msg += (f'\nInstead of {field!r}, did you mean ' - f'{most_similar_field[0]!r}?') - else: - err_msg += f'\nFound unsupported field {field!r}.' + if tuple(e.schema_path) == ('properties', 'envs', + 'additionalProperties'): + # Hack. Here the error is Task.envs having some invalid keys. So + # we should not print "unsupported field". + # + # This will print something like: + # 'hello world' does not match any of the regexes: + err_msg = (err_msg_prefix + + 'The `envs` field contains invalid keys:\n' + + e.message) + else: + err_msg = err_msg_prefix + 'The following fields are invalid:' + known_fields = set(e.schema.get('properties', {}).keys()) + for field in e.instance: + if field not in known_fields: + most_similar_field = difflib.get_close_matches( + field, known_fields, 1) + if most_similar_field: + err_msg += (f'\nInstead of {field!r}, did you mean ' + f'{most_similar_field[0]!r}?') + else: + err_msg += f'\nFound unsupported field {field!r}.' else: # Example e.json_path value: '$.resources' err_msg = (err_msg_prefix + e.message + diff --git a/sky/cli.py b/sky/cli.py index 1fcb9786f17..e16eb90042f 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -941,15 +941,23 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]: entrypoint: Path to a YAML file. """ is_yaml = True - config = None + config: Optional[List[Dict[str, Any]]] = None + result = None shell_splits = shlex.split(entrypoint) - yaml_file_provided = len(shell_splits) == 1 and \ - (shell_splits[0].endswith('yaml') or shell_splits[0].endswith('.yml')) + yaml_file_provided = (len(shell_splits) == 1 and + (shell_splits[0].endswith('yaml') or + shell_splits[0].endswith('.yml'))) try: with open(entrypoint, 'r') as f: try: - config = list(yaml.safe_load_all(f))[0] - if isinstance(config, str): + config = list(yaml.safe_load_all(f)) + if config: + # FIXME(zongheng): in a chain DAG YAML it only returns the + # first section. OK for downstream but is weird. + result = config[0] + else: + result = {} + if isinstance(result, str): # 'sky exec cluster ./my_script.sh' is_yaml = False except yaml.YAMLError as e: @@ -977,7 +985,7 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]: f'{entrypoint!r} looks like a yaml path but {invalid_reason}\n' 'It will be treated as a command to be run remotely. Continue?', abort=True) - return is_yaml, config + return is_yaml, result def _make_task_or_dag_from_entrypoint_with_overrides( @@ -1041,7 +1049,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides( if is_yaml: assert entrypoint is not None usage_lib.messages.usage.update_user_task_yaml(entrypoint) - dag = dag_utils.load_chain_dag_from_yaml(entrypoint) + dag = dag_utils.load_chain_dag_from_yaml(entrypoint, env_overrides=env) if len(dag.tasks) > 1: # When the dag has more than 1 task. It is unclear how to # override the params for the dag. So we just ignore the @@ -1052,7 +1060,9 @@ def _make_task_or_dag_from_entrypoint_with_overrides( 'since the yaml file contains multiple tasks.', fg='yellow') return dag - task = sky.Task.from_yaml(entrypoint) + assert len(dag.tasks) == 1, ( + f'If you see this, please file an issue; tasks: {dag.tasks}') + task = dag.tasks[0] else: task = sky.Task(name='sky-cmd', run=entrypoint) task.set_resources({sky.Resources()}) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 8f2ac579fce..c9d44f54cf8 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -622,7 +622,7 @@ def get_credential_file_mounts(self) -> Dict[str, str]: # be modified on the remote cluster by ray causing authentication # problems. The backup file will be updated to the remote cluster # whenever the original file is not empty and will be applied - # appropriately on the remote cluster when neccessary. + # appropriately on the remote cluster when necessary. if (os.path.exists(os.path.expanduser(GCP_CONFIG_PATH)) and os.path.getsize(os.path.expanduser(GCP_CONFIG_PATH)) > 0): subprocess.run(f'cp {GCP_CONFIG_PATH} {GCP_CONFIG_SKY_BACKUP_PATH}', diff --git a/sky/data/storage.py b/sky/data/storage.py index 48dbe3134d3..bb2c9ab2ebb 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -1321,11 +1321,10 @@ def _validate(self): if not _is_storage_cloud_enabled(str(clouds.GCP())): with ux_utils.print_exception_no_traceback(): raise exceptions.ResourcesUnavailableError( - 'Storage \'store: gcs\' specified, but ' \ - 'GCP access is disabled. To fix, enable '\ - 'GCP by running `sky check`. '\ - 'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long - ) + 'Storage \'store: gcs\' specified, but ' + 'GCP access is disabled. To fix, enable ' + 'GCP by running `sky check`. ' + 'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.') # pylint: disable=line-too-long @classmethod def validate_name(cls, name) -> str: diff --git a/sky/task.py b/sky/task.py index bc2ff18b66e..fb013e2ee13 100644 --- a/sky/task.py +++ b/sky/task.py @@ -1,5 +1,6 @@ """Task: a coarse-grained stage in an application.""" import inspect +import json import os import re import typing @@ -68,6 +69,40 @@ def _is_valid_env_var(name: str) -> bool: return bool(re.fullmatch(_VALID_ENV_VAR_REGEX, name)) +def _fill_in_env_vars_in_file_mounts( + file_mounts: Dict[str, Any], + task_envs: Dict[str, str], +) -> Dict[str, Any]: + """Detects env vars in file_mounts and fills them with task_envs. + + Use cases of env vars in file_mounts: + - dst/src paths; e.g., + /model_path/llama-${SIZE}b: s3://llama-weights/llama-${SIZE}b + - storage's name (bucket name) + - storage's source (local path) + + We simply dump file_mounts into a json string, and replace env vars using + regex. This should be safe as file_mounts has been schema-validated. + + Env vars of the following forms are detected: + - ${ENV} + - $ENV + where must appear in task.envs. + """ + # TODO(zongheng): support ${ENV:-default}? + file_mounts_str = json.dumps(file_mounts) + + def replace_var(match): + var_name = match.group(1) + # If the variable isn't in the dictionary, return it unchanged + return task_envs.get(var_name, match.group(0)) + + # Pattern for valid env var names in bash. + pattern = r'\$\{?\b([a-zA-Z_][a-zA-Z0-9_]*)\b\}?' + file_mounts_str = re.sub(pattern, replace_var, file_mounts_str) + return json.loads(file_mounts_str) + + class Task: """Task: a computation to be run on the cloud.""" @@ -231,10 +266,37 @@ def _validate(self): f'a symlink to a directory). {self.workdir} not found.') @staticmethod - def from_yaml_config(config: Dict[str, Any]) -> 'Task': + def from_yaml_config( + config: Dict[str, Any], + env_overrides: Optional[List[Tuple[str, str]]] = None, + ) -> 'Task': + if env_overrides is not None: + # We must override env vars before constructing the Task, because + # the Storage object creation is eager and it (its name/source + # fields) may depend on env vars. + # + # FIXME(zongheng): The eagerness / how we construct Task's from + # entrypoint (YAML, CLI args) should be fixed. + new_envs = config.get('envs', {}) + new_envs.update(env_overrides) + config['envs'] = new_envs + + # More robust handling for 'envs': explicitly convert keys and values to + # str, since users may pass '123' as keys/values which will get parsed + # as int causing validate_schema() to fail. + envs = config.get('envs') + if envs is not None and isinstance(envs, dict): + config['envs'] = {str(k): str(v) for k, v in envs.items()} + backend_utils.validate_schema(config, schemas.get_task_schema(), 'Invalid task YAML: ') + # Fill in any Task.envs into file_mounts (src/dst paths, storage + # name/source). + if config.get('file_mounts') is not None: + config['file_mounts'] = _fill_in_env_vars_in_file_mounts( + config['file_mounts'], config.get('envs', {})) + task = Task( config.pop('name', None), run=config.pop('run', None), @@ -269,8 +331,7 @@ def from_yaml_config(config: Dict[str, Any]) -> 'Task': all_storages = fm_storages for storage in all_storages: mount_path = storage[0] - assert mount_path, \ - 'Storage mount path cannot be empty.' + assert mount_path, 'Storage mount path cannot be empty.' try: storage_obj = storage_lib.Storage.from_yaml_config(storage[1]) except exceptions.StorageSourceError as e: @@ -649,7 +710,7 @@ def update_storage_mounts( Different from set_storage_mounts(), this function updates into the existing storage_mounts (calls ``dict.update()``), rather than - overwritting it. + overwriting it. This should be called before provisioning in order to take effect. diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 8f51c372034..1a281b07a06 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -182,7 +182,11 @@ def read_yaml(path) -> Dict[str, Any]: def read_yaml_all(path: str) -> List[Dict[str, Any]]: with open(path, 'r') as f: config = yaml.safe_load_all(f) - return list(config) + configs = list(config) + if not configs: + # Empty YAML file. + return [{}] + return configs def dump_yaml(path, config) -> None: diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index d8a8ed03d8c..0d00438b316 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -1,11 +1,29 @@ """Utilities for loading and dumping DAGs from/to YAML files.""" +from typing import List, Optional, Tuple + from sky import dag as dag_lib from sky import task as task_lib from sky.backends import backend_utils from sky.utils import common_utils -def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag: +def load_chain_dag_from_yaml( + path: str, + env_overrides: Optional[List[Tuple[str, str]]] = None, +) -> dag_lib.Dag: + """Loads a chain DAG from a YAML file. + + Has special handling for an initial section in YAML that contains only the + 'name' field, which is the DAG name. + + 'env_overrides' is in effect only when there's exactly one task. It is a + list of (key, value) pairs that will be used to update the task's 'envs' + section. + + Returns: + A chain Dag with 1 or more tasks (an empty entrypoint would create a + trivial task). + """ configs = common_utils.read_yaml_all(path) dag_name = None if set(configs[0].keys()) == {'name'}: @@ -14,12 +32,22 @@ def load_chain_dag_from_yaml(path: str) -> dag_lib.Dag: elif len(configs) == 1: dag_name = configs[0].get('name') + if len(configs) == 0: + # YAML has only `name: xxx`. Still instantiate a task. + configs = [{'name': dag_name}] + + if len(configs) > 1: + # TODO(zongheng): in a chain DAG of N tasks, cli.py currently makes the + # decision to not apply overrides. Here we maintain this behavior. We + # can listen to user feedback to change this. + env_overrides = None + current_task = None with dag_lib.Dag() as dag: for task_config in configs: if task_config is None: continue - task = task_lib.Task.from_yaml_config(task_config) + task = task_lib.Task.from_yaml_config(task_config, env_overrides) if current_task is not None: current_task >> task # pylint: disable=pointless-statement current_task = task diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index f3c4b8f7137..a77f1c4c5ec 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -170,7 +170,13 @@ def get_task_schema(): 'envs': { 'type': 'object', 'required': [], - 'additionalProperties': True, + 'patternProperties': { + # Checks env keys are valid env var names. + '^[a-zA-Z_][a-zA-Z0-9_]*$': { + 'type': 'string' + } + }, + 'additionalProperties': False, }, # inputs and outputs are experimental 'inputs': { diff --git a/tests/test_optimizer_dryruns.py b/tests/test_optimizer_dryruns.py index a3ca0798979..dc7b11768c6 100644 --- a/tests/test_optimizer_dryruns.py +++ b/tests/test_optimizer_dryruns.py @@ -1,6 +1,6 @@ import tempfile import textwrap -from typing import List +from typing import Callable, List, Optional import pytest @@ -9,31 +9,39 @@ from sky import exceptions -def _test_parse_cpus(spec, expected_cpus): +def _test_parse_task_yaml(spec: str, test_fn: Optional[Callable] = None): + """Tests parsing a task from a YAML spec and running a test_fn.""" with tempfile.NamedTemporaryFile('w') as f: f.write(spec) f.flush() with sky.Dag(): task = sky.Task.from_yaml(f.name) - assert list(task.resources)[0].cpus == expected_cpus + if test_fn is not None: + test_fn(task) + + +def _test_parse_cpus(spec, expected_cpus): + + def test_fn(task): + assert list(task.resources)[0].cpus == expected_cpus + + _test_parse_task_yaml(spec, test_fn) def _test_parse_memory(spec, expected_memory): - with tempfile.NamedTemporaryFile('w') as f: - f.write(spec) - f.flush() - with sky.Dag(): - task = sky.Task.from_yaml(f.name) - assert list(task.resources)[0].memory == expected_memory + + def test_fn(task): + assert list(task.resources)[0].memory == expected_memory + + _test_parse_task_yaml(spec, test_fn) def _test_parse_accelerators(spec, expected_accelerators): - with tempfile.NamedTemporaryFile('w') as f: - f.write(spec) - f.flush() - with sky.Dag(): - task = sky.Task.from_yaml(f.name) - assert list(task.resources)[0].accelerators == expected_accelerators + + def test_fn(task): + assert list(task.resources)[0].accelerators == expected_accelerators + + _test_parse_task_yaml(spec, test_fn) # Monkey-patching is required because in the test environment, no cloud is @@ -547,3 +555,47 @@ def test_invalid_num_nodes(): task = sky.Task() task.num_nodes = invalid_value assert 'num_nodes should be a positive int' in str(e.value) + + +def test_parse_empty_yaml(): + spec = textwrap.dedent("""\ + """) + + def test_fn(task): + assert task.num_nodes == 1 + + _test_parse_task_yaml(spec, test_fn) + + +def test_parse_name_only_yaml(): + spec = textwrap.dedent("""\ + name: test_task + """) + + def test_fn(task): + assert task.name == 'test_task' + + _test_parse_task_yaml(spec, test_fn) + + +def test_parse_invalid_envs_yaml(monkeypatch): + spec = textwrap.dedent("""\ + envs: + hello world: 1 # invalid key + 123: val # invalid key + good_key: val + """) + with pytest.raises(ValueError) as e: + _test_parse_task_yaml(spec) + assert '\'123\', \'hello world\' do not match any of the regexes' in str( + e.value) + + +def test_parse_valid_envs_yaml(monkeypatch): + spec = textwrap.dedent("""\ + envs: + hello_world: 1 + HELLO: val + GOOD123: 123 + """) + _test_parse_task_yaml(spec) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 27be66877cb..dcc159a9888 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -4,7 +4,7 @@ # Run all tests except for AWS and Lambda Cloud # > pytest tests/test_smoke.py # -# Terminate failed clusetrs after test finishes +# Terminate failed clusters after test finishes # > pytest tests/test_smoke.py --terminate-on-failure # # Re-run last failed tests @@ -738,6 +738,28 @@ def test_scp_file_mounts(): run_one_test(test) +def test_using_file_mounts_with_env_vars(generic_cloud: str): + name = _get_cluster_name() + test_commands = [ + *storage_setup_commands, + (f'sky launch -y -c {name} --cpus 2+ --cloud {generic_cloud} ' + 'examples/using_file_mounts_with_env_vars.yaml'), + f'sky logs {name} 1 --status', # Ensure the job succeeded. + # Override with --env: + (f'sky launch -y -c {name}-2 --cpus 2+ --cloud {generic_cloud} ' + 'examples/using_file_mounts_with_env_vars.yaml ' + '--env MY_LOCAL_PATH=tmpfile'), + f'sky logs {name}-2 1 --status', # Ensure the job succeeded. + ] + test = Test( + 'using_file_mounts_with_env_vars', + test_commands, + f'sky down -y {name} {name}-2', + timeout=20 * 60, # 20 mins + ) + run_one_test(test) + + # ---------- storage ---------- @pytest.mark.aws def test_aws_storage_mounts():