Ray AWS launcher (#518)
jieru-hu authored Oct 28, 2020
1 parent e98f518 commit 6244a3e
Showing 33 changed files with 1,869 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -66,7 +66,7 @@ commands:
name: Preparing environment - Other dependency
command: |
sudo apt-get update
sudo apt-get install -y expect fish openjdk-11-jre
sudo apt-get install -y expect fish openjdk-11-jre rsync
- run:
name: Preparing environment - Hydra
command: |
@@ -273,4 +273,4 @@ workflows:


orbs:
win: circleci/[email protected]
win: circleci/[email protected]
1 change: 0 additions & 1 deletion .isort.cfg
@@ -23,4 +23,3 @@ skip=
,hydra/grammar/gen
,tools/configen/example/gen
,tools/configen/tests/test_modules/expected

3 changes: 3 additions & 0 deletions plugins/hydra_ray_launcher/MANIFEST.in
@@ -0,0 +1,3 @@
global-exclude *.pyc
global-exclude __pycache__
recursive-include hydra_plugins/* *.yaml
4 changes: 4 additions & 0 deletions plugins/hydra_ray_launcher/README.md
@@ -0,0 +1,4 @@
# Hydra Ray Launcher
Provides a [`Ray`](https://docs.ray.io/en/latest/)-based Hydra Launcher supporting execution on AWS.

See the [website](https://hydra.cc/docs/next/plugins/ray_launcher) for more information.
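
As a quick illustration (a sketch only: `ray_local` appears in the example config in this commit, while `ray_aws` is assumed to be the name of the AWS variant, and `my_app.py` is a placeholder for any Hydra application), the launcher is selected like any other Hydra launcher override:

    python my_app.py --multirun hydra/launcher=ray_local
    python my_app.py --multirun hydra/launcher=ray_aws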
1 change: 1 addition & 0 deletions plugins/hydra_ray_launcher/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
6 changes: 6 additions & 0 deletions plugins/hydra_ray_launcher/example/conf/config.yaml
@@ -0,0 +1,6 @@
defaults:
- hydra/launcher: ray_local


random_seed: 0
checkpoint_path: checkpoint
24 changes: 24 additions & 0 deletions plugins/hydra_ray_launcher/example/conf/extra_configs/aws.yaml
@@ -0,0 +1,24 @@
# @package _global_

hydra:
launcher:
sync_up:
# source_dir is relative in this case, assuming you are running from
# <project_root>/hydra/plugins/hydra_ray_launcher/example
# An absolute path is also supported.
source_dir: "."
# We leave target_dir null; as a result the files will be synced to a temp dir on the
# remote cluster, and the temp dir will be cleaned up after the jobs are done.
# We recommend leaving target_dir null when you are syncing code/artifacts to the remote
# cluster, so that you don't need to configure $PYTHONPATH on the remote cluster.
include: ["model", "*.py"]
# No need to sync up config files.
exclude: ["*"]
sync_down:
include: ["*.pt", "*/"]
# No need to sync down config files.
exclude: ["*"]
ray_cluster_cfg:
provider:
cache_stopped_nodes: true
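
One plausible invocation that composes this override with the example app (assuming the config group name `extra_configs` follows the directory name above, and that `ray_aws` is the launcher name corresponding to the `RayAWSLauncher` added in this commit):

    python train.py --multirun hydra/launcher=ray_aws +extra_configs=aws random_seed=1,2,3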
19 changes: 19 additions & 0 deletions plugins/hydra_ray_launcher/example/model/my_model.py
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from datetime import datetime
from pathlib import Path

log = logging.getLogger(__name__)


class MyModel:
def __init__(self, random_seed: int):
self.random_seed = random_seed
log.info("Init my model")

def save(self, checkpoint_path: str) -> None:
checkpoint_dir = Path(checkpoint_path)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
log.info(f"Created dir for checkpoints. dir={checkpoint_dir}")
with open(checkpoint_dir / f"checkpoint_{self.random_seed}.pt", "w") as f:
f.write(f"checkpoint@{datetime.now()}")
20 changes: 20 additions & 0 deletions plugins/hydra_ray_launcher/example/train.py
@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging

import hydra
from model.my_model import MyModel
from omegaconf import DictConfig

log = logging.getLogger(__name__)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
log.info("Start training...")
model = MyModel(cfg.random_seed)
# save checkpoint to current working dir.
model.save(cfg.checkpoint_path)


if __name__ == "__main__":
main()
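
A minimal local run of this example (assuming the plugin is installed so that the `ray_local` launcher from the defaults list in `conf/config.yaml` can be found) might look like:

    python train.py --multirun random_seed=1,2,3

With `--multirun`, Hydra hands the jobs to the configured launcher (here `ray_local`) rather than to the default basic launcher.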
@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.config_search_path import ConfigSearchPath
from hydra.plugins.search_path_plugin import SearchPathPlugin


class RayLauncherSearchPathPlugin(SearchPathPlugin):
def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
# Appends the search path for this plugin to the end of the search path
search_path.append(
"hydra-ray-launcher", "pkg://hydra_plugins.hydra_ray_launcher.conf"
)
@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Sequence

import cloudpickle # type: ignore
import pickle5 as pickle # type: ignore
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, configure_log, filter_overrides, setup_globals
from omegaconf import OmegaConf, open_dict, read_write

from hydra_plugins.hydra_ray_launcher._launcher_util import ( # type: ignore
JOB_RETURN_PICKLE,
JOB_SPEC_PICKLE,
ray_down,
ray_exec,
ray_rsync_down,
ray_rsync_up,
ray_tmp_dir,
ray_up,
rsync,
)
from hydra_plugins.hydra_ray_launcher.ray_aws_launcher import ( # type: ignore
RayAWSLauncher,
)

log = logging.getLogger(__name__)


def _get_abs_code_dir(code_dir: str) -> str:
if code_dir:
if os.path.isabs(code_dir):
return code_dir
else:
return os.path.join(os.getcwd(), code_dir)
else:
return ""


def _pickle_jobs(tmp_dir: str, **jobspec: Dict[Any, Any]) -> None:
path = os.path.join(tmp_dir, JOB_SPEC_PICKLE)
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
cloudpickle.dump(jobspec, f)
log.info(f"Pickle for jobs: {f.name}")


def launch(
launcher: RayAWSLauncher,
job_overrides: Sequence[Sequence[str]],
initial_job_idx: int,
) -> Sequence[JobReturn]:
setup_globals()
assert launcher.config is not None
assert launcher.config_loader is not None
assert launcher.task_function is not None

setup_commands = launcher.mandatory_install.install_commands
setup_commands.extend(launcher.ray_cluster_cfg.setup_commands)

with read_write(launcher.ray_cluster_cfg):
launcher.ray_cluster_cfg.setup_commands = setup_commands

configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)

log.info(f"Ray Launcher is launching {len(job_overrides)} jobs")

with tempfile.TemporaryDirectory() as local_tmp_dir:
sweep_configs = []
for idx, overrides in enumerate(job_overrides):
idx = initial_job_idx + idx
ostr = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {ostr}")
sweep_config = launcher.config_loader.load_sweep_config(
launcher.config, list(overrides)
)
with open_dict(sweep_config):
# job.id will be set on the EC2 instance before running the job.
sweep_config.hydra.job.num = idx

sweep_configs.append(sweep_config)

_pickle_jobs(
tmp_dir=local_tmp_dir,
sweep_configs=sweep_configs, # type: ignore
task_function=launcher.task_function,
singleton_state=Singleton.get_state(),
)

with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
with open(f.name, "w") as file:
OmegaConf.save(
config=launcher.ray_cluster_cfg, f=file.name, resolve=True
)
launcher.ray_yaml_path = f.name
log.info(
f"Saving RayClusterConf in a temp yaml file: {launcher.ray_yaml_path}."
)

return launch_jobs(
launcher, local_tmp_dir, Path(HydraConfig.get().sweep.dir)
)


def launch_jobs(
launcher: RayAWSLauncher, local_tmp_dir: str, sweep_dir: Path
) -> Sequence[JobReturn]:
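    """Run the pickled jobs on the Ray cluster: bring the cluster up from the saved
    yaml (ray_up), rsync the pickled job spec and _remote_invoke.py to a temp dir on
    the cluster, execute them there (ray_exec), rsync the results and any configured
    sync_down outputs back down, then stop or delete the cluster depending on
    stop_cluster and provider.cache_stopped_nodes."""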
ray_up(launcher.ray_yaml_path)
with tempfile.TemporaryDirectory() as local_tmp_download_dir:

with ray_tmp_dir(
launcher.ray_yaml_path, launcher.docker_enabled
) as remote_tmp_dir:

ray_rsync_up(
launcher.ray_yaml_path, os.path.join(local_tmp_dir, ""), remote_tmp_dir
)

script_path = os.path.join(os.path.dirname(__file__), "_remote_invoke.py")
ray_rsync_up(launcher.ray_yaml_path, script_path, remote_tmp_dir)

if launcher.sync_up.source_dir:
source_dir = _get_abs_code_dir(launcher.sync_up.source_dir)
target_dir = (
launcher.sync_up.target_dir
if launcher.sync_up.target_dir
else remote_tmp_dir
)
rsync(
launcher.ray_yaml_path,
launcher.sync_up.include,
launcher.sync_up.exclude,
os.path.join(source_dir, ""),
target_dir,
)

ray_exec(
launcher.ray_yaml_path,
launcher.docker_enabled,
os.path.join(remote_tmp_dir, "_remote_invoke.py"),
remote_tmp_dir,
)

ray_rsync_down(
launcher.ray_yaml_path,
os.path.join(remote_tmp_dir, JOB_RETURN_PICKLE),
local_tmp_download_dir,
)

sync_down_cfg = launcher.sync_down

if (
sync_down_cfg.target_dir
or sync_down_cfg.source_dir
or sync_down_cfg.include
or sync_down_cfg.exclude
):
source_dir = (
sync_down_cfg.source_dir if sync_down_cfg.source_dir else sweep_dir
)
target_dir = (
sync_down_cfg.target_dir if sync_down_cfg.target_dir else sweep_dir
)
target_dir = Path(_get_abs_code_dir(target_dir))
target_dir.mkdir(parents=True, exist_ok=True)

rsync(
launcher.ray_yaml_path,
launcher.sync_down.include,
launcher.sync_down.exclude,
os.path.join(source_dir),
str(target_dir),
up=False,
)
log.info(
f"Syncing outputs from remote dir: {source_dir} to local dir: {target_dir.absolute()} "
)

if launcher.stop_cluster:
log.info("Stopping cluster now. (stop_cluster=true)")
if launcher.ray_cluster_cfg.provider.cache_stopped_nodes:
log.info("NOT deleting the cluster (provider.cache_stopped_nodes=true)")
else:
log.info("Deleted the cluster (provider.cache_stopped_nodes=false)")
ray_down(launcher.ray_yaml_path)
else:
log.warning(
"NOT stopping cluster, this may incur extra cost for you. (stop_cluster=false)"
)

with open(os.path.join(local_tmp_download_dir, JOB_RETURN_PICKLE), "rb") as f:
job_returns = pickle.load(f) # nosec
assert isinstance(job_returns, List)
for run in job_returns:
assert isinstance(run, JobReturn)
return job_returns
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from pathlib import Path
from typing import Sequence

import ray
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, configure_log, filter_overrides, setup_globals
from omegaconf import open_dict

from hydra_plugins.hydra_ray_launcher._launcher_util import ( # type: ignore
launch_job_on_ray,
start_ray,
)
from hydra_plugins.hydra_ray_launcher.ray_local_launcher import ( # type: ignore
RayLocalLauncher,
)

log = logging.getLogger(__name__)


def launch(
launcher: RayLocalLauncher,
job_overrides: Sequence[Sequence[str]],
initial_job_idx: int,
) -> Sequence[JobReturn]:
setup_globals()
assert launcher.config is not None
assert launcher.config_loader is not None
assert launcher.task_function is not None

configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)
sweep_dir = Path(str(launcher.config.hydra.sweep.dir))
sweep_dir.mkdir(parents=True, exist_ok=True)
log.info(
f"Ray Launcher is launching {len(job_overrides)} jobs, "
f"sweep output dir: {sweep_dir}"
)

start_ray(launcher.ray_init_cfg)

runs = []
for idx, overrides in enumerate(job_overrides):
idx = initial_job_idx + idx
ostr = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {ostr}")
sweep_config = launcher.config_loader.load_sweep_config(
launcher.config, list(overrides)
)
with open_dict(sweep_config):
# The job id typically comes from the underlying scheduler (SLURM_JOB_ID, for instance).
# In that case it is not available here, because we are still in the main process;
# it would instead be populated remotely before calling the task_function.
sweep_config.hydra.job.id = f"job_id_for_{idx}"
sweep_config.hydra.job.num = idx
ray_obj = launch_job_on_ray(
launcher.ray_remote_cfg,
sweep_config,
launcher.task_function,
Singleton.get_state(),
)
runs.append(ray_obj)

return [ray.get(run) for run in runs]