Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spot] Refactor spot APIs into spot.xxx #3417

Merged
merged 14 commits into from
Apr 11, 2024
8 changes: 4 additions & 4 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.skylet.job_lib import JobStatus
# TODO (zhwu): These imports are for backward compatibility, and spot APIs
# should be called with `sky.spot.xxx` instead. Remove in release 0.7.0
from sky.spot import cancel as spot_cancel
from sky.spot import launch as spot_launch
from sky.spot import queue as spot_queue
from sky.spot import tail_logs as spot_tail_logs
from sky.spot.core import spot_cancel
from sky.spot.core import spot_launch
from sky.spot.core import spot_queue
from sky.spot.core import spot_tail_logs
from sky.status_lib import ClusterStatus
from sky.task import Task

Expand Down
2 changes: 1 addition & 1 deletion sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,7 +2292,7 @@ def is_controller_accessible(
# will not start the controller manually from the cloud console.
#
# The acquire_lock_timeout is set to 0 to avoid hanging the command when
# multiple spot_launch commands are running at the same time. Our later
# multiple spot.launch commands are running at the same time. Our later
# code will check if the controller is accessible by directly checking
# the ssh connection to the controller, if it fails to get accurate
# status of the controller.
Expand Down
12 changes: 6 additions & 6 deletions sky/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def start(
retry_until_up: bool = False,
down: bool = False, # pylint: disable=redefined-outer-name
force: bool = False,
) -> None:
) -> backends.CloudVmRayResourceHandle:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Restart a cluster.

Expand Down Expand Up @@ -276,11 +276,11 @@ def start(
if down and idle_minutes_to_autostop is None:
raise ValueError(
'`idle_minutes_to_autostop` must be set if `down` is True.')
_start(cluster_name,
idle_minutes_to_autostop,
retry_until_up,
down,
force=force)
return _start(cluster_name,
idle_minutes_to_autostop,
retry_until_up,
down,
force=force)


def _stop_not_supported_message(resources: 'resources_lib.Resources') -> str:
Expand Down
2 changes: 1 addition & 1 deletion sky/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _execute(
f'automatically recover from preemptions.{reset}\n{yellow}To '
'get automatic recovery, use managed spot instead: '
f'{reset}{bold}sky spot launch{reset} {yellow}or{reset} '
f'{bold}sky.spot_launch(){reset}.')
f'{bold}sky.spot.launch(){reset}.')

if Stage.OPTIMIZE in stages:
if task.best_resources is None:
Expand Down
237 changes: 127 additions & 110 deletions sky/spot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,110 @@
from sky.utils import ux_utils


