Showing 33 changed files with 1,869 additions and 7 deletions.
@@ -66,7 +66,7 @@ commands:
          name: Preparing environment - Other dependency
          command: |
            sudo apt-get update
-           sudo apt-get install -y expect fish openjdk-11-jre
+           sudo apt-get install -y expect fish openjdk-11-jre rsync
      - run:
          name: Preparing environment - Hydra
          command: |
@@ -273,4 +273,4 @@ workflows:


orbs:
-  win: circleci/[email protected]
+  win: circleci/[email protected]
@@ -23,4 +23,3 @@ skip=
,hydra/grammar/gen
,tools/configen/example/gen
,tools/configen/tests/test_modules/expected
@@ -0,0 +1,3 @@
global-exclude *.pyc
global-exclude __pycache__
recursive-include hydra_plugins/* *.yaml
@@ -0,0 +1,4 @@
# Hydra Ray Launcher
Provides a [`Ray`](https://docs.ray.io/en/latest/)-based Hydra Launcher supporting execution on AWS.

See the [website](https://hydra.cc/docs/next/plugins/ray_launcher) for more information.
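A minimal quick-start sketch (not part of this commit): the `hydra-ray-launcher` distribution name is an assumption, while `ray_local` being the default launcher comes from the example's `conf/config.yaml` shown further below.

```bash
# Hypothetical quick start -- the package name is an assumption, not taken from this diff.
pip install hydra-ray-launcher
cd plugins/hydra_ray_launcher/example
# ray_local is the default launcher in the example config, so --multirun uses it.
python train.py --multirun random_seed=1,2,3
```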
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
@@ -0,0 +1,6 @@
defaults:
  - hydra/launcher: ray_local


random_seed: 0
checkpoint_path: checkpoint
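As a quick sanity check (a sketch; `--cfg job` is Hydra's standard flag for printing the composed job config, nothing plugin-specific), the user-facing part of this config reduces to the two fields above:

```bash
cd plugins/hydra_ray_launcher/example
python train.py --cfg job
# Expected output (sketch):
# random_seed: 0
# checkpoint_path: checkpoint
```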
plugins/hydra_ray_launcher/example/conf/extra_configs/aws.yaml (24 additions, 0 deletions)
@@ -0,0 +1,24 @@
# @package _global_

hydra:
  launcher:
    sync_up:
      # source_dir is relative in this case, assuming you are running from
      # <project_root>/hydra/plugins/hydra_ray_launcher/example;
      # an absolute path is also supported.
      source_dir: "."
      # We leave target_dir null; as a result the files will be synced to a
      # temp dir on the remote cluster, and that temp dir will be cleaned up
      # after the jobs are done. We recommend leaving target_dir null when
      # syncing code/artifacts to the remote cluster so you don't need to
      # configure $PYTHONPATH on the remote cluster.
      include: ["model", "*.py"]
      # No need to sync up config files.
      exclude: ["*"]
    sync_down:
      include: ["*.pt", "*/"]
      # No need to sync down config files.
      exclude: ["*"]
    ray_cluster_cfg:
      provider:
        cache_stopped_nodes: true
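A hedged sketch of how this extra config might be applied together with the AWS launcher; the `ray_aws` launcher name and the `+extra_configs=aws` override are assumptions inferred from the file layout, not shown elsewhere in this diff.

```bash
cd plugins/hydra_ray_launcher/example
# Assumed invocation: switch to the AWS launcher and pull in this file as a global override.
python train.py --multirun hydra/launcher=ray_aws +extra_configs=aws random_seed=1,2,3
```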
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from datetime import datetime
from pathlib import Path

log = logging.getLogger(__name__)


class MyModel:
    def __init__(self, random_seed: int):
        self.random_seed = random_seed
        log.info("Init my model")

    def save(self, checkpoint_path: str) -> None:
        checkpoint_dir = Path(checkpoint_path)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        log.info(f"Created dir for checkpoints. dir={checkpoint_dir}")
        with open(checkpoint_dir / f"checkpoint_{self.random_seed}.pt", "w") as f:
            f.write(f"checkpoint@{datetime.now()}")
@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging

import hydra
from model.my_model import MyModel
from omegaconf import DictConfig

log = logging.getLogger(__name__)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    log.info("Start training...")
    model = MyModel(cfg.random_seed)
    # save checkpoint to current working dir.
    model.save(cfg.checkpoint_path)


if __name__ == "__main__":
    main()
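Run as a sweep, each job gets its own Hydra working directory and writes one checkpoint file into it. The layout below is illustrative only, based on Hydra's default multirun directory pattern, not output captured from this commit.

```bash
python train.py --multirun random_seed=1,2,3
# Illustrative layout (dates and times will differ):
#   multirun/2020-11-10/12-00-00/0/checkpoint/checkpoint_1.pt
#   multirun/2020-11-10/12-00-00/1/checkpoint/checkpoint_2.pt
#   multirun/2020-11-10/12-00-00/2/checkpoint/checkpoint_3.pt
```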
plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/__init__.py (11 additions, 0 deletions)
@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from hydra.core.config_search_path import ConfigSearchPath
from hydra.plugins.search_path_plugin import SearchPathPlugin


class RayLauncherSearchPathPlugin(SearchPathPlugin):
    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        # Appends the search path for this plugin to the end of the search path
        search_path.append(
            "hydra-ray-launcher", "pkg://hydra_plugins.hydra_ray_launcher.conf"
        )
plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core_aws.py (199 additions, 0 deletions)
@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Sequence

import cloudpickle  # type: ignore
import pickle5 as pickle  # type: ignore
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, configure_log, filter_overrides, setup_globals
from omegaconf import OmegaConf, open_dict, read_write

from hydra_plugins.hydra_ray_launcher._launcher_util import (  # type: ignore
    JOB_RETURN_PICKLE,
    JOB_SPEC_PICKLE,
    ray_down,
    ray_exec,
    ray_rsync_down,
    ray_rsync_up,
    ray_tmp_dir,
    ray_up,
    rsync,
)
from hydra_plugins.hydra_ray_launcher.ray_aws_launcher import (  # type: ignore
    RayAWSLauncher,
)

log = logging.getLogger(__name__)


def _get_abs_code_dir(code_dir: str) -> str:
    if code_dir:
        if os.path.isabs(code_dir):
            return code_dir
        else:
            return os.path.join(os.getcwd(), code_dir)
    else:
        return ""


def _pickle_jobs(tmp_dir: str, **jobspec: Dict[Any, Any]) -> None:
    path = os.path.join(tmp_dir, JOB_SPEC_PICKLE)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        cloudpickle.dump(jobspec, f)
        log.info(f"Pickle for jobs: {f.name}")


def launch(
    launcher: RayAWSLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.config_loader is not None
    assert launcher.task_function is not None

    setup_commands = launcher.mandatory_install.install_commands
    setup_commands.extend(launcher.ray_cluster_cfg.setup_commands)

    with read_write(launcher.ray_cluster_cfg):
        launcher.ray_cluster_cfg.setup_commands = setup_commands

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)

    log.info(f"Ray Launcher is launching {len(job_overrides)} jobs")

    with tempfile.TemporaryDirectory() as local_tmp_dir:
        sweep_configs = []
        for idx, overrides in enumerate(job_overrides):
            idx = initial_job_idx + idx
            ostr = " ".join(filter_overrides(overrides))
            log.info(f"\t#{idx} : {ostr}")
            sweep_config = launcher.config_loader.load_sweep_config(
                launcher.config, list(overrides)
            )
            with open_dict(sweep_config):
                # job.id will be set on the EC2 instance before running the job.
                sweep_config.hydra.job.num = idx

            sweep_configs.append(sweep_config)

        _pickle_jobs(
            tmp_dir=local_tmp_dir,
            sweep_configs=sweep_configs,  # type: ignore
            task_function=launcher.task_function,
            singleton_state=Singleton.get_state(),
        )

        with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
            with open(f.name, "w") as file:
                OmegaConf.save(
                    config=launcher.ray_cluster_cfg, f=file.name, resolve=True
                )
            launcher.ray_yaml_path = f.name
            log.info(
                f"Saving RayClusterConf in a temp yaml file: {launcher.ray_yaml_path}."
            )

        return launch_jobs(
            launcher, local_tmp_dir, Path(HydraConfig.get().sweep.dir)
        )


def launch_jobs(
    launcher: RayAWSLauncher, local_tmp_dir: str, sweep_dir: Path
) -> Sequence[JobReturn]:
    ray_up(launcher.ray_yaml_path)
    with tempfile.TemporaryDirectory() as local_tmp_download_dir:

        with ray_tmp_dir(
            launcher.ray_yaml_path, launcher.docker_enabled
        ) as remote_tmp_dir:

            ray_rsync_up(
                launcher.ray_yaml_path, os.path.join(local_tmp_dir, ""), remote_tmp_dir
            )

            script_path = os.path.join(os.path.dirname(__file__), "_remote_invoke.py")
            ray_rsync_up(launcher.ray_yaml_path, script_path, remote_tmp_dir)

            if launcher.sync_up.source_dir:
                source_dir = _get_abs_code_dir(launcher.sync_up.source_dir)
                target_dir = (
                    launcher.sync_up.target_dir
                    if launcher.sync_up.target_dir
                    else remote_tmp_dir
                )
                rsync(
                    launcher.ray_yaml_path,
                    launcher.sync_up.include,
                    launcher.sync_up.exclude,
                    os.path.join(source_dir, ""),
                    target_dir,
                )

            ray_exec(
                launcher.ray_yaml_path,
                launcher.docker_enabled,
                os.path.join(remote_tmp_dir, "_remote_invoke.py"),
                remote_tmp_dir,
            )

            ray_rsync_down(
                launcher.ray_yaml_path,
                os.path.join(remote_tmp_dir, JOB_RETURN_PICKLE),
                local_tmp_download_dir,
            )

            sync_down_cfg = launcher.sync_down

            if (
                sync_down_cfg.target_dir
                or sync_down_cfg.source_dir
                or sync_down_cfg.include
                or sync_down_cfg.exclude
            ):
                source_dir = (
                    sync_down_cfg.source_dir if sync_down_cfg.source_dir else sweep_dir
                )
                # Destination for the local copy; fall back to the sweep dir if target_dir is unset.
                target_dir = (
                    sync_down_cfg.target_dir if sync_down_cfg.target_dir else sweep_dir
                )
                target_dir = Path(_get_abs_code_dir(target_dir))
                target_dir.mkdir(parents=True, exist_ok=True)

                rsync(
                    launcher.ray_yaml_path,
                    launcher.sync_down.include,
                    launcher.sync_down.exclude,
                    os.path.join(source_dir),
                    str(target_dir),
                    up=False,
                )
                log.info(
                    f"Syncing outputs from remote dir: {source_dir} to local dir: {target_dir.absolute()}"
                )

        if launcher.stop_cluster:
            log.info("Stopping cluster now. (stop_cluster=true)")
            if launcher.ray_cluster_cfg.provider.cache_stopped_nodes:
                log.info("NOT deleting the cluster (provider.cache_stopped_nodes=true)")
            else:
                log.info("Deleting the cluster (provider.cache_stopped_nodes=false)")
            ray_down(launcher.ray_yaml_path)
        else:
            log.warning(
                "NOT stopping cluster, this may incur extra cost for you. (stop_cluster=false)"
            )

        with open(os.path.join(local_tmp_download_dir, JOB_RETURN_PICKLE), "rb") as f:
            job_returns = pickle.load(f)  # nosec
            assert isinstance(job_returns, List)
            for run in job_returns:
                assert isinstance(run, JobReturn)
            return job_returns
plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core_local.py (64 additions, 0 deletions)
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from pathlib import Path
from typing import Sequence

import ray
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, configure_log, filter_overrides, setup_globals
from omegaconf import open_dict

from hydra_plugins.hydra_ray_launcher._launcher_util import (  # type: ignore
    launch_job_on_ray,
    start_ray,
)
from hydra_plugins.hydra_ray_launcher.ray_local_launcher import (  # type: ignore
    RayLocalLauncher,
)

log = logging.getLogger(__name__)


def launch(
    launcher: RayLocalLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.config_loader is not None
    assert launcher.task_function is not None

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)
    sweep_dir = Path(str(launcher.config.hydra.sweep.dir))
    sweep_dir.mkdir(parents=True, exist_ok=True)
    log.info(
        f"Ray Launcher is launching {len(job_overrides)} jobs, "
        f"sweep output dir: {sweep_dir}"
    )

    start_ray(launcher.ray_init_cfg)

    runs = []
    for idx, overrides in enumerate(job_overrides):
        idx = initial_job_idx + idx
        ostr = " ".join(filter_overrides(overrides))
        log.info(f"\t#{idx} : {ostr}")
        sweep_config = launcher.config_loader.load_sweep_config(
            launcher.config, list(overrides)
        )
        with open_dict(sweep_config):
            # hydra.job.id typically comes from the underlying scheduler (e.g. SLURM_JOB_ID).
            # It is not available here because we are still in the main process;
            # instead it should be populated remotely before calling the task_function.
            sweep_config.hydra.job.id = f"job_id_for_{idx}"
            sweep_config.hydra.job.num = idx
        ray_obj = launch_job_on_ray(
            launcher.ray_remote_cfg,
            sweep_config,
            launcher.task_function,
            Singleton.get_state(),
        )
        runs.append(ray_obj)

    return [ray.get(run) for run in runs]