Skip to content

Commit

Permalink
remove BasicLauncher constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed May 4, 2021
1 parent fbdbfc4 commit d2a4213
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 32 deletions.
26 changes: 6 additions & 20 deletions hydra/_internal/core_plugins/basic_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from omegaconf import DictConfig, open_dict

from hydra._internal.callbacks import Callbacks
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.utils import (
JobReturn,
Expand Down Expand Up @@ -38,31 +36,23 @@ def __init__(self) -> None:
self.config: Optional[DictConfig] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None
# BasicLauncher supports Hydra 1.0 style setup for compatibility with 3rd party Sweeper
self.config_loader: Optional[ConfigLoader] = None

def setup(
self,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
hydra_context: Optional[HydraContext] = None,
config_loader: Optional[ConfigLoader] = None,
) -> None:
self.config = config
self.hydra_context = hydra_context
self.config_loader = config_loader
self.task_function = task_function

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
) -> Sequence[JobReturn]:
setup_globals()
config_loader = (
self.hydra_context.config_loader
if self.hydra_context
else self.config_loader
)
assert config_loader is not None
assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None

Expand All @@ -71,20 +61,16 @@ def launch(
Path(str(sweep_dir)).mkdir(parents=True, exist_ok=True)
log.info(f"Launching {len(job_overrides)} jobs locally")
runs: List[JobReturn] = []
if self.hydra_context is None:
self.hydra_context = HydraContext(
config_loader=config_loader, callbacks=Callbacks()
)

for idx, overrides in enumerate(job_overrides):
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = config_loader.load_sweep_config(self.config, list(overrides))
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
sweep_config.hydra.job.id = idx
sweep_config.hydra.job.num = idx

ret = run_job(
hydra_context=self.hydra_context,
task_function=self.task_function,
Expand Down
19 changes: 7 additions & 12 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from omegaconf import DictConfig

from hydra._internal.callbacks import Callbacks
from hydra._internal.sources_registry import SourcesRegistry
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
Expand Down Expand Up @@ -125,17 +126,7 @@ def _setup_plugin(

param_keys = signature(plugin.setup).parameters.keys()

from hydra._internal.core_plugins.basic_launcher import BasicLauncher

if isinstance(plugin, BasicLauncher):
# BasicLauncher supports Hydra 1.0 style setup for compatibility with 3rd party Sweeper
plugin.setup(
config=config,
config_loader=config_loader,
task_function=task_function,
hydra_context=hydra_context,
)
elif "config_loader" in param_keys:
if "config_loader" in param_keys:
warnings.warn(
message=(
"\n"
Expand All @@ -156,7 +147,11 @@ def _setup_plugin(
task_function=task_function,
)
else:
assert hydra_context is not None
if hydra_context is None:
assert config_loader is not None
hydra_context = HydraContext(
config_loader=config_loader, callbacks=Callbacks()
)
plugin.setup(
config=config, hydra_context=hydra_context, task_function=task_function
)
Expand Down

0 comments on commit d2a4213

Please sign in to comment.