@usage_lib.entrypoint
def launch(
    task: Union['sky.Task', 'sky.Dag'],
    name: Optional[str] = None,
    stream_logs: bool = True,
    detach_run: bool = False,
    retry_until_up: bool = False,
) -> None:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Launch a managed spot job.

    Please refer to sky.cli.spot_launch for the full documentation.

    Args:
        task: sky.Task, or sky.Dag (a single task or a chain of tasks) to
          launch as a managed spot job.
        name: Name of the spot job. Within this function it is only used to
          name the generated controller YAML file; when omitted, the DAG name
          is used for that file instead.
          NOTE(review): `name` is not applied to `dag.name` here -- confirm
          that callers set the DAG name themselves when they need it.
        stream_logs: Forwarded to `sky.launch` for the controller task.
        detach_run: Whether to detach the run.
        retry_until_up: Forwarded to the controller template as the
          `retry_until_up` variable. The controller cluster itself is always
          launched with `retry_until_up=True`.

    Raises:
        ValueError: the DAG is not a chain, or two tasks in the DAG share the
          same name.
        sky.exceptions.NotSupportedError: the feature is not supported.
    """
    # Short random suffix that keeps the generated YAML file names unique.
    dag_uuid = uuid.uuid4().hex[:4]

    dag = dag_utils.convert_entrypoint_to_dag(task)
    if not dag.is_chain():
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Only single-task or chain DAG is allowed for '
                             f'sky.spot.launch. Dag:\n{dag}')

    dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)

    # Task names must be unique within the DAG.
    task_names = set()
    for task_ in dag.tasks:
        if task_.name in task_names:
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    f'Task name {task_.name!r} is duplicated in the DAG. '
                    'Either '
                    'change task names to be unique, or specify the DAG name '
                    'only '
                    'and comment out the task names (so that they will be '
                    'auto-'
                    'generated) .')
        task_names.add(task_.name)

    dag_utils.fill_default_spot_config_in_dag_for_spot_launch(dag)

    for task_ in dag.tasks:
        controller_utils.maybe_translate_local_file_mounts_and_sync_up(
            task_, path='spot')

    # Everything below stays inside the `with` block: the temporary file's
    # path is handed to the controller template as `user_yaml_path`, so the
    # file has to exist until the controller launch has finished.
    with tempfile.NamedTemporaryFile(prefix=f'spot-dag-{dag.name}-',
                                     mode='w') as f:
        dag_utils.dump_chain_dag_to_yaml(dag, f.name)
        controller_name = spot_utils.SPOT_CONTROLLER_NAME
        prefix = constants.SPOT_TASK_YAML_PREFIX
        remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml'
        remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml'
        controller_resources = controller_utils.get_controller_resources(
            controller_type='spot',
            controller_resources_config=constants.CONTROLLER_RESOURCES)

        vars_to_fill = {
            'remote_user_yaml_path': remote_user_yaml_path,
            'user_yaml_path': f.name,
            'spot_controller': controller_name,
            # Note: actual spot cluster name will be <task.name>-<spot job ID>
            'dag_name': dag.name,
            'retry_until_up': retry_until_up,
            'remote_user_config_path': remote_user_config_path,
            'sky_python_cmd': skylet_constants.SKY_PYTHON_CMD,
            'modified_catalogs':
                service_catalog_common.get_modified_catalog_file_mounts(),
            **controller_utils.shared_controller_vars_to_fill(
                'spot',
                remote_user_config_path=remote_user_config_path,
            ),
        }

        # `name` defaults to None; fall back to the DAG name so the rendered
        # controller YAML is not written to 'None-<uuid>.yaml'.
        yaml_name = name or dag.name
        yaml_path = os.path.join(constants.SPOT_CONTROLLER_YAML_PREFIX,
                                 f'{yaml_name}-{dag_uuid}.yaml')
        common_utils.fill_template(constants.SPOT_CONTROLLER_TEMPLATE,
                                   vars_to_fill,
                                   output_path=yaml_path)
        controller_task = task_lib.Task.from_yaml(yaml_path)
        controller_task.set_resources(controller_resources)

        controller_task.spot_dag = dag
        assert len(controller_task.resources) == 1

        sky_logging.print(
            f'{colorama.Fore.YELLOW}'
            f'Launching managed spot job {dag.name!r} from spot controller...'
            f'{colorama.Style.RESET_ALL}')
        sky_logging.print('Launching spot controller...')
        sky.launch(
            task=controller_task,
            stream_logs=stream_logs,
            cluster_name=controller_name,
            detach_run=detach_run,
            idle_minutes_to_autostop=(
                skylet_constants.CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP),
            retry_until_up=True,
            _disable_controller_check=True)


@usage_lib.entrypoint
def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
Expand Down Expand Up @@ -55,7 +159,7 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
"""
stopped_message = ''
if not refresh:
stopped_message = ('No in-progress spot jobs.')
stopped_message = 'No in-progress spot jobs.'
try:
handle = backend_utils.is_controller_accessible(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
Expand Down Expand Up @@ -135,8 +239,9 @@ def cancel(name: Optional[str] = None,
argument_str = f'job_ids={job_id_str}' if len(job_ids) > 0 else ''
argument_str += f' name={name}' if name is not None else ''
argument_str += ' all' if all else ''
raise ValueError('Can only specify one of JOB_IDS or name or all. '
f'Provided {argument_str!r}.')
with ux_utils.print_exception_no_traceback():
raise ValueError('Can only specify one of JOB_IDS or name or all. '
f'Provided {argument_str!r}.')

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)
Expand All @@ -157,7 +262,8 @@ def cancel(name: Optional[str] = None,
'Failed to cancel managed spot job',
stdout)
except exceptions.CommandError as e:
raise RuntimeError(e.error_msg) from e
with ux_utils.print_exception_no_traceback():
raise RuntimeError(e.error_msg) from e

sky_logging.print(stdout)
if 'Multiple jobs found with name' in stdout:
Expand Down Expand Up @@ -191,109 +297,20 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool) -> None:
backend.tail_spot_logs(handle, job_id=job_id, job_name=name, follow=follow)


