-
-
Notifications
You must be signed in to change notification settings - Fork 653
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
1,899 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,4 +23,3 @@ skip= | |
,hydra/grammar/gen | ||
,tools/configen/example/gen | ||
,tools/configen/tests/test_modules/expected | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
global-exclude *.pyc | ||
global-exclude __pycache__ | ||
recursive-include hydra_plugins/* *.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Hydra Ray Launcher | ||
Provides a [`Ray`](https://docs.ray.io/en/latest/) based Hydra Launcher supporting execution on AWS. | ||
|
||
See [website](https://hydra.cc/docs/next/plugins/ray_launcher) for more information |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
defaults: | ||
- hydra/launcher: ray_local | ||
|
||
|
||
random_seed: 0 | ||
checkpoint_path: checkpoint |
24 changes: 24 additions & 0 deletions
24
plugins/hydra_ray_launcher/example/conf/extra_configs/aws.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# @package _global_ | ||
|
||
hydra: | ||
launcher: | ||
sync_up: | ||
# source dir is relative in this case, assuming you are running from | ||
# <project_root>/hydra/plugins/hydra_ray_launcher/example | ||
# absolute path is also supported. | ||
source_dir: "." | ||
# we leave target_dir to be null | ||
# as a result the files will be synced to a temp dir on remote cluster. | ||
# the temp dir will be cleaned up after the jobs are done. | ||
# recommend to leave target_dir to be null if you are syncing code/artifacts to remote cluster so you don't need | ||
# configure $PYTHONPATH on remote cluster | ||
include: ["model", "*.py"] | ||
# No need to sync up config files. | ||
exclude: ["*"] | ||
sync_down: | ||
include: ["*.pt", "*/"] | ||
# No need to sync down config files. | ||
exclude: ["*"] | ||
ray_cluster_cfg: | ||
provider: | ||
cache_stopped_nodes: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import logging | ||
from datetime import datetime | ||
from pathlib import Path | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class MyModel: | ||
def __init__(self, random_seed: int): | ||
self.random_seed = random_seed | ||
log.info("Init my model") | ||
|
||
def save(self, checkpoint_path: str) -> None: | ||
checkpoint_dir = Path(checkpoint_path) | ||
checkpoint_dir.mkdir(parents=True, exist_ok=True) | ||
log.info(f"Created dir for checkpoints. dir={checkpoint_dir}") | ||
with open(checkpoint_dir / f"checkpoint_{self.random_seed}.pt", "w") as f: | ||
f.write(f"checkpoint@{datetime.now()}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import logging | ||
|
||
import hydra | ||
from model.my_model import MyModel | ||
from omegaconf import DictConfig | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@hydra.main(config_path="conf", config_name="config") | ||
def main(cfg: DictConfig) -> None: | ||
log.info("Start training...") | ||
model = MyModel(cfg.random_seed) | ||
# save checkpoint to current working dir. | ||
model.save(cfg.checkpoint_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
11 changes: 11 additions & 0 deletions
11
plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
from hydra.core.config_search_path import ConfigSearchPath | ||
from hydra.plugins.search_path_plugin import SearchPathPlugin | ||
|
||
|
||
class RayLauncherSearchPathPlugin(SearchPathPlugin): | ||
def manipulate_search_path(self, search_path: ConfigSearchPath) -> None: | ||
# Appends the search path for this plugin to the end of the search path | ||
search_path.append( | ||
"hydra-ray-launcher", "pkg://hydra_plugins.hydra_ray_launcher.conf" | ||
) |
Oops, something went wrong.