Skip to content

Commit

Permalink
remove pickling Singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed Jan 20, 2021
1 parent e8cd36b commit bcd1627
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

import ray
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, configure_log, filter_overrides, setup_globals
from hydra.core.utils import (
JobReturn,
JobRuntime,
configure_log,
filter_overrides,
setup_globals,
)
from omegaconf import open_dict

from hydra_plugins.hydra_ray_launcher._launcher_util import ( # type: ignore
Expand Down Expand Up @@ -51,11 +57,13 @@ def launch(
# but instead should be populated remotely before calling the task_function.
sweep_config.hydra.job.id = f"job_id_for_{idx}"
sweep_config.hydra.job.num = idx

ray_obj = launch_job_on_ray(
launcher.ray_cfg.remote,
sweep_config,
launcher.task_function,
Singleton.get_state(),
JobRuntime.instance().conf,
Singleton.get_state().get("omegaconf_resolvers"),
)
runs.append(ray_obj)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import pickle5 as pickle # type: ignore
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, configure_log, filter_overrides, setup_globals
from hydra.core.utils import (
JobReturn,
JobRuntime,
configure_log,
filter_overrides,
setup_globals,
)
from omegaconf import OmegaConf, open_dict, read_write

from hydra_plugins.hydra_ray_launcher._launcher_util import ( # type: ignore
Expand Down Expand Up @@ -94,7 +100,8 @@ def launch(
tmp_dir=local_tmp_dir,
sweep_configs=sweep_configs, # type: ignore
task_function=launcher.task_function,
singleton_state=Singleton.get_state(),
job_runtime=JobRuntime.instance().conf,
omegaconf_resolvers=Singleton.get_state().get("omegaconf_resolvers"),
)

with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import logging
import os
from contextlib import contextmanager
from copy import deepcopy
from subprocess import PIPE, Popen
from typing import Any, Dict, Generator, List, Tuple

import ray
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, run_job, setup_globals
from hydra.core.utils import JobReturn, JobRuntime, run_job
from hydra.types import TaskFunction
from omegaconf import DictConfig, OmegaConf
from omegaconf.basecontainer import BaseContainer

log = logging.getLogger(__name__)

Expand All @@ -32,11 +33,12 @@ def start_ray(init_cfg: DictConfig) -> None:
def _run_job(
sweep_config: DictConfig,
task_function: TaskFunction,
singleton_state: Dict[Any, Any],
job_runtime: DictConfig,
resolvers: Dict[str, Any],
) -> JobReturn:
setup_globals()
Singleton.set_state(singleton_state)
HydraConfig.instance().set_config(sweep_config)
JobRuntime.instance().conf = job_runtime
BaseContainer._resolvers = deepcopy(resolvers)
return run_job(
config=sweep_config,
task_function=task_function,
Expand All @@ -49,7 +51,8 @@ def launch_job_on_ray(
ray_remote: DictConfig,
sweep_config: DictConfig,
task_function: TaskFunction,
singleton_state: Any,
job_runtime: DictConfig,
omegaconf_resolvers: Dict[str, Any],
) -> Any:
if ray_remote:
run_job_ray = ray.remote(**ray_remote)(_run_job)
Expand All @@ -59,7 +62,8 @@ def launch_job_on_ray(
ret = run_job_ray.remote(
sweep_config=sweep_config,
task_function=task_function,
singleton_state=singleton_state,
job_runtime=job_runtime,
resolvers=omegaconf_resolvers,
)
return ret

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pickle5 as pickle # type: ignore
import ray
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, setup_globals
from omegaconf import open_dict

Expand All @@ -30,9 +29,10 @@ def launch_jobs(temp_dir: str) -> None:
runs = []
with open(os.path.join(temp_dir, JOB_SPEC_PICKLE), "rb") as f:
job_spec = pickle.load(f) # nosec
singleton_state = job_spec["singleton_state"]
sweep_configs = job_spec["sweep_configs"]
task_function = job_spec["task_function"]
job_runtime = job_spec["job_runtime"]
omegaconf_resolvers = job_spec["omegaconf_resolvers"]

instance_id = _get_instance_id()

Expand All @@ -44,7 +44,6 @@ def launch_jobs(temp_dir: str) -> None:
f"{instance_id}_{sweep_config.hydra.job.num}"
)
setup_globals()
Singleton.set_state(singleton_state)
HydraConfig.instance().set_config(sweep_config)
ray_init = sweep_config.hydra.launcher.ray.init
ray_remote = sweep_config.hydra.launcher.ray.remote
Expand All @@ -55,7 +54,11 @@ def launch_jobs(temp_dir: str) -> None:

start_ray(ray_init)
ray_obj = launch_job_on_ray(
ray_remote, sweep_config, task_function, singleton_state
ray_remote,
sweep_config,
task_function,
job_runtime,
omegaconf_resolvers,
)
runs.append(ray_obj)

Expand Down
17 changes: 1 addition & 16 deletions plugins/hydra_ray_launcher/tests/test_ray_aws_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,7 @@
chdir_plugin_root()


def build_ray_launcher_wheel(tmpdir: str) -> str:
"""
This only works on ray launcher plugin wheels for now, reasons being in our base AMI
we do not necessarily have the dependency for other plugins.
"""
command = "python -m pip --disable-pip-version-check list | grep hydra | grep -v hydra-core "
output = subprocess.getoutput(command).split("\n")
plugins_path = [x.split()[0].replace("-", "_") for x in output]
assert (
len(plugins_path) == 1 and "hydra_ray_launcher" == plugins_path[0]
), "Ray test AMI doesn't have dependency installed for other plugins."

return build_plugin_wheel(tmpdir)


def build_plugin_wheel(tmp_wheel_dir: str) -> str:
def build_ray_launcher_wheel(tmp_wheel_dir: str) -> str:
chdir_hydra_root()
plugin = "hydra_ray_launcher"
os.chdir(Path("plugins") / plugin)
Expand Down

0 comments on commit bcd1627

Please sign in to comment.