@usage_lib.entrypoint
def launch(
    task: Union['sky.Task', 'sky.Dag'],
    name: Optional[str] = None,
    stream_logs: bool = True,
    detach_run: bool = False,
    retry_until_up: bool = False,
) -> None:
    # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
    """Launch a managed spot job.

    Please refer to sky.cli.spot_launch for the full documentation.

    Args:
        task: sky.Task, or sky.Dag (a single task or a chain of tasks) to
          launch as a managed spot job.
        name: Name of the spot job. Within this function it is only used to
          name the generated controller YAML file; when omitted, the DAG name
          is used for that file instead.
        stream_logs: Forwarded to `sky.launch` for the controller task.
        detach_run: Whether to detach the run.
        retry_until_up: Forwarded to the controller template as the
          `retry_until_up` variable. The controller cluster itself is always
          launched with `retry_until_up=True`.

    Raises:
        ValueError: the DAG is not a chain, or two tasks in the DAG share the
          same name.
        sky.exceptions.NotSupportedError: the feature is not supported.
    """
    # Short random suffix that keeps the generated YAML file names unique.
    dag_uuid = uuid.uuid4().hex[:4]

    dag = dag_utils.convert_entrypoint_to_dag(task)
    # Validate with a real exception instead of `assert`: assertions are
    # stripped when Python runs with -O, which would let a non-chain DAG
    # through unchecked.
    if not dag.is_chain():
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Only single-task or chain DAG is '
                             f'allowed for spot_launch. Dag:\n{dag}')

    dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)

    # Task names must be unique within the DAG.
    task_names = set()
    for task_ in dag.tasks:
        if task_.name in task_names:
            raise ValueError(
                f'Task name {task_.name!r} is duplicated in the DAG. Either '
                'change task names to be unique, or specify the DAG name only '
                'and comment out the task names (so that they will be auto-'
                'generated) .')
        task_names.add(task_.name)

    dag_utils.fill_default_spot_config_in_dag_for_spot_launch(dag)

    for task_ in dag.tasks:
        controller_utils.maybe_translate_local_file_mounts_and_sync_up(
            task_, path='spot')

    # Everything below stays inside the `with` block: the temporary file's
    # path is handed to the controller template as `user_yaml_path`, so the
    # file has to exist until the controller launch has finished.
    with tempfile.NamedTemporaryFile(prefix=f'spot-dag-{dag.name}-',
                                     mode='w') as f:
        dag_utils.dump_chain_dag_to_yaml(dag, f.name)
        controller_name = spot_utils.SPOT_CONTROLLER_NAME
        prefix = constants.SPOT_TASK_YAML_PREFIX
        remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml'
        remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml'
        controller_resources = controller_utils.get_controller_resources(
            controller_type='spot',
            controller_resources_config=constants.CONTROLLER_RESOURCES)

        vars_to_fill = {
            'remote_user_yaml_path': remote_user_yaml_path,
            'user_yaml_path': f.name,
            'spot_controller': controller_name,
            # Note: actual spot cluster name will be <task.name>-<spot job ID>
            'dag_name': dag.name,
            'retry_until_up': retry_until_up,
            'remote_user_config_path': remote_user_config_path,
            'sky_python_cmd': skylet_constants.SKY_PYTHON_CMD,
            'modified_catalogs':
                service_catalog_common.get_modified_catalog_file_mounts(),
            **controller_utils.shared_controller_vars_to_fill(
                'spot',
                remote_user_config_path=remote_user_config_path,
            ),
        }

        # `name` defaults to None; fall back to the DAG name so the rendered
        # controller YAML is not written to 'None-<uuid>.yaml'.
        yaml_name = name or dag.name
        yaml_path = os.path.join(constants.SPOT_CONTROLLER_YAML_PREFIX,
                                 f'{yaml_name}-{dag_uuid}.yaml')
        common_utils.fill_template(constants.SPOT_CONTROLLER_TEMPLATE,
                                   vars_to_fill,
                                   output_path=yaml_path)
        controller_task = task_lib.Task.from_yaml(yaml_path)
        assert len(controller_task.resources) == 1, controller_task
        # Backward compatibility: if the user changed the
        # spot-controller.yaml.j2 to customize the controller resources,
        # we should use it.
        controller_task_resources = list(controller_task.resources)[0]
        if not controller_task_resources.is_empty():
            controller_resources = controller_task_resources
        controller_task.set_resources(controller_resources)

        controller_task.spot_dag = dag
        assert len(controller_task.resources) == 1

        print(f'{colorama.Fore.YELLOW}'
              f'Launching managed spot job {dag.name!r} from spot controller...'
              f'{colorama.Style.RESET_ALL}')
        print('Launching spot controller...')
        sky.launch(
            task=controller_task,
            stream_logs=stream_logs,
            cluster_name=controller_name,
            detach_run=detach_run,
            idle_minutes_to_autostop=(
                skylet_constants.CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP),
            retry_until_up=True,
            _disable_controller_check=True)
