From cb28b8d62a1a717357b2033949a7a50573357e12 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 19 Sep 2024 22:01:39 +0000 Subject: [PATCH 01/61] support policy hook --- sky/__init__.py | 4 ++ sky/dag.py | 30 ++++++++------ sky/execution.py | 2 +- sky/jobs/controller.py | 2 + sky/jobs/core.py | 2 +- sky/policy.py | 94 ++++++++++++++++++++++++++++++++++++++++++ sky/skypilot_config.py | 10 ++++- sky/task.py | 1 + sky/utils/dag_utils.py | 13 ++++-- sky/utils/schemas.py | 5 +++ 10 files changed, 143 insertions(+), 20 deletions(-) create mode 100644 sky/policy.py diff --git a/sky/__init__.py b/sky/__init__.py index a077fb8966a..6d5b32e2b90 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -110,6 +110,8 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.jobs.core import spot_tail_logs from sky.optimizer import Optimizer from sky.optimizer import OptimizeTarget +from sky.policy import MutatedUserTask +from sky.policy import UserTask from sky.resources import Resources from sky.skylet.job_lib import JobStatus from sky.status_lib import ClusterStatus @@ -185,4 +187,6 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): # core APIs Storage Management 'storage_ls', 'storage_delete', + 'UserTask', + 'MutatedUserTask', ] diff --git a/sky/dag.py b/sky/dag.py index d1904eb9fcc..4af5adc76b5 100644 --- a/sky/dag.py +++ b/sky/dag.py @@ -1,8 +1,12 @@ """DAGs: user applications to be run.""" import pprint import threading +import typing from typing import List, Optional +if typing.TYPE_CHECKING: + from sky import task + class Dag: """Dag: a user application, represented as a DAG of Tasks. @@ -13,37 +17,37 @@ class Dag: >>> task = sky.Task(...) """ - def __init__(self): - self.tasks = [] + def __init__(self) -> None: + self.tasks: List['task.Task'] = [] import networkx as nx # pylint: disable=import-outside-toplevel self.graph = nx.DiGraph() - self.name = None + self.name: Optional[str] = None - def add(self, task): + def add(self, task: 'task.Task') -> None: self.graph.add_node(task) self.tasks.append(task) - def remove(self, task): + def remove(self, task: 'task.Task') -> None: self.tasks.remove(task) self.graph.remove_node(task) - def add_edge(self, op1, op2): + def add_edge(self, op1: 'task.Task', op2: 'task.Task') -> None: assert op1 in self.graph.nodes assert op2 in self.graph.nodes self.graph.add_edge(op1, op2) - def __len__(self): + def __len__(self) -> int: return len(self.tasks) - def __enter__(self): + def __enter__(self) -> 'Dag': push_dag(self) return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: pop_dag() - def __repr__(self): + def __repr__(self) -> str: pformat = pprint.pformat(self.tasks) return f'DAG:\n{pformat}' @@ -70,15 +74,15 @@ def is_chain(self) -> bool: class _DagContext(threading.local): """A thread-local stack of Dags.""" - _current_dag = None + _current_dag: Optional[Dag] = None _previous_dags: List[Dag] = [] - def push_dag(self, dag): + def push_dag(self, dag: Dag): if self._current_dag is not None: self._previous_dags.append(self._current_dag) self._current_dag = dag - def pop_dag(self): + def pop_dag(self) -> Optional[Dag]: old_dag = self._current_dag if self._previous_dags: self._current_dag = self._previous_dags.pop() diff --git a/sky/execution.py b/sky/execution.py index 1f6bd09f9c3..f691b340c6c 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -158,7 +158,7 @@ def _execute( handle: Optional[backends.ResourceHandle]; the handle to the cluster. None if dryrun. """ - dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + dag = dag_utils.convert_entrypoint_to_dag_and_apply_policy(entrypoint) assert len(dag) == 1, f'We support 1 task for now. {dag}' task = dag.tasks[0] diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 39c89d2784b..f3cd81576e2 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -64,6 +64,7 @@ def __init__(self, job_id: int, dag_yaml: str, if len(self._dag.tasks) <= 1: task_name = self._dag_name else: + assert task.name is not None, task task_name = task.name # This is guaranteed by the spot_launch API, where we fill in # the task.name with @@ -447,6 +448,7 @@ def _cleanup(job_id: int, dag_yaml: str): # controller, we should keep it in sync with JobsController.__init__() dag, _ = _get_dag_and_name(dag_yaml) for task in dag.tasks: + assert task.name is not None, task cluster_name = managed_job_utils.generate_managed_job_cluster_name( task.name, job_id) recovery_strategy.terminate_cluster(cluster_name) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 561d47f4b25..c66199fe270 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -53,7 +53,7 @@ def launch( entrypoint = task dag_uuid = str(uuid.uuid4().hex[:4]) - dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + dag = dag_utils.convert_entrypoint_to_dag_and_apply_policy(entrypoint) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' diff --git a/sky/policy.py b/sky/policy.py new file mode 100644 index 00000000000..204b5fc7eeb --- /dev/null +++ b/sky/policy.py @@ -0,0 +1,94 @@ +"""Customize policy by users.""" +import dataclasses +import importlib +import os +import tempfile +import typing +from typing import Any, Callable, Dict, Optional + +from sky import dag as dag_lib +from sky import skypilot_config +from sky.utils import common_utils +from sky.utils import ux_utils + +if typing.TYPE_CHECKING: + from sky import task as task_lib + + +@dataclasses.dataclass +class UserTask: + task: 'task_lib.Task' + skypilot_config: Dict[str, Any] + + +@dataclasses.dataclass +class MutatedUserTask: + task: 'task_lib.Task' + skypilot_config: Dict[str, Any] + + +class Policy: + """Customize Policy by users.""" + + def __init__(self) -> None: + # Policy is a string to a python function within some user provided + # module. + self.policy: Optional[str] = skypilot_config.get_nested(('policy',), + None) + self.policy_fn: Optional[Callable[[UserTask], MutatedUserTask]] = None + if self.policy is not None: + try: + module_path, func_name = self.policy.rsplit('.', 1) + module = importlib.import_module(module_path) + except ImportError as e: + raise ImportError( + f'Failed to import policy module: {module_path}. Please ' + 'check if the module is in your Python environment.') from e + try: + self.policy_fn = getattr(module, func_name) + except AttributeError as e: + raise AttributeError( + f'Failed to get policy function: {func_name} from module: ' + f'{module_path}. Please check with your policy admin if ' + f'the function {func_name!r} is in the module.') from e + + def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': + if self.policy_fn is None: + return dag + config = skypilot_config.to_dict() + mutated_dag = dag_lib.Dag() + mutated_dag.name = dag.name + + mutated_config = None + for task in dag.tasks: + user_task = UserTask(task, config) + mutated_user_task = self.policy_fn(user_task) + if mutated_config is None: + mutated_config = mutated_user_task.skypilot_config + else: + if mutated_config != mutated_user_task.skypilot_config: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + 'All tasks must have the same skypilot ' + 'config after applying the policy. Please' + 'check with your policy admin for details.') + mutated_dag.add(mutated_user_task.task) + + # Update the new_dag's graph with the old dag's graph + for u, v in dag.graph.edges: + u_idx = dag.tasks.index(u) + v_idx = dag.tasks.index(v) + mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], + mutated_dag.tasks[v_idx]) + + if config != mutated_config: + with tempfile.NamedTemporaryFile( + delete=False, + mode='w', + prefix='policy-mutated-skypilot-config-', + suffix='.yaml') as temp_file: + common_utils.dump_yaml(temp_file.name, mutated_config) + os.environ[ + skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name + skypilot_config.reload_config() + return mutated_dag diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 52e1d0ae3d9..fd79ab8d393 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -82,7 +82,7 @@ # The loaded config. _dict: Optional[Dict[str, Any]] = None -_loaded_config_path = None +_loaded_config_path: Optional[str] = None def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], @@ -210,6 +210,14 @@ def _try_load_config() -> None: logger.debug('Config syntax check passed.') +def reload_config() -> None: + """Reloads the config from the file specified by the env var.""" + global _dict, _loaded_config_path + _dict = None + _loaded_config_path = None + _try_load_config() + + def loaded_config_path() -> Optional[str]: """Returns the path to the loaded config file.""" return _loaded_config_path diff --git a/sky/task.py b/sky/task.py index cebc616dc6d..ebd11f00e8b 100644 --- a/sky/task.py +++ b/sky/task.py @@ -1,4 +1,5 @@ """Task: a coarse-grained stage in an application.""" +import dataclasses import inspect import json import os diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 7a4fe90e7fb..e68873e4691 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -4,6 +4,7 @@ from sky import dag as dag_lib from sky import jobs +from sky import policy from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils @@ -35,31 +36,35 @@ """.strip() -def convert_entrypoint_to_dag(entrypoint: Any) -> 'dag_lib.Dag': - """Convert the entrypoint to a sky.Dag. +def convert_entrypoint_to_dag_and_apply_policy( + entrypoint: Any) -> 'dag_lib.Dag': + """Convert the entrypoint to a sky.Dag and apply the policy. Raises TypeError if 'entrypoint' is not a 'sky.Task' or 'sky.Dag'. """ # Not suppressing stacktrace: when calling this via API user may want to # see their own program in the stacktrace. Our CLI impl would not trigger # these errors. + converted_dag: 'dag_lib.Dag' if isinstance(entrypoint, str): with ux_utils.print_exception_no_traceback(): raise TypeError(_ENTRYPOINT_STRING_AS_DAG_MESSAGE) elif isinstance(entrypoint, dag_lib.Dag): - return copy.deepcopy(entrypoint) + converted_dag = copy.deepcopy(entrypoint) elif isinstance(entrypoint, task_lib.Task): entrypoint = copy.deepcopy(entrypoint) with dag_lib.Dag() as dag: dag.add(entrypoint) dag.name = entrypoint.name - return dag + converted_dag = dag else: with ux_utils.print_exception_no_traceback(): raise TypeError( 'Expected a sky.Task or sky.Dag but received argument of type: ' f'{type(entrypoint)}') + return policy.Policy().apply(converted_dag) + def load_chain_dag_from_yaml( path: str, diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 01dc14f617c..203592486af 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -442,6 +442,11 @@ def _experimental_task_schema() -> dict: 'additionalProperties': False, 'properties': { 'config_overrides': config_override_schema, + 'policy': { + 'type': 'string', + # Check regex to be a valid python module path + 'pattern': '^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$', + } } } } From b64efa0a71f66715303316bcef8c41d53660fecc Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 19 Sep 2024 23:53:56 +0000 Subject: [PATCH 02/61] test task labels --- examples/policy/config.yaml | 1 + .../example_policy/example_policy/__init__.py | 1 + .../example_policy/skypilot_policy.py | 16 ++++++++++++++++ examples/policy/example_policy/pyproject.toml | 7 +++++++ examples/policy/task.yaml | 12 ++++++++++++ sky/policy.py | 18 +++++++++++------- sky/task.py | 1 - sky/utils/schemas.py | 15 ++++++++++----- 8 files changed, 58 insertions(+), 13 deletions(-) create mode 100644 examples/policy/config.yaml create mode 100644 examples/policy/example_policy/example_policy/__init__.py create mode 100644 examples/policy/example_policy/example_policy/skypilot_policy.py create mode 100644 examples/policy/example_policy/pyproject.toml create mode 100644 examples/policy/task.yaml diff --git a/examples/policy/config.yaml b/examples/policy/config.yaml new file mode 100644 index 00000000000..2186d9ab9f6 --- /dev/null +++ b/examples/policy/config.yaml @@ -0,0 +1 @@ +policy: example_policy.example_policy_task_label diff --git a/examples/policy/example_policy/example_policy/__init__.py b/examples/policy/example_policy/example_policy/__init__.py new file mode 100644 index 00000000000..d8f7a51898b --- /dev/null +++ b/examples/policy/example_policy/example_policy/__init__.py @@ -0,0 +1 @@ +from example_policy.skypilot_policy import example_policy_task_label diff --git a/examples/policy/example_policy/example_policy/skypilot_policy.py b/examples/policy/example_policy/example_policy/skypilot_policy.py new file mode 100644 index 00000000000..1f1cf08c865 --- /dev/null +++ b/examples/policy/example_policy/example_policy/skypilot_policy.py @@ -0,0 +1,16 @@ +import getpass + +from sky import MutatedUserTask +from sky import UserTask + + +def example_policy_task_label(user_task: UserTask) -> MutatedUserTask: + """Example policy.""" + local_user_name = getpass.getuser() + + # Add label for task with the local user name + task = user_task.task + for r in task.resources: + r.labels['local_user'] = local_user_name + + return MutatedUserTask(task=task, skypilot_config=user_task.skypilot_config) diff --git a/examples/policy/example_policy/pyproject.toml b/examples/policy/example_policy/pyproject.toml new file mode 100644 index 00000000000..b4aa56be4b2 --- /dev/null +++ b/examples/policy/example_policy/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "example_policy" +version = "0.0.1" diff --git a/examples/policy/task.yaml b/examples/policy/task.yaml new file mode 100644 index 00000000000..065b4cbfb11 --- /dev/null +++ b/examples/policy/task.yaml @@ -0,0 +1,12 @@ +resources: + cloud: aws + cpus: 2 + labels: + other_labels: test + + +setup: | + echo "setup" + +run: | + echo "run" diff --git a/sky/policy.py b/sky/policy.py index 204b5fc7eeb..d42e8d6e804 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -41,16 +41,20 @@ def __init__(self) -> None: module_path, func_name = self.policy.rsplit('.', 1) module = importlib.import_module(module_path) except ImportError as e: - raise ImportError( - f'Failed to import policy module: {module_path}. Please ' - 'check if the module is in your Python environment.') from e + with ux_utils.print_exception_no_traceback(): + raise ImportError( + f'Failed to import policy module: {module_path}. ' + 'Please check if the module is in your Python ' + 'environment.' + ) from e try: self.policy_fn = getattr(module, func_name) except AttributeError as e: - raise AttributeError( - f'Failed to get policy function: {func_name} from module: ' - f'{module_path}. Please check with your policy admin if ' - f'the function {func_name!r} is in the module.') from e + with ux_utils.print_exception_no_traceback(): + raise AttributeError( + f'Failed to get policy function: {func_name} from module: ' + f'{module_path}. Please check with your policy admin if ' + f'the function {func_name!r} is in the module.') from e def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': if self.policy_fn is None: diff --git a/sky/task.py b/sky/task.py index ebd11f00e8b..cebc616dc6d 100644 --- a/sky/task.py +++ b/sky/task.py @@ -1,5 +1,4 @@ """Task: a coarse-grained stage in an application.""" -import dataclasses import inspect import json import os diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 203592486af..996bc083185 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -442,11 +442,6 @@ def _experimental_task_schema() -> dict: 'additionalProperties': False, 'properties': { 'config_overrides': config_override_schema, - 'policy': { - 'type': 'string', - # Check regex to be a valid python module path - 'pattern': '^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$', - } } } } @@ -853,6 +848,15 @@ def get_config_schema(): }, } + policy_schema = { + 'policy': { + 'type': 'string', + # Check regex to be a valid python module path + 'pattern': (r'^[a-zA-Z_][a-zA-Z0-9_]*' + r'(\.[a-zA-Z_][a-zA-Z0-9_]*)+$'), + } + } + allowed_clouds = { # A list of cloud names that are allowed to be used 'type': 'array', @@ -910,6 +914,7 @@ def get_config_schema(): 'spot': controller_resources_schema, 'serve': controller_resources_schema, 'allowed_clouds': allowed_clouds, + 'policy': policy_schema, 'docker': docker_configs, 'nvidia_gpus': gpu_configs, **cloud_configs, From cf89929f6a290c255a8a51891b2f62cea2c014da Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 00:40:55 +0000 Subject: [PATCH 03/61] Add test for policy that sets labels --- examples/policy/config.yaml | 1 - examples/policy/config_label_config.yaml | 1 + .../example_policy/example_policy/__init__.py | 3 +- .../example_policy/skypilot_policy.py | 17 +++++++++- examples/policy/task_label_config.yaml | 1 + sky/policy.py | 20 ++++++++--- sky/skypilot_config.py | 8 ----- tests/unit_tests/test_policy.py | 34 +++++++++++++++++++ 8 files changed, 69 insertions(+), 16 deletions(-) delete mode 100644 examples/policy/config.yaml create mode 100644 examples/policy/config_label_config.yaml create mode 100644 examples/policy/task_label_config.yaml create mode 100644 tests/unit_tests/test_policy.py diff --git a/examples/policy/config.yaml b/examples/policy/config.yaml deleted file mode 100644 index 2186d9ab9f6..00000000000 --- a/examples/policy/config.yaml +++ /dev/null @@ -1 +0,0 @@ -policy: example_policy.example_policy_task_label diff --git a/examples/policy/config_label_config.yaml b/examples/policy/config_label_config.yaml new file mode 100644 index 00000000000..ad79e2fe1e2 --- /dev/null +++ b/examples/policy/config_label_config.yaml @@ -0,0 +1 @@ +policy: example_policy.config_label_policy diff --git a/examples/policy/example_policy/example_policy/__init__.py b/examples/policy/example_policy/example_policy/__init__.py index d8f7a51898b..5de65e5bc57 100644 --- a/examples/policy/example_policy/example_policy/__init__.py +++ b/examples/policy/example_policy/example_policy/__init__.py @@ -1 +1,2 @@ -from example_policy.skypilot_policy import example_policy_task_label +from example_policy.skypilot_policy import task_label_policy +from example_policy.skypilot_policy import config_label_policy diff --git a/examples/policy/example_policy/example_policy/skypilot_policy.py b/examples/policy/example_policy/example_policy/skypilot_policy.py index 1f1cf08c865..1edea4ab18e 100644 --- a/examples/policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/policy/example_policy/example_policy/skypilot_policy.py @@ -4,7 +4,7 @@ from sky import UserTask -def example_policy_task_label(user_task: UserTask) -> MutatedUserTask: +def task_label_policy(user_task: UserTask) -> MutatedUserTask: """Example policy.""" local_user_name = getpass.getuser() @@ -14,3 +14,18 @@ def example_policy_task_label(user_task: UserTask) -> MutatedUserTask: r.labels['local_user'] = local_user_name return MutatedUserTask(task=task, skypilot_config=user_task.skypilot_config) + + +def config_label_policy(user_task: UserTask) -> MutatedUserTask: + """Example policy.""" + local_user_name = getpass.getuser() + + # Add label for skypilot_config with the local user name + skypilot_config = user_task.skypilot_config + if not skypilot_config.get('aws'): + skypilot_config['aws'] = {} + labels = skypilot_config['aws'].get('labels', {}) + labels['local_user'] = local_user_name + skypilot_config['aws']['labels'] = labels + + return MutatedUserTask(task=user_task.task, skypilot_config=skypilot_config) diff --git a/examples/policy/task_label_config.yaml b/examples/policy/task_label_config.yaml new file mode 100644 index 00000000000..e1e494c4827 --- /dev/null +++ b/examples/policy/task_label_config.yaml @@ -0,0 +1 @@ +policy: example_policy.task_label_policy diff --git a/sky/policy.py b/sky/policy.py index d42e8d6e804..0e9577587c1 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -1,4 +1,5 @@ """Customize policy by users.""" +import copy import dataclasses import importlib import os @@ -10,6 +11,9 @@ from sky import skypilot_config from sky.utils import common_utils from sky.utils import ux_utils +from sky import sky_logging + +logger = sky_logging.init_logger(__name__) if typing.TYPE_CHECKING: from sky import task as task_lib @@ -59,7 +63,9 @@ def __init__(self) -> None: def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': if self.policy_fn is None: return dag - config = skypilot_config.to_dict() + logger.info(f'Applying policy: {self.policy}') + original_config = skypilot_config.to_dict() + config = copy.deepcopy(original_config) mutated_dag = dag_lib.Dag() mutated_dag.name = dag.name @@ -85,14 +91,18 @@ def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], mutated_dag.tasks[v_idx]) - if config != mutated_config: + if original_config != mutated_config: with tempfile.NamedTemporaryFile( delete=False, mode='w', prefix='policy-mutated-skypilot-config-', suffix='.yaml') as temp_file: common_utils.dump_yaml(temp_file.name, mutated_config) - os.environ[ - skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name - skypilot_config.reload_config() + os.environ[ + skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name + logger.debug(f'Updated SkyPilot config: {temp_file.name}') + # This is not a clean way to update the SkyPilot config, because + # we are resetting the global context for a single DAG, which is + # conceptually weird. + importlib.reload(skypilot_config) return mutated_dag diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index fd79ab8d393..b6a12cf7ef4 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -210,14 +210,6 @@ def _try_load_config() -> None: logger.debug('Config syntax check passed.') -def reload_config() -> None: - """Reloads the config from the file specified by the env var.""" - global _dict, _loaded_config_path - _dict = None - _loaded_config_path = None - _try_load_config() - - def loaded_config_path() -> Optional[str]: """Returns the path to the loaded config file.""" return _loaded_config_path diff --git a/tests/unit_tests/test_policy.py b/tests/unit_tests/test_policy.py new file mode 100644 index 00000000000..5285b2b7e52 --- /dev/null +++ b/tests/unit_tests/test_policy.py @@ -0,0 +1,34 @@ +import os +import sys +import pytest +import importlib + +import sky +from sky import sky_logging +from sky import skypilot_config +from sky.utils import dag_utils + +logger = sky_logging.init_logger(__name__) + +POLICY_PATH = os.path.join(os.path.dirname(os.path.dirname(sky.__file__)), + 'examples', 'policy') + +@pytest.fixture +def add_example_policy_paths(): + # Add to path to be able to import + sys.path.append(os.path.join(POLICY_PATH, 'example_policy')) + +def _load_task_and_apply_policy(config_path) -> sky.Dag: + os.environ['SKYPILOT_CONFIG'] = config_path + importlib.reload(skypilot_config) + task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) + return dag_utils.convert_entrypoint_to_dag_and_apply_policy(task) + +def test_task_level_changes_policy(add_example_policy_paths): + dag = _load_task_and_apply_policy(os.path.join(POLICY_PATH, 'task_label_config.yaml')) + assert 'local_user' in list(dag.tasks[0].resources)[0].labels + + +def test_config_level_changes_policy(add_example_policy_paths): + _load_task_and_apply_policy(os.path.join(POLICY_PATH, 'config_label_config.yaml')) + assert 'local_user' in skypilot_config.get_nested(('aws', 'labels'), {}) From 54c93ea0a2f3f1ac2622df005c155cc1736621e8 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 00:59:02 +0000 Subject: [PATCH 04/61] Fix comment --- sky/policy.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sky/policy.py b/sky/policy.py index 0e9577587c1..89b75ca552a 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -101,8 +101,9 @@ def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': os.environ[ skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name logger.debug(f'Updated SkyPilot config: {temp_file.name}') - # This is not a clean way to update the SkyPilot config, because - # we are resetting the global context for a single DAG, which is - # conceptually weird. + # TODO(zhwu): This is not a clean way to update the SkyPilot config, + # because we are resetting the global context for a single DAG, + # which is conceptually weird. importlib.reload(skypilot_config) + return mutated_dag From 1d1c500591a26afd3ad28ed0162ee31ea9b3e9a2 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 01:14:34 +0000 Subject: [PATCH 05/61] format --- .../example_policy/example_policy/__init__.py | 2 +- .../example_policy/skypilot_policy.py | 2 +- sky/policy.py | 15 +++++++-------- tests/unit_tests/test_policy.py | 12 +++++++++--- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/policy/example_policy/example_policy/__init__.py b/examples/policy/example_policy/example_policy/__init__.py index 5de65e5bc57..a728ae70dfb 100644 --- a/examples/policy/example_policy/example_policy/__init__.py +++ b/examples/policy/example_policy/example_policy/__init__.py @@ -1,2 +1,2 @@ -from example_policy.skypilot_policy import task_label_policy from example_policy.skypilot_policy import config_label_policy +from example_policy.skypilot_policy import task_label_policy diff --git a/examples/policy/example_policy/example_policy/skypilot_policy.py b/examples/policy/example_policy/example_policy/skypilot_policy.py index 1edea4ab18e..a7fa558f2dd 100644 --- a/examples/policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/policy/example_policy/example_policy/skypilot_policy.py @@ -27,5 +27,5 @@ def config_label_policy(user_task: UserTask) -> MutatedUserTask: labels = skypilot_config['aws'].get('labels', {}) labels['local_user'] = local_user_name skypilot_config['aws']['labels'] = labels - + return MutatedUserTask(task=user_task.task, skypilot_config=skypilot_config) diff --git a/sky/policy.py b/sky/policy.py index 89b75ca552a..c4c5202461e 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -8,10 +8,10 @@ from typing import Any, Callable, Dict, Optional from sky import dag as dag_lib +from sky import sky_logging from sky import skypilot_config from sky.utils import common_utils from sky.utils import ux_utils -from sky import sky_logging logger = sky_logging.init_logger(__name__) @@ -49,16 +49,16 @@ def __init__(self) -> None: raise ImportError( f'Failed to import policy module: {module_path}. ' 'Please check if the module is in your Python ' - 'environment.' - ) from e + 'environment.') from e try: self.policy_fn = getattr(module, func_name) except AttributeError as e: with ux_utils.print_exception_no_traceback(): raise AttributeError( - f'Failed to get policy function: {func_name} from module: ' - f'{module_path}. Please check with your policy admin if ' - f'the function {func_name!r} is in the module.') from e + f'Failed to get policy function: {func_name} from ' + f'module: {module_path}. Please check with your policy ' + f'admin if the function {func_name!r} is in the ' + 'module.') from e def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': if self.policy_fn is None: @@ -98,8 +98,7 @@ def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': prefix='policy-mutated-skypilot-config-', suffix='.yaml') as temp_file: common_utils.dump_yaml(temp_file.name, mutated_config) - os.environ[ - skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name + os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name logger.debug(f'Updated SkyPilot config: {temp_file.name}') # TODO(zhwu): This is not a clean way to update the SkyPilot config, # because we are resetting the global context for a single DAG, diff --git a/tests/unit_tests/test_policy.py b/tests/unit_tests/test_policy.py index 5285b2b7e52..1e967287d08 100644 --- a/tests/unit_tests/test_policy.py +++ b/tests/unit_tests/test_policy.py @@ -1,7 +1,8 @@ +import importlib import os import sys + import pytest -import importlib import sky from sky import sky_logging @@ -13,22 +14,27 @@ POLICY_PATH = os.path.join(os.path.dirname(os.path.dirname(sky.__file__)), 'examples', 'policy') + @pytest.fixture def add_example_policy_paths(): # Add to path to be able to import sys.path.append(os.path.join(POLICY_PATH, 'example_policy')) + def _load_task_and_apply_policy(config_path) -> sky.Dag: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) return dag_utils.convert_entrypoint_to_dag_and_apply_policy(task) + def test_task_level_changes_policy(add_example_policy_paths): - dag = _load_task_and_apply_policy(os.path.join(POLICY_PATH, 'task_label_config.yaml')) + dag = _load_task_and_apply_policy( + os.path.join(POLICY_PATH, 'task_label_config.yaml')) assert 'local_user' in list(dag.tasks[0].resources)[0].labels def test_config_level_changes_policy(add_example_policy_paths): - _load_task_and_apply_policy(os.path.join(POLICY_PATH, 'config_label_config.yaml')) + _load_task_and_apply_policy( + os.path.join(POLICY_PATH, 'config_label_config.yaml')) assert 'local_user' in skypilot_config.get_nested(('aws', 'labels'), {}) From a0bdb2c31f423d912e118d6dc0a9119d9a25830d Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 01:16:01 +0000 Subject: [PATCH 06/61] use -e to make test related files visible --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ac723f35fc2..3faf75acf8d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -53,7 +53,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ".[all]" + pip install -e ".[all]" pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0 - name: Run tests with pytest From 543e66a7924e2b8bb2f562f2063162a7d849b93e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 06:25:43 +0000 Subject: [PATCH 07/61] Add config.rst --- docs/source/reference/config.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 6c2fe2569a6..7c5f38629ec 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -17,7 +17,6 @@ Spec: ``~/.sky/config.yaml`` Available fields and semantics: .. code-block:: yaml - # Custom managed jobs controller resources (optional). # # These take effects only when a managed jobs controller does not already exist. @@ -87,6 +86,13 @@ Available fields and semantics: # Default: false. disable_ecc: false + # Custom policy to be applied to all tasks. + # + # The policy function to be applied and mutate all tasks, which can be used to + # enforce certain policies on all tasks. + # See more details in :ref:`advanced-policy-config` + policy: my_package.skypilot_policy_fn_v1 + # Advanced AWS configurations (optional). # Apply to all new instances but not existing ones. aws: From 520a2a1a1ad6aaf8cc5c8b436d2f8489af42f2ed Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 06:55:50 +0000 Subject: [PATCH 08/61] Fix test --- tests/unit_tests/test_resources.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 70da0532e9b..91123071ea6 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -1,3 +1,5 @@ +import importlib +import os from typing import Dict from unittest.mock import Mock from unittest.mock import patch @@ -91,10 +93,6 @@ def test_kubernetes_labels_resources(): _run_label_test(allowed_labels, invalid_labels, cloud) -@patch.object(skypilot_config, 'CONFIG_PATH', - './tests/test_yamls/test_aws_config.yaml') -@patch.object(skypilot_config, '_dict', None) -@patch.object(skypilot_config, '_loaded_config_path', None) @patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) @patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', return_value={'fake-acc': 2}) @@ -102,7 +100,8 @@ def test_kubernetes_labels_resources(): return_value='fake-image') @patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') def test_aws_make_deploy_variables(*mocks) -> None: - skypilot_config._try_load_config() + os.environ['SKYPILOT_CONFIG'] = './tests/test_yamls/test_aws_config.yaml' + importlib.reload(skypilot_config) cloud = clouds.AWS() cluster_name = resources_utils.ClusterName(display_name='display', From b5333511a86b8669d014dec3a3065cacbf90415d Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 07:01:59 +0000 Subject: [PATCH 09/61] fix config rst --- docs/source/reference/config.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 7c5f38629ec..bb45e0e8508 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -17,6 +17,7 @@ Spec: ``~/.sky/config.yaml`` Available fields and semantics: .. code-block:: yaml + # Custom managed jobs controller resources (optional). # # These take effects only when a managed jobs controller does not already exist. @@ -90,7 +91,8 @@ Available fields and semantics: # # The policy function to be applied and mutate all tasks, which can be used to # enforce certain policies on all tasks. - # See more details in :ref:`advanced-policy-config` + # + # See details in: policy: my_package.skypilot_policy_fn_v1 # Advanced AWS configurations (optional). From 466f7fe8764c85d558f8a660edc3fb532dd17f5f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 20:08:12 +0000 Subject: [PATCH 10/61] Apply policy to service --- sky/policy.py | 9 +++++++++ sky/serve/core.py | 3 +++ 2 files changed, 12 insertions(+) diff --git a/sky/policy.py b/sky/policy.py index c4c5202461e..3d7a4c9a677 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -106,3 +106,12 @@ def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': importlib.reload(skypilot_config) return mutated_dag + + def apply_to_task(self, task: 'task_lib.Task') -> 'task_lib.Task': + if self.policy_fn is None: + return task + + dag = dag_lib.Dag() + dag.add(task) + dag = self.apply(dag) + return dag.tasks[0] diff --git a/sky/serve/core.py b/sky/serve/core.py index 4f15413cf7f..7d88ba688c3 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -6,6 +6,7 @@ import colorama import sky +from sky import policy from sky import backends from sky import exceptions from sky import sky_logging @@ -124,6 +125,8 @@ def up( _validate_service_task(task) + task = policy.Policy().apply_to_task(task) + controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, path='serve') From 050dc7a6e8d587f69291dc90e802b011e33b785f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 20:15:06 +0000 Subject: [PATCH 11/61] add policy for serving --- sky/serve/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/serve/core.py b/sky/serve/core.py index 7d88ba688c3..a1272c8f152 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -6,9 +6,9 @@ import colorama import sky -from sky import policy from sky import backends from sky import exceptions +from sky import policy from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils From 31e0174e27b3bd5a6aef2a39683abc2b536f5aca Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 20:20:43 +0000 Subject: [PATCH 12/61] Add docs --- sky/policy.py | 47 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/sky/policy.py b/sky/policy.py index 3d7a4c9a677..c740826bb34 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -32,7 +32,28 @@ class MutatedUserTask: class Policy: - """Customize Policy by users.""" + """User-defined policy. + + A user-defined policy is a string to a python function that can be imported + from the same environment where SkyPilot is running. + + It can be specified in the SkyPilot config file under the key 'policy', e.g. + + policy: my_package.skypilot_policy_fn_v1 + + The python function is expected to have the following signature: + + def skypilot_policy_fn_v1(user_task: UserTask) -> MutatedUserTask: + ... + return MutatedUserTask(task=..., skypilot_config=...) + + The function can mutate both task and skypilot_config. + """ + + def skypilot_policy_fn_v1(user_task: UserTask) -> MutatedUserTask: + ... + return MutatedUserTask(task=..., skypilot_config=...) + """ def __init__(self) -> None: # Policy is a string to a python function within some user provided @@ -61,6 +82,17 @@ def __init__(self) -> None: 'module.') from e def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': + """Apply user-defined policy to a DAG. + + It mutates a Dag by applying user-defined policy and also update the + global SkyPilot config if there is any changes made by the policy. + + Args: + dag: The dag to be mutated by the policy. + + Returns: + The mutated dag. + """ if self.policy_fn is None: return dag logger.info(f'Applying policy: {self.policy}') @@ -108,8 +140,17 @@ def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': return mutated_dag def apply_to_task(self, task: 'task_lib.Task') -> 'task_lib.Task': - if self.policy_fn is None: - return task + """Apply user-defined policy to a task. + + It mutates a task by applying user-defined policy and also update the + global SkyPilot config if there is any changes made by the policy. + + Args: + task: The task to be mutated by the policy. + + Returns: + The mutated task. + """ dag = dag_lib.Dag() dag.add(task) From 0c74f2a2585edc3a2376bf2e4eb47ebc1147e13f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 20:21:39 +0000 Subject: [PATCH 13/61] fix --- sky/policy.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sky/policy.py b/sky/policy.py index c740826bb34..6b9f0d64f9b 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -50,12 +50,8 @@ def skypilot_policy_fn_v1(user_task: UserTask) -> MutatedUserTask: The function can mutate both task and skypilot_config. """ - def skypilot_policy_fn_v1(user_task: UserTask) -> MutatedUserTask: - ... - return MutatedUserTask(task=..., skypilot_config=...) - """ - def __init__(self) -> None: + """Initialize the policy from SkyPilot config.""" # Policy is a string to a python function within some user provided # module. self.policy: Optional[str] = skypilot_config.get_nested(('policy',), From 48a6cc9f284a3b68b55fbbfd4e69bfe77dd0e091 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 20:27:25 +0000 Subject: [PATCH 14/61] format --- sky/policy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/policy.py b/sky/policy.py index 6b9f0d64f9b..9860747b0a9 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -33,7 +33,7 @@ class MutatedUserTask: class Policy: """User-defined policy. - + A user-defined policy is a string to a python function that can be imported from the same environment where SkyPilot is running. @@ -79,7 +79,7 @@ def __init__(self) -> None: def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': """Apply user-defined policy to a DAG. - + It mutates a Dag by applying user-defined policy and also update the global SkyPilot config if there is any changes made by the policy. @@ -137,7 +137,7 @@ def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': def apply_to_task(self, task: 'task_lib.Task') -> 'task_lib.Task': """Apply user-defined policy to a task. - + It mutates a task by applying user-defined policy and also update the global SkyPilot config if there is any changes made by the policy. From 1ca5a8aa9463cce0b46fe78d25b4127bc6b9155b Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 20 Sep 2024 22:33:46 +0000 Subject: [PATCH 15/61] Update interface --- docs/source/reference/config.rst | 6 +- .../admin_policy/config_label_config.yaml | 1 + .../example_policy/example_policy/__init__.py | 4 + .../example_policy/skypilot_policy.py | 33 ++++ .../example_policy/pyproject.toml | 0 examples/{policy => admin_policy}/task.yaml | 0 examples/admin_policy/task_label_config.yaml | 1 + examples/policy/config_label_config.yaml | 1 - .../example_policy/example_policy/__init__.py | 2 - .../example_policy/skypilot_policy.py | 31 ---- examples/policy/task_label_config.yaml | 1 - sky/__init__.py | 6 +- sky/execution.py | 4 +- sky/jobs/core.py | 4 +- sky/policy.py | 147 +++--------------- sky/serve/core.py | 4 +- sky/utils/dag_utils.py | 8 +- sky/utils/policy_utils.py | 142 +++++++++++++++++ sky/utils/schemas.py | 14 +- tests/unit_tests/test_policy.py | 8 +- 20 files changed, 230 insertions(+), 187 deletions(-) create mode 100644 examples/admin_policy/config_label_config.yaml create mode 100644 examples/admin_policy/example_policy/example_policy/__init__.py create mode 100644 examples/admin_policy/example_policy/example_policy/skypilot_policy.py rename examples/{policy => admin_policy}/example_policy/pyproject.toml (100%) rename examples/{policy => admin_policy}/task.yaml (100%) create mode 100644 examples/admin_policy/task_label_config.yaml delete mode 100644 examples/policy/config_label_config.yaml delete mode 100644 examples/policy/example_policy/example_policy/__init__.py delete mode 100644 examples/policy/example_policy/example_policy/skypilot_policy.py delete mode 100644 examples/policy/task_label_config.yaml create mode 100644 sky/utils/policy_utils.py diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index bb45e0e8508..9fde305dfba 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -89,11 +89,11 @@ Available fields and semantics: # Custom policy to be applied to all tasks. # - # The policy function to be applied and mutate all tasks, which can be used to + # The policy class to be applied and mutate all tasks, which can be used to # enforce certain policies on all tasks. # - # See details in: - policy: my_package.skypilot_policy_fn_v1 + # The policy class should implement the sky.AdminPolicy interface. + admin_policy: my_package.SkyPilotPolicyV1 # Advanced AWS configurations (optional). # Apply to all new instances but not existing ones. diff --git a/examples/admin_policy/config_label_config.yaml b/examples/admin_policy/config_label_config.yaml new file mode 100644 index 00000000000..9228986028d --- /dev/null +++ b/examples/admin_policy/config_label_config.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.ConfigLabelPolicy diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py new file mode 100644 index 00000000000..f901a9a30c6 --- /dev/null +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -0,0 +1,4 @@ +"""Example module for SkyPilot admin policies.""" + +from example_policy.skypilot_policy import ConfigLabelPolicy +from example_policy.skypilot_policy import TaskLabelPolicy diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py new file mode 100644 index 00000000000..aaf72c5fee0 --- /dev/null +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -0,0 +1,33 @@ +import getpass + +import sky + +class TaskLabelPolicy(sky.AdminPolicy): + """Example policy: add label for task with the local user name.""" + + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + local_user_name = getpass.getuser() + + # Add label for task with the local user name + task = user_request.task + for r in task.resources: + r.labels['local_user'] = local_user_name + + return sky.MutatedUserRequest(task=task, skypilot_config=user_request.skypilot_config) + +class ConfigLabelPolicy(sky.AdminPolicy): + """Example policy: add label for skypilot_config with the local user name.""" + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + local_user_name = getpass.getuser() + + # Add label for skypilot_config with the local user name + skypilot_config = user_request.skypilot_config + if skypilot_config.get('gcp') is None: + skypilot_config['gcp'] = {} + labels = skypilot_config['gcp'].get('labels', {}) + labels['local_user'] = local_user_name + skypilot_config['gcp']['labels'] = labels + + return sky.MutatedUserRequest(task=user_request.task, skypilot_config=skypilot_config) diff --git a/examples/policy/example_policy/pyproject.toml b/examples/admin_policy/example_policy/pyproject.toml similarity index 100% rename from examples/policy/example_policy/pyproject.toml rename to examples/admin_policy/example_policy/pyproject.toml diff --git a/examples/policy/task.yaml b/examples/admin_policy/task.yaml similarity index 100% rename from examples/policy/task.yaml rename to examples/admin_policy/task.yaml diff --git a/examples/admin_policy/task_label_config.yaml b/examples/admin_policy/task_label_config.yaml new file mode 100644 index 00000000000..f21774e7086 --- /dev/null +++ b/examples/admin_policy/task_label_config.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.TaskLabelPolicy diff --git a/examples/policy/config_label_config.yaml b/examples/policy/config_label_config.yaml deleted file mode 100644 index ad79e2fe1e2..00000000000 --- a/examples/policy/config_label_config.yaml +++ /dev/null @@ -1 +0,0 @@ -policy: example_policy.config_label_policy diff --git a/examples/policy/example_policy/example_policy/__init__.py b/examples/policy/example_policy/example_policy/__init__.py deleted file mode 100644 index a728ae70dfb..00000000000 --- a/examples/policy/example_policy/example_policy/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from example_policy.skypilot_policy import config_label_policy -from example_policy.skypilot_policy import task_label_policy diff --git a/examples/policy/example_policy/example_policy/skypilot_policy.py b/examples/policy/example_policy/example_policy/skypilot_policy.py deleted file mode 100644 index a7fa558f2dd..00000000000 --- a/examples/policy/example_policy/example_policy/skypilot_policy.py +++ /dev/null @@ -1,31 +0,0 @@ -import getpass - -from sky import MutatedUserTask -from sky import UserTask - - -def task_label_policy(user_task: UserTask) -> MutatedUserTask: - """Example policy.""" - local_user_name = getpass.getuser() - - # Add label for task with the local user name - task = user_task.task - for r in task.resources: - r.labels['local_user'] = local_user_name - - return MutatedUserTask(task=task, skypilot_config=user_task.skypilot_config) - - -def config_label_policy(user_task: UserTask) -> MutatedUserTask: - """Example policy.""" - local_user_name = getpass.getuser() - - # Add label for skypilot_config with the local user name - skypilot_config = user_task.skypilot_config - if not skypilot_config.get('aws'): - skypilot_config['aws'] = {} - labels = skypilot_config['aws'].get('labels', {}) - labels['local_user'] = local_user_name - skypilot_config['aws']['labels'] = labels - - return MutatedUserTask(task=user_task.task, skypilot_config=skypilot_config) diff --git a/examples/policy/task_label_config.yaml b/examples/policy/task_label_config.yaml deleted file mode 100644 index e1e494c4827..00000000000 --- a/examples/policy/task_label_config.yaml +++ /dev/null @@ -1 +0,0 @@ -policy: example_policy.task_label_policy diff --git a/sky/__init__.py b/sky/__init__.py index 6d5b32e2b90..31707ca32d3 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -110,12 +110,14 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.jobs.core import spot_tail_logs from sky.optimizer import Optimizer from sky.optimizer import OptimizeTarget -from sky.policy import MutatedUserTask -from sky.policy import UserTask from sky.resources import Resources from sky.skylet.job_lib import JobStatus from sky.status_lib import ClusterStatus from sky.task import Task +# Admin Policy interfaces +from sky.policy import UserRequest +from sky.policy import MutatedUserRequest +from sky.policy import AdminPolicy # Aliases. IBM = clouds.IBM diff --git a/sky/execution.py b/sky/execution.py index f691b340c6c..aa3b93a44c7 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -24,6 +24,7 @@ from sky.utils import subprocess_utils from sky.utils import timeline from sky.utils import ux_utils +from sky.utils import policy_utils logger = sky_logging.init_logger(__name__) @@ -158,7 +159,8 @@ def _execute( handle: Optional[backends.ResourceHandle]; the handle to the cluster. None if dryrun. """ - dag = dag_utils.convert_entrypoint_to_dag_and_apply_policy(entrypoint) + dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + dag = policy_utils.apply(dag) assert len(dag) == 1, f'We support 1 task for now. {dag}' task = dag.tasks[0] diff --git a/sky/jobs/core.py b/sky/jobs/core.py index c66199fe270..a8c598978b2 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -21,6 +21,7 @@ from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils +from sky.utils import policy_utils from sky.utils import rich_utils from sky.utils import subprocess_utils from sky.utils import ux_utils @@ -53,7 +54,8 @@ def launch( entrypoint = task dag_uuid = str(uuid.uuid4().hex[:4]) - dag = dag_utils.convert_entrypoint_to_dag_and_apply_policy(entrypoint) + dag = dag_utils.convert_entrypoint_to_dag(entrypoint) + dag = policy_utils.apply(dag) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' diff --git a/sky/policy.py b/sky/policy.py index 9860747b0a9..952dc70242c 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -1,154 +1,47 @@ -"""Customize policy by users.""" -import copy +"""Interface for admin-defined policy for user requests.""" +import abc import dataclasses -import importlib -import os -import tempfile import typing -from typing import Any, Callable, Dict, Optional - -from sky import dag as dag_lib -from sky import sky_logging -from sky import skypilot_config -from sky.utils import common_utils -from sky.utils import ux_utils - -logger = sky_logging.init_logger(__name__) +from typing import Any, Dict if typing.TYPE_CHECKING: from sky import task as task_lib @dataclasses.dataclass -class UserTask: +class UserRequest: task: 'task_lib.Task' skypilot_config: Dict[str, Any] @dataclasses.dataclass -class MutatedUserTask: +class MutatedUserRequest: task: 'task_lib.Task' skypilot_config: Dict[str, Any] - -class Policy: - """User-defined policy. - +class AdminPolicy: + """Interface for admin-defined policy for user requests. + A user-defined policy is a string to a python function that can be imported from the same environment where SkyPilot is running. It can be specified in the SkyPilot config file under the key 'policy', e.g. - policy: my_package.skypilot_policy_fn_v1 + policy: my_package.SkyPilotPolicyV1 - The python function is expected to have the following signature: + The AdminPolicy class is expected to have the following signature: + + import sky - def skypilot_policy_fn_v1(user_task: UserTask) -> MutatedUserTask: - ... - return MutatedUserTask(task=..., skypilot_config=...) + class SkyPilotPolicyV1(sky.AdminPolicy): + def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: + ... + return MutatedUserRequest(task=..., skypilot_config=...) The function can mutate both task and skypilot_config. """ - def __init__(self) -> None: - """Initialize the policy from SkyPilot config.""" - # Policy is a string to a python function within some user provided - # module. - self.policy: Optional[str] = skypilot_config.get_nested(('policy',), - None) - self.policy_fn: Optional[Callable[[UserTask], MutatedUserTask]] = None - if self.policy is not None: - try: - module_path, func_name = self.policy.rsplit('.', 1) - module = importlib.import_module(module_path) - except ImportError as e: - with ux_utils.print_exception_no_traceback(): - raise ImportError( - f'Failed to import policy module: {module_path}. ' - 'Please check if the module is in your Python ' - 'environment.') from e - try: - self.policy_fn = getattr(module, func_name) - except AttributeError as e: - with ux_utils.print_exception_no_traceback(): - raise AttributeError( - f'Failed to get policy function: {func_name} from ' - f'module: {module_path}. Please check with your policy ' - f'admin if the function {func_name!r} is in the ' - 'module.') from e - - def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag': - """Apply user-defined policy to a DAG. - - It mutates a Dag by applying user-defined policy and also update the - global SkyPilot config if there is any changes made by the policy. - - Args: - dag: The dag to be mutated by the policy. - - Returns: - The mutated dag. - """ - if self.policy_fn is None: - return dag - logger.info(f'Applying policy: {self.policy}') - original_config = skypilot_config.to_dict() - config = copy.deepcopy(original_config) - mutated_dag = dag_lib.Dag() - mutated_dag.name = dag.name - - mutated_config = None - for task in dag.tasks: - user_task = UserTask(task, config) - mutated_user_task = self.policy_fn(user_task) - if mutated_config is None: - mutated_config = mutated_user_task.skypilot_config - else: - if mutated_config != mutated_user_task.skypilot_config: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - 'All tasks must have the same skypilot ' - 'config after applying the policy. Please' - 'check with your policy admin for details.') - mutated_dag.add(mutated_user_task.task) - - # Update the new_dag's graph with the old dag's graph - for u, v in dag.graph.edges: - u_idx = dag.tasks.index(u) - v_idx = dag.tasks.index(v) - mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], - mutated_dag.tasks[v_idx]) - - if original_config != mutated_config: - with tempfile.NamedTemporaryFile( - delete=False, - mode='w', - prefix='policy-mutated-skypilot-config-', - suffix='.yaml') as temp_file: - common_utils.dump_yaml(temp_file.name, mutated_config) - os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name - logger.debug(f'Updated SkyPilot config: {temp_file.name}') - # TODO(zhwu): This is not a clean way to update the SkyPilot config, - # because we are resetting the global context for a single DAG, - # which is conceptually weird. - importlib.reload(skypilot_config) - - return mutated_dag - - def apply_to_task(self, task: 'task_lib.Task') -> 'task_lib.Task': - """Apply user-defined policy to a task. - - It mutates a task by applying user-defined policy and also update the - global SkyPilot config if there is any changes made by the policy. - - Args: - task: The task to be mutated by the policy. - - Returns: - The mutated task. - """ - - dag = dag_lib.Dag() - dag.add(task) - dag = self.apply(dag) - return dag.tasks[0] + @classmethod + @abc.abstractmethod + def validate_and_mutate(cls, user_request: UserRequest) -> MutatedUserRequest: + raise NotImplementedError('Your policy must implement validate_and_mutate') diff --git a/sky/serve/core.py b/sky/serve/core.py index a1272c8f152..f9eb5c12874 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -8,7 +8,7 @@ import sky from sky import backends from sky import exceptions -from sky import policy +from sky.utils import policy_utils from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils @@ -125,7 +125,7 @@ def up( _validate_service_task(task) - task = policy.Policy().apply_to_task(task) + task = policy_utils.apply(task) controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, path='serve') diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index e68873e4691..2f716d5b86d 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -4,7 +4,7 @@ from sky import dag as dag_lib from sky import jobs -from sky import policy +from sky.utils import policy_utils from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils @@ -36,9 +36,9 @@ """.strip() -def convert_entrypoint_to_dag_and_apply_policy( +def convert_entrypoint_to_dag( entrypoint: Any) -> 'dag_lib.Dag': - """Convert the entrypoint to a sky.Dag and apply the policy. + """Converts the entrypoint to a sky.Dag and applies the policy. Raises TypeError if 'entrypoint' is not a 'sky.Task' or 'sky.Dag'. """ @@ -63,7 +63,7 @@ def convert_entrypoint_to_dag_and_apply_policy( 'Expected a sky.Task or sky.Dag but received argument of type: ' f'{type(entrypoint)}') - return policy.Policy().apply(converted_dag) + return converted_dag def load_chain_dag_from_yaml( diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py new file mode 100644 index 00000000000..db48730b4cc --- /dev/null +++ b/sky/utils/policy_utils.py @@ -0,0 +1,142 @@ +"""Customize policy by users.""" +import copy +import importlib +import os +import tempfile +import typing +from typing import Optional, Union + +from sky import dag as dag_lib +from sky import sky_logging +from sky import task as task_lib +from sky import skypilot_config +from sky.utils import common_utils +from sky.utils import ux_utils +from sky import policy as policy_lib + +logger = sky_logging.init_logger(__name__) + + + +def _get_policy_cls() -> Optional[policy_lib.AdminPolicy]: + """Get admin-defined policy.""" + policy = skypilot_config.get_nested(('admin_policy',), None) + if policy is None: + return None + try: + module_path, class_name = policy.rsplit('.', 1) + module = importlib.import_module(module_path) + except ImportError as e: + with ux_utils.print_exception_no_traceback(): + raise ImportError( + f'Failed to import policy module: {policy}. ' + 'Please check if the module is installed in your Python ' + 'environment.') from e + + try: + policy_cls = getattr(module, class_name) + except AttributeError as e: + with ux_utils.print_exception_no_traceback(): + raise AttributeError( + f'Could not find {class_name} class in module {module_path}. ' + 'Please check with your policy admin for details.') from e + + # Check if the module implements the AdminPolicy interface. + if not issubclass(policy_cls, policy_lib.AdminPolicy): + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Policy module {policy} does not implement the AdminPolicy ' + 'interface. Please check with your policy admin for details.') + return policy_cls + +@typing.overload +def apply(dag: 'dag_lib.Dag') -> 'dag_lib.Dag': + ... + +@typing.overload +def apply(dag: 'task_lib.Task') -> 'task_lib.Task': + ... + +def apply(entrypoint: Union['dag_lib.Dag', 'task_lib.Task']) -> Union['dag_lib.Dag', 'task_lib.Task']: + """Apply user-defined policy to a DAG or a task. + + It mutates a Dag by applying user-defined policy and also update the + global SkyPilot config if there is any changes made by the policy. + + Args: + dag: The dag to be mutated by the policy. + + Returns: + The mutated dag or task. + """ + if isinstance(entrypoint, task_lib.Task): + return _apply_to_task(entrypoint) + + dag = entrypoint + + policy_cls = _get_policy_cls() + if policy_cls is None: + return dag + logger.info(f'Applying policy: {policy_cls}') + import traceback + logger.debug(f'Stack trace: {traceback.format_stack()}') + original_config = skypilot_config.to_dict() + config = copy.deepcopy(original_config) + mutated_dag = dag_lib.Dag() + mutated_dag.name = dag.name + + mutated_config = None + for task in dag.tasks: + user_request = policy_lib.UserRequest(task, config) + mutated_user_request = policy_cls.validate_and_mutate(user_request) + if mutated_config is None: + mutated_config = mutated_user_request.skypilot_config + else: + if mutated_config != mutated_user_request.skypilot_config: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + 'All tasks must have the same skypilot ' + 'config after applying the policy. Please' + 'check with your policy admin for details.') + mutated_dag.add(mutated_user_request.task) + + # Update the new_dag's graph with the old dag's graph + for u, v in dag.graph.edges: + u_idx = dag.tasks.index(u) + v_idx = dag.tasks.index(v) + mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], + mutated_dag.tasks[v_idx]) + + if original_config != mutated_config: + with tempfile.NamedTemporaryFile( + delete=False, + mode='w', + prefix='policy-mutated-skypilot-config-', + suffix='.yaml') as temp_file: + common_utils.dump_yaml(temp_file.name, mutated_config) + os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name + logger.debug(f'Updated SkyPilot config: {temp_file.name}') + # TODO(zhwu): This is not a clean way to update the SkyPilot config, + # because we are resetting the global context for a single DAG, + # which is conceptually weird. + importlib.reload(skypilot_config) + + logger.debug(f'Mutated user request: {mutated_user_request}') + return mutated_dag + +def _apply_to_task(task: 'task_lib.Task') -> 'task_lib.Task': + """Apply user-defined policy to a task. + + It mutates a task by applying user-defined policy and also update the + global SkyPilot config if there is any changes made by the policy. + + Args: + task: The task to be mutated by the policy. + + Returns: + The mutated task. + """ + dag = dag_lib.Dag() + dag.add(task) + dag = apply(dag) + return dag.tasks[0] diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 996bc083185..a50c400b805 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -848,13 +848,11 @@ def get_config_schema(): }, } - policy_schema = { - 'policy': { - 'type': 'string', - # Check regex to be a valid python module path - 'pattern': (r'^[a-zA-Z_][a-zA-Z0-9_]*' - r'(\.[a-zA-Z_][a-zA-Z0-9_]*)+$'), - } + admin_policy_schema = { + 'type': 'string', + # Check regex to be a valid python module path + 'pattern': (r'^[a-zA-Z_][a-zA-Z0-9_]*' + r'(\.[a-zA-Z_][a-zA-Z0-9_]*)+$'), } allowed_clouds = { @@ -914,7 +912,7 @@ def get_config_schema(): 'spot': controller_resources_schema, 'serve': controller_resources_schema, 'allowed_clouds': allowed_clouds, - 'policy': policy_schema, + 'admin_policy': admin_policy_schema, 'docker': docker_configs, 'nvidia_gpus': gpu_configs, **cloud_configs, diff --git a/tests/unit_tests/test_policy.py b/tests/unit_tests/test_policy.py index 1e967287d08..482ae5ff924 100644 --- a/tests/unit_tests/test_policy.py +++ b/tests/unit_tests/test_policy.py @@ -7,7 +7,7 @@ import sky from sky import sky_logging from sky import skypilot_config -from sky.utils import dag_utils +from sky.utils import policy_utils logger = sky_logging.init_logger(__name__) @@ -25,13 +25,13 @@ def _load_task_and_apply_policy(config_path) -> sky.Dag: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) - return dag_utils.convert_entrypoint_to_dag_and_apply_policy(task) + return policy_utils.apply(task) def test_task_level_changes_policy(add_example_policy_paths): - dag = _load_task_and_apply_policy( + task = _load_task_and_apply_policy( os.path.join(POLICY_PATH, 'task_label_config.yaml')) - assert 'local_user' in list(dag.tasks[0].resources)[0].labels + assert 'local_user' in list(task.resources)[0].labels def test_config_level_changes_policy(add_example_policy_paths): From 14b23467019268b494ded7c6db402a23cd05acfd Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 00:32:37 +0000 Subject: [PATCH 16/61] fix --- .../example_policy/skypilot_policy.py | 15 +++- sky/__init__.py | 15 ++-- sky/execution.py | 2 +- sky/jobs/core.py | 6 +- sky/policy.py | 23 ++--- sky/serve/core.py | 7 +- sky/skypilot_config.py | 81 +++++++++-------- sky/templates/jobs-controller.yaml.j2 | 2 +- sky/templates/sky-serve-controller.yaml.j2 | 2 +- sky/utils/controller_utils.py | 65 ++++++++------ sky/utils/dag_utils.py | 4 +- sky/utils/policy_utils.py | 89 ++++++++++--------- 12 files changed, 177 insertions(+), 134 deletions(-) diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index aaf72c5fee0..9a100e07c8b 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -2,11 +2,13 @@ import sky + class TaskLabelPolicy(sky.AdminPolicy): """Example policy: add label for task with the local user name.""" @classmethod - def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: local_user_name = getpass.getuser() # Add label for task with the local user name @@ -14,12 +16,16 @@ def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRe for r in task.resources: r.labels['local_user'] = local_user_name - return sky.MutatedUserRequest(task=task, skypilot_config=user_request.skypilot_config) + return sky.MutatedUserRequest( + task=task, skypilot_config=user_request.skypilot_config) + class ConfigLabelPolicy(sky.AdminPolicy): """Example policy: add label for skypilot_config with the local user name.""" + @classmethod - def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: local_user_name = getpass.getuser() # Add label for skypilot_config with the local user name @@ -30,4 +36,5 @@ def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRe labels['local_user'] = local_user_name skypilot_config['gcp']['labels'] = labels - return sky.MutatedUserRequest(task=user_request.task, skypilot_config=skypilot_config) + return sky.MutatedUserRequest(task=user_request.task, + skypilot_config=skypilot_config) diff --git a/sky/__init__.py b/sky/__init__.py index 31707ca32d3..1895ba80c81 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -110,14 +110,15 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.jobs.core import spot_tail_logs from sky.optimizer import Optimizer from sky.optimizer import OptimizeTarget +# Admin Policy interfaces +from sky.policy import AdminPolicy +from sky.policy import MutatedUserRequest +from sky.policy import UserRequest from sky.resources import Resources from sky.skylet.job_lib import JobStatus +from sky.skypilot_config import NestedConfig from sky.status_lib import ClusterStatus from sky.task import Task -# Admin Policy interfaces -from sky.policy import UserRequest -from sky.policy import MutatedUserRequest -from sky.policy import AdminPolicy # Aliases. IBM = clouds.IBM @@ -189,6 +190,8 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): # core APIs Storage Management 'storage_ls', 'storage_delete', - 'UserTask', - 'MutatedUserTask', + 'UserRequest', + 'MutatedUserRequest', + 'AdminPolicy', + 'NestedConfig', ] diff --git a/sky/execution.py b/sky/execution.py index aa3b93a44c7..c9801b1be96 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -19,12 +19,12 @@ from sky.utils import controller_utils from sky.utils import dag_utils from sky.utils import env_options +from sky.utils import policy_utils from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils from sky.utils import timeline from sky.utils import ux_utils -from sky.utils import policy_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index a8c598978b2..d7a166812ef 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -55,7 +55,10 @@ def launch( dag_uuid = str(uuid.uuid4().hex[:4]) dag = dag_utils.convert_entrypoint_to_dag(entrypoint) - dag = policy_utils.apply(dag) + # TODO(zhwu): We should only apply policy to dag and save the config file, + # instead of having the config file actually being used. + dag, mutated_user_config = policy_utils.apply(dag, + apply_skypilot_config=False) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' @@ -105,6 +108,7 @@ def launch( **controller_utils.shared_controller_vars_to_fill( controller_utils.Controllers.JOBS_CONTROLLER, remote_user_config_path=remote_user_config_path, + local_user_config=mutated_user_config, ), } diff --git a/sky/policy.py b/sky/policy.py index 952dc70242c..f015abcbc8c 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -2,26 +2,27 @@ import abc import dataclasses import typing -from typing import Any, Dict if typing.TYPE_CHECKING: - from sky import task as task_lib + import sky @dataclasses.dataclass class UserRequest: - task: 'task_lib.Task' - skypilot_config: Dict[str, Any] + task: 'sky.Task' + skypilot_config: 'sky.NestedConfig' @dataclasses.dataclass class MutatedUserRequest: - task: 'task_lib.Task' - skypilot_config: Dict[str, Any] + task: 'sky.Task' + skypilot_config: 'sky.NestedConfig' + +# pylint: disable=line-too-long class AdminPolicy: """Interface for admin-defined policy for user requests. - + A user-defined policy is a string to a python function that can be imported from the same environment where SkyPilot is running. @@ -30,7 +31,7 @@ class AdminPolicy: policy: my_package.SkyPilotPolicyV1 The AdminPolicy class is expected to have the following signature: - + import sky class SkyPilotPolicyV1(sky.AdminPolicy): @@ -43,5 +44,7 @@ def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: @classmethod @abc.abstractmethod - def validate_and_mutate(cls, user_request: UserRequest) -> MutatedUserRequest: - raise NotImplementedError('Your policy must implement validate_and_mutate') + def validate_and_mutate(cls, + user_request: UserRequest) -> MutatedUserRequest: + raise NotImplementedError( + 'Your policy must implement validate_and_mutate') diff --git a/sky/serve/core.py b/sky/serve/core.py index f9eb5c12874..c1914afe90e 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -8,7 +8,6 @@ import sky from sky import backends from sky import exceptions -from sky.utils import policy_utils from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils @@ -20,6 +19,7 @@ from sky.usage import usage_lib from sky.utils import common_utils from sky.utils import controller_utils +from sky.utils import policy_utils from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils @@ -125,7 +125,9 @@ def up( _validate_service_task(task) - task = policy_utils.apply(task) + dag, mutated_user_config = policy_utils.apply(task, + apply_skypilot_config=False) + task = dag.tasks[0] controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, path='serve') @@ -161,6 +163,7 @@ def up( **controller_utils.shared_controller_vars_to_fill( controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER, remote_user_config_path=remote_config_yaml_path, + local_user_config=mutated_user_config, ), } common_utils.fill_template(serve_constants.CONTROLLER_TEMPLATE, diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index b6a12cf7ef4..042e5382251 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -51,6 +51,7 @@ import copy import os import pprint +import typing from typing import Any, Dict, Iterable, Optional, Tuple import yaml @@ -61,6 +62,38 @@ from sky.utils import schemas from sky.utils import ux_utils +K = typing.TypeVar('K') +V = typing.TypeVar('V') + + +class NestedConfig(Dict[str, Any]): + """A nested dictionary that allows for setting and getting values.""" + + def get_nested(self, + keys: Tuple[str, ...], + default_value: Any, + override_configs: Optional[Dict[str, Any]] = None) -> Any: + """Gets a nested key.""" + config = copy.deepcopy(self) + if override_configs is not None: + config = _recursive_update(config, override_configs) + return _get_nested(config, keys, default_value) + + def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: + """Sets a nested key.""" + override = {} + for i, key in enumerate(reversed(keys)): + if i == 0: + override = {key: value} + else: + override = {key: override} + _recursive_update(self, override) + + @classmethod + def from_dict(cls, config: Dict[str, Any]) -> 'NestedConfig': + return cls(**config) + + # The config path is discovered in this order: # # (1) (Used internally) If env var {ENV_VAR_SKYPILOT_CONFIG} exists, use its @@ -81,7 +114,7 @@ logger = sky_logging.init_logger(__name__) # The loaded config. -_dict: Optional[Dict[str, Any]] = None +_dict = NestedConfig() _loaded_config_path: Optional[str] = None @@ -131,17 +164,11 @@ def get_nested(keys: Tuple[str, ...], ), (f'Override configs must not be provided when keys {keys} is not within ' 'constants.OVERRIDEABLE_CONFIG_KEYS: ' f'{constants.OVERRIDEABLE_CONFIG_KEYS}') - config: Dict[str, Any] = {} - if _dict is not None: - config = copy.deepcopy(_dict) - if override_configs is None: - override_configs = {} - config = _recursive_update(config, override_configs) - return _get_nested(config, keys, default_value) + return _dict.get_nested(keys, default_value, override_configs) -def _recursive_update(base_config: Dict[str, Any], - override_config: Dict[str, Any]) -> Dict[str, Any]: +def _recursive_update(base_config: NestedConfig, + override_config: Dict[str, Any]) -> NestedConfig: """Recursively updates base configuration with override configuration""" for key, value in override_config.items(): if (isinstance(value, dict) and key in base_config and @@ -157,22 +184,14 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: Like get_nested(), if any key is not found, this will not raise an error. """ - _check_loaded_or_die() - assert _dict is not None - override = {} - for i, key in enumerate(reversed(keys)): - if i == 0: - override = {key: value} - else: - override = {key: override} - return _recursive_update(copy.deepcopy(_dict), override) + copied_dict = copy.deepcopy(_dict) + copied_dict.set_nested(keys, value) + return copied_dict -def to_dict() -> Dict[str, Any]: +def to_dict() -> NestedConfig: """Returns a deep-copied version of the current config.""" - if _dict is not None: - return copy.deepcopy(_dict) - return {} + return copy.deepcopy(_dict) def _try_load_config() -> None: @@ -192,13 +211,13 @@ def _try_load_config() -> None: config_path = os.path.expanduser(config_path) if os.path.exists(config_path): logger.debug(f'Using config path: {config_path}') - _loaded_config_path = config_path try: - _dict = common_utils.read_yaml(config_path) + _dict = NestedConfig.from_dict(common_utils.read_yaml(config_path)) + _loaded_config_path = config_path logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}') except yaml.YAMLError as e: logger.error(f'Error in loading config file ({config_path}):', e) - if _dict is not None: + if _dict: common_utils.validate_schema( _dict, schemas.get_config_schema(), @@ -219,14 +238,6 @@ def loaded_config_path() -> Optional[str]: _try_load_config() -def _check_loaded_or_die(): - """Checks loaded() is true; otherwise raises RuntimeError.""" - if _dict is None: - raise RuntimeError( - f'No user configs loaded. Check {CONFIG_PATH} exists and ' - 'can be loaded.') - - def loaded() -> bool: """Returns if the user configurations are loaded.""" - return _dict is not None + return _loaded_config_path is not None diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 51083e84a59..8869a5874bf 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -4,7 +4,7 @@ name: {{dag_name}} file_mounts: {{remote_user_yaml_path}}: {{user_yaml_path}} - {{remote_user_config_path}}: skypilot:local_skypilot_config_path + {{remote_user_config_path}}: {{local_user_config_path}} {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} diff --git a/sky/templates/sky-serve-controller.yaml.j2 b/sky/templates/sky-serve-controller.yaml.j2 index a20c2d680aa..3b6a5ad2d49 100644 --- a/sky/templates/sky-serve-controller.yaml.j2 +++ b/sky/templates/sky-serve-controller.yaml.j2 @@ -23,7 +23,7 @@ setup: | file_mounts: {{remote_task_yaml_path}}: {{local_task_yaml_path}} - {{remote_user_config_path}}: skypilot:local_skypilot_config_path + {{remote_user_config_path}}: {{local_user_config_path}} {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 866aaf1ee1a..9704d6ebd0b 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -45,7 +45,8 @@ 'Details:\n {err}') # The placeholder for the local skypilot config path in file mounts. -LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER = 'skypilot:local_skypilot_config_path' +_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER = ( + '__skypilot:local_skypilot_config_path.yaml') @dataclasses.dataclass @@ -350,8 +351,19 @@ def download_and_stream_latest_job_log( def shared_controller_vars_to_fill( - controller: Controllers, - remote_user_config_path: str) -> Dict[str, str]: + controller: Controllers, remote_user_config_path: str, + local_user_config: Dict[str, Any]) -> Dict[str, str]: + if not local_user_config: + local_user_config_path = None + else: + local_user_config.pop('admin_policy', None) + with tempfile.NamedTemporaryFile( + delete=False, + suffix=_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER + ) as temp_file: + common_utils.dump_yaml(temp_file, local_user_config) + local_user_config_path = temp_file.name + vars_to_fill: Dict[str, Any] = { 'cloud_dependencies_installation_commands': _get_cloud_dependencies_installation_commands(controller), @@ -360,6 +372,7 @@ def shared_controller_vars_to_fill( # accessed. 'sky_activate_python_env': constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV, 'sky_python_cmd': constants.SKY_PYTHON_CMD, + 'local_user_config_path': local_user_config_path, } env_vars: Dict[str, str] = { env.value: '1' for env in env_options.Options if env.get() @@ -481,7 +494,8 @@ def get_controller_resources( def _setup_proxy_command_on_controller( - controller_launched_cloud: 'clouds.Cloud') -> Dict[str, Any]: + controller_launched_cloud: 'clouds.Cloud', + user_config: Dict[str, Any]) -> skypilot_config.NestedConfig: """Sets up proxy command on the controller. This function should be called on the controller (remote cluster), which @@ -515,21 +529,20 @@ def _setup_proxy_command_on_controller( # (or name). It may not be a sufficient check (as it's always # possible that peering is not set up), but it may catch some # obvious errors. + config = skypilot_config.NestedConfig.from_dict(user_config) proxy_command_key = (str(controller_launched_cloud).lower(), 'ssh_proxy_command') - ssh_proxy_command = skypilot_config.get_nested(proxy_command_key, None) - config_dict = skypilot_config.to_dict() + ssh_proxy_command = config.get_nested(proxy_command_key, None) if isinstance(ssh_proxy_command, str): - config_dict = skypilot_config.set_nested(proxy_command_key, None) + config.set_nested(proxy_command_key, None) elif isinstance(ssh_proxy_command, dict): # Instead of removing the key, we set the value to empty string # so that the controller will only try the regions specified by # the keys. ssh_proxy_command = {k: None for k in ssh_proxy_command} - config_dict = skypilot_config.set_nested(proxy_command_key, - ssh_proxy_command) + config.set_nested(proxy_command_key, ssh_proxy_command) - return config_dict + return config def replace_skypilot_config_path_in_file_mounts( @@ -543,25 +556,21 @@ def replace_skypilot_config_path_in_file_mounts( if file_mounts is None: return replaced = False - to_replace = True - with tempfile.NamedTemporaryFile('w', delete=False) as f: - if skypilot_config.loaded(): - new_skypilot_config = _setup_proxy_command_on_controller(cloud) - common_utils.dump_yaml(f.name, new_skypilot_config) - to_replace = True - else: - # Empty config. Remove the placeholder below. - to_replace = False - for remote_path, local_path in list(file_mounts.items()): - if local_path == LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER: - if to_replace: - file_mounts[remote_path] = f.name - replaced = True - else: - del file_mounts[remote_path] + for remote_path, local_path in list(file_mounts.items()): + if local_path is None: + del file_mounts[remote_path] + continue + if _LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER in local_path: + with tempfile.NamedTemporaryFile('w', delete=False) as f: + user_config = common_utils.read_yaml(local_path) + config = _setup_proxy_command_on_controller(cloud, user_config) + common_utils.dump_yaml(f, config) + file_mounts[remote_path] = f.name + replaced = True if replaced: - logger.debug(f'Replaced {LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER} with ' - f'the real path in file mounts: {file_mounts}') + logger.debug( + f'Replaced {_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER} ' + f'with the real path in file mounts: {file_mounts}') def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 2f716d5b86d..e6b491c3168 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -4,7 +4,6 @@ from sky import dag as dag_lib from sky import jobs -from sky.utils import policy_utils from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils @@ -36,8 +35,7 @@ """.strip() -def convert_entrypoint_to_dag( - entrypoint: Any) -> 'dag_lib.Dag': +def convert_entrypoint_to_dag(entrypoint: Any) -> 'dag_lib.Dag': """Converts the entrypoint to a sky.Dag and applies the policy. Raises TypeError if 'entrypoint' is not a 'sky.Task' or 'sky.Dag'. diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index db48730b4cc..98fe632cf47 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -4,20 +4,19 @@ import os import tempfile import typing -from typing import Optional, Union +from typing import Literal, Optional, Tuple, Union from sky import dag as dag_lib +from sky import policy as policy_lib from sky import sky_logging -from sky import task as task_lib from sky import skypilot_config +from sky import task as task_lib from sky.utils import common_utils from sky.utils import ux_utils -from sky import policy as policy_lib logger = sky_logging.init_logger(__name__) - def _get_policy_cls() -> Optional[policy_lib.AdminPolicy]: """Get admin-defined policy.""" policy = skypilot_config.get_nested(('admin_policy',), None) @@ -49,15 +48,27 @@ def _get_policy_cls() -> Optional[policy_lib.AdminPolicy]: 'interface. Please check with your policy admin for details.') return policy_cls + @typing.overload -def apply(dag: 'dag_lib.Dag') -> 'dag_lib.Dag': +def apply( + entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], + apply_skypilot_config: Literal[True] = True, +) -> 'dag_lib.Dag': ... + @typing.overload -def apply(dag: 'task_lib.Task') -> 'task_lib.Task': +def apply( + entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], + apply_skypilot_config: Literal[False], +) -> Tuple['dag_lib.Dag', skypilot_config.NestedConfig]: ... -def apply(entrypoint: Union['dag_lib.Dag', 'task_lib.Task']) -> Union['dag_lib.Dag', 'task_lib.Task']: + +def apply( + entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], + apply_skypilot_config: bool = True, +) -> Union['dag_lib.Dag', Tuple['dag_lib.Dag', skypilot_config.NestedConfig]]: """Apply user-defined policy to a DAG or a task. It mutates a Dag by applying user-defined policy and also update the @@ -65,31 +76,38 @@ def apply(entrypoint: Union['dag_lib.Dag', 'task_lib.Task']) -> Union['dag_lib.D Args: dag: The dag to be mutated by the policy. + apply_skypilot_config: Whether to apply the skypilot config changes to + the global skypilot config. Returns: The mutated dag or task. + Or, a tuple of the mutated dag and path to the mutated skypilot + config, if apply_skypilot_config is set to False. """ if isinstance(entrypoint, task_lib.Task): - return _apply_to_task(entrypoint) - - dag = entrypoint + dag = dag_lib.Dag() + dag.add(entrypoint) + else: + dag = entrypoint policy_cls = _get_policy_cls() if policy_cls is None: - return dag + if apply_skypilot_config: + return dag + else: + return dag, skypilot_config.to_dict() + logger.info(f'Applying policy: {policy_cls}') - import traceback - logger.debug(f'Stack trace: {traceback.format_stack()}') original_config = skypilot_config.to_dict() config = copy.deepcopy(original_config) mutated_dag = dag_lib.Dag() mutated_dag.name = dag.name - mutated_config = None + mutated_config = skypilot_config.NestedConfig() for task in dag.tasks: user_request = policy_lib.UserRequest(task, config) mutated_user_request = policy_cls.validate_and_mutate(user_request) - if mutated_config is None: + if not mutated_config: mutated_config = mutated_user_request.skypilot_config else: if mutated_config != mutated_user_request.skypilot_config: @@ -105,38 +123,25 @@ def apply(entrypoint: Union['dag_lib.Dag', 'task_lib.Task']) -> Union['dag_lib.D u_idx = dag.tasks.index(u) v_idx = dag.tasks.index(v) mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], - mutated_dag.tasks[v_idx]) + mutated_dag.tasks[v_idx]) - if original_config != mutated_config: + if apply_skypilot_config and original_config != mutated_config: with tempfile.NamedTemporaryFile( delete=False, mode='w', prefix='policy-mutated-skypilot-config-', suffix='.yaml') as temp_file: - common_utils.dump_yaml(temp_file.name, mutated_config) - os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name - logger.debug(f'Updated SkyPilot config: {temp_file.name}') - # TODO(zhwu): This is not a clean way to update the SkyPilot config, - # because we are resetting the global context for a single DAG, - # which is conceptually weird. - importlib.reload(skypilot_config) - - logger.debug(f'Mutated user request: {mutated_user_request}') - return mutated_dag - -def _apply_to_task(task: 'task_lib.Task') -> 'task_lib.Task': - """Apply user-defined policy to a task. - - It mutates a task by applying user-defined policy and also update the - global SkyPilot config if there is any changes made by the policy. - Args: - task: The task to be mutated by the policy. + common_utils.dump_yaml(temp_file.name, config) + os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name + logger.debug(f'Updated SkyPilot config: {temp_file.name}') + # TODO(zhwu): This is not a clean way to update the SkyPilot config, + # because we are resetting the global context for a single DAG, + # which is conceptually weird. + importlib.reload(skypilot_config) - Returns: - The mutated task. - """ - dag = dag_lib.Dag() - dag.add(task) - dag = apply(dag) - return dag.tasks[0] + logger.debug(f'Mutated user request: {mutated_user_request}') + if apply_skypilot_config: + return mutated_dag + else: + return mutated_dag, mutated_config From cb39c734b501444b956fdebddfe35524343ea364 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 00:58:30 +0000 Subject: [PATCH 17/61] Fix --- sky/skypilot_config.py | 4 ---- sky/utils/common_utils.py | 9 +++++---- sky/utils/controller_utils.py | 4 ++-- sky/utils/policy_utils.py | 8 ++++---- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 042e5382251..3208514e90e 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -51,7 +51,6 @@ import copy import os import pprint -import typing from typing import Any, Dict, Iterable, Optional, Tuple import yaml @@ -62,9 +61,6 @@ from sky.utils import schemas from sky.utils import ux_utils -K = typing.TypeVar('K') -V = typing.TypeVar('V') - class NestedConfig(Dict[str, Any]): """A nested dictionary that allows for setting and getting values.""" diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index a9227fb4c20..02ca585e13c 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -316,12 +316,13 @@ def read_yaml_all(path: str) -> List[Dict[str, Any]]: return configs -def dump_yaml(path, config) -> None: +def dump_yaml(path: str, config: Union[List[Dict[str, Any]], + Dict[str, Any]]) -> None: with open(path, 'w', encoding='utf-8') as f: f.write(dump_yaml_str(config)) -def dump_yaml_str(config): +def dump_yaml_str(config: Union[List[Dict[str, Any]], Dict[str, Any]]) -> str: # https://github.com/yaml/pyyaml/issues/127 class LineBreakDumper(yaml.SafeDumper): @@ -331,9 +332,9 @@ def write_line_break(self, data=None): super().write_line_break() if isinstance(config, list): - dump_func = yaml.dump_all + dump_func = yaml.dump_all # type: ignore else: - dump_func = yaml.dump + dump_func = yaml.dump # type: ignore return dump_func(config, Dumper=LineBreakDumper, sort_keys=False, diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 9704d6ebd0b..de771b6a4c7 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -361,7 +361,7 @@ def shared_controller_vars_to_fill( delete=False, suffix=_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER ) as temp_file: - common_utils.dump_yaml(temp_file, local_user_config) + common_utils.dump_yaml(temp_file.name, dict(**local_user_config)) local_user_config_path = temp_file.name vars_to_fill: Dict[str, Any] = { @@ -564,7 +564,7 @@ def replace_skypilot_config_path_in_file_mounts( with tempfile.NamedTemporaryFile('w', delete=False) as f: user_config = common_utils.read_yaml(local_path) config = _setup_proxy_command_on_controller(cloud, user_config) - common_utils.dump_yaml(f, config) + common_utils.dump_yaml(f.name, dict(**config)) file_mounts[remote_path] = f.name replaced = True if replaced: diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 98fe632cf47..cbcd88d07c2 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -17,9 +17,8 @@ logger = sky_logging.init_logger(__name__) -def _get_policy_cls() -> Optional[policy_lib.AdminPolicy]: +def _get_policy_cls(policy: Optional[str]) -> Optional[policy_lib.AdminPolicy]: """Get admin-defined policy.""" - policy = skypilot_config.get_nested(('admin_policy',), None) if policy is None: return None try: @@ -90,7 +89,8 @@ def apply( else: dag = entrypoint - policy_cls = _get_policy_cls() + policy = skypilot_config.get_nested(('admin_policy',), None) + policy_cls = _get_policy_cls(policy) if policy_cls is None: if apply_skypilot_config: return dag @@ -132,7 +132,7 @@ def apply( prefix='policy-mutated-skypilot-config-', suffix='.yaml') as temp_file: - common_utils.dump_yaml(temp_file.name, config) + common_utils.dump_yaml(temp_file.name, dict(**config)) os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name logger.debug(f'Updated SkyPilot config: {temp_file.name}') # TODO(zhwu): This is not a clean way to update the SkyPilot config, From 1e3ddeffe1255b580315bab6311606933afa9d76 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 01:34:09 +0000 Subject: [PATCH 18/61] fix --- sky/utils/policy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index cbcd88d07c2..1bf6221bbdc 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -97,7 +97,7 @@ def apply( else: return dag, skypilot_config.to_dict() - logger.info(f'Applying policy: {policy_cls}') + logger.info(f'Applying policy: {policy}') original_config = skypilot_config.to_dict() config = copy.deepcopy(original_config) mutated_dag = dag_lib.Dag() From aa87df733371477d2e154da90ff231e54a5a9c3c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 02:18:07 +0000 Subject: [PATCH 19/61] Fix test config --- .../example_policy/skypilot_policy.py | 13 ++-- sky/skypilot_config.py | 70 ++++++++++++------- sky/utils/common_utils.py | 2 +- tests/test_config.py | 20 ++++-- 4 files changed, 67 insertions(+), 38 deletions(-) diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 9a100e07c8b..9c10f6e8d07 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -1,3 +1,4 @@ +import copy import getpass import sky @@ -9,6 +10,7 @@ class TaskLabelPolicy(sky.AdminPolicy): @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Add label for task with the local user name.""" local_user_name = getpass.getuser() # Add label for task with the local user name @@ -26,15 +28,12 @@ class ConfigLabelPolicy(sky.AdminPolicy): @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Add label for skypilot_config with the local user name.""" local_user_name = getpass.getuser() # Add label for skypilot_config with the local user name - skypilot_config = user_request.skypilot_config - if skypilot_config.get('gcp') is None: - skypilot_config['gcp'] = {} - labels = skypilot_config['gcp'].get('labels', {}) - labels['local_user'] = local_user_name - skypilot_config['gcp']['labels'] = labels - + skypilot_config = copy.deepcopy(user_request.skypilot_config) + skypilot_config.set_nested(('gcp', 'labels', 'local_user'), + local_user_name) return sky.MutatedUserRequest(task=user_request.task, skypilot_config=skypilot_config) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 3208514e90e..4313e04b2cc 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -61,6 +61,25 @@ from sky.utils import schemas from sky.utils import ux_utils +logger = sky_logging.init_logger(__name__) + +# The config path is discovered in this order: +# +# (1) (Used internally) If env var {ENV_VAR_SKYPILOT_CONFIG} exists, use its +# path; +# (2) If file {CONFIG_PATH} exists, use this file. +# +# If the path discovered by (1) fails to load, we do not attempt to go to step +# 2 in the list. + +# (Used internally) An env var holding the path to the local config file. This +# is only used by jobs controller tasks to ensure recoveries of the same job +# use the same config file. +ENV_VAR_SKYPILOT_CONFIG = 'SKYPILOT_CONFIG' + +# Path to the local config file. +CONFIG_PATH = '~/.sky/config.yaml' + class NestedConfig(Dict[str, Any]): """A nested dictionary that allows for setting and getting values.""" @@ -69,14 +88,31 @@ def get_nested(self, keys: Tuple[str, ...], default_value: Any, override_configs: Optional[Dict[str, Any]] = None) -> Any: - """Gets a nested key.""" + """Gets a nested key. + + If any key is not found, or any intermediate key does not point to a + dict value, returns 'default_value'. + + Args: + keys: A tuple of strings representing the nested keys. + default_value: The default value to return if the key is not found. + override_configs: A dict of override configs with the same schema as + the config file, but only containing the keys to override. + + Returns: + The value of the nested key, or 'default_value' if not found. + """ config = copy.deepcopy(self) if override_configs is not None: config = _recursive_update(config, override_configs) return _get_nested(config, keys, default_value) def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: - """Sets a nested key.""" + """Returns a deep-copied config with the nested key set to value. + + Like get_nested(), if any key is not found, this will not raise an + error. + """ override = {} for i, key in enumerate(reversed(keys)): if i == 0: @@ -86,29 +122,12 @@ def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: _recursive_update(self, override) @classmethod - def from_dict(cls, config: Dict[str, Any]) -> 'NestedConfig': + def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'NestedConfig': + if config is None: + return cls() return cls(**config) -# The config path is discovered in this order: -# -# (1) (Used internally) If env var {ENV_VAR_SKYPILOT_CONFIG} exists, use its -# path; -# (2) If file {CONFIG_PATH} exists, use this file. -# -# If the path discovered by (1) fails to load, we do not attempt to go to step -# 2 in the list. - -# (Used internally) An env var holding the path to the local config file. This -# is only used by jobs controller tasks to ensure recoveries of the same job -# use the same config file. -ENV_VAR_SKYPILOT_CONFIG = 'SKYPILOT_CONFIG' - -# Path to the local config file. -CONFIG_PATH = '~/.sky/config.yaml' - -logger = sky_logging.init_logger(__name__) - # The loaded config. _dict = NestedConfig() _loaded_config_path: Optional[str] = None @@ -182,7 +201,7 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: """ copied_dict = copy.deepcopy(_dict) copied_dict.set_nested(keys, value) - return copied_dict + return dict(**copied_dict) def to_dict() -> NestedConfig: @@ -208,7 +227,8 @@ def _try_load_config() -> None: if os.path.exists(config_path): logger.debug(f'Using config path: {config_path}') try: - _dict = NestedConfig.from_dict(common_utils.read_yaml(config_path)) + config = common_utils.read_yaml(config_path) + _dict = NestedConfig.from_dict(config) _loaded_config_path = config_path logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}') except yaml.YAMLError as e: @@ -236,4 +256,4 @@ def loaded_config_path() -> Optional[str]: def loaded() -> bool: """Returns if the user configurations are loaded.""" - return _loaded_config_path is not None + return bool(_dict) diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 02ca585e13c..dffe784cc33 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -300,7 +300,7 @@ def user_and_hostname_hash() -> str: return f'{getpass.getuser()}-{hostname_hash}' -def read_yaml(path) -> Dict[str, Any]: +def read_yaml(path: str) -> Dict[str, Any]: with open(path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) return config diff --git a/tests/test_config.py b/tests/test_config.py index 0cae5f9befb..a66e641d2a4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import copy +import importlib import pathlib import textwrap @@ -21,19 +22,17 @@ def _reload_config() -> None: - skypilot_config._dict = None + skypilot_config._dict = skypilot_config.NestedConfig() + skypilot_config._loaded_config_path = None skypilot_config._try_load_config() - def _check_empty_config() -> None: """Check that the config is empty.""" - assert not skypilot_config.loaded() + assert not skypilot_config.loaded(), (skypilot_config._dict, skypilot_config._loaded_config_path) assert skypilot_config.get_nested( ('aws', 'ssh_proxy_command'), None) is None assert skypilot_config.get_nested(('aws', 'ssh_proxy_command'), 'default') == 'default' - with pytest.raises(RuntimeError): - skypilot_config.set_nested(('aws', 'ssh_proxy_command'), 'value') def _create_config_file(config_file_path: pathlib.Path) -> None: @@ -98,6 +97,17 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: """)) +def test_nested_config(monkeypatch) -> None: + """Test that the nested config works.""" + config = skypilot_config.NestedConfig() + config.set_nested(('aws', 'ssh_proxy_command'), 'value') + assert config == {'aws': {'ssh_proxy_command': 'value'}} + + assert config.get_nested(('admin_policy', ), 'default') == 'default' + config.set_nested(('aws', 'use_internal_ips'), True) + assert config == {'aws': {'ssh_proxy_command': 'value', 'use_internal_ips': True}} + + def test_no_config(monkeypatch) -> None: """Test that the config is not loaded if the config file does not exist.""" monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', '/tmp/does_not_exist') From 28487a47fedba618c446f6ea783bfa8b86508d77 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 02:35:20 +0000 Subject: [PATCH 20/61] Fix mutated config --- .../example_policy/example_policy/__init__.py | 1 + .../example_policy/skypilot_policy.py | 11 +++++++++++ sky/exceptions.py | 5 +++++ sky/utils/policy_utils.py | 18 +++++++++++++----- tests/test_config.py | 13 ++++++++++--- tests/unit_tests/test_policy.py | 17 +++++++++++++---- 6 files changed, 53 insertions(+), 12 deletions(-) diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py index f901a9a30c6..4c8db4d2f4b 100644 --- a/examples/admin_policy/example_policy/example_policy/__init__.py +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -1,4 +1,5 @@ """Example module for SkyPilot admin policies.""" from example_policy.skypilot_policy import ConfigLabelPolicy +from example_policy.skypilot_policy import RejectAllPolicy from example_policy.skypilot_policy import TaskLabelPolicy diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 9c10f6e8d07..86bf71cbaec 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -37,3 +37,14 @@ def validate_and_mutate( local_user_name) return sky.MutatedUserRequest(task=user_request.task, skypilot_config=skypilot_config) + + +class RejectAllPolicy(sky.AdminPolicy): + """Example policy: reject all user requests.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Reject all user requests.""" + del user_request + raise RuntimeError('Reject all policy') diff --git a/sky/exceptions.py b/sky/exceptions.py index 15f3ea3f34e..2a5a39f7edd 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -286,3 +286,8 @@ class ServeUserTerminatedError(Exception): class PortDoesNotExistError(Exception): """Raised when the port does not exist.""" + + +class UserRequestRejectedByPolicy(Exception): + """Raised when a user request is rejected by policy.""" + pass diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 1bf6221bbdc..7891f5b08e6 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -7,6 +7,7 @@ from typing import Literal, Optional, Tuple, Union from sky import dag as dag_lib +from sky import exceptions from sky import policy as policy_lib from sky import sky_logging from sky import skypilot_config @@ -103,20 +104,27 @@ def apply( mutated_dag = dag_lib.Dag() mutated_dag.name = dag.name - mutated_config = skypilot_config.NestedConfig() + mutated_config = None for task in dag.tasks: user_request = policy_lib.UserRequest(task, config) - mutated_user_request = policy_cls.validate_and_mutate(user_request) - if not mutated_config: + try: + mutated_user_request = policy_cls.validate_and_mutate(user_request) + except Exception as e: + with ux_utils.print_exception_no_traceback(): + raise exceptions.UserRequestRejectedByPolicy( + f'User request rejected by policy {policy}, due to error: ' + f'{common_utils.format_exception(e, use_bracket=True)}') + if mutated_config is None: mutated_config = mutated_user_request.skypilot_config else: if mutated_config != mutated_user_request.skypilot_config: with ux_utils.print_exception_no_traceback(): - raise ValueError( + raise exceptions.UserRequestRejectedByPolicy( 'All tasks must have the same skypilot ' 'config after applying the policy. Please' 'check with your policy admin for details.') mutated_dag.add(mutated_user_request.task) + assert mutated_config is not None, dag # Update the new_dag's graph with the old dag's graph for u, v in dag.graph.edges: @@ -132,7 +140,7 @@ def apply( prefix='policy-mutated-skypilot-config-', suffix='.yaml') as temp_file: - common_utils.dump_yaml(temp_file.name, dict(**config)) + common_utils.dump_yaml(temp_file.name, dict(**mutated_config)) os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name logger.debug(f'Updated SkyPilot config: {temp_file.name}') # TODO(zhwu): This is not a clean way to update the SkyPilot config, diff --git a/tests/test_config.py b/tests/test_config.py index a66e641d2a4..d3a2517bba5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,9 +26,11 @@ def _reload_config() -> None: skypilot_config._loaded_config_path = None skypilot_config._try_load_config() + def _check_empty_config() -> None: """Check that the config is empty.""" - assert not skypilot_config.loaded(), (skypilot_config._dict, skypilot_config._loaded_config_path) + assert not skypilot_config.loaded(), (skypilot_config._dict, + skypilot_config._loaded_config_path) assert skypilot_config.get_nested( ('aws', 'ssh_proxy_command'), None) is None assert skypilot_config.get_nested(('aws', 'ssh_proxy_command'), @@ -103,9 +105,14 @@ def test_nested_config(monkeypatch) -> None: config.set_nested(('aws', 'ssh_proxy_command'), 'value') assert config == {'aws': {'ssh_proxy_command': 'value'}} - assert config.get_nested(('admin_policy', ), 'default') == 'default' + assert config.get_nested(('admin_policy',), 'default') == 'default' config.set_nested(('aws', 'use_internal_ips'), True) - assert config == {'aws': {'ssh_proxy_command': 'value', 'use_internal_ips': True}} + assert config == { + 'aws': { + 'ssh_proxy_command': 'value', + 'use_internal_ips': True + } + } def test_no_config(monkeypatch) -> None: diff --git a/tests/unit_tests/test_policy.py b/tests/unit_tests/test_policy.py index 482ae5ff924..cb08a533a29 100644 --- a/tests/unit_tests/test_policy.py +++ b/tests/unit_tests/test_policy.py @@ -5,6 +5,7 @@ import pytest import sky +from sky import exceptions from sky import sky_logging from sky import skypilot_config from sky.utils import policy_utils @@ -12,7 +13,7 @@ logger = sky_logging.init_logger(__name__) POLICY_PATH = os.path.join(os.path.dirname(os.path.dirname(sky.__file__)), - 'examples', 'policy') + 'examples', 'admin_policy') @pytest.fixture @@ -29,12 +30,20 @@ def _load_task_and_apply_policy(config_path) -> sky.Dag: def test_task_level_changes_policy(add_example_policy_paths): - task = _load_task_and_apply_policy( + dag = _load_task_and_apply_policy( os.path.join(POLICY_PATH, 'task_label_config.yaml')) - assert 'local_user' in list(task.resources)[0].labels + assert 'local_user' in list(dag.tasks[0].resources)[0].labels def test_config_level_changes_policy(add_example_policy_paths): _load_task_and_apply_policy( os.path.join(POLICY_PATH, 'config_label_config.yaml')) - assert 'local_user' in skypilot_config.get_nested(('aws', 'labels'), {}) + print(skypilot_config._dict) + assert 'local_user' in skypilot_config.get_nested(('gcp', 'labels'), {}) + + +def test_reject_all_policy(add_example_policy_paths): + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Reject all policy'): + _load_task_and_apply_policy( + os.path.join(POLICY_PATH, 'reject_all_config.yaml')) From d1f048035cc79153ef1edb63bd0e68fedb605d02 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 02:36:05 +0000 Subject: [PATCH 21/61] fix --- sky/utils/policy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 7891f5b08e6..a4c6e65bcae 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -109,7 +109,7 @@ def apply( user_request = policy_lib.UserRequest(task, config) try: mutated_user_request = policy_cls.validate_and_mutate(user_request) - except Exception as e: + except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise exceptions.UserRequestRejectedByPolicy( f'User request rejected by policy {policy}, due to error: ' From f42ace5dfa3edf8a3117bb9af7a60b4c58686ee6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 02:56:04 +0000 Subject: [PATCH 22/61] Add policy doc --- docs/source/cloud-setup/policy.rst | 138 +++++++++++++++++++ docs/source/docs/index.rst | 3 +- examples/admin_policy/reject_all_config.yaml | 1 + 3 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 docs/source/cloud-setup/policy.rst create mode 100644 examples/admin_policy/reject_all_config.yaml diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst new file mode 100644 index 00000000000..561d38d6fb8 --- /dev/null +++ b/docs/source/cloud-setup/policy.rst @@ -0,0 +1,138 @@ +.. _advanced-policy-config: + +Advanced Admin Policy Enforcement +================================= + + +SkyPilot allows admins to enforce policies on users' SkyPilot usage by applying +custom validation and mutation logic on user's task and SkyPilot config. + +In short, admins offers a Python package with a customized inheritance of SkyPilot's +``AdminPolicy`` interface, and a user just needs to set the ``admin_policy`` field in +the SkyPilot config ``~/.sky/config.yaml`` to enforce the policy to all their +tasks. + +Overview +-------- + + + +User-Side +~~~~~~~~~~ + +To apply the policy, a user needs to set the ``admin_policy`` field in the SkyPilot config +``~/.sky/config.yaml`` to the path of the Python package that implements the policy. +For example: + +.. code-block:: yaml + + admin_policy: mypackage.subpackage.MyPolicy + + +.. hint:: + + SkyPilot loads the policy from the given package in the same Python environment. + You can test the existance of the policy by running: + + .. code-block:: bash + + python -c "from mypackage.subpackage import MyPolicy" + + +Admin-Side +~~~~~~~~~~ + +An admin can distribute the Python package to users with pre-defined policy. The +policy should follow the following interface: + +.. code-block:: python + + import sky + + class MyPolicy(sky.AdminPolicy): + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + # Logics for validate and modify user requests. + ... + return sky.MutatedUserRequest(user_request.task, + user_request.skypilot_config) + + +``UserRequest`` and ``MutatedUserRequest`` are defined as follows: + +.. code-block:: python + + class UserRequest: + task: sky.Task + skypilot_config: sky.NestedConfig + # More fields can be added in the future. + + class MutatedUserRequest: + task: sky.Task + skypilot_config: sky.NestedConfig + +That said, an ``AdminPolicy`` can mutate any fields of a user request, including +the :ref:`task ` and the :ref:`global skypilot config `, +giving admins a lot of flexibility to control user's SkyPilot usage. + +An ``AdminPolicy`` is responsible to both validate and mutate user requests. If +a request should be rejected, the policy should raise an exception. + + +Example Policies +---------------- + +Reject All +~~~~~~~~~~ + +.. code-block:: python + + class RejectAllPolicy(sky.AdminPolicy): + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + raise RuntimeError("This policy rejects all user requests.") + +.. code-block:: yaml + + admin_policy: examples.admin_policy.reject_all.RejectAllPolicy + + +Add Kubernetes Labels for all Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class AddLabelsPolicy(sky.AdminPolicy): + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + config = user_request.skypilot_config + labels = config.get_nested(('kubernetes', 'labels'), {}) + labels['app'] = 'skypilot' + config.set_nested(('kubernetes', 'labels'), labels) + return sky.MutatedUserRequest(user_request.task, config) + +.. code-block:: yaml + + admin_policy: examples.admin_policy.add_labels.AddLabelsPolicy + + +Always Disable Public IP for AWS Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class DisablePublicIPPolicy(sky.AdminPolicy): + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + config = user_request.skypilot_config + config.set_nested(('aws', 'use_internal_ip'), True) + if config.get_nested(('aws', 'vpc_name'), None) is None: + # If no VPC name is specified, it is likely a mistake. We should + # reject the request + raise RuntimeError('VPC name should be set. Check organization ' + 'wiki for more information.') + return sky.MutatedUserRequest(user_request.task, config) + +.. code-block:: yaml + + admin_policy: examples.admin_policy.disable_public_ip.DisablePublicIPPolicy diff --git a/docs/source/docs/index.rst b/docs/source/docs/index.rst index eeef2386337..00a645a3834 100644 --- a/docs/source/docs/index.rst +++ b/docs/source/docs/index.rst @@ -201,7 +201,8 @@ Read the research: ../cloud-setup/cloud-permissions/index ../cloud-setup/cloud-auth ../cloud-setup/quota - + ../cloud-setup/policy + .. toctree:: :hidden: :maxdepth: 1 diff --git a/examples/admin_policy/reject_all_config.yaml b/examples/admin_policy/reject_all_config.yaml new file mode 100644 index 00000000000..fe6632089d9 --- /dev/null +++ b/examples/admin_policy/reject_all_config.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.RejectAllPolicy From c04f3dcb547e5f8a733cc46280a2577c5a486fb0 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 02:57:25 +0000 Subject: [PATCH 23/61] rename --- docs/source/cloud-setup/policy.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 561d38d6fb8..dc5e4cf17a5 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -1,7 +1,7 @@ .. _advanced-policy-config: -Advanced Admin Policy Enforcement -================================= +Admin Policy Enforcement +======================== SkyPilot allows admins to enforce policies on users' SkyPilot usage by applying From 58f413c9e61728ffe8c3e88ec453d67cd49f64ee Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 17:11:33 +0000 Subject: [PATCH 24/61] minor --- sky/__init__.py | 1 + sky/utils/policy_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sky/__init__.py b/sky/__init__.py index 1895ba80c81..cfb40a993d3 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -190,6 +190,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): # core APIs Storage Management 'storage_ls', 'storage_delete', + # Admin Policy 'UserRequest', 'MutatedUserRequest', 'AdminPolicy', diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index a4c6e65bcae..30aaf09fcec 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -19,7 +19,7 @@ def _get_policy_cls(policy: Optional[str]) -> Optional[policy_lib.AdminPolicy]: - """Get admin-defined policy.""" + """Gets admin-defined policy.""" if policy is None: return None try: @@ -69,7 +69,7 @@ def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: bool = True, ) -> Union['dag_lib.Dag', Tuple['dag_lib.Dag', skypilot_config.NestedConfig]]: - """Apply user-defined policy to a DAG or a task. + """Applies user-defined policy to a DAG or a task. It mutates a Dag by applying user-defined policy and also update the global SkyPilot config if there is any changes made by the policy. From 52053bdadc4ba624e5311947ca1f55e6875b1845 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 18:07:47 +0000 Subject: [PATCH 25/61] Add additional arguments for autostop --- docs/source/cloud-setup/policy.rst | 33 ++++++++++++++++++- examples/admin_policy/enforce_autostop.yaml | 1 + .../example_policy/example_policy/__init__.py | 1 + .../example_policy/skypilot_policy.py | 24 ++++++++++++++ sky/__init__.py | 2 ++ sky/execution.py | 29 ++++++++++------ sky/policy.py | 32 ++++++++++++++++++ sky/utils/policy_utils.py | 7 ++-- tests/unit_tests/test_policy.py | 26 +++++++++++++-- 9 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 examples/admin_policy/enforce_autostop.yaml diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index dc5e4cf17a5..3a85af90546 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -65,7 +65,7 @@ policy should follow the following interface: class UserRequest: task: sky.Task skypilot_config: sky.NestedConfig - # More fields can be added in the future. + operation_args: sky.OperationArgs class MutatedUserRequest: task: sky.Task @@ -136,3 +136,34 @@ Always Disable Public IP for AWS Tasks .. code-block:: yaml admin_policy: examples.admin_policy.disable_public_ip.DisablePublicIPPolicy + + +Enforce Autostop for all Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + class EnforceAutostopPolicy(sky.AdminPolicy): + @classmethod + def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + operation_args = user_request.operation_args + # Operation args can be None for jobs and services, for which we + # don't need to enforce autostop, as they are already managed. + if operation_args is None: + return sky.MutatedUserRequest( + task=user_request.task, + skypilot_config=user_request.skypilot_config) + idle_minutes_to_autostop = operation_args.idle_minutes_to_autostop + # Enforce autostop/down to be set for all tasks for new clusters. + if not operation_args.cluster_exists and ( + idle_minutes_to_autostop is None or + idle_minutes_to_autostop < 0): + raise RuntimeError('Autostop/down must be set for all newly ' + 'launched clusters.') + return sky.MutatedUserRequest( + task=user_request.task, + skypilot_config=user_request.skypilot_config) + +.. code-block:: yaml + + admin_policy: examples.admin_policy.enforce_autostop.EnforceAutostopPolicy diff --git a/examples/admin_policy/enforce_autostop.yaml b/examples/admin_policy/enforce_autostop.yaml new file mode 100644 index 00000000000..f0194fb994e --- /dev/null +++ b/examples/admin_policy/enforce_autostop.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.EnforceAutostopPolicy diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py index 4c8db4d2f4b..7f73a72b729 100644 --- a/examples/admin_policy/example_policy/example_policy/__init__.py +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -1,5 +1,6 @@ """Example module for SkyPilot admin policies.""" from example_policy.skypilot_policy import ConfigLabelPolicy +from example_policy.skypilot_policy import EnforceAutostopPolicy from example_policy.skypilot_policy import RejectAllPolicy from example_policy.skypilot_policy import TaskLabelPolicy diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 86bf71cbaec..d66ef0e7ad4 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -48,3 +48,27 @@ def validate_and_mutate( """Reject all user requests.""" del user_request raise RuntimeError('Reject all policy') + + +class EnforceAutostopPolicy(sky.AdminPolicy): + """Example policy: enforce autostop for all tasks.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Enforce autostop for all tasks.""" + operation_args = user_request.operation_args + if operation_args is None: + return sky.MutatedUserRequest( + task=user_request.task, + skypilot_config=user_request.skypilot_config) + idle_minutes_to_autostop = operation_args.idle_minutes_to_autostop + # Enforce autostop/down to be set for all tasks for new clusters. + if not operation_args.cluster_exists and ( + idle_minutes_to_autostop is None or + idle_minutes_to_autostop < 0): + raise RuntimeError('Autostop/down must be set for all newly ' + 'launched clusters.') + return sky.MutatedUserRequest( + task=user_request.task, + skypilot_config=user_request.skypilot_config) diff --git a/sky/__init__.py b/sky/__init__.py index cfb40a993d3..b3296216be1 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -113,6 +113,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): # Admin Policy interfaces from sky.policy import AdminPolicy from sky.policy import MutatedUserRequest +from sky.policy import OperationArgs from sky.policy import UserRequest from sky.resources import Resources from sky.skylet.job_lib import JobStatus @@ -195,4 +196,5 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'MutatedUserRequest', 'AdminPolicy', 'NestedConfig', + 'OperationArgs', ] diff --git a/sky/execution.py b/sky/execution.py index c9801b1be96..622559f992b 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -13,6 +13,7 @@ from sky import clouds from sky import global_user_state from sky import optimizer +from sky import policy from sky import sky_logging from sky.backends import backend_utils from sky.usage import usage_lib @@ -159,8 +160,25 @@ def _execute( handle: Optional[backends.ResourceHandle]; the handle to the cluster. None if dryrun. """ + cluster_exists = False + if cluster_name is not None: + existing_handle = global_user_state.get_handle_from_cluster_name( + cluster_name) + cluster_exists = existing_handle is not None + # TODO(woosuk): If the cluster exists, print a warning that + # `cpus` and `memory` are not used as a job scheduling constraint, + # unlike `gpus`. + dag = dag_utils.convert_entrypoint_to_dag(entrypoint) - dag = policy_utils.apply(dag) + dag = policy_utils.apply( + dag, + operation_args=policy.OperationArgs( + cluster_name=cluster_name, + cluster_exists=cluster_exists, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun, + )) assert len(dag) == 1, f'We support 1 task for now. {dag}' task = dag.tasks[0] @@ -170,15 +188,6 @@ def _execute( 'Job recovery is specified in the task. To launch a ' 'managed job, please use: sky jobs launch') - cluster_exists = False - if cluster_name is not None: - existing_handle = global_user_state.get_handle_from_cluster_name( - cluster_name) - cluster_exists = existing_handle is not None - # TODO(woosuk): If the cluster exists, print a warning that - # `cpus` and `memory` are not used as a job scheduling constraint, - # unlike `gpus`. - stages = stages if stages is not None else list(Stage) # Requested features that some clouds support and others don't. diff --git a/sky/policy.py b/sky/policy.py index f015abcbc8c..7a690fe4638 100644 --- a/sky/policy.py +++ b/sky/policy.py @@ -2,15 +2,34 @@ import abc import dataclasses import typing +from typing import Optional if typing.TYPE_CHECKING: import sky +@dataclasses.dataclass +class OperationArgs: + cluster_name: Optional[str] + cluster_exists: bool + idle_minutes_to_autostop: Optional[int] + down: bool + dryrun: bool + + @dataclasses.dataclass class UserRequest: + """User request to the policy. + + Args: + task: User specified task. + skypilot_config: Global skypilot config. + execution_args: Execution arguments. It can be None for jobs and + services. + """ task: 'sky.Task' skypilot_config: 'sky.NestedConfig' + operation_args: Optional['OperationArgs'] = None @dataclasses.dataclass @@ -46,5 +65,18 @@ def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: @abc.abstractmethod def validate_and_mutate(cls, user_request: UserRequest) -> MutatedUserRequest: + """Validates and mutates the user request and returns mutated request. + + Args: + user_request: The user request to validate and mutate. + UserRequest contains (sky.Task, sky.NestedConfig) + + Returns: + MutatedUserRequest: The mutated user request. + MutatedUserRequest contains (sky.Task, sky.NestedConfig) + + Raises: + Any exception to reject the user request. + """ raise NotImplementedError( 'Your policy must implement validate_and_mutate') diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 30aaf09fcec..1cac783eceb 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -53,6 +53,7 @@ def _get_policy_cls(policy: Optional[str]) -> Optional[policy_lib.AdminPolicy]: def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: Literal[True] = True, + execution_args: Optional[policy_lib.OperationArgs] = None, ) -> 'dag_lib.Dag': ... @@ -61,6 +62,7 @@ def apply( def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: Literal[False], + execution_args: Optional[policy_lib.OperationArgs] = None, ) -> Tuple['dag_lib.Dag', skypilot_config.NestedConfig]: ... @@ -68,10 +70,11 @@ def apply( def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: bool = True, + operation_args: Optional[policy_lib.OperationArgs] = None, ) -> Union['dag_lib.Dag', Tuple['dag_lib.Dag', skypilot_config.NestedConfig]]: """Applies user-defined policy to a DAG or a task. - It mutates a Dag by applying user-defined policy and also update the + It mutates a Dag by applying user-defined policy and also updates the global SkyPilot config if there is any changes made by the policy. Args: @@ -106,7 +109,7 @@ def apply( mutated_config = None for task in dag.tasks: - user_request = policy_lib.UserRequest(task, config) + user_request = policy_lib.UserRequest(task, config, operation_args) try: mutated_user_request = policy_cls.validate_and_mutate(user_request) except Exception as e: # pylint: disable=broad-except diff --git a/tests/unit_tests/test_policy.py b/tests/unit_tests/test_policy.py index cb08a533a29..7e1e1e69000 100644 --- a/tests/unit_tests/test_policy.py +++ b/tests/unit_tests/test_policy.py @@ -1,6 +1,7 @@ import importlib import os import sys +from typing import Optional import pytest @@ -22,11 +23,21 @@ def add_example_policy_paths(): sys.path.append(os.path.join(POLICY_PATH, 'example_policy')) -def _load_task_and_apply_policy(config_path) -> sky.Dag: +def _load_task_and_apply_policy( + config_path: str, + idle_minutes_to_autostop: Optional[int] = None) -> sky.Dag: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) - return policy_utils.apply(task) + return policy_utils.apply( + task, + operation_args=sky.OperationArgs( + cluster_name='test', + cluster_exists=False, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=False, + dryrun=False, + )) def test_task_level_changes_policy(add_example_policy_paths): @@ -47,3 +58,14 @@ def test_reject_all_policy(add_example_policy_paths): match='Reject all policy'): _load_task_and_apply_policy( os.path.join(POLICY_PATH, 'reject_all_config.yaml')) + + +def test_enforce_autostop_policy(add_example_policy_paths): + _load_task_and_apply_policy(os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) From 4a4f682246f146ea5550d7fe9362d11907adba55 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 21 Sep 2024 23:14:07 +0000 Subject: [PATCH 26/61] fix mypy --- sky/utils/policy_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 1cac783eceb..53e7a98aa80 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -53,7 +53,7 @@ def _get_policy_cls(policy: Optional[str]) -> Optional[policy_lib.AdminPolicy]: def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: Literal[True] = True, - execution_args: Optional[policy_lib.OperationArgs] = None, + operation_args: Optional[policy_lib.OperationArgs] = None, ) -> 'dag_lib.Dag': ... @@ -62,7 +62,7 @@ def apply( def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: Literal[False], - execution_args: Optional[policy_lib.OperationArgs] = None, + operation_args: Optional[policy_lib.OperationArgs] = None, ) -> Tuple['dag_lib.Dag', skypilot_config.NestedConfig]: ... From a8d1c440b679012162ffc52fd8e954c35d0a4d4a Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 03:16:20 +0000 Subject: [PATCH 27/61] format --- sky/utils/policy_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 53e7a98aa80..b6d1ca97403 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -115,8 +115,10 @@ def apply( except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise exceptions.UserRequestRejectedByPolicy( - f'User request rejected by policy {policy}, due to error: ' - f'{common_utils.format_exception(e, use_bracket=True)}') + f'User request rejected by policy {policy!r}, due to ' + 'error: ' + f'{common_utils.format_exception(e, use_bracket=True)}' + ) from e if mutated_config is None: mutated_config = mutated_user_request.skypilot_config else: From 6c73d81a8834b3c08509a61e072c429b55f6bdef Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 03:18:10 +0000 Subject: [PATCH 28/61] rejected message --- sky/utils/policy_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index b6d1ca97403..8c466d6ca77 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -1,4 +1,5 @@ """Customize policy by users.""" +import colorama import copy import importlib import os @@ -115,8 +116,8 @@ def apply( except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise exceptions.UserRequestRejectedByPolicy( - f'User request rejected by policy {policy!r}, due to ' - 'error: ' + f'{colorama.Fore.RED}User request rejected by policy ' + f'{policy!r}{colorama.Fore.RESET}: ' f'{common_utils.format_exception(e, use_bracket=True)}' ) from e if mutated_config is None: From 247c0b8a8fd94606f2cd1165f88c85afb2397ed2 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 03:23:09 +0000 Subject: [PATCH 29/61] format --- sky/utils/policy_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 8c466d6ca77..bc5354c857f 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -1,5 +1,4 @@ """Customize policy by users.""" -import colorama import copy import importlib import os @@ -7,6 +6,8 @@ import typing from typing import Literal, Optional, Tuple, Union +import colorama + from sky import dag as dag_lib from sky import exceptions from sky import policy as policy_lib From f8a5a641d9023cffd19ae494eb3cfd7f349f33ea Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 11:42:10 -0700 Subject: [PATCH 30/61] Update sky/utils/policy_utils.py Co-authored-by: Zongheng Yang --- sky/utils/policy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index bc5354c857f..404b1c63789 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -1,4 +1,4 @@ -"""Customize policy by users.""" +"""Admin policy utils.""" import copy import importlib import os From 73a45811daf7876ed0389173ce1122c95662bcf7 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 11:45:19 -0700 Subject: [PATCH 31/61] Update sky/utils/policy_utils.py Co-authored-by: Zongheng Yang --- sky/utils/policy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/utils/policy_utils.py b/sky/utils/policy_utils.py index 404b1c63789..53e55a85815 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/policy_utils.py @@ -74,7 +74,7 @@ def apply( apply_skypilot_config: bool = True, operation_args: Optional[policy_lib.OperationArgs] = None, ) -> Union['dag_lib.Dag', Tuple['dag_lib.Dag', skypilot_config.NestedConfig]]: - """Applies user-defined policy to a DAG or a task. + """Applies an admin policy (if registered) to a DAG or a task. It mutates a Dag by applying user-defined policy and also updates the global SkyPilot config if there is any changes made by the policy. From d78a822ba394eb3acbb2b35762f98997dd4502e2 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 18:52:27 +0000 Subject: [PATCH 32/61] Fix --- sky/__init__.py | 10 ++-- sky/{policy.py => admin_policy.py} | 0 sky/execution.py | 8 +-- sky/jobs/core.py | 6 +-- sky/serve/core.py | 6 +-- ...{policy_utils.py => admin_policy_utils.py} | 52 +++++-------------- .../{test_policy.py => test_admin_policy.py} | 13 ++--- 7 files changed, 36 insertions(+), 59 deletions(-) rename sky/{policy.py => admin_policy.py} (100%) rename sky/utils/{policy_utils.py => admin_policy_utils.py} (77%) rename tests/unit_tests/{test_policy.py => test_admin_policy.py} (88%) diff --git a/sky/__init__.py b/sky/__init__.py index b3296216be1..bdeb8299755 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -82,6 +82,11 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky import backends from sky import benchmark from sky import clouds +# Admin Policy interfaces +from sky.admin_policy import AdminPolicy +from sky.admin_policy import MutatedUserRequest +from sky.admin_policy import OperationArgs +from sky.admin_policy import UserRequest from sky.clouds.service_catalog import list_accelerators from sky.core import autostop from sky.core import cancel @@ -110,11 +115,6 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.jobs.core import spot_tail_logs from sky.optimizer import Optimizer from sky.optimizer import OptimizeTarget -# Admin Policy interfaces -from sky.policy import AdminPolicy -from sky.policy import MutatedUserRequest -from sky.policy import OperationArgs -from sky.policy import UserRequest from sky.resources import Resources from sky.skylet.job_lib import JobStatus from sky.skypilot_config import NestedConfig diff --git a/sky/policy.py b/sky/admin_policy.py similarity index 100% rename from sky/policy.py rename to sky/admin_policy.py diff --git a/sky/execution.py b/sky/execution.py index 622559f992b..ccbf40a929d 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -9,18 +9,18 @@ import colorama import sky +from sky import admin_policy from sky import backends from sky import clouds from sky import global_user_state from sky import optimizer -from sky import policy from sky import sky_logging from sky.backends import backend_utils from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import controller_utils from sky.utils import dag_utils from sky.utils import env_options -from sky.utils import policy_utils from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils @@ -170,9 +170,9 @@ def _execute( # unlike `gpus`. dag = dag_utils.convert_entrypoint_to_dag(entrypoint) - dag = policy_utils.apply( + dag, _ = admin_policy_utils.apply( dag, - operation_args=policy.OperationArgs( + operation_args=admin_policy.OperationArgs( cluster_name=cluster_name, cluster_exists=cluster_exists, idle_minutes_to_autostop=idle_minutes_to_autostop, diff --git a/sky/jobs/core.py b/sky/jobs/core.py index d7a166812ef..689ba91a373 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -18,10 +18,10 @@ from sky.jobs import utils as managed_job_utils from sky.skylet import constants as skylet_constants from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils -from sky.utils import policy_utils from sky.utils import rich_utils from sky.utils import subprocess_utils from sky.utils import ux_utils @@ -57,8 +57,8 @@ def launch( dag = dag_utils.convert_entrypoint_to_dag(entrypoint) # TODO(zhwu): We should only apply policy to dag and save the config file, # instead of having the config file actually being used. - dag, mutated_user_config = policy_utils.apply(dag, - apply_skypilot_config=False) + dag, mutated_user_config = admin_policy_utils.apply( + dag, apply_skypilot_config=False) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' diff --git a/sky/serve/core.py b/sky/serve/core.py index c1914afe90e..9d148ece957 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -17,9 +17,9 @@ from sky.serve import serve_utils from sky.skylet import constants from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import common_utils from sky.utils import controller_utils -from sky.utils import policy_utils from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils @@ -125,8 +125,8 @@ def up( _validate_service_task(task) - dag, mutated_user_config = policy_utils.apply(task, - apply_skypilot_config=False) + dag, mutated_user_config = admin_policy_utils.apply( + task, apply_skypilot_config=False) task = dag.tasks[0] controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, diff --git a/sky/utils/policy_utils.py b/sky/utils/admin_policy_utils.py similarity index 77% rename from sky/utils/policy_utils.py rename to sky/utils/admin_policy_utils.py index bc5354c857f..3927877f359 100644 --- a/sky/utils/policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -3,14 +3,13 @@ import importlib import os import tempfile -import typing -from typing import Literal, Optional, Tuple, Union +from typing import Optional, Tuple, Union import colorama +from sky import admin_policy from sky import dag as dag_lib from sky import exceptions -from sky import policy as policy_lib from sky import sky_logging from sky import skypilot_config from sky import task as task_lib @@ -20,7 +19,8 @@ logger = sky_logging.init_logger(__name__) -def _get_policy_cls(policy: Optional[str]) -> Optional[policy_lib.AdminPolicy]: +def _get_policy_cls( + policy: Optional[str]) -> Optional[admin_policy.AdminPolicy]: """Gets admin-defined policy.""" if policy is None: return None @@ -43,37 +43,19 @@ def _get_policy_cls(policy: Optional[str]) -> Optional[policy_lib.AdminPolicy]: 'Please check with your policy admin for details.') from e # Check if the module implements the AdminPolicy interface. - if not issubclass(policy_cls, policy_lib.AdminPolicy): + if not issubclass(policy_cls, admin_policy.AdminPolicy): with ux_utils.print_exception_no_traceback(): raise ValueError( - f'Policy module {policy} does not implement the AdminPolicy ' + f'Policy class {policy!r} does not implement the AdminPolicy ' 'interface. Please check with your policy admin for details.') return policy_cls -@typing.overload -def apply( - entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], - apply_skypilot_config: Literal[True] = True, - operation_args: Optional[policy_lib.OperationArgs] = None, -) -> 'dag_lib.Dag': - ... - - -@typing.overload -def apply( - entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], - apply_skypilot_config: Literal[False], - operation_args: Optional[policy_lib.OperationArgs] = None, -) -> Tuple['dag_lib.Dag', skypilot_config.NestedConfig]: - ... - - def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], apply_skypilot_config: bool = True, - operation_args: Optional[policy_lib.OperationArgs] = None, -) -> Union['dag_lib.Dag', Tuple['dag_lib.Dag', skypilot_config.NestedConfig]]: + operation_args: Optional[admin_policy.OperationArgs] = None, +) -> Tuple['dag_lib.Dag', skypilot_config.NestedConfig]: """Applies user-defined policy to a DAG or a task. It mutates a Dag by applying user-defined policy and also updates the @@ -83,11 +65,11 @@ def apply( dag: The dag to be mutated by the policy. apply_skypilot_config: Whether to apply the skypilot config changes to the global skypilot config. + operation_args: Additional arguments user passed in SkyPilot operations. Returns: - The mutated dag or task. - Or, a tuple of the mutated dag and path to the mutated skypilot - config, if apply_skypilot_config is set to False. + - The mutated dag or task. + - The mutated skypilot config. """ if isinstance(entrypoint, task_lib.Task): dag = dag_lib.Dag() @@ -98,10 +80,7 @@ def apply( policy = skypilot_config.get_nested(('admin_policy',), None) policy_cls = _get_policy_cls(policy) if policy_cls is None: - if apply_skypilot_config: - return dag - else: - return dag, skypilot_config.to_dict() + return dag, skypilot_config.to_dict() logger.info(f'Applying policy: {policy}') original_config = skypilot_config.to_dict() @@ -111,7 +90,7 @@ def apply( mutated_config = None for task in dag.tasks: - user_request = policy_lib.UserRequest(task, config, operation_args) + user_request = admin_policy.UserRequest(task, config, operation_args) try: mutated_user_request = policy_cls.validate_and_mutate(user_request) except Exception as e: # pylint: disable=broad-except @@ -156,7 +135,4 @@ def apply( importlib.reload(skypilot_config) logger.debug(f'Mutated user request: {mutated_user_request}') - if apply_skypilot_config: - return mutated_dag - else: - return mutated_dag, mutated_config + return mutated_dag, mutated_config diff --git a/tests/unit_tests/test_policy.py b/tests/unit_tests/test_admin_policy.py similarity index 88% rename from tests/unit_tests/test_policy.py rename to tests/unit_tests/test_admin_policy.py index 7e1e1e69000..582be733d10 100644 --- a/tests/unit_tests/test_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -1,7 +1,7 @@ import importlib import os import sys -from typing import Optional +from typing import Optional, Tuple import pytest @@ -9,7 +9,7 @@ from sky import exceptions from sky import sky_logging from sky import skypilot_config -from sky.utils import policy_utils +from sky.utils import admin_policy_utils logger = sky_logging.init_logger(__name__) @@ -24,12 +24,13 @@ def add_example_policy_paths(): def _load_task_and_apply_policy( - config_path: str, - idle_minutes_to_autostop: Optional[int] = None) -> sky.Dag: + config_path: str, + idle_minutes_to_autostop: Optional[int] = None +) -> Tuple[sky.Dag, skypilot_config.NestedConfig]: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) - return policy_utils.apply( + return admin_policy_utils.apply( task, operation_args=sky.OperationArgs( cluster_name='test', @@ -41,7 +42,7 @@ def _load_task_and_apply_policy( def test_task_level_changes_policy(add_example_policy_paths): - dag = _load_task_and_apply_policy( + dag, _ = _load_task_and_apply_policy( os.path.join(POLICY_PATH, 'task_label_config.yaml')) assert 'local_user' in list(dag.tasks[0].resources)[0].labels From 68275f60f719057ca2a5f1eeca239f0da5159905 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 12:29:05 -0700 Subject: [PATCH 33/61] Update examples/admin_policy/example_policy/example_policy/__init__.py Co-authored-by: Zongheng Yang --- examples/admin_policy/example_policy/example_policy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py index 7f73a72b729..37baacc7d01 100644 --- a/examples/admin_policy/example_policy/example_policy/__init__.py +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -1,4 +1,4 @@ -"""Example module for SkyPilot admin policies.""" +"""Example admin policy module and prebuilt policies.""" from example_policy.skypilot_policy import ConfigLabelPolicy from example_policy.skypilot_policy import EnforceAutostopPolicy From 9644622c57f9facc800484acd1e42376e642e784 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 12:30:04 -0700 Subject: [PATCH 34/61] Update docs/source/reference/config.rst Co-authored-by: Zongheng Yang --- docs/source/reference/config.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 9fde305dfba..9118b4cfcc1 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -87,7 +87,7 @@ Available fields and semantics: # Default: false. disable_ecc: false - # Custom policy to be applied to all tasks. + # Custom admin policy to be applied to all tasks. # # The policy class to be applied and mutate all tasks, which can be used to # enforce certain policies on all tasks. From 17f8fa108cb262a2bf11266e0caeb679ada076a4 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 19:41:55 +0000 Subject: [PATCH 35/61] Address comments --- docs/source/cloud-setup/policy.rst | 47 +++++++++++++++++++++++++-- docs/source/reference/config.rst | 8 +++-- sky/__init__.py | 7 ++-- sky/admin_policy.py | 39 ++++++++++++---------- sky/exceptions.py | 2 +- sky/execution.py | 2 +- sky/jobs/core.py | 4 +-- sky/serve/core.py | 2 +- sky/skypilot_config.py | 18 +++++----- sky/utils/admin_policy_utils.py | 17 +++++----- sky/utils/controller_utils.py | 24 ++++++++------ tests/test_config.py | 4 +-- tests/unit_tests/test_admin_policy.py | 4 +-- 13 files changed, 113 insertions(+), 65 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 3a85af90546..7cbb2c30d66 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -63,13 +63,24 @@ policy should follow the following interface: .. code-block:: python class UserRequest: + """User request to the policy. + + It is a combination of a task, request options, and the global skypilot + config used to run a task, including `sky launch / exec / jobs launch / ..`. + + Args: + task: User specified task. + skypilot_config: Global skypilot config to be used in this request. + request_options: Request options. It can be None for jobs and + services. + """ task: sky.Task - skypilot_config: sky.NestedConfig - operation_args: sky.OperationArgs + skypilot_config: sky.Config + operation_args: sky.RequestOptions class MutatedUserRequest: task: sky.Task - skypilot_config: sky.NestedConfig + skypilot_config: sky.Config That said, an ``AdminPolicy`` can mutate any fields of a user request, including the :ref:`task ` and the :ref:`global skypilot config `, @@ -78,6 +89,36 @@ giving admins a lot of flexibility to control user's SkyPilot usage. An ``AdminPolicy`` is responsible to both validate and mutate user requests. If a request should be rejected, the policy should raise an exception. +The ``sky.Config`` and ``sky.RequestOptions`` are defined as follows: + +.. code-block:: python + + class Config: + def get_nested(self, + keys: Tuple[str, ...], + default_value: Any, + override_configs: Optional[Dict[str, Any]] = None, + ) -> Any: + """Gets a value with nested keys. + + If override_configs is provided, it value will be merged on top of + the current config. + """ + ... + + def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: + """Sets a value with nested keys.""" + ... + + @dataclass + class RequestOptions: + """Options a user specified in their request to SkyPilot.""" + cluster_name: Optional[str] + cluster_exists: bool + idle_minutes_to_autostop: Optional[int] + down: bool + dryrun: bool + Example Policies ---------------- diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 9fde305dfba..4e0d901f36e 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -87,10 +87,12 @@ Available fields and semantics: # Default: false. disable_ecc: false - # Custom policy to be applied to all tasks. + # Custom policy to be applied to all tasks. (optional). # - # The policy class to be applied and mutate all tasks, which can be used to - # enforce certain policies on all tasks. + # The policy class to be applied to all tasks, which can be used to validate + # and mutate user requests. + # This is useful for enforcing certain policies on all tasks, e.g., + # add custom labels; enforce certain resource limits; etc. # # The policy class should implement the sky.AdminPolicy interface. admin_policy: my_package.SkyPilotPolicyV1 diff --git a/sky/__init__.py b/sky/__init__.py index bdeb8299755..37b5a1caf08 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -82,10 +82,8 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky import backends from sky import benchmark from sky import clouds -# Admin Policy interfaces from sky.admin_policy import AdminPolicy from sky.admin_policy import MutatedUserRequest -from sky.admin_policy import OperationArgs from sky.admin_policy import UserRequest from sky.clouds.service_catalog import list_accelerators from sky.core import autostop @@ -117,7 +115,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.optimizer import OptimizeTarget from sky.resources import Resources from sky.skylet.job_lib import JobStatus -from sky.skypilot_config import NestedConfig +from sky.skypilot_config import Config from sky.status_lib import ClusterStatus from sky.task import Task @@ -195,6 +193,5 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'UserRequest', 'MutatedUserRequest', 'AdminPolicy', - 'NestedConfig', - 'OperationArgs', + 'Config', ] diff --git a/sky/admin_policy.py b/sky/admin_policy.py index 7a690fe4638..9d524f1ce19 100644 --- a/sky/admin_policy.py +++ b/sky/admin_policy.py @@ -9,7 +9,7 @@ @dataclasses.dataclass -class OperationArgs: +class RequestOptions: cluster_name: Optional[str] cluster_exists: bool idle_minutes_to_autostop: Optional[int] @@ -21,35 +21,35 @@ class OperationArgs: class UserRequest: """User request to the policy. + It is a combination of a task, request options, and the global skypilot + config used to run a task, including `sky launch / exec / jobs launch / ..`. + Args: task: User specified task. - skypilot_config: Global skypilot config. - execution_args: Execution arguments. It can be None for jobs and + skypilot_config: Global skypilot config to be used in this request. + request_options: Request options. It can be None for jobs and services. """ task: 'sky.Task' - skypilot_config: 'sky.NestedConfig' - operation_args: Optional['OperationArgs'] = None + skypilot_config: 'sky.Config' + request_options: Optional['RequestOptions'] = None @dataclasses.dataclass class MutatedUserRequest: task: 'sky.Task' - skypilot_config: 'sky.NestedConfig' + skypilot_config: 'sky.Config' # pylint: disable=line-too-long class AdminPolicy: - """Interface for admin-defined policy for user requests. - - A user-defined policy is a string to a python function that can be imported - from the same environment where SkyPilot is running. + """Abstract interface of an admin-defined policy for all user requests. - It can be specified in the SkyPilot config file under the key 'policy', e.g. + A policy is a string to a python class inheriting from AdminPolicy that can + be imported from the same environment where SkyPilot is running. - policy: my_package.SkyPilotPolicyV1 - The AdminPolicy class is expected to have the following signature: + Admins can implement a subclass of AdminPolicy with the following signature: import sky @@ -58,7 +58,12 @@ def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: ... return MutatedUserRequest(task=..., skypilot_config=...) - The function can mutate both task and skypilot_config. + The policy can mutate both task and skypilot_config. + + Users can register a subclass of AdminPolicy in the SkyPilot config file + under the key 'admin_policy', e.g. + + admin_policy: my_package.SkyPilotPolicyV1 """ @classmethod @@ -69,14 +74,14 @@ def validate_and_mutate(cls, Args: user_request: The user request to validate and mutate. - UserRequest contains (sky.Task, sky.NestedConfig) + UserRequest contains (sky.Task, sky.Config) Returns: MutatedUserRequest: The mutated user request. - MutatedUserRequest contains (sky.Task, sky.NestedConfig) + MutatedUserRequest contains (sky.Task, sky.Config) Raises: - Any exception to reject the user request. + Exception to throw if the user request failed the validation. """ raise NotImplementedError( 'Your policy must implement validate_and_mutate') diff --git a/sky/exceptions.py b/sky/exceptions.py index 2a5a39f7edd..04c50ad4e08 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -289,5 +289,5 @@ class PortDoesNotExistError(Exception): class UserRequestRejectedByPolicy(Exception): - """Raised when a user request is rejected by policy.""" + """Raised when a user request is rejected by an admin policy.""" pass diff --git a/sky/execution.py b/sky/execution.py index ccbf40a929d..97822618684 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -172,7 +172,7 @@ def _execute( dag = dag_utils.convert_entrypoint_to_dag(entrypoint) dag, _ = admin_policy_utils.apply( dag, - operation_args=admin_policy.OperationArgs( + operation_args=admin_policy.RequestOptions( cluster_name=cluster_name, cluster_exists=cluster_exists, idle_minutes_to_autostop=idle_minutes_to_autostop, diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 689ba91a373..d11b61cb3ca 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -55,10 +55,8 @@ def launch( dag_uuid = str(uuid.uuid4().hex[:4]) dag = dag_utils.convert_entrypoint_to_dag(entrypoint) - # TODO(zhwu): We should only apply policy to dag and save the config file, - # instead of having the config file actually being used. dag, mutated_user_config = admin_policy_utils.apply( - dag, apply_skypilot_config=False) + dag, update_skypilot_config_for_current_request=False) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' diff --git a/sky/serve/core.py b/sky/serve/core.py index 9d148ece957..46415a48040 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -126,7 +126,7 @@ def up( _validate_service_task(task) dag, mutated_user_config = admin_policy_utils.apply( - task, apply_skypilot_config=False) + task, update_skypilot_config_for_current_request=False) task = dag.tasks[0] controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 4313e04b2cc..aae62afc616 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -81,8 +81,8 @@ CONFIG_PATH = '~/.sky/config.yaml' -class NestedConfig(Dict[str, Any]): - """A nested dictionary that allows for setting and getting values.""" +class Config(Dict[str, Any]): + """SkyPilot config that supports setting/getting values with nested keys.""" def get_nested(self, keys: Tuple[str, ...], @@ -108,7 +108,7 @@ def get_nested(self, return _get_nested(config, keys, default_value) def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: - """Returns a deep-copied config with the nested key set to value. + """In-place sets a nested key to value. Like get_nested(), if any key is not found, this will not raise an error. @@ -122,14 +122,14 @@ def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: _recursive_update(self, override) @classmethod - def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'NestedConfig': + def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config': if config is None: return cls() return cls(**config) # The loaded config. -_dict = NestedConfig() +_dict = Config() _loaded_config_path: Optional[str] = None @@ -182,8 +182,8 @@ def get_nested(keys: Tuple[str, ...], return _dict.get_nested(keys, default_value, override_configs) -def _recursive_update(base_config: NestedConfig, - override_config: Dict[str, Any]) -> NestedConfig: +def _recursive_update(base_config: Config, + override_config: Dict[str, Any]) -> Config: """Recursively updates base configuration with override configuration""" for key, value in override_config.items(): if (isinstance(value, dict) and key in base_config and @@ -204,7 +204,7 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: return dict(**copied_dict) -def to_dict() -> NestedConfig: +def to_dict() -> Config: """Returns a deep-copied version of the current config.""" return copy.deepcopy(_dict) @@ -228,7 +228,7 @@ def _try_load_config() -> None: logger.debug(f'Using config path: {config_path}') try: config = common_utils.read_yaml(config_path) - _dict = NestedConfig.from_dict(config) + _dict = Config.from_dict(config) _loaded_config_path = config_path logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}') except yaml.YAMLError as e: diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index d67e2f756ed..f0f7fe6eab8 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -53,13 +53,14 @@ def _get_policy_cls( def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], - apply_skypilot_config: bool = True, - operation_args: Optional[admin_policy.OperationArgs] = None, -) -> Tuple['dag_lib.Dag', skypilot_config.NestedConfig]: + update_skypilot_config_for_current_request: bool = True, + operation_args: Optional[admin_policy.RequestOptions] = None, +) -> Tuple['dag_lib.Dag', skypilot_config.Config]: """Applies an admin policy (if registered) to a DAG or a task. - It mutates a Dag by applying user-defined policy and also updates the - global SkyPilot config if there is any changes made by the policy. + It mutates a Dag by applying any registered admin policy and also + potentially updates (controlled by `apply_skypilot_config`) the global + SkyPilot config if there is any changes made by the policy. Args: dag: The dag to be mutated by the policy. @@ -68,8 +69,8 @@ def apply( operation_args: Additional arguments user passed in SkyPilot operations. Returns: - - The mutated dag or task. - - The mutated skypilot config. + - The new copy of dag after applying the policy + - The new copy of skypilot config after applying the policy. """ if isinstance(entrypoint, task_lib.Task): dag = dag_lib.Dag() @@ -119,7 +120,7 @@ def apply( mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], mutated_dag.tasks[v_idx]) - if apply_skypilot_config and original_config != mutated_config: + if (update_skypilot_config_for_current_request and original_config != mutated_config): with tempfile.NamedTemporaryFile( delete=False, mode='w', diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index de771b6a4c7..118f9a2b718 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -44,8 +44,11 @@ '{controller_type}.controller.resources is a valid resources spec. ' 'Details:\n {err}') -# The placeholder for the local skypilot config path in file mounts. -_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER = ( +# The suffix for local skypilot config path for a job/service in file mounts +# that tells the controller logic to update the config with specific settings, +# e.g., removing the ssh_proxy_command when a job/service is launched in a same +# cloud as controller. +_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX = ( '__skypilot:local_skypilot_config_path.yaml') @@ -356,11 +359,13 @@ def shared_controller_vars_to_fill( if not local_user_config: local_user_config_path = None else: + # Remove admin_policy from local_user_config so that it is not applied + # again on the controller. This is required since admin_policy is not + # installed on the controller. local_user_config.pop('admin_policy', None) with tempfile.NamedTemporaryFile( delete=False, - suffix=_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER - ) as temp_file: + suffix=_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX) as temp_file: common_utils.dump_yaml(temp_file.name, dict(**local_user_config)) local_user_config_path = temp_file.name @@ -495,7 +500,7 @@ def get_controller_resources( def _setup_proxy_command_on_controller( controller_launched_cloud: 'clouds.Cloud', - user_config: Dict[str, Any]) -> skypilot_config.NestedConfig: + user_config: Dict[str, Any]) -> skypilot_config.Config: """Sets up proxy command on the controller. This function should be called on the controller (remote cluster), which @@ -529,7 +534,7 @@ def _setup_proxy_command_on_controller( # (or name). It may not be a sufficient check (as it's always # possible that peering is not set up), but it may catch some # obvious errors. - config = skypilot_config.NestedConfig.from_dict(user_config) + config = skypilot_config.Config.from_dict(user_config) proxy_command_key = (str(controller_launched_cloud).lower(), 'ssh_proxy_command') ssh_proxy_command = config.get_nested(proxy_command_key, None) @@ -560,7 +565,7 @@ def replace_skypilot_config_path_in_file_mounts( if local_path is None: del file_mounts[remote_path] continue - if _LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER in local_path: + if local_path.endswith(_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX): with tempfile.NamedTemporaryFile('w', delete=False) as f: user_config = common_utils.read_yaml(local_path) config = _setup_proxy_command_on_controller(cloud, user_config) @@ -568,9 +573,8 @@ def replace_skypilot_config_path_in_file_mounts( file_mounts[remote_path] = f.name replaced = True if replaced: - logger.debug( - f'Replaced {_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX_IDENTIFIER} ' - f'with the real path in file mounts: {file_mounts}') + logger.debug(f'Replaced {_LOCAL_SKYPILOT_CONFIG_PATH_SUFFIX} ' + f'with the real path in file mounts: {file_mounts}') def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', diff --git a/tests/test_config.py b/tests/test_config.py index d3a2517bba5..5789214dc61 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -22,7 +22,7 @@ def _reload_config() -> None: - skypilot_config._dict = skypilot_config.NestedConfig() + skypilot_config._dict = skypilot_config.Config() skypilot_config._loaded_config_path = None skypilot_config._try_load_config() @@ -101,7 +101,7 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: def test_nested_config(monkeypatch) -> None: """Test that the nested config works.""" - config = skypilot_config.NestedConfig() + config = skypilot_config.Config() config.set_nested(('aws', 'ssh_proxy_command'), 'value') assert config == {'aws': {'ssh_proxy_command': 'value'}} diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index 582be733d10..eaa8ddeb711 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -26,13 +26,13 @@ def add_example_policy_paths(): def _load_task_and_apply_policy( config_path: str, idle_minutes_to_autostop: Optional[int] = None -) -> Tuple[sky.Dag, skypilot_config.NestedConfig]: +) -> Tuple[sky.Dag, skypilot_config.Config]: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) return admin_policy_utils.apply( task, - operation_args=sky.OperationArgs( + operation_args=sky.RequestOptions( cluster_name='test', cluster_exists=False, idle_minutes_to_autostop=idle_minutes_to_autostop, From 07c47484af0cbba2db67f6a58c411a3c641b7c38 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 19:42:01 +0000 Subject: [PATCH 36/61] format --- sky/utils/admin_policy_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index f0f7fe6eab8..f2edaa2792d 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -120,7 +120,8 @@ def apply( mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], mutated_dag.tasks[v_idx]) - if (update_skypilot_config_for_current_request and original_config != mutated_config): + if (update_skypilot_config_for_current_request and + original_config != mutated_config): with tempfile.NamedTemporaryFile( delete=False, mode='w', From 994272b35f0c61df204d838793363fa01a1cd601 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 19:43:29 +0000 Subject: [PATCH 37/61] changes in examples --- .../example_policy/skypilot_policy.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index d66ef0e7ad4..0ab74c06d58 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -5,15 +5,15 @@ class TaskLabelPolicy(sky.AdminPolicy): - """Example policy: add label for task with the local user name.""" + """Example policy: adds a label of the local user name to all tasks.""" @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Add label for task with the local user name.""" + """Adds a label for task with the local user name.""" local_user_name = getpass.getuser() - # Add label for task with the local user name + # Adds a label for task with the local user name task = user_request.task for r in task.resources: r.labels['local_user'] = local_user_name @@ -23,15 +23,15 @@ def validate_and_mutate( class ConfigLabelPolicy(sky.AdminPolicy): - """Example policy: add label for skypilot_config with the local user name.""" + """Example policy: adds a label for skypilot_config with local user name.""" @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Add label for skypilot_config with the local user name.""" + """Adds a label for skypilot_config with local user name.""" local_user_name = getpass.getuser() - # Add label for skypilot_config with the local user name + # Adds label for skypilot_config with the local user name skypilot_config = copy.deepcopy(user_request.skypilot_config) skypilot_config.set_nested(('gcp', 'labels', 'local_user'), local_user_name) @@ -40,12 +40,12 @@ def validate_and_mutate( class RejectAllPolicy(sky.AdminPolicy): - """Example policy: reject all user requests.""" + """Example policy: rejects all user requests.""" @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Reject all user requests.""" + """Rejects all user requests.""" del user_request raise RuntimeError('Reject all policy') @@ -56,15 +56,17 @@ class EnforceAutostopPolicy(sky.AdminPolicy): @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Enforce autostop for all tasks.""" - operation_args = user_request.operation_args - if operation_args is None: + """Enforces autostop for all tasks.""" + request_options = user_request.request_options + # Request options is None when a task is executed with `jobs launch` or + # `sky serve up`. + if request_options is None: return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) - idle_minutes_to_autostop = operation_args.idle_minutes_to_autostop + idle_minutes_to_autostop = request_options.idle_minutes_to_autostop # Enforce autostop/down to be set for all tasks for new clusters. - if not operation_args.cluster_exists and ( + if not request_options.cluster_exists and ( idle_minutes_to_autostop is None or idle_minutes_to_autostop < 0): raise RuntimeError('Autostop/down must be set for all newly ' From 3597dae3a46fb8442dd3b973da9b27cb312c4095 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 20:15:43 +0000 Subject: [PATCH 38/61] Fix enforce autostop --- .../example_policy/skypilot_policy.py | 2 +- sky/admin_policy.py | 12 +++++++++++- sky/execution.py | 14 ++++++++++---- sky/utils/admin_policy_utils.py | 1 + 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 0ab74c06d58..8c31a047c05 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -66,7 +66,7 @@ def validate_and_mutate( skypilot_config=user_request.skypilot_config) idle_minutes_to_autostop = request_options.idle_minutes_to_autostop # Enforce autostop/down to be set for all tasks for new clusters. - if not request_options.cluster_exists and ( + if not request_options.cluster_running and ( idle_minutes_to_autostop is None or idle_minutes_to_autostop < 0): raise RuntimeError('Autostop/down must be set for all newly ' diff --git a/sky/admin_policy.py b/sky/admin_policy.py index 9d524f1ce19..835760aae5f 100644 --- a/sky/admin_policy.py +++ b/sky/admin_policy.py @@ -10,8 +10,18 @@ @dataclasses.dataclass class RequestOptions: + """Request options for admin policy. + + Args: + cluster_name: Name of the cluster to create/reuse. + cluster_running: Whether the cluster is running. + idle_minutes_to_autostop: If provided, the cluster will be set to + autostop after this many minutes of idleness. + down: Whether to down the cluster. + dryrun: Whether to dryrun the request. + """ cluster_name: Optional[str] - cluster_exists: bool + cluster_running: bool idle_minutes_to_autostop: Optional[int] down: bool dryrun: bool diff --git a/sky/execution.py b/sky/execution.py index 97822618684..ca793c49927 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -15,6 +15,7 @@ from sky import global_user_state from sky import optimizer from sky import sky_logging +from sky import status_lib from sky.backends import backend_utils from sky.usage import usage_lib from sky.utils import admin_policy_utils @@ -161,10 +162,15 @@ def _execute( if dryrun. """ cluster_exists = False + cluster_running = False if cluster_name is not None: - existing_handle = global_user_state.get_handle_from_cluster_name( - cluster_name) - cluster_exists = existing_handle is not None + cluster_record = global_user_state.get_cluster_from_name(cluster_name) + cluster_exists = cluster_record is not None + cluster_running = cluster_exists and cluster_record['status'] in [ + status_lib.ClusterStatus.UP, + status_lib.ClusterStatus.INIT, + ] + # TODO(woosuk): If the cluster exists, print a warning that # `cpus` and `memory` are not used as a job scheduling constraint, # unlike `gpus`. @@ -174,7 +180,7 @@ def _execute( dag, operation_args=admin_policy.RequestOptions( cluster_name=cluster_name, - cluster_exists=cluster_exists, + cluster_running=cluster_running, idle_minutes_to_autostop=idle_minutes_to_autostop, down=down, dryrun=dryrun, diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index f2edaa2792d..f79d66f1f67 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -137,4 +137,5 @@ def apply( importlib.reload(skypilot_config) logger.debug(f'Mutated user request: {mutated_user_request}') + logger.info(f'Applied policy: {policy}') return mutated_dag, mutated_config From 43a60882118b464ee7fc444f7b88641018143195 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 20:27:03 +0000 Subject: [PATCH 39/61] Fix autostop enforcement --- docs/source/cloud-setup/policy.rst | 26 +++++++++++++++++++------- sky/execution.py | 9 +++++---- sky/utils/admin_policy_utils.py | 1 - 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 7cbb2c30d66..c61b8dacfcf 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -129,8 +129,11 @@ Reject All .. code-block:: python class RejectAllPolicy(sky.AdminPolicy): + """Example policy: rejects all user requests.""" + @classmethod def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Rejects all user requests.""" raise RuntimeError("This policy rejects all user requests.") .. code-block:: yaml @@ -144,6 +147,8 @@ Add Kubernetes Labels for all Tasks .. code-block:: python class AddLabelsPolicy(sky.AdminPolicy): + """Example policy: adds a kubernetes label for skypilot_config.""" + @classmethod def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: config = user_request.skypilot_config @@ -163,6 +168,8 @@ Always Disable Public IP for AWS Tasks .. code-block:: python class DisablePublicIPPolicy(sky.AdminPolicy): + """Example policy: disables public IP for all tasks.""" + @classmethod def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: config = user_request.skypilot_config @@ -185,18 +192,22 @@ Enforce Autostop for all Tasks .. code-block:: python class EnforceAutostopPolicy(sky.AdminPolicy): + """Example policy: enforce autostop for all tasks.""" + @classmethod - def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - operation_args = user_request.operation_args - # Operation args can be None for jobs and services, for which we - # don't need to enforce autostop, as they are already managed. - if operation_args is None: + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + """Enforces autostop for all tasks.""" + request_options = user_request.request_options + # Request options is None when a task is executed with `jobs launch` or + # `sky serve up`. + if request_options is None: return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) - idle_minutes_to_autostop = operation_args.idle_minutes_to_autostop + idle_minutes_to_autostop = request_options.idle_minutes_to_autostop # Enforce autostop/down to be set for all tasks for new clusters. - if not operation_args.cluster_exists and ( + if not request_options.cluster_running and ( idle_minutes_to_autostop is None or idle_minutes_to_autostop < 0): raise RuntimeError('Autostop/down must be set for all newly ' @@ -205,6 +216,7 @@ Enforce Autostop for all Tasks task=user_request.task, skypilot_config=user_request.skypilot_config) + .. code-block:: yaml admin_policy: examples.admin_policy.enforce_autostop.EnforceAutostopPolicy diff --git a/sky/execution.py b/sky/execution.py index ca793c49927..849c9a9bbe1 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -166,10 +166,11 @@ def _execute( if cluster_name is not None: cluster_record = global_user_state.get_cluster_from_name(cluster_name) cluster_exists = cluster_record is not None - cluster_running = cluster_exists and cluster_record['status'] in [ - status_lib.ClusterStatus.UP, - status_lib.ClusterStatus.INIT, - ] + cluster_running = (cluster_record is not None and + cluster_record['status'] in [ + status_lib.ClusterStatus.UP, + status_lib.ClusterStatus.INIT, + ]) # TODO(woosuk): If the cluster exists, print a warning that # `cpus` and `memory` are not used as a job scheduling constraint, diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index f79d66f1f67..f2edaa2792d 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -137,5 +137,4 @@ def apply( importlib.reload(skypilot_config) logger.debug(f'Mutated user request: {mutated_user_request}') - logger.info(f'Applied policy: {policy}') return mutated_dag, mutated_config From 8770d0b82735bbf279a83f801516a41b0159a7b3 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 20:29:38 +0000 Subject: [PATCH 40/61] fix test --- tests/unit_tests/test_admin_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index eaa8ddeb711..e8c1d75ac35 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -32,9 +32,9 @@ def _load_task_and_apply_policy( task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) return admin_policy_utils.apply( task, - operation_args=sky.RequestOptions( + operation_args=sky.admin_policy.RequestOptions( cluster_name='test', - cluster_exists=False, + cluster_running=False, idle_minutes_to_autostop=idle_minutes_to_autostop, down=False, dryrun=False, From 7984beb100afd2a0a34f20542091e330f7bcbce8 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 17:11:59 -0700 Subject: [PATCH 41/61] Update docs/source/cloud-setup/policy.rst Co-authored-by: Zongheng Yang --- docs/source/cloud-setup/policy.rst | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index c61b8dacfcf..d6d1b2ab808 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -4,13 +4,22 @@ Admin Policy Enforcement ======================== -SkyPilot allows admins to enforce policies on users' SkyPilot usage by applying -custom validation and mutation logic on user's task and SkyPilot config. - -In short, admins offers a Python package with a customized inheritance of SkyPilot's -``AdminPolicy`` interface, and a user just needs to set the ``admin_policy`` field in -the SkyPilot config ``~/.sky/config.yaml`` to enforce the policy to all their -tasks. +SkyPilot provides an **admin policy** mechanism that admins can use to enforce certain policies on users' SkyPilot usage. An admin policy applies +custom validation and mutation logic to a user's tasks and SkyPilot config. + +Example usage: + + - Adds custom labels to all tasks [Link to below, fix case] + - Always Disable Public IP for AWS Tasks [Link to below] + - Enforce Autostop for all Tasks [Link to below] + +To implement and use an admin policy: + +- Admins writes a simple Python package with a policy class that implements SkyPilot's +``sky.AdminPolicy`` interface; +- Admins distributes this package to users; +- Users simply set the ``admin_policy`` field in +the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. Overview -------- From d155d6066bc1c4706e8dee3313c641388e75d9bc Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 17:12:25 -0700 Subject: [PATCH 42/61] Update sky/admin_policy.py Co-authored-by: Zongheng Yang --- sky/admin_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/admin_policy.py b/sky/admin_policy.py index 835760aae5f..0345ebe43e0 100644 --- a/sky/admin_policy.py +++ b/sky/admin_policy.py @@ -17,8 +17,8 @@ class RequestOptions: cluster_running: Whether the cluster is running. idle_minutes_to_autostop: If provided, the cluster will be set to autostop after this many minutes of idleness. - down: Whether to down the cluster. - dryrun: Whether to dryrun the request. + down: If true, use autodown rather than autostop. + dryrun: Is the request a dryrun? """ cluster_name: Optional[str] cluster_running: bool From 6ffa5ae8037ce68f6603c2e2d6f4f956fde42a06 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 17:26:34 -0700 Subject: [PATCH 43/61] Update sky/admin_policy.py Co-authored-by: Zongheng Yang --- sky/admin_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/admin_policy.py b/sky/admin_policy.py index 0345ebe43e0..12b4a3e367a 100644 --- a/sky/admin_policy.py +++ b/sky/admin_policy.py @@ -68,7 +68,7 @@ def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: ... return MutatedUserRequest(task=..., skypilot_config=...) - The policy can mutate both task and skypilot_config. + The policy can mutate both task and skypilot_config. Admins then distribute a simple module that contains this implementation, installable in a way that it can be imported by users from the same Python environment where SkyPilot is running. Users can register a subclass of AdminPolicy in the SkyPilot config file under the key 'admin_policy', e.g. From a6dd900d48fb6d362da06a192f241734953f9c56 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 00:27:16 +0000 Subject: [PATCH 44/61] wip --- sky/execution.py | 2 +- sky/jobs/core.py | 2 +- sky/serve/core.py | 2 +- sky/utils/admin_policy_utils.py | 15 ++++++++------- tests/unit_tests/test_admin_policy.py | 2 +- tests/unit_tests/test_backend_utils.py | 10 +++------- 6 files changed, 15 insertions(+), 18 deletions(-) diff --git a/sky/execution.py b/sky/execution.py index 849c9a9bbe1..eb7ef731c26 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -179,7 +179,7 @@ def _execute( dag = dag_utils.convert_entrypoint_to_dag(entrypoint) dag, _ = admin_policy_utils.apply( dag, - operation_args=admin_policy.RequestOptions( + request_options=admin_policy.RequestOptions( cluster_name=cluster_name, cluster_running=cluster_running, idle_minutes_to_autostop=idle_minutes_to_autostop, diff --git a/sky/jobs/core.py b/sky/jobs/core.py index d11b61cb3ca..c4f59f65eca 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -56,7 +56,7 @@ def launch( dag = dag_utils.convert_entrypoint_to_dag(entrypoint) dag, mutated_user_config = admin_policy_utils.apply( - dag, update_skypilot_config_for_current_request=False) + dag, use_mutated_config_in_current_request=False) if not dag.is_chain(): with ux_utils.print_exception_no_traceback(): raise ValueError('Only single-task or chain DAG is ' diff --git a/sky/serve/core.py b/sky/serve/core.py index 46415a48040..2bb6e1384ee 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -126,7 +126,7 @@ def up( _validate_service_task(task) dag, mutated_user_config = admin_policy_utils.apply( - task, update_skypilot_config_for_current_request=False) + task, use_mutated_config_in_current_request=False) task = dag.tasks[0] controller_utils.maybe_translate_local_file_mounts_and_sync_up(task, diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index f2edaa2792d..7bc644018c4 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -53,20 +53,21 @@ def _get_policy_cls( def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], - update_skypilot_config_for_current_request: bool = True, - operation_args: Optional[admin_policy.RequestOptions] = None, + use_mutated_config_in_current_request: bool = True, + request_options: Optional[admin_policy.RequestOptions] = None, ) -> Tuple['dag_lib.Dag', skypilot_config.Config]: """Applies an admin policy (if registered) to a DAG or a task. It mutates a Dag by applying any registered admin policy and also - potentially updates (controlled by `apply_skypilot_config`) the global - SkyPilot config if there is any changes made by the policy. + potentially updates (controlled by + `update_skypilot_config_for_current_request`) the global SkyPilot config + if there is any changes made by the policy. Args: dag: The dag to be mutated by the policy. apply_skypilot_config: Whether to apply the skypilot config changes to the global skypilot config. - operation_args: Additional arguments user passed in SkyPilot operations. + request_options: Additional arguments user passed in SkyPilot operations. Returns: - The new copy of dag after applying the policy @@ -91,7 +92,7 @@ def apply( mutated_config = None for task in dag.tasks: - user_request = admin_policy.UserRequest(task, config, operation_args) + user_request = admin_policy.UserRequest(task, config, request_options) try: mutated_user_request = policy_cls.validate_and_mutate(user_request) except Exception as e: # pylint: disable=broad-except @@ -120,7 +121,7 @@ def apply( mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx], mutated_dag.tasks[v_idx]) - if (update_skypilot_config_for_current_request and + if (use_mutated_config_in_current_request and original_config != mutated_config): with tempfile.NamedTemporaryFile( delete=False, diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index e8c1d75ac35..c1c17510dbe 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -32,7 +32,7 @@ def _load_task_and_apply_policy( task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) return admin_policy_utils.apply( task, - operation_args=sky.admin_policy.RequestOptions( + request_options=sky.admin_policy.RequestOptions( cluster_name='test', cluster_running=False, idle_minutes_to_autostop=idle_minutes_to_autostop, diff --git a/tests/unit_tests/test_backend_utils.py b/tests/unit_tests/test_backend_utils.py index cb1b83f1999..aa22f36df20 100644 --- a/tests/unit_tests/test_backend_utils.py +++ b/tests/unit_tests/test_backend_utils.py @@ -1,19 +1,14 @@ +import os import pathlib -from typing import Dict -from unittest.mock import Mock from unittest.mock import patch -import pytest - from sky import clouds from sky import skypilot_config from sky.backends import backend_utils from sky.resources import Resources -from sky.resources import resources_utils -@patch.object(skypilot_config, 'CONFIG_PATH', - './tests/test_yamls/test_aws_config.yaml') +# Set env var to test config file. @patch.object(skypilot_config, '_dict', None) @patch.object(skypilot_config, '_loaded_config_path', None) @patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) @@ -29,6 +24,7 @@ @patch('sky.utils.common_utils.fill_template') def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None: + os.environ['SKYPILOT_CONFIG'] = './tests/test_yamls/test_aws_config.yaml' skypilot_config._try_load_config() cloud = clouds.AWS() From 4274287f7071382dfeee1eea5c508ac9bf181379 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 17:27:39 -0700 Subject: [PATCH 45/61] Update docs/source/cloud-setup/policy.rst Co-authored-by: Zongheng Yang --- docs/source/cloud-setup/policy.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index d6d1b2ab808..249c279a3ea 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -51,7 +51,7 @@ For example: Admin-Side ~~~~~~~~~~ -An admin can distribute the Python package to users with pre-defined policy. The +An admin can distribute the Python package to users with a pre-defined policy. The policy should follow the following interface: .. code-block:: python From 060948234b7a4126921b078912507040e12aca06 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 17:27:47 -0700 Subject: [PATCH 46/61] Update docs/source/cloud-setup/policy.rst Co-authored-by: Zongheng Yang --- docs/source/cloud-setup/policy.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 249c279a3ea..b784f14f7ac 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -41,7 +41,7 @@ For example: .. hint:: SkyPilot loads the policy from the given package in the same Python environment. - You can test the existance of the policy by running: + You can test the existence of the policy by running: .. code-block:: bash From 67552d7a0863f0065ddd346048d5aec7785a0ff9 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 17:28:00 -0700 Subject: [PATCH 47/61] Update docs/source/cloud-setup/policy.rst Co-authored-by: Zongheng Yang --- docs/source/cloud-setup/policy.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index b784f14f7ac..ec3829fe45a 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -61,7 +61,7 @@ policy should follow the following interface: class MyPolicy(sky.AdminPolicy): @classmethod def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - # Logics for validate and modify user requests. + # Logic for validate and modify user requests. ... return sky.MutatedUserRequest(user_request.task, user_request.skypilot_config) From 7de757e97180b461e22cbfb928c7b9b4c7610fb6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 00:28:21 +0000 Subject: [PATCH 48/61] fix --- docs/source/reference/config.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 4e0d901f36e..ebe8db6751f 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -87,10 +87,11 @@ Available fields and semantics: # Default: false. disable_ecc: false - # Custom policy to be applied to all tasks. (optional). + # Admin policy to be applied to all tasks. (optional). # # The policy class to be applied to all tasks, which can be used to validate # and mutate user requests. + # # This is useful for enforcing certain policies on all tasks, e.g., # add custom labels; enforce certain resource limits; etc. # From 7fbc30d806bc1fa05f5cecd640420ce70ee2c501 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 00:41:43 +0000 Subject: [PATCH 49/61] fix --- docs/source/cloud-setup/policy.rst | 10 ++++---- .../example_policy/skypilot_policy.py | 24 ++++++++++++++----- sky/utils/admin_policy_utils.py | 6 ++--- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index ec3829fe45a..2cc3f4153cb 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -13,13 +13,13 @@ Example usage: - Always Disable Public IP for AWS Tasks [Link to below] - Enforce Autostop for all Tasks [Link to below] + To implement and use an admin policy: -- Admins writes a simple Python package with a policy class that implements SkyPilot's -``sky.AdminPolicy`` interface; -- Admins distributes this package to users; -- Users simply set the ``admin_policy`` field in -the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. + - Admins writes a simple Python package with a policy class that implements SkyPilot's ``sky.AdminPolicy`` interface; + - Admins distributes this package to users; + - Users simply set the ``admin_policy`` field in the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. + Overview -------- diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 8c31a047c05..ad73f46e000 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -64,13 +64,25 @@ def validate_and_mutate( return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) + cluster_record = sky.status(request_options.cluster_name, refresh=True) + need_autostop = False + if not cluster_record: + # Cluster does not exist + need_autostop = True + elif cluster_record[0]['status'] == sky.ClusterStatus.STOPPED: + # Cluster is stopped + need_autostop = True + elif cluster_record[0]['autostop'] < 0: + # Cluster is running but autostop is not set + need_autostop = True + + is_setting_autostop = False idle_minutes_to_autostop = request_options.idle_minutes_to_autostop - # Enforce autostop/down to be set for all tasks for new clusters. - if not request_options.cluster_running and ( - idle_minutes_to_autostop is None or - idle_minutes_to_autostop < 0): - raise RuntimeError('Autostop/down must be set for all newly ' - 'launched clusters.') + is_setting_autostop = (idle_minutes_to_autostop is not None and + idle_minutes_to_autostop >= 0) + if need_autostop and not is_setting_autostop: + raise RuntimeError('Autostop/down must be set for all clusters.') + return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index 7bc644018c4..94db4b39fef 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -65,9 +65,9 @@ def apply( Args: dag: The dag to be mutated by the policy. - apply_skypilot_config: Whether to apply the skypilot config changes to - the global skypilot config. - request_options: Additional arguments user passed in SkyPilot operations. + use_mutated_config_in_current_request: Whether to use the mutated + config in the current request. + request_options: Additional options user passed for the current request. Returns: - The new copy of dag after applying the policy From 92b68fc22330fd86bf381c0683bb0f11689e91e6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 01:36:40 +0000 Subject: [PATCH 50/61] fix --- sky/utils/admin_policy_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index 94db4b39fef..a18a9512ff4 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -59,9 +59,8 @@ def apply( """Applies an admin policy (if registered) to a DAG or a task. It mutates a Dag by applying any registered admin policy and also - potentially updates (controlled by - `update_skypilot_config_for_current_request`) the global SkyPilot config - if there is any changes made by the policy. + potentially updates (controlled by `use_mutated_config_in_current_request`) + the global SkyPilot config if there is any changes made by the policy. Args: dag: The dag to be mutated by the policy. From 7d8af9ab1d9320a2bc982fd146aaf5cc8f8d8103 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 04:20:59 +0000 Subject: [PATCH 51/61] Use sky.status for autostop --- docs/source/cloud-setup/policy.rst | 55 +++++-- examples/admin_policy/add_labels.yaml | 1 + .../admin_policy/config_label_config.yaml | 1 - examples/admin_policy/disable_public_ip.yaml | 1 + .../example_policy/example_policy/__init__.py | 6 +- .../example_policy/skypilot_policy.py | 97 +++++++----- ...reject_all_config.yaml => reject_all.yaml} | 0 examples/admin_policy/task_label_config.yaml | 1 - examples/admin_policy/use_spot_for_gpu.yaml | 1 + sky/admin_policy.py | 29 ++-- sky/execution.py | 9 -- tests/unit_tests/test_admin_policy.py | 138 +++++++++++++++--- 12 files changed, 248 insertions(+), 91 deletions(-) create mode 100644 examples/admin_policy/add_labels.yaml delete mode 100644 examples/admin_policy/config_label_config.yaml create mode 100644 examples/admin_policy/disable_public_ip.yaml rename examples/admin_policy/{reject_all_config.yaml => reject_all.yaml} (100%) delete mode 100644 examples/admin_policy/task_label_config.yaml create mode 100644 examples/admin_policy/use_spot_for_gpu.yaml diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 2cc3f4153cb..482ee262ad8 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -132,6 +132,14 @@ The ``sky.Config`` and ``sky.RequestOptions`` are defined as follows: Example Policies ---------------- +We have provided a few example policies in `examples/admin_policy/example_policy `_. You can test these policies by installing the example policy package in your Python environment. + +.. code-block:: bash + + git clone https://github.com/skypilot-org/skypilot.git + cd skypilot + pip install examples/admin_policy/example_policy + Reject All ~~~~~~~~~~ @@ -147,7 +155,7 @@ Reject All .. code-block:: yaml - admin_policy: examples.admin_policy.reject_all.RejectAllPolicy + admin_policy: example_policy.RejectAllPolicy Add Kubernetes Labels for all Tasks @@ -168,7 +176,7 @@ Add Kubernetes Labels for all Tasks .. code-block:: yaml - admin_policy: examples.admin_policy.add_labels.AddLabelsPolicy + admin_policy: example_policy.AddLabelsPolicy Always Disable Public IP for AWS Tasks @@ -192,7 +200,7 @@ Always Disable Public IP for AWS Tasks .. code-block:: yaml - admin_policy: examples.admin_policy.disable_public_ip.DisablePublicIPPolicy + admin_policy: example_policy.DisablePublicIPPolicy Enforce Autostop for all Tasks @@ -206,21 +214,46 @@ Enforce Autostop for all Tasks @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Enforces autostop for all tasks.""" + """Enforces autostop for all tasks. + + Note that with this policy enforced, users can still change the autostop + setting for an existing cluster by using `sky autostop`. + """ request_options = user_request.request_options + # Request options is None when a task is executed with `jobs launch` or # `sky serve up`. if request_options is None: return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) + + # Get the cluster record to operate on. + cluster_record = sky.status(request_options.cluster_name, refresh=True) + + # Check if the user request should specify autostop settings. + need_autostop = False + if not cluster_record: + # Cluster does not exist + need_autostop = True + elif cluster_record[0]['status'] == sky.ClusterStatus.STOPPED: + # Cluster is stopped + need_autostop = True + elif cluster_record[0]['autostop'] < 0: + # Cluster is running but autostop is not set + need_autostop = True + + # Check if the user request is setting autostop settings. + is_setting_autostop = False idle_minutes_to_autostop = request_options.idle_minutes_to_autostop - # Enforce autostop/down to be set for all tasks for new clusters. - if not request_options.cluster_running and ( - idle_minutes_to_autostop is None or - idle_minutes_to_autostop < 0): - raise RuntimeError('Autostop/down must be set for all newly ' - 'launched clusters.') + is_setting_autostop = (idle_minutes_to_autostop is not None and + idle_minutes_to_autostop >= 0) + + # If the cluster requires autostop but the user request is not setting + # autostop settings, raise an error. + if need_autostop and not is_setting_autostop: + raise RuntimeError('Autostop/down must be set for all clusters.') + return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) @@ -228,4 +261,4 @@ Enforce Autostop for all Tasks .. code-block:: yaml - admin_policy: examples.admin_policy.enforce_autostop.EnforceAutostopPolicy + admin_policy: example_policy.EnforceAutostopPolicy diff --git a/examples/admin_policy/add_labels.yaml b/examples/admin_policy/add_labels.yaml new file mode 100644 index 00000000000..113b3b78044 --- /dev/null +++ b/examples/admin_policy/add_labels.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.AddLabelsPolicy diff --git a/examples/admin_policy/config_label_config.yaml b/examples/admin_policy/config_label_config.yaml deleted file mode 100644 index 9228986028d..00000000000 --- a/examples/admin_policy/config_label_config.yaml +++ /dev/null @@ -1 +0,0 @@ -admin_policy: example_policy.ConfigLabelPolicy diff --git a/examples/admin_policy/disable_public_ip.yaml b/examples/admin_policy/disable_public_ip.yaml new file mode 100644 index 00000000000..eb00757f409 --- /dev/null +++ b/examples/admin_policy/disable_public_ip.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.DisablePublicIPPolicy diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py index 37baacc7d01..862dfbd3d9f 100644 --- a/examples/admin_policy/example_policy/example_policy/__init__.py +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -1,6 +1,6 @@ """Example admin policy module and prebuilt policies.""" - -from example_policy.skypilot_policy import ConfigLabelPolicy +from example_policy.skypilot_policy import AddLabelsPolicy +from example_policy.skypilot_policy import DisablePublicIPPolicy from example_policy.skypilot_policy import EnforceAutostopPolicy from example_policy.skypilot_policy import RejectAllPolicy -from example_policy.skypilot_policy import TaskLabelPolicy +from example_policy.skypilot_policy import UseSpotForGPUPolicy diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index ad73f46e000..5f126d6c63a 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -1,53 +1,66 @@ -import copy -import getpass - +"""Example prebuilt admin policies.""" import sky -class TaskLabelPolicy(sky.AdminPolicy): - """Example policy: adds a label of the local user name to all tasks.""" +class RejectAllPolicy(sky.AdminPolicy): + """Example policy: rejects all user requests.""" @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Adds a label for task with the local user name.""" - local_user_name = getpass.getuser() + """Rejects all user requests.""" + raise RuntimeError('Reject all policy') - # Adds a label for task with the local user name - task = user_request.task - for r in task.resources: - r.labels['local_user'] = local_user_name - return sky.MutatedUserRequest( - task=task, skypilot_config=user_request.skypilot_config) +class AddLabelsPolicy(sky.AdminPolicy): + """Example policy: adds a kubernetes label for skypilot_config.""" + + @classmethod + def validate_and_mutate( + cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: + config = user_request.skypilot_config + labels = config.get_nested(('kubernetes', 'custom_metadata', 'labels'), + {}) + labels['app'] = 'skypilot' + config.set_nested(('kubernetes', 'custom_metadata', 'labels'), labels) + return sky.MutatedUserRequest(user_request.task, config) -class ConfigLabelPolicy(sky.AdminPolicy): - """Example policy: adds a label for skypilot_config with local user name.""" +class DisablePublicIPPolicy(sky.AdminPolicy): + """Example policy: disables public IP for all tasks.""" @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Adds a label for skypilot_config with local user name.""" - local_user_name = getpass.getuser() - - # Adds label for skypilot_config with the local user name - skypilot_config = copy.deepcopy(user_request.skypilot_config) - skypilot_config.set_nested(('gcp', 'labels', 'local_user'), - local_user_name) - return sky.MutatedUserRequest(task=user_request.task, - skypilot_config=skypilot_config) + config = user_request.skypilot_config + config.set_nested(('aws', 'use_internal_ip'), True) + if config.get_nested(('aws', 'vpc_name'), None) is None: + # If no VPC name is specified, it is likely a mistake. We should + # reject the request + raise RuntimeError('VPC name should be set. Check organization ' + 'wiki for more information.') + return sky.MutatedUserRequest(user_request.task, config) -class RejectAllPolicy(sky.AdminPolicy): - """Example policy: rejects all user requests.""" +class UseSpotForGPUPolicy(sky.AdminPolicy): + """Example policy: use spot instances for all GPU tasks.""" @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Rejects all user requests.""" - del user_request - raise RuntimeError('Reject all policy') + """Sets use_spot to True for all GPU tasks.""" + task = user_request.task + new_resources = [] + for r in task.resources: + if r.accelerators: + new_resources.append(r.copy(use_spot=True)) + else: + new_resources.append(r) + + task.set_resources(type(task.resources)(new_resources)) + + return sky.MutatedUserRequest( + task=task, skypilot_config=user_request.skypilot_config) class EnforceAutostopPolicy(sky.AdminPolicy): @@ -56,30 +69,46 @@ class EnforceAutostopPolicy(sky.AdminPolicy): @classmethod def validate_and_mutate( cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Enforces autostop for all tasks.""" + """Enforces autostop for all tasks. + + Note that with this policy enforced, users can still change the autostop + setting for an existing cluster by using `sky autostop`. + """ request_options = user_request.request_options + # Request options is None when a task is executed with `jobs launch` or # `sky serve up`. if request_options is None: return sky.MutatedUserRequest( task=user_request.task, skypilot_config=user_request.skypilot_config) - cluster_record = sky.status(request_options.cluster_name, refresh=True) + + # Get the cluster record to operate on. + cluster_name = request_options.cluster_name + cluster_records = [] + if cluster_name is not None: + cluster_records = sky.status(cluster_name, refresh=True) + + # Check if the user request should specify autostop settings. need_autostop = False - if not cluster_record: + if not cluster_records: # Cluster does not exist need_autostop = True - elif cluster_record[0]['status'] == sky.ClusterStatus.STOPPED: + elif cluster_records[0]['status'] == sky.ClusterStatus.STOPPED: # Cluster is stopped need_autostop = True - elif cluster_record[0]['autostop'] < 0: + elif cluster_records[0]['autostop'] < 0: # Cluster is running but autostop is not set need_autostop = True + # Check if the user request is setting autostop settings. is_setting_autostop = False idle_minutes_to_autostop = request_options.idle_minutes_to_autostop is_setting_autostop = (idle_minutes_to_autostop is not None and idle_minutes_to_autostop >= 0) + + # If the cluster requires autostop but the user request is not setting + # autostop settings, raise an error. if need_autostop and not is_setting_autostop: raise RuntimeError('Autostop/down must be set for all clusters.') diff --git a/examples/admin_policy/reject_all_config.yaml b/examples/admin_policy/reject_all.yaml similarity index 100% rename from examples/admin_policy/reject_all_config.yaml rename to examples/admin_policy/reject_all.yaml diff --git a/examples/admin_policy/task_label_config.yaml b/examples/admin_policy/task_label_config.yaml deleted file mode 100644 index f21774e7086..00000000000 --- a/examples/admin_policy/task_label_config.yaml +++ /dev/null @@ -1 +0,0 @@ -admin_policy: example_policy.TaskLabelPolicy diff --git a/examples/admin_policy/use_spot_for_gpu.yaml b/examples/admin_policy/use_spot_for_gpu.yaml new file mode 100644 index 00000000000..d1234c27729 --- /dev/null +++ b/examples/admin_policy/use_spot_for_gpu.yaml @@ -0,0 +1 @@ +admin_policy: example_policy.UseSpotForGPUPolicy diff --git a/sky/admin_policy.py b/sky/admin_policy.py index 12b4a3e367a..7c17eaae258 100644 --- a/sky/admin_policy.py +++ b/sky/admin_policy.py @@ -20,8 +20,8 @@ class RequestOptions: down: If true, use autodown rather than autostop. dryrun: Is the request a dryrun? """ + # Cluster name is None if not specified by the user. cluster_name: Optional[str] - cluster_running: bool idle_minutes_to_autostop: Optional[int] down: bool dryrun: bool @@ -29,16 +29,23 @@ class RequestOptions: @dataclasses.dataclass class UserRequest: - """User request to the policy. + """A user request. - It is a combination of a task, request options, and the global skypilot - config used to run a task, including `sky launch / exec / jobs launch / ..`. + A "user request" is defined as a `sky launch / exec` command or its API + equivalent. + + `sky jobs launch / serve up` involves multiple launch requests, including + the launch of controller and clusters for a job (which can have multiple + tasks if it is a pipeline) or service replicas. Each launch is a separate + request. + + This class wraps the underlying task, the global skypilot config used to run + a task, and the request options. Args: task: User specified task. skypilot_config: Global skypilot config to be used in this request. - request_options: Request options. It can be None for jobs and - services. + request_options: Request options. It is None for jobs and services. """ task: 'sky.Task' skypilot_config: 'sky.Config' @@ -55,10 +62,6 @@ class MutatedUserRequest: class AdminPolicy: """Abstract interface of an admin-defined policy for all user requests. - A policy is a string to a python class inheriting from AdminPolicy that can - be imported from the same environment where SkyPilot is running. - - Admins can implement a subclass of AdminPolicy with the following signature: import sky @@ -68,7 +71,10 @@ def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest: ... return MutatedUserRequest(task=..., skypilot_config=...) - The policy can mutate both task and skypilot_config. Admins then distribute a simple module that contains this implementation, installable in a way that it can be imported by users from the same Python environment where SkyPilot is running. + The policy can mutate both task and skypilot_config. Admins then distribute + a simple module that contains this implementation, installable in a way + that it can be imported by users from the same Python environment where + SkyPilot is running. Users can register a subclass of AdminPolicy in the SkyPilot config file under the key 'admin_policy', e.g. @@ -88,7 +94,6 @@ def validate_and_mutate(cls, Returns: MutatedUserRequest: The mutated user request. - MutatedUserRequest contains (sky.Task, sky.Config) Raises: Exception to throw if the user request failed the validation. diff --git a/sky/execution.py b/sky/execution.py index eb7ef731c26..0b2663765b1 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -15,7 +15,6 @@ from sky import global_user_state from sky import optimizer from sky import sky_logging -from sky import status_lib from sky.backends import backend_utils from sky.usage import usage_lib from sky.utils import admin_policy_utils @@ -162,16 +161,9 @@ def _execute( if dryrun. """ cluster_exists = False - cluster_running = False if cluster_name is not None: cluster_record = global_user_state.get_cluster_from_name(cluster_name) cluster_exists = cluster_record is not None - cluster_running = (cluster_record is not None and - cluster_record['status'] in [ - status_lib.ClusterStatus.UP, - status_lib.ClusterStatus.INIT, - ]) - # TODO(woosuk): If the cluster exists, print a warning that # `cpus` and `memory` are not used as a job scheduling constraint, # unlike `gpus`. @@ -181,7 +173,6 @@ def _execute( dag, request_options=admin_policy.RequestOptions( cluster_name=cluster_name, - cluster_running=cluster_running, idle_minutes_to_autostop=idle_minutes_to_autostop, down=down, dryrun=dryrun, diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index c1c17510dbe..70f62c4f1e2 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -2,6 +2,7 @@ import os import sys from typing import Optional, Tuple +from unittest.mock import patch import pytest @@ -23,50 +24,147 @@ def add_example_policy_paths(): sys.path.append(os.path.join(POLICY_PATH, 'example_policy')) +@pytest.fixture +def task(): + return sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) + + def _load_task_and_apply_policy( + task: sky.Task, config_path: str, - idle_minutes_to_autostop: Optional[int] = None + idle_minutes_to_autostop: Optional[int] = None, ) -> Tuple[sky.Dag, skypilot_config.Config]: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) - task = sky.Task.from_yaml(os.path.join(POLICY_PATH, 'task.yaml')) return admin_policy_utils.apply( task, request_options=sky.admin_policy.RequestOptions( cluster_name='test', - cluster_running=False, idle_minutes_to_autostop=idle_minutes_to_autostop, down=False, dryrun=False, )) -def test_task_level_changes_policy(add_example_policy_paths): +def test_use_spot_for_all_gpus_policy(add_example_policy_paths, task): + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) + assert not any(r.use_spot for r in dag.tasks[0].resources), ( + 'use_spot should be False as GPU is not specified') + + task.set_resources([ + sky.Resources(cloud='gcp', accelerators={'A100': 1}), + sky.Resources(accelerators={'L4': 1}) + ]) dag, _ = _load_task_and_apply_policy( - os.path.join(POLICY_PATH, 'task_label_config.yaml')) - assert 'local_user' in list(dag.tasks[0].resources)[0].labels + task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) + assert all( + r.use_spot for r in dag.tasks[0].resources), 'use_spot should be True' + + task.set_resources([ + sky.Resources(accelerators={'A100': 1}), + sky.Resources(accelerators={'L4': 1}, use_spot=True), + sky.Resources(instance_type='n1-standard-2'), + sky.Resources(instance_type='n1-standard-2'), + ]) + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) + for r in dag.tasks[0].resources: + if r.accelerators: + assert r.use_spot, 'use_spot should be True' + else: + assert not r.use_spot, 'use_spot should be False' -def test_config_level_changes_policy(add_example_policy_paths): - _load_task_and_apply_policy( - os.path.join(POLICY_PATH, 'config_label_config.yaml')) - print(skypilot_config._dict) - assert 'local_user' in skypilot_config.get_nested(('gcp', 'labels'), {}) +def test_add_labels_policy(add_example_policy_paths, task): + dag, _ = _load_task_and_apply_policy( + task, os.path.join(POLICY_PATH, 'add_labels.yaml')) + assert 'app' in skypilot_config.get_nested( + ('kubernetes', 'custom_metadata', 'labels'), + {}), ('label should be set') -def test_reject_all_policy(add_example_policy_paths): +def test_reject_all_policy(add_example_policy_paths, task): with pytest.raises(exceptions.UserRequestRejectedByPolicy, match='Reject all policy'): _load_task_and_apply_policy( - os.path.join(POLICY_PATH, 'reject_all_config.yaml')) + task, os.path.join(POLICY_PATH, 'reject_all.yaml')) -def test_enforce_autostop_policy(add_example_policy_paths): - _load_task_and_apply_policy(os.path.join(POLICY_PATH, - 'enforce_autostop.yaml'), - idle_minutes_to_autostop=10) - with pytest.raises(exceptions.UserRequestRejectedByPolicy, - match='Autostop/down must be set'): - _load_task_and_apply_policy(os.path.join(POLICY_PATH, +def test_enforce_autostop_policy(add_example_policy_paths, task): + + def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: + return { + 'name': 'test', + 'status': status, + 'autostop': autostop, + } + + # Cluster does not exist + with patch('sky.status', return_value=[]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is stopped + with patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.STOPPED, 10)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is running but autostop is not set + with patch('sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.UP, -1)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is init but autostop is not set + with patch('sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.INIT, -1)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + with pytest.raises(exceptions.UserRequestRejectedByPolicy, + match='Autostop/down must be set'): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=None) + + # Cluster is running and autostop is set + with patch('sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.UP, 10)]): + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, + 'enforce_autostop.yaml'), + idle_minutes_to_autostop=10) + _load_task_and_apply_policy(task, + os.path.join(POLICY_PATH, 'enforce_autostop.yaml'), idle_minutes_to_autostop=None) From 5b37f471d8d7b846b419e9e2627f2577d0e3adf9 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 04:25:14 +0000 Subject: [PATCH 52/61] update policy --- docs/source/cloud-setup/policy.rst | 32 +++++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 482ee262ad8..c02a04d7548 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -52,7 +52,7 @@ Admin-Side ~~~~~~~~~~ An admin can distribute the Python package to users with a pre-defined policy. The -policy should follow the following interface: +policy should implement the `sky.AdminPolicy` `interface `_: .. code-block:: python @@ -67,29 +67,37 @@ policy should follow the following interface: user_request.skypilot_config) -``UserRequest`` and ``MutatedUserRequest`` are defined as follows: +``UserRequest`` and ``MutatedUserRequest`` are defined as follows (see `source code `_ for more details): .. code-block:: python class UserRequest: - """User request to the policy. + """A user request. - It is a combination of a task, request options, and the global skypilot - config used to run a task, including `sky launch / exec / jobs launch / ..`. + A "user request" is defined as a `sky launch / exec` command or its API + equivalent. + + `sky jobs launch / serve up` involves multiple launch requests, including + the launch of controller and clusters for a job (which can have multiple + tasks if it is a pipeline) or service replicas. Each launch is a separate + request. + + This class wraps the underlying task, the global skypilot config used to run + a task, and the request options. Args: task: User specified task. skypilot_config: Global skypilot config to be used in this request. - request_options: Request options. It can be None for jobs and - services. + request_options: Request options. It is None for jobs and services. """ - task: sky.Task - skypilot_config: sky.Config - operation_args: sky.RequestOptions + task: 'sky.Task' + skypilot_config: 'sky.Config' + request_options: Optional['RequestOptions'] = None + class MutatedUserRequest: - task: sky.Task - skypilot_config: sky.Config + task: 'sky.Task' + skypilot_config: 'sky.Config' That said, an ``AdminPolicy`` can mutate any fields of a user request, including the :ref:`task ` and the :ref:`global skypilot config `, From c7af310a3fbc9fed73e6fc2d7278e4c554df629f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 22 Sep 2024 21:25:38 -0700 Subject: [PATCH 53/61] Update docs/source/cloud-setup/policy.rst Co-authored-by: Zongheng Yang --- docs/source/cloud-setup/policy.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 482ee262ad8..3a25d90894e 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -91,14 +91,14 @@ policy should follow the following interface: task: sky.Task skypilot_config: sky.Config -That said, an ``AdminPolicy`` can mutate any fields of a user request, including +In other words, an ``AdminPolicy`` can mutate any fields of a user request, including the :ref:`task ` and the :ref:`global skypilot config `, giving admins a lot of flexibility to control user's SkyPilot usage. -An ``AdminPolicy`` is responsible to both validate and mutate user requests. If +An ``AdminPolicy`` can be used to both validate and mutate user requests. If a request should be rejected, the policy should raise an exception. -The ``sky.Config`` and ``sky.RequestOptions`` are defined as follows: +The ``sky.Config`` and ``sky.RequestOptions`` classes are defined as follows: .. code-block:: python From cb232a8a0f8e436181c8f570d24fa86ae7c1af97 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 04:26:35 +0000 Subject: [PATCH 54/61] fix policy.rst --- docs/source/cloud-setup/policy.rst | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index c02a04d7548..d9346261dd2 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -127,11 +127,19 @@ The ``sky.Config`` and ``sky.RequestOptions`` are defined as follows: """Sets a value with nested keys.""" ... - @dataclass class RequestOptions: - """Options a user specified in their request to SkyPilot.""" + """Request options for admin policy. + + Args: + cluster_name: Name of the cluster to create/reuse. + cluster_running: Whether the cluster is running. + idle_minutes_to_autostop: If provided, the cluster will be set to + autostop after this many minutes of idleness. + down: If true, use autodown rather than autostop. + dryrun: Is the request a dryrun? + """ + # Cluster name is None if not specified by the user. cluster_name: Optional[str] - cluster_exists: bool idle_minutes_to_autostop: Optional[int] down: bool dryrun: bool From deb4c92de100ab9abd483f8f1424511db9b6c177 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 06:01:28 +0000 Subject: [PATCH 55/61] Add comment --- sky/utils/admin_policy_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index a18a9512ff4..4e69fa6c988 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -105,6 +105,11 @@ def apply( mutated_config = mutated_user_request.skypilot_config else: if mutated_config != mutated_user_request.skypilot_config: + # In the case of a pipeline of tasks, the mutated config + # generated should remain the same for all tasks for now for + # simplicity. + # TODO(zhwu): We should support per-task mutated config or + # allowing overriding required global config in task YAML. with ux_utils.print_exception_no_traceback(): raise exceptions.UserRequestRejectedByPolicy( 'All tasks must have the same skypilot ' From cbff59d14eae1dc5cc47be4cde62c43dc82dfb16 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 06:01:56 +0000 Subject: [PATCH 56/61] Fix logging --- sky/utils/admin_policy_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index 4e69fa6c988..09db2fc4be8 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -112,9 +112,9 @@ def apply( # allowing overriding required global config in task YAML. with ux_utils.print_exception_no_traceback(): raise exceptions.UserRequestRejectedByPolicy( - 'All tasks must have the same skypilot ' - 'config after applying the policy. Please' - 'check with your policy admin for details.') + 'All tasks must have the same SkyPilot config after ' + 'applying the policy. Please check with your policy ' + 'admin for details.') mutated_dag.add(mutated_user_request.task) assert mutated_config is not None, dag From 1fe350a2e0367fa8c8143dccb4dae9f42ef8ab23 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 07:29:55 +0000 Subject: [PATCH 57/61] fix CI --- tests/unit_tests/test_admin_policy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index 70f62c4f1e2..3003c4fa21f 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -64,8 +64,7 @@ def test_use_spot_for_all_gpus_policy(add_example_policy_paths, task): task.set_resources([ sky.Resources(accelerators={'A100': 1}), sky.Resources(accelerators={'L4': 1}, use_spot=True), - sky.Resources(instance_type='n1-standard-2'), - sky.Resources(instance_type='n1-standard-2'), + sky.Resources(cpus='2+'), ]) dag, _ = _load_task_and_apply_policy( task, os.path.join(POLICY_PATH, 'use_spot_for_gpu.yaml')) From 2e8e41c7db5e0940864be6f866f006effe86034c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 10:07:51 -0700 Subject: [PATCH 58/61] Update docs/source/cloud-setup/policy.rst Co-authored-by: Zongheng Yang --- docs/source/cloud-setup/policy.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 18324e3ffe5..6650bf129b2 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -16,8 +16,8 @@ Example usage: To implement and use an admin policy: - - Admins writes a simple Python package with a policy class that implements SkyPilot's ``sky.AdminPolicy`` interface; - - Admins distributes this package to users; + - Admins write a simple Python package with a policy class that implements SkyPilot's ``sky.AdminPolicy`` interface; + - Admins distribute this package to users; - Users simply set the ``admin_policy`` field in the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. From aae42ceb0cb2f521ffb242e83fa71133f6d41b9a Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 18:37:55 +0000 Subject: [PATCH 59/61] Use sphnix inline code --- docs/source/cloud-setup/policy.rst | 245 ++++++------------ examples/admin_policy/disable_public_ip.yaml | 2 +- .../example_policy/example_policy/__init__.py | 4 +- .../example_policy/skypilot_policy.py | 6 +- examples/admin_policy/use_spot_for_gpu.yaml | 2 +- sky/admin_policy.py | 9 +- sky/execution.py | 15 +- tests/unit_tests/test_admin_policy.py | 21 +- tests/unit_tests/test_backend_utils.py | 29 ++- tests/unit_tests/test_common_utils.py | 8 +- tests/unit_tests/test_resources.py | 18 +- 11 files changed, 139 insertions(+), 220 deletions(-) diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 18324e3ffe5..0d3e3444372 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -9,16 +9,17 @@ custom validation and mutation logic to a user's tasks and SkyPilot config. Example usage: - - Adds custom labels to all tasks [Link to below, fix case] - - Always Disable Public IP for AWS Tasks [Link to below] - - Enforce Autostop for all Tasks [Link to below] +- :ref:`kubernetes-labels-policy` +- :ref:`disable-public-ip-policy` +- :ref:`use-spot-for-gpu-policy` +- :ref:`enforce-autostop-policy` To implement and use an admin policy: - - Admins writes a simple Python package with a policy class that implements SkyPilot's ``sky.AdminPolicy`` interface; - - Admins distributes this package to users; - - Users simply set the ``admin_policy`` field in the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. +- Admins writes a simple Python package with a policy class that implements SkyPilot's ``sky.AdminPolicy`` interface; +- Admins distributes this package to users; +- Users simply set the ``admin_policy`` field in the SkyPilot config file ``~/.sky/config.yaml`` for the policy to go into effect. Overview @@ -52,7 +53,16 @@ Admin-Side ~~~~~~~~~~ An admin can distribute the Python package to users with a pre-defined policy. The -policy should implement the `sky.AdminPolicy` `interface `_: +policy should implement the ``sky.AdminPolicy`` `interface `_: + + +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: AdminPolicy + :caption: `AdminPolicy Interface `_ + + +Your custom admin policy should look like this: .. code-block:: python @@ -69,35 +79,17 @@ policy should implement the `sky.AdminPolicy` `interface `_ for more details): -.. code-block:: python - - class UserRequest: - """A user request. - A "user request" is defined as a `sky launch / exec` command or its API - equivalent. +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: UserRequest + :caption: `UserRequest Class `_ - `sky jobs launch / serve up` involves multiple launch requests, including - the launch of controller and clusters for a job (which can have multiple - tasks if it is a pipeline) or service replicas. Each launch is a separate - request. +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: MutatedUserRequest + :caption: `MutatedUserRequest Class `_ - This class wraps the underlying task, the global skypilot config used to run - a task, and the request options. - - Args: - task: User specified task. - skypilot_config: Global skypilot config to be used in this request. - request_options: Request options. It is None for jobs and services. - """ - task: 'sky.Task' - skypilot_config: 'sky.Config' - request_options: Optional['RequestOptions'] = None - - - class MutatedUserRequest: - task: 'sky.Task' - skypilot_config: 'sky.Config' In other words, an ``AdminPolicy`` can mutate any fields of a user request, including the :ref:`task ` and the :ref:`global skypilot config `, @@ -106,43 +98,19 @@ giving admins a lot of flexibility to control user's SkyPilot usage. An ``AdminPolicy`` can be used to both validate and mutate user requests. If a request should be rejected, the policy should raise an exception. -The ``sky.Config`` and ``sky.RequestOptions`` classes are defined as follows: -.. code-block:: python +The ``sky.Config`` and ``sky.RequestOptions`` classes are defined as follows: - class Config: - def get_nested(self, - keys: Tuple[str, ...], - default_value: Any, - override_configs: Optional[Dict[str, Any]] = None, - ) -> Any: - """Gets a value with nested keys. - - If override_configs is provided, it value will be merged on top of - the current config. - """ - ... +.. literalinclude:: ../../../sky/skypilot_config.py + :language: python + :pyobject: Config + :caption: `Config Class `_ - def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: - """Sets a value with nested keys.""" - ... - class RequestOptions: - """Request options for admin policy. - - Args: - cluster_name: Name of the cluster to create/reuse. - cluster_running: Whether the cluster is running. - idle_minutes_to_autostop: If provided, the cluster will be set to - autostop after this many minutes of idleness. - down: If true, use autodown rather than autostop. - dryrun: Is the request a dryrun? - """ - # Cluster name is None if not specified by the user. - cluster_name: Optional[str] - idle_minutes_to_autostop: Optional[int] - down: bool - dryrun: bool +.. literalinclude:: ../../../sky/admin_policy.py + :language: python + :pyobject: RequestOptions + :caption: `RequestOptions Class `_ Example Policies @@ -159,122 +127,69 @@ We have provided a few example policies in `examples/admin_policy/example_policy Reject All ~~~~~~~~~~ -.. code-block:: python - - class RejectAllPolicy(sky.AdminPolicy): - """Example policy: rejects all user requests.""" +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: RejectAllPolicy + :caption: `RejectAllPolicy `_ - @classmethod - def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Rejects all user requests.""" - raise RuntimeError("This policy rejects all user requests.") - -.. code-block:: yaml +.. literalinclude:: ../../../examples/admin_policy/reject_all.yaml + :language: yaml + :caption: `Config YAML for using RejectAllPolicy `_ - admin_policy: example_policy.RejectAllPolicy +.. _kubernetes-labels-policy: +Add Labels for all Tasks on Kubernetes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Add Kubernetes Labels for all Tasks -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: python - - class AddLabelsPolicy(sky.AdminPolicy): - """Example policy: adds a kubernetes label for skypilot_config.""" - - @classmethod - def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - config = user_request.skypilot_config - labels = config.get_nested(('kubernetes', 'labels'), {}) - labels['app'] = 'skypilot' - config.set_nested(('kubernetes', 'labels'), labels) - return sky.MutatedUserRequest(user_request.task, config) - -.. code-block:: yaml +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: AddLabelsPolicy + :caption: `AddLabelsPolicy `_ - admin_policy: example_policy.AddLabelsPolicy +.. literalinclude:: ../../../examples/admin_policy/add_labels.yaml + :language: yaml + :caption: `Config YAML for using AddLabelsPolicy `_ +.. _disable-public-ip-policy: + Always Disable Public IP for AWS Tasks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. code-block:: python +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: DisablePublicIpPolicy + :caption: `DisablePublicIpPolicy `_ - class DisablePublicIPPolicy(sky.AdminPolicy): - """Example policy: disables public IP for all tasks.""" +.. literalinclude:: ../../../examples/admin_policy/disable_public_ip.yaml + :language: yaml + :caption: `Config YAML for using DisablePublicIpPolicy `_ - @classmethod - def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - config = user_request.skypilot_config - config.set_nested(('aws', 'use_internal_ip'), True) - if config.get_nested(('aws', 'vpc_name'), None) is None: - # If no VPC name is specified, it is likely a mistake. We should - # reject the request - raise RuntimeError('VPC name should be set. Check organization ' - 'wiki for more information.') - return sky.MutatedUserRequest(user_request.task, config) +.. _use-spot-for-gpu-policy: -.. code-block:: yaml +Use Spot for all GPU Tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: UseSpotForGpuPolicy + :caption: `UseSpotForGpuPolicy `_ - admin_policy: example_policy.DisablePublicIPPolicy +.. literalinclude:: ../../../examples/admin_policy/use_spot_for_gpu.yaml + :language: yaml + :caption: `Config YAML for using UseSpotForGpuPolicy `_ +.. _enforce-autostop-policy: Enforce Autostop for all Tasks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. code-block:: python - - class EnforceAutostopPolicy(sky.AdminPolicy): - """Example policy: enforce autostop for all tasks.""" - - @classmethod - def validate_and_mutate( - cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest: - """Enforces autostop for all tasks. - - Note that with this policy enforced, users can still change the autostop - setting for an existing cluster by using `sky autostop`. - """ - request_options = user_request.request_options - - # Request options is None when a task is executed with `jobs launch` or - # `sky serve up`. - if request_options is None: - return sky.MutatedUserRequest( - task=user_request.task, - skypilot_config=user_request.skypilot_config) - - # Get the cluster record to operate on. - cluster_record = sky.status(request_options.cluster_name, refresh=True) - - # Check if the user request should specify autostop settings. - need_autostop = False - if not cluster_record: - # Cluster does not exist - need_autostop = True - elif cluster_record[0]['status'] == sky.ClusterStatus.STOPPED: - # Cluster is stopped - need_autostop = True - elif cluster_record[0]['autostop'] < 0: - # Cluster is running but autostop is not set - need_autostop = True - - # Check if the user request is setting autostop settings. - is_setting_autostop = False - idle_minutes_to_autostop = request_options.idle_minutes_to_autostop - is_setting_autostop = (idle_minutes_to_autostop is not None and - idle_minutes_to_autostop >= 0) - - # If the cluster requires autostop but the user request is not setting - # autostop settings, raise an error. - if need_autostop and not is_setting_autostop: - raise RuntimeError('Autostop/down must be set for all clusters.') - - return sky.MutatedUserRequest( - task=user_request.task, - skypilot_config=user_request.skypilot_config) - - -.. code-block:: yaml +.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py + :language: python + :pyobject: EnforceAutostopPolicy + :caption: `EnforceAutostopPolicy `_ - admin_policy: example_policy.EnforceAutostopPolicy +.. literalinclude:: ../../../examples/admin_policy/enforce_autostop.yaml + :language: yaml + :caption: `Config YAML for using EnforceAutostopPolicy `_ diff --git a/examples/admin_policy/disable_public_ip.yaml b/examples/admin_policy/disable_public_ip.yaml index eb00757f409..cef910cbdaf 100644 --- a/examples/admin_policy/disable_public_ip.yaml +++ b/examples/admin_policy/disable_public_ip.yaml @@ -1 +1 @@ -admin_policy: example_policy.DisablePublicIPPolicy +admin_policy: example_policy.DisablePublicIpPolicy diff --git a/examples/admin_policy/example_policy/example_policy/__init__.py b/examples/admin_policy/example_policy/example_policy/__init__.py index 862dfbd3d9f..12ca4e952e2 100644 --- a/examples/admin_policy/example_policy/example_policy/__init__.py +++ b/examples/admin_policy/example_policy/example_policy/__init__.py @@ -1,6 +1,6 @@ """Example admin policy module and prebuilt policies.""" from example_policy.skypilot_policy import AddLabelsPolicy -from example_policy.skypilot_policy import DisablePublicIPPolicy +from example_policy.skypilot_policy import DisablePublicIpPolicy from example_policy.skypilot_policy import EnforceAutostopPolicy from example_policy.skypilot_policy import RejectAllPolicy -from example_policy.skypilot_policy import UseSpotForGPUPolicy +from example_policy.skypilot_policy import UseSpotForGpuPolicy diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 5f126d6c63a..0676f65d502 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -26,8 +26,8 @@ def validate_and_mutate( return sky.MutatedUserRequest(user_request.task, config) -class DisablePublicIPPolicy(sky.AdminPolicy): - """Example policy: disables public IP for all tasks.""" +class DisablePublicIpPolicy(sky.AdminPolicy): + """Example policy: disables public IP for all AWS tasks.""" @classmethod def validate_and_mutate( @@ -42,7 +42,7 @@ def validate_and_mutate( return sky.MutatedUserRequest(user_request.task, config) -class UseSpotForGPUPolicy(sky.AdminPolicy): +class UseSpotForGpuPolicy(sky.AdminPolicy): """Example policy: use spot instances for all GPU tasks.""" @classmethod diff --git a/examples/admin_policy/use_spot_for_gpu.yaml b/examples/admin_policy/use_spot_for_gpu.yaml index d1234c27729..45f257017a4 100644 --- a/examples/admin_policy/use_spot_for_gpu.yaml +++ b/examples/admin_policy/use_spot_for_gpu.yaml @@ -1 +1 @@ -admin_policy: example_policy.UseSpotForGPUPolicy +admin_policy: example_policy.UseSpotForGpuPolicy diff --git a/sky/admin_policy.py b/sky/admin_policy.py index 7c17eaae258..304285d04b7 100644 --- a/sky/admin_policy.py +++ b/sky/admin_policy.py @@ -13,14 +13,13 @@ class RequestOptions: """Request options for admin policy. Args: - cluster_name: Name of the cluster to create/reuse. - cluster_running: Whether the cluster is running. - idle_minutes_to_autostop: If provided, the cluster will be set to - autostop after this many minutes of idleness. + cluster_name: Name of the cluster to create/reuse. It is None if not + specified by the user. + idle_minutes_to_autostop: Autostop setting requested by a user. The + cluster will be set to autostop after this many minutes of idleness. down: If true, use autodown rather than autostop. dryrun: Is the request a dryrun? """ - # Cluster name is None if not specified by the user. cluster_name: Optional[str] idle_minutes_to_autostop: Optional[int] down: bool diff --git a/sky/execution.py b/sky/execution.py index 0b2663765b1..792ca5fffc0 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -160,13 +160,6 @@ def _execute( handle: Optional[backends.ResourceHandle]; the handle to the cluster. None if dryrun. """ - cluster_exists = False - if cluster_name is not None: - cluster_record = global_user_state.get_cluster_from_name(cluster_name) - cluster_exists = cluster_record is not None - # TODO(woosuk): If the cluster exists, print a warning that - # `cpus` and `memory` are not used as a job scheduling constraint, - # unlike `gpus`. dag = dag_utils.convert_entrypoint_to_dag(entrypoint) dag, _ = admin_policy_utils.apply( @@ -186,6 +179,14 @@ def _execute( 'Job recovery is specified in the task. To launch a ' 'managed job, please use: sky jobs launch') + cluster_exists = False + if cluster_name is not None: + cluster_record = global_user_state.get_cluster_from_name(cluster_name) + cluster_exists = cluster_record is not None + # TODO(woosuk): If the cluster exists, print a warning that + # `cpus` and `memory` are not used as a job scheduling constraint, + # unlike `gpus`. + stages = stages if stages is not None else list(Stage) # Requested features that some clouds support and others don't. diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index 3003c4fa21f..96b666493d3 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -2,7 +2,7 @@ import os import sys from typing import Optional, Tuple -from unittest.mock import patch +from unittest import mock import pytest @@ -100,7 +100,7 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: } # Cluster does not exist - with patch('sky.status', return_value=[]): + with mock.patch('sky.status', return_value=[]): _load_task_and_apply_policy(task, os.path.join(POLICY_PATH, 'enforce_autostop.yaml'), @@ -114,7 +114,7 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) # Cluster is stopped - with patch( + with mock.patch( 'sky.status', return_value=[_gen_cluster_record(sky.ClusterStatus.STOPPED, 10)]): _load_task_and_apply_policy(task, @@ -129,8 +129,9 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) # Cluster is running but autostop is not set - with patch('sky.status', - return_value=[_gen_cluster_record(sky.ClusterStatus.UP, -1)]): + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.UP, -1)]): _load_task_and_apply_policy(task, os.path.join(POLICY_PATH, 'enforce_autostop.yaml'), @@ -143,8 +144,9 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) # Cluster is init but autostop is not set - with patch('sky.status', - return_value=[_gen_cluster_record(sky.ClusterStatus.INIT, -1)]): + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.INIT, -1)]): _load_task_and_apply_policy(task, os.path.join(POLICY_PATH, 'enforce_autostop.yaml'), @@ -157,8 +159,9 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) # Cluster is running and autostop is set - with patch('sky.status', - return_value=[_gen_cluster_record(sky.ClusterStatus.UP, 10)]): + with mock.patch( + 'sky.status', + return_value=[_gen_cluster_record(sky.ClusterStatus.UP, 10)]): _load_task_and_apply_policy(task, os.path.join(POLICY_PATH, 'enforce_autostop.yaml'), diff --git a/tests/unit_tests/test_backend_utils.py b/tests/unit_tests/test_backend_utils.py index aa22f36df20..5da4410abb9 100644 --- a/tests/unit_tests/test_backend_utils.py +++ b/tests/unit_tests/test_backend_utils.py @@ -1,6 +1,6 @@ import os import pathlib -from unittest.mock import patch +from unittest import mock from sky import clouds from sky import skypilot_config @@ -9,19 +9,20 @@ # Set env var to test config file. -@patch.object(skypilot_config, '_dict', None) -@patch.object(skypilot_config, '_loaded_config_path', None) -@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) -@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', - return_value={'fake-acc': 2}) -@patch('sky.clouds.service_catalog.get_image_id_from_tag', - return_value='fake-image') -@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') -@patch('sky.check.get_cloud_credential_file_mounts', - return_value='~/.aws/credentials') -@patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', - return_value='/tmp/fake/path') -@patch('sky.utils.common_utils.fill_template') +@mock.patch.object(skypilot_config, '_dict', None) +@mock.patch.object(skypilot_config, '_loaded_config_path', None) +@mock.patch('sky.clouds.service_catalog.instance_type_exists', + return_value=True) +@mock.patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', + return_value={'fake-acc': 2}) +@mock.patch('sky.clouds.service_catalog.get_image_id_from_tag', + return_value='fake-image') +@mock.patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') +@mock.patch('sky.check.get_cloud_credential_file_mounts', + return_value='~/.aws/credentials') +@mock.patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', + return_value='/tmp/fake/path') +@mock.patch('sky.utils.common_utils.fill_template') def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None: os.environ['SKYPILOT_CONFIG'] = './tests/test_yamls/test_aws_config.yaml' diff --git a/tests/unit_tests/test_common_utils.py b/tests/unit_tests/test_common_utils.py index f38e14069e5..38c31263baa 100644 --- a/tests/unit_tests/test_common_utils.py +++ b/tests/unit_tests/test_common_utils.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest import mock import pytest @@ -33,18 +33,18 @@ def test_check_when_none(self): class TestMakeClusterNameOnCloud: - @patch('sky.utils.common_utils.get_user_hash') + @mock.patch('sky.utils.common_utils.get_user_hash') def test_make(self, mock_get_user_hash): mock_get_user_hash.return_value = MOCKED_USER_HASH assert "lora-ab12" == common_utils.make_cluster_name_on_cloud("lora") - @patch('sky.utils.common_utils.get_user_hash') + @mock.patch('sky.utils.common_utils.get_user_hash') def test_make_with_hyphen(self, mock_get_user_hash): mock_get_user_hash.return_value = MOCKED_USER_HASH assert "seed-1-ab12" == common_utils.make_cluster_name_on_cloud( "seed-1") - @patch('sky.utils.common_utils.get_user_hash') + @mock.patch('sky.utils.common_utils.get_user_hash') def test_make_with_characters_to_transform(self, mock_get_user_hash): mock_get_user_hash.return_value = MOCKED_USER_HASH assert "cuda-11-8-ab12" == common_utils.make_cluster_name_on_cloud( diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 91123071ea6..b6180612615 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -1,8 +1,7 @@ import importlib import os from typing import Dict -from unittest.mock import Mock -from unittest.mock import patch +from unittest import mock import pytest @@ -25,7 +24,7 @@ def test_get_reservations_available_resources(): - mock = Mock() + mock = mock.Mock() r = Resources(cloud=mock, instance_type="instance_type") r._region = "region" r._zone = "zone" @@ -93,12 +92,13 @@ def test_kubernetes_labels_resources(): _run_label_test(allowed_labels, invalid_labels, cloud) -@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) -@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', - return_value={'fake-acc': 2}) -@patch('sky.clouds.service_catalog.get_image_id_from_tag', - return_value='fake-image') -@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') +@mock.patch('sky.clouds.service_catalog.instance_type_exists', + return_value=True) +@mock.patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', + return_value={'fake-acc': 2}) +@mock.patch('sky.clouds.service_catalog.get_image_id_from_tag', + return_value='fake-image') +@mock.patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') def test_aws_make_deploy_variables(*mocks) -> None: os.environ['SKYPILOT_CONFIG'] = './tests/test_yamls/test_aws_config.yaml' importlib.reload(skypilot_config) From 11bbd5e78a30f298cebe03e7252ec9a0d2b61c91 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 18:48:36 +0000 Subject: [PATCH 60/61] Add comment --- .../example_policy/example_policy/skypilot_policy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 0676f65d502..dc4e4b873fb 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -73,6 +73,10 @@ def validate_and_mutate( Note that with this policy enforced, users can still change the autostop setting for an existing cluster by using `sky autostop`. + + Since we refresh the cluster status with `sky.status` whenever this + policy is applied, we should expect a few seconds latency when a user + run a request. """ request_options = user_request.request_options From 3630535745d8a9a788466d803d84f8338750630c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 23 Sep 2024 20:50:19 +0000 Subject: [PATCH 61/61] fix skypilot config file mounts for jobs and serve --- sky/templates/jobs-controller.yaml.j2 | 2 ++ sky/templates/sky-serve-controller.yaml.j2 | 2 ++ tests/unit_tests/test_resources.py | 6 +++--- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 8869a5874bf..45cdb5141d4 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -4,7 +4,9 @@ name: {{dag_name}} file_mounts: {{remote_user_yaml_path}}: {{user_yaml_path}} + {%- if local_user_config_path is not none %} {{remote_user_config_path}}: {{local_user_config_path}} + {%- endif %} {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} diff --git a/sky/templates/sky-serve-controller.yaml.j2 b/sky/templates/sky-serve-controller.yaml.j2 index 3b6a5ad2d49..507a6e3a325 100644 --- a/sky/templates/sky-serve-controller.yaml.j2 +++ b/sky/templates/sky-serve-controller.yaml.j2 @@ -23,7 +23,9 @@ setup: | file_mounts: {{remote_task_yaml_path}}: {{local_task_yaml_path}} + {%- if local_user_config_path is not none %} {{remote_user_config_path}}: {{local_user_config_path}} + {%- endif %} {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index b6180612615..01b83132a1b 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -24,12 +24,12 @@ def test_get_reservations_available_resources(): - mock = mock.Mock() - r = Resources(cloud=mock, instance_type="instance_type") + mock_cloud = mock.Mock() + r = Resources(cloud=mock_cloud, instance_type="instance_type") r._region = "region" r._zone = "zone" r.get_reservations_available_resources() - mock.get_reservations_available_resources.assert_called_once_with( + mock_cloud.get_reservations_available_resources.assert_called_once_with( "instance_type", "region", "zone", set())