[Core] Admin policy enforcement plugin #3966

Merged
merged 67 commits into from
Sep 24, 2024
Changes from 13 commits
Commits
cb28b8d
support policy hook
Michaelvll Sep 19, 2024
b64efa0
test task labels
Michaelvll Sep 19, 2024
cf89929
Add test for policy that sets labels
Michaelvll Sep 20, 2024
54c93ea
Fix comment
Michaelvll Sep 20, 2024
1d1c500
format
Michaelvll Sep 20, 2024
a0bdb2c
use -e to make test related files visible
Michaelvll Sep 20, 2024
543e66a
Add config.rst
Michaelvll Sep 20, 2024
520a2a1
Fix test
Michaelvll Sep 20, 2024
b533351
fix config rst
Michaelvll Sep 20, 2024
466f7fe
Apply policy to service
Michaelvll Sep 20, 2024
050dc7a
add policy for serving
Michaelvll Sep 20, 2024
31e0174
Add docs
Michaelvll Sep 20, 2024
0c74f2a
fix
Michaelvll Sep 20, 2024
48a6cc9
format
Michaelvll Sep 20, 2024
1ca5a8a
Update interface
Michaelvll Sep 20, 2024
14b2346
fix
Michaelvll Sep 21, 2024
cb39c73
Fix
Michaelvll Sep 21, 2024
1e3ddef
fix
Michaelvll Sep 21, 2024
aa87df7
Fix test config
Michaelvll Sep 21, 2024
28487a4
Fix mutated config
Michaelvll Sep 21, 2024
d1f0480
fix
Michaelvll Sep 21, 2024
f42ace5
Add policy doc
Michaelvll Sep 21, 2024
c04f3dc
rename
Michaelvll Sep 21, 2024
58f413c
minor
Michaelvll Sep 21, 2024
52053bd
Add additional arguments for autostop
Michaelvll Sep 21, 2024
4a4f682
fix mypy
Michaelvll Sep 21, 2024
a8d1c44
format
Michaelvll Sep 22, 2024
6c73d81
rejected message
Michaelvll Sep 22, 2024
247c0b8
format
Michaelvll Sep 22, 2024
f8a5a64
Update sky/utils/policy_utils.py
Michaelvll Sep 22, 2024
73a4581
Update sky/utils/policy_utils.py
Michaelvll Sep 22, 2024
d78a822
Fix
Michaelvll Sep 22, 2024
8cc963c
Merge branch 'policy-hook' of github.com:skypilot-org/skypilot into p…
Michaelvll Sep 22, 2024
68275f6
Update examples/admin_policy/example_policy/example_policy/__init__.py
Michaelvll Sep 22, 2024
9644622
Update docs/source/reference/config.rst
Michaelvll Sep 22, 2024
17f8fa1
Address comments
Michaelvll Sep 22, 2024
07c4748
format
Michaelvll Sep 22, 2024
15f1062
Merge branch 'policy-hook' of github.com:skypilot-org/skypilot into p…
Michaelvll Sep 22, 2024
994272b
changes in examples
Michaelvll Sep 22, 2024
3597dae
Fix enforce autostop
Michaelvll Sep 22, 2024
43a6088
Fix autostop enforcement
Michaelvll Sep 22, 2024
8770d0b
fix test
Michaelvll Sep 22, 2024
7984beb
Update docs/source/cloud-setup/policy.rst
Michaelvll Sep 23, 2024
d155d60
Update sky/admin_policy.py
Michaelvll Sep 23, 2024
6ffa5ae
Update sky/admin_policy.py
Michaelvll Sep 23, 2024
a6dd900
wip
Michaelvll Sep 23, 2024
4274287
Update docs/source/cloud-setup/policy.rst
Michaelvll Sep 23, 2024
0609482
Update docs/source/cloud-setup/policy.rst
Michaelvll Sep 23, 2024
67552d7
Update docs/source/cloud-setup/policy.rst
Michaelvll Sep 23, 2024
7de757e
fix
Michaelvll Sep 23, 2024
8443ddc
Merge branch 'policy-hook' of github.com:skypilot-org/skypilot into p…
Michaelvll Sep 23, 2024
7fbc30d
fix
Michaelvll Sep 23, 2024
92b68fc
fix
Michaelvll Sep 23, 2024
7d8af9a
Use sky.status for autostop
Michaelvll Sep 23, 2024
5b37f47
update policy
Michaelvll Sep 23, 2024
c7af310
Update docs/source/cloud-setup/policy.rst
Michaelvll Sep 23, 2024
cb232a8
fix policy.rst
Michaelvll Sep 23, 2024
5e9f544
Merge branch 'policy-hook' of github.com:skypilot-org/skypilot into p…
Michaelvll Sep 23, 2024
deb4c92
Add comment
Michaelvll Sep 23, 2024
cbff59d
Fix logging
Michaelvll Sep 23, 2024
1fe350a
fix CI
Michaelvll Sep 23, 2024
2e8e41c
Update docs/source/cloud-setup/policy.rst
Michaelvll Sep 23, 2024
aae42ce
Use sphnix inline code
Michaelvll Sep 23, 2024
73c8fb7
Merge branch 'policy-hook' of github.com:skypilot-org/skypilot into p…
Michaelvll Sep 23, 2024
11bbd5e
Add comment
Michaelvll Sep 23, 2024
3630535
fix skypilot config file mounts for jobs and serve
Michaelvll Sep 23, 2024
e020dea
Merge branch 'master' of github.com:skypilot-org/skypilot into policy…
Michaelvll Sep 23, 2024
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -53,7 +53,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ".[all]"
pip install -e ".[all]"
pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0

- name: Run tests with pytest
8 changes: 8 additions & 0 deletions docs/source/reference/config.rst
@@ -87,6 +87,14 @@ Available fields and semantics:
# Default: false.
disable_ecc: false

# Custom policy to be applied to all tasks.
#
# The policy function is applied to every task and may mutate it; it can be
# used to enforce certain policies on all tasks.
#
# See details in: <TODO: add link to policy docs>
policy: my_package.skypilot_policy_fn_v1

# Advanced AWS configurations (optional).
# Apply to all new instances but not existing ones.
aws:
1 change: 1 addition & 0 deletions examples/policy/config_label_config.yaml
@@ -0,0 +1 @@
policy: example_policy.config_label_policy
2 changes: 2 additions & 0 deletions examples/policy/example_policy/example_policy/__init__.py
@@ -0,0 +1,2 @@
from example_policy.skypilot_policy import config_label_policy
from example_policy.skypilot_policy import task_label_policy
31 changes: 31 additions & 0 deletions examples/policy/example_policy/example_policy/skypilot_policy.py
@@ -0,0 +1,31 @@
import getpass

from sky import MutatedUserTask
from sky import UserTask


def task_label_policy(user_task: UserTask) -> MutatedUserTask:
    """Example policy: label the task's resources with the local user name."""
    local_user_name = getpass.getuser()

    # Add label for task with the local user name
    task = user_task.task
    for r in task.resources:
        r.labels['local_user'] = local_user_name

    return MutatedUserTask(task=task, skypilot_config=user_task.skypilot_config)


def config_label_policy(user_task: UserTask) -> MutatedUserTask:
    """Example policy: label the AWS config section with the local user name."""
    local_user_name = getpass.getuser()

    # Add label for skypilot_config with the local user name
    skypilot_config = user_task.skypilot_config
    if not skypilot_config.get('aws'):
        skypilot_config['aws'] = {}
    labels = skypilot_config['aws'].get('labels', {})
    labels['local_user'] = local_user_name
    skypilot_config['aws']['labels'] = labels

    return MutatedUserTask(task=user_task.task, skypilot_config=skypilot_config)
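
The two example policies above share one shape: take a `UserTask`, modify the task or the config, and return a `MutatedUserTask`. The following is a minimal sketch of that shape that runs without SkyPilot installed. `FakeResources`, `FakeTask`, the local `UserTask`/`MutatedUserTask` copies, and `make_label_policy` are hypothetical stand-ins introduced only for illustration; they are not the classes added in this PR. The user name is passed in through a factory instead of `getpass.getuser()` so the result is deterministic.

```python
import dataclasses
from typing import Any, Callable, Dict, List


# Hypothetical stand-ins for illustration only; not SkyPilot's real classes.
@dataclasses.dataclass
class FakeResources:
    labels: Dict[str, str] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class FakeTask:
    resources: List[FakeResources] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class UserTask:
    task: FakeTask
    skypilot_config: Dict[str, Any]


@dataclasses.dataclass
class MutatedUserTask:
    task: FakeTask
    skypilot_config: Dict[str, Any]


def make_label_policy(
        user_name: str) -> Callable[[UserTask], MutatedUserTask]:
    """Build a policy that labels both the task and the AWS config section."""

    def policy(user_task: UserTask) -> MutatedUserTask:
        # Label every resources object attached to the task.
        task = user_task.task
        for r in task.resources:
            r.labels['local_user'] = user_name

        # Label the AWS section of the config, creating it when missing.
        config = user_task.skypilot_config
        aws = config.get('aws') or {}
        labels = aws.get('labels') or {}
        labels['local_user'] = user_name
        aws['labels'] = labels
        config['aws'] = aws

        return MutatedUserTask(task=task, skypilot_config=config)

    return policy


task = FakeTask(resources=[FakeResources(labels={'other_labels': 'test'})])
policy_fn = make_label_policy('alice')
result = policy_fn(UserTask(task=task, skypilot_config={}))
```

With the `task.yaml` labels from this PR as input, the pre-existing `other_labels` entry is kept and `local_user` is added next to it.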
7 changes: 7 additions & 0 deletions examples/policy/example_policy/pyproject.toml
@@ -0,0 +1,7 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "example_policy"
version = "0.0.1"
12 changes: 12 additions & 0 deletions examples/policy/task.yaml
@@ -0,0 +1,12 @@
resources:
cloud: aws
cpus: 2
labels:
other_labels: test


setup: |
echo "setup"

run: |
echo "run"
1 change: 1 addition & 0 deletions examples/policy/task_label_config.yaml
@@ -0,0 +1 @@
policy: example_policy.task_label_policy
4 changes: 4 additions & 0 deletions sky/__init__.py
@@ -110,6 +110,8 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.jobs.core import spot_tail_logs
from sky.optimizer import Optimizer
from sky.optimizer import OptimizeTarget
from sky.policy import MutatedUserTask
from sky.policy import UserTask
from sky.resources import Resources
from sky.skylet.job_lib import JobStatus
from sky.status_lib import ClusterStatus
@@ -185,4 +187,6 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
# core APIs Storage Management
'storage_ls',
'storage_delete',
'UserTask',
'MutatedUserTask',
]
30 changes: 17 additions & 13 deletions sky/dag.py
@@ -1,8 +1,12 @@
"""DAGs: user applications to be run."""
import pprint
import threading
import typing
from typing import List, Optional

if typing.TYPE_CHECKING:
from sky import task


class Dag:
"""Dag: a user application, represented as a DAG of Tasks.
@@ -13,37 +17,37 @@ class Dag:
>>> task = sky.Task(...)
"""

def __init__(self):
self.tasks = []
def __init__(self) -> None:
self.tasks: List['task.Task'] = []
import networkx as nx # pylint: disable=import-outside-toplevel

self.graph = nx.DiGraph()
self.name = None
self.name: Optional[str] = None

def add(self, task):
def add(self, task: 'task.Task') -> None:
self.graph.add_node(task)
self.tasks.append(task)

def remove(self, task):
def remove(self, task: 'task.Task') -> None:
self.tasks.remove(task)
self.graph.remove_node(task)

def add_edge(self, op1, op2):
def add_edge(self, op1: 'task.Task', op2: 'task.Task') -> None:
assert op1 in self.graph.nodes
assert op2 in self.graph.nodes
self.graph.add_edge(op1, op2)

def __len__(self):
def __len__(self) -> int:
return len(self.tasks)

def __enter__(self):
def __enter__(self) -> 'Dag':
push_dag(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
pop_dag()

def __repr__(self):
def __repr__(self) -> str:
pformat = pprint.pformat(self.tasks)
return f'DAG:\n{pformat}'

@@ -70,15 +74,15 @@ def is_chain(self) -> bool:

class _DagContext(threading.local):
"""A thread-local stack of Dags."""
_current_dag = None
_current_dag: Optional[Dag] = None
_previous_dags: List[Dag] = []

def push_dag(self, dag):
def push_dag(self, dag: Dag):
if self._current_dag is not None:
self._previous_dags.append(self._current_dag)
self._current_dag = dag

def pop_dag(self):
def pop_dag(self) -> Optional[Dag]:
old_dag = self._current_dag
if self._previous_dags:
self._current_dag = self._previous_dags.pop()
2 changes: 1 addition & 1 deletion sky/execution.py
@@ -158,7 +158,7 @@ def _execute(
handle: Optional[backends.ResourceHandle]; the handle to the cluster. None
if dryrun.
"""
dag = dag_utils.convert_entrypoint_to_dag(entrypoint)
dag = dag_utils.convert_entrypoint_to_dag_and_apply_policy(entrypoint)
assert len(dag) == 1, f'We support 1 task for now. {dag}'
task = dag.tasks[0]

2 changes: 2 additions & 0 deletions sky/jobs/controller.py
@@ -64,6 +64,7 @@ def __init__(self, job_id: int, dag_yaml: str,
if len(self._dag.tasks) <= 1:
task_name = self._dag_name
else:
assert task.name is not None, task
task_name = task.name
# This is guaranteed by the spot_launch API, where we fill in
# the task.name with
@@ -447,6 +448,7 @@ def _cleanup(job_id: int, dag_yaml: str):
# controller, we should keep it in sync with JobsController.__init__()
dag, _ = _get_dag_and_name(dag_yaml)
for task in dag.tasks:
assert task.name is not None, task
cluster_name = managed_job_utils.generate_managed_job_cluster_name(
task.name, job_id)
recovery_strategy.terminate_cluster(cluster_name)
2 changes: 1 addition & 1 deletion sky/jobs/core.py
@@ -53,7 +53,7 @@ def launch(
entrypoint = task
dag_uuid = str(uuid.uuid4().hex[:4])

dag = dag_utils.convert_entrypoint_to_dag(entrypoint)
dag = dag_utils.convert_entrypoint_to_dag_and_apply_policy(entrypoint)
if not dag.is_chain():
with ux_utils.print_exception_no_traceback():
raise ValueError('Only single-task or chain DAG is '
154 changes: 154 additions & 0 deletions sky/policy.py
@@ -0,0 +1,154 @@
"""Customize policy by users."""
import copy
import dataclasses
import importlib
import os
import tempfile
import typing
from typing import Any, Callable, Dict, Optional

from sky import dag as dag_lib
from sky import sky_logging
from sky import skypilot_config
from sky.utils import common_utils
from sky.utils import ux_utils

logger = sky_logging.init_logger(__name__)

if typing.TYPE_CHECKING:
    from sky import task as task_lib


@dataclasses.dataclass
class UserTask:
    task: 'task_lib.Task'
    skypilot_config: Dict[str, Any]


@dataclasses.dataclass
class MutatedUserTask:
    task: 'task_lib.Task'
    skypilot_config: Dict[str, Any]


class Policy:
    """User-defined policy.

    A user-defined policy is a dotted path to a Python function that can be
    imported from the same environment where SkyPilot is running.

    It can be specified in the SkyPilot config file under the key 'policy', e.g.

        policy: my_package.skypilot_policy_fn_v1

    The Python function is expected to have the following signature:

        def skypilot_policy_fn_v1(user_task: UserTask) -> MutatedUserTask:
            ...
            return MutatedUserTask(task=..., skypilot_config=...)

    The function can mutate both task and skypilot_config.
    """

    def __init__(self) -> None:
        """Initialize the policy from SkyPilot config."""
        # Policy is a dotted path to a Python function within some
        # user-provided module.
        self.policy: Optional[str] = skypilot_config.get_nested(('policy',),
                                                                None)
        self.policy_fn: Optional[Callable[[UserTask], MutatedUserTask]] = None
        if self.policy is not None:
            try:
                module_path, func_name = self.policy.rsplit('.', 1)
                module = importlib.import_module(module_path)
            except ImportError as e:
                with ux_utils.print_exception_no_traceback():
                    raise ImportError(
                        f'Failed to import policy module: {module_path}. '
                        'Please check if the module is in your Python '
                        'environment.') from e
            try:
                self.policy_fn = getattr(module, func_name)
            except AttributeError as e:
                with ux_utils.print_exception_no_traceback():
                    raise AttributeError(
                        f'Failed to get policy function: {func_name} from '
                        f'module: {module_path}. Please check with your policy '
                        f'admin if the function {func_name!r} is in the '
                        'module.') from e

    def apply(self, dag: 'dag_lib.Dag') -> 'dag_lib.Dag':
        """Apply user-defined policy to a DAG.

        It mutates a Dag by applying the user-defined policy and also updates
        the global SkyPilot config if the policy made any changes to it.

        Args:
            dag: The dag to be mutated by the policy.

        Returns:
            The mutated dag.
        """
        if self.policy_fn is None:
            return dag
        logger.info(f'Applying policy: {self.policy}')
        original_config = skypilot_config.to_dict()
        config = copy.deepcopy(original_config)
        mutated_dag = dag_lib.Dag()
        mutated_dag.name = dag.name

        mutated_config = None
        for task in dag.tasks:
            user_task = UserTask(task, config)
            mutated_user_task = self.policy_fn(user_task)
            if mutated_config is None:
                mutated_config = mutated_user_task.skypilot_config
            else:
                if mutated_config != mutated_user_task.skypilot_config:
                    with ux_utils.print_exception_no_traceback():
                        raise ValueError(
                            'All tasks must have the same skypilot '
                            'config after applying the policy. Please '
                            'check with your policy admin for details.')
            mutated_dag.add(mutated_user_task.task)

        # Copy the edges of the original dag's graph to the mutated dag,
        # matching tasks by their position in the task list.
        for u, v in dag.graph.edges:
            u_idx = dag.tasks.index(u)
            v_idx = dag.tasks.index(v)
            mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx],
                                       mutated_dag.tasks[v_idx])

        if original_config != mutated_config:
            with tempfile.NamedTemporaryFile(
                    delete=False,
                    mode='w',
                    prefix='policy-mutated-skypilot-config-',
                    suffix='.yaml') as temp_file:
                common_utils.dump_yaml(temp_file.name, mutated_config)
            os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name
            logger.debug(f'Updated SkyPilot config: {temp_file.name}')
            # TODO(zhwu): This is not a clean way to update the SkyPilot config,
            # because we are resetting the global context for a single DAG,
            # which is conceptually weird.
            importlib.reload(skypilot_config)

        return mutated_dag

    def apply_to_task(self, task: 'task_lib.Task') -> 'task_lib.Task':
        """Apply user-defined policy to a task.

        It mutates a task by applying the user-defined policy and also updates
        the global SkyPilot config if the policy made any changes to it.

        Args:
            task: The task to be mutated by the policy.

        Returns:
            The mutated task.
        """

        dag = dag_lib.Dag()
        dag.add(task)
        dag = self.apply(dag)
        return dag.tasks[0]
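
`Policy.__init__` above resolves a string such as `my_package.skypilot_policy_fn_v1` into a function by splitting off the last dotted component, importing the module, and looking up the attribute. A standalone sketch of that lookup, using only the standard library, might look like the following. `load_callable` is a hypothetical helper name, not part of this PR, and the explicit check for a path without a dot is an addition of this sketch rather than something the PR's code does.

```python
import importlib
from typing import Any, Callable


def load_callable(dotted_path: str) -> Callable[..., Any]:
    """Resolve a path like 'package.module.func' to the function object."""
    # rpartition splits on the LAST dot: module path on the left, name on the
    # right. For a string without any dot, sep is the empty string.
    module_path, sep, func_name = dotted_path.rpartition('.')
    if not sep or not module_path or not func_name:
        raise ValueError(
            f'Expected a path like "module.func", got: {dotted_path!r}')
    module = importlib.import_module(module_path)  # may raise ImportError
    func = getattr(module, func_name)  # may raise AttributeError
    if not callable(func):
        raise TypeError(f'{dotted_path!r} is not callable')
    return func


# A standard-library target, chosen only so the example runs anywhere.
join = load_callable('os.path.join')
```

Keeping the import and the attribute lookup as separate steps, as the PR does, lets each failure produce its own message: a missing module and a missing function usually have different causes.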
3 changes: 3 additions & 0 deletions sky/serve/core.py
@@ -8,6 +8,7 @@
import sky
from sky import backends
from sky import exceptions
from sky import policy
from sky import sky_logging
from sky import task as task_lib
from sky.backends import backend_utils
@@ -124,6 +125,8 @@ def up(

_validate_service_task(task)

task = policy.Policy().apply_to_task(task)
Review comment (Member):

A bit unclear if it's in-place since it also returns the task. How about "apply_in_place"?

Reply (Collaborator, Author):

It is not in-place. Either way is fine, but I find that not-in-place operations are more commonly seen.


controller_utils.maybe_translate_local_file_mounts_and_sync_up(task,
path='serve')
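
The review exchange above is about whether `apply_to_task` works in place. The sketch below illustrates the calling convention the author describes: the caller rebinds the name to the returned value instead of relying on side effects. `Task`, `apply`, `apply_to_task`, and `add_owner` here are hypothetical stand-ins, not the PR's code. In particular, the `copy.deepcopy` is this sketch's own choice to make the not-in-place behavior unambiguous; the `Policy.apply` shown in this PR passes the task object to the policy function without copying it.

```python
import copy
import dataclasses
from typing import Callable, Dict, List


@dataclasses.dataclass
class Task:
    """Hypothetical stand-in for a task; not SkyPilot's Task class."""
    name: str
    labels: Dict[str, str] = dataclasses.field(default_factory=dict)


PolicyFn = Callable[[Task], Task]


def apply(tasks: List[Task], policy_fn: PolicyFn) -> List[Task]:
    """Return new tasks with the policy applied; inputs are left untouched."""
    # Copying first is this sketch's choice, so that a policy which mutates
    # its argument cannot affect the caller's objects.
    return [policy_fn(copy.deepcopy(t)) for t in tasks]


def apply_to_task(task: Task, policy_fn: PolicyFn) -> Task:
    """Single-task convenience wrapper: wrap, apply, unwrap."""
    return apply([task], policy_fn)[0]


def add_owner(task: Task) -> Task:
    """Example policy that mutates the object it is given."""
    task.labels['owner'] = 'alice'
    return task


original = Task(name='svc')
# Rebind to the returned value, as `task = ...apply_to_task(task)` does above.
updated = apply_to_task(original, add_owner)
```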

2 changes: 1 addition & 1 deletion sky/skypilot_config.py
@@ -82,7 +82,7 @@

# The loaded config.
_dict: Optional[Dict[str, Any]] = None
_loaded_config_path = None
_loaded_config_path: Optional[str] = None


def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str],