# Backward-compatible aliases for the old `spot_*` entry points. Each alias
# wraps the corresponding function defined in this module via
# `common_utils.deprecated_function`, passing the new public name, the old
# name, and the release in which the alias is scheduled for removal.
spot_launch = common_utils.deprecated_function(
    func=launch,
    name='sky.spot.launch',
    deprecated_name='spot_launch',
    removing_version='0.7.0',
)
spot_queue = common_utils.deprecated_function(
    func=queue,
    name='sky.spot.queue',
    deprecated_name='spot_queue',
    removing_version='0.7.0',
)
spot_cancel = common_utils.deprecated_function(
    func=cancel,
    name='sky.spot.cancel',
    deprecated_name='spot_cancel',
    removing_version='0.7.0',
)
spot_tail_logs = common_utils.deprecated_function(
    func=tail_logs,
    name='sky.spot.tail_logs',
    deprecated_name='spot_tail_logs',
    removing_version='0.7.0',
)
4 changes: 2 additions & 2 deletions sky/utils/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def rsync(

backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5)
while max_retry >= 0:
returncode, _, stderr = log_lib.run_with_log(
returncode, stdout, stderr = log_lib.run_with_log(
command,
log_path=log_path,
stream_logs=stream_logs,
Expand All @@ -454,7 +454,7 @@ def rsync(
subprocess_utils.handle_returncode(returncode,
command,
error_msg,
stderr=stderr,
stderr=stdout + stderr,
stream_logs=stream_logs)

def check_connection(self) -> bool:
Expand Down
17 changes: 17 additions & 0 deletions sky/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,20 @@ def fill_template(template_name: str, variables: Dict,
content = j2_template.render(**variables)
with open(output_path, 'w', encoding='utf-8') as fout:
fout.write(content)


def deprecated_function(func: Callable, name: str, deprecated_name: str,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other deprecated functions in the system? If so, should they be changed to use this decorator as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I don't think there are any other deprecated functions in the system.

removing_version: str) -> Callable:
"""Decorator for creating deprecated functions, for backward compatibility.

It will result in a warning being emitted when the function is used.
"""

@functools.wraps(func)
def new_func(*args, **kwargs):
logger.warning(
f'Call to deprecated function {deprecated_name}, which will be '
f'removed in {removing_version}. Please use {name}() instead.')
return func(*args, **kwargs)

return new_func
2 changes: 1 addition & 1 deletion sky/utils/controller_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

logger = sky_logging.init_logger(__name__)

# Message thrown when APIs sky.spot_launch(),sky.serve.up() received an invalid
# Message thrown when APIs sky.spot.launch(),sky.serve.up() received an invalid
# controller resources spec.
CONTROLLER_RESOURCES_NOT_VALID_MESSAGE = (
'{controller_type} controller resources is not valid, please check '
Expand Down
Loading
Loading