Skip to content

Commit

Permalink
updated instantiate documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 25, 2020
1 parent 8343d38 commit 011b6cd
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 49 deletions.
10 changes: 10 additions & 0 deletions examples/instantiate/docs_example/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
trainer:
_target_: my_app.Trainer
optimizer:
_target_: my_app.Optimizer
algo: SGD
lr: 0.01
dataset:
_target_: my_app.Dataset
name: Imagenet
path: /datasets/imagenet
82 changes: 82 additions & 0 deletions examples/instantiate/docs_example/my_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


class Optimizer:
algo: str
lr: float

def __init__(self, algo: str, lr: float):
self.algo = algo
self.lr = lr

def __repr__(self):
return f"Optimizer(algo={self.algo},lr={self.lr})"


class Dataset:
name: str
path: str

def __init__(self, name: str, path: str):
self.name = name
self.path = path

def __repr__(self):
return f"Dataset(name={self.name}, path={self.path})"


class Trainer:
def __init__(self, optimizer: Optimizer, dataset: Dataset):
self.optimizer = optimizer
self.dataset = dataset

def __repr__(self):
return f"Trainer(\n optimizer={self.optimizer},\n dataset={self.dataset}\n)"


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
optimizer = instantiate(cfg.trainer.optimizer)
print(optimizer)
# Optimizer(algo=SGD,lr=0.01)

# override parameters on the call-site
optimizer = instantiate(cfg.trainer.optimizer, lr=0.2)
print(optimizer)
# Optimizer(algo=SGD,lr=0.2)

# recursive instantiation
trainer = instantiate(cfg.trainer)
print(trainer)
# Trainer(
# optimizer=Optimizer(algo=SGD,lr=0.01),
# dataset=Dataset(name=Imagenet, path=/datasets/imagenet)
# )

# override nested parameters from the call-site
trainer = instantiate(
cfg.trainer,
optimizer={"lr": 0.3},
dataset={"name": "cifar10", "path": "/datasets/cifar10"},
)
print(trainer)
# Trainer(
# optimizer=Optimizer(algo=SGD,lr=0.3),
# dataset=Dataset(name=cifar10, path=/datasets/cifar10)
# )

# non recursive instantiation
optimizer = instantiate(cfg.trainer, _recursive_=False)
print(optimizer)
# Trainer(
# optimizer={'_target_': 'my_app.Optimizer', 'algo': 'SGD', 'lr': 0.01},
# dataset={'_target_': 'my_app.Dataset', 'name': 'Imagenet', 'path': '/datasets/imagenet'}
# )


if __name__ == "__main__":
my_app()
8 changes: 4 additions & 4 deletions hydra/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def _get_kwargs(
config: Union[DictConfig, ListConfig],
**kwargs: Any,
) -> Any:
from hydra.utils import _call
from hydra.utils import instantiate

assert OmegaConf.is_config(config)

Expand All @@ -602,12 +602,12 @@ def _get_kwargs(
if recursive:
for k, v in final_kwargs.items():
if _is_target(v):
final_kwargs[k] = _call(v)
final_kwargs[k] = instantiate(v)
elif OmegaConf.is_dict(v) and not OmegaConf.is_none(v):
d = OmegaConf.create({}, flags={"allow_objects": True})
for key, value in v.items():
if _is_target(value):
d[key] = _call(value)
d[key] = instantiate(value)
elif OmegaConf.is_config(value):
d[key] = _get_kwargs(value)
else:
Expand All @@ -617,7 +617,7 @@ def _get_kwargs(
lst = OmegaConf.create([], flags={"allow_objects": True})
for x in v:
if _is_target(x):
lst.append(_call(x))
lst.append(instantiate(x))
elif OmegaConf.is_config(x):
lst.append(_get_kwargs(x))
else:
Expand Down
29 changes: 16 additions & 13 deletions hydra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,21 @@
log = logging.getLogger(__name__)


def call(config: Any, *args: Any, **kwargs: Any) -> Any:
return _call(config, *args, **kwargs)


def _call(config: Any, *args: Any, **kwargs: Any) -> Any:
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
"""
:param config: An object describing what to call and what params to use.
_target_ : str : Mandatory target (class name, function name etc)
_recursive_: bool = True : recursive instantiation, defaults to True
:param args: optional positional parameters pass-through
:param kwargs: optional named parameters pass-through
:return: the return value from the specified class or method
:param config: An config object describing what to call and what params to use.
In addition to the parameters, the config must contain:
_target_ : target class or callable name (str)
_recursive_: Construct nested objects as well (bool).
True by default.
may be overridden via a _recursive_ key in
the kwargs
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
in the config objects are being passed as is to the target.
:return: if _target_ is a class name: the instantiated object
if _target_ is a callable: the return value of the call
"""

if OmegaConf.is_none(config):
Expand Down Expand Up @@ -70,8 +73,8 @@ def _call(config: Any, *args: Any, **kwargs: Any) -> Any:
raise type(e)(f"Error instantiating/calling '{cls}' : {e}")


# Alias for call
instantiate = call
# Alias for instantiate
call = instantiate


def get_class(path: str) -> type:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_examples/test_instantiate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,38 @@ def test_instantiate_schema_recursive(
cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides
result, _err = get_run_output(cmd)
assert_text_same(result, expected)


@pytest.mark.parametrize( # type: ignore
"overrides,expected",
[
(
[],
dedent(
"""\
Optimizer(algo=SGD,lr=0.01)
Optimizer(algo=SGD,lr=0.2)
Trainer(
optimizer=Optimizer(algo=SGD,lr=0.01),
dataset=Dataset(name=Imagenet, path=/datasets/imagenet)
)
Trainer(
optimizer=Optimizer(algo=SGD,lr=0.3),
dataset=Dataset(name=cifar10, path=/datasets/cifar10)
)
Trainer(
optimizer={'_target_': 'my_app.Optimizer', 'algo': 'SGD', 'lr': 0.01},
dataset={'_target_': 'my_app.Dataset', 'name': 'Imagenet', 'path': '/datasets/imagenet'}
)
"""
),
),
],
)
def test_instantiate_docs_example(
monkeypatch: Any, tmpdir: Path, overrides: List[str], expected: str
) -> None:
monkeypatch.chdir("examples/instantiate/docs_example")
cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides
result, _err = get_run_output(cmd)
assert_text_same(result, expected)
2 changes: 1 addition & 1 deletion website/docs/patterns/instantiate_objects/config_files.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ id: config_files
title: Config files example
sidebar_label: Config files example
---
[![Example application](https://img.shields.io/badge/-Example%20application-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate/object)
[![Example applications](https://img.shields.io/badge/-Example%20applications-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate)

This example demonstrates the use of config files to instantiated objects.

Expand Down
Loading

0 comments on commit 011b6cd

Please sign in to comment.