From 011b6cdceb1826f2a74356d4a1b094cf30ac37d8 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Fri, 25 Sep 2020 00:11:25 -0700 Subject: [PATCH] updated instantiate documentation --- examples/instantiate/docs_example/config.yaml | 10 ++ examples/instantiate/docs_example/my_app.py | 82 +++++++++ hydra/_internal/utils.py | 8 +- hydra/utils.py | 29 ++-- .../test_instantiate_examples.py | 35 ++++ .../instantiate_objects/config_files.md | 2 +- .../patterns/instantiate_objects/overview.md | 157 ++++++++++++++---- .../instantiate_objects/structured_config.md | 2 +- 8 files changed, 276 insertions(+), 49 deletions(-) create mode 100644 examples/instantiate/docs_example/config.yaml create mode 100644 examples/instantiate/docs_example/my_app.py diff --git a/examples/instantiate/docs_example/config.yaml b/examples/instantiate/docs_example/config.yaml new file mode 100644 index 00000000000..9d230c66d2e --- /dev/null +++ b/examples/instantiate/docs_example/config.yaml @@ -0,0 +1,10 @@ +trainer: + _target_: my_app.Trainer + optimizer: + _target_: my_app.Optimizer + algo: SGD + lr: 0.01 + dataset: + _target_: my_app.Dataset + name: Imagenet + path: /datasets/imagenet diff --git a/examples/instantiate/docs_example/my_app.py b/examples/instantiate/docs_example/my_app.py new file mode 100644 index 00000000000..fe48bb54e5a --- /dev/null +++ b/examples/instantiate/docs_example/my_app.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import hydra +from hydra.utils import instantiate +from omegaconf import DictConfig + + +class Optimizer: + algo: str + lr: float + + def __init__(self, algo: str, lr: float): + self.algo = algo + self.lr = lr + + def __repr__(self): + return f"Optimizer(algo={self.algo},lr={self.lr})" + + +class Dataset: + name: str + path: str + + def __init__(self, name: str, path: str): + self.name = name + self.path = path + + def __repr__(self): + return f"Dataset(name={self.name}, path={self.path})" + + +class Trainer: + def __init__(self, optimizer: Optimizer, dataset: Dataset): + self.optimizer = optimizer + self.dataset = dataset + + def __repr__(self): + return f"Trainer(\n optimizer={self.optimizer},\n dataset={self.dataset}\n)" + + +@hydra.main(config_name="config") +def my_app(cfg: DictConfig) -> None: + optimizer = instantiate(cfg.trainer.optimizer) + print(optimizer) + # Optimizer(algo=SGD,lr=0.01) + + # override parameters on the call-site + optimizer = instantiate(cfg.trainer.optimizer, lr=0.2) + print(optimizer) + # Optimizer(algo=SGD,lr=0.2) + + # recursive instantiation + trainer = instantiate(cfg.trainer) + print(trainer) + # Trainer( + # optimizer=Optimizer(algo=SGD,lr=0.01), + # dataset=Dataset(name=Imagenet, path=/datasets/imagenet) + # ) + + # override nested parameters from the call-site + trainer = instantiate( + cfg.trainer, + optimizer={"lr": 0.3}, + dataset={"name": "cifar10", "path": "/datasets/cifar10"}, + ) + print(trainer) + # Trainer( + # optimizer=Optimizer(algo=SGD,lr=0.3), + # dataset=Dataset(name=cifar10, path=/datasets/cifar10) + # ) + + # non recursive instantiation + optimizer = instantiate(cfg.trainer, _recursive_=False) + print(optimizer) + # Trainer( + # optimizer={'_target_': 'my_app.Optimizer', 'algo': 'SGD', 'lr': 0.01}, + # dataset={'_target_': 'my_app.Dataset', 'name': 'Imagenet', 'path': '/datasets/imagenet'} + # ) + + +if __name__ == "__main__": + my_app() diff --git a/hydra/_internal/utils.py b/hydra/_internal/utils.py index 0af623a5085..71d8f88645d 100644 --- a/hydra/_internal/utils.py +++ b/hydra/_internal/utils.py @@ -579,7 +579,7 @@ def _get_kwargs( config: Union[DictConfig, ListConfig], **kwargs: Any, ) -> Any: - from hydra.utils import _call + from hydra.utils import instantiate assert OmegaConf.is_config(config) @@ -602,12 +602,12 @@ def _get_kwargs( if recursive: for k, v in final_kwargs.items(): if _is_target(v): - final_kwargs[k] = _call(v) + final_kwargs[k] = instantiate(v) elif OmegaConf.is_dict(v) and not OmegaConf.is_none(v): d = OmegaConf.create({}, flags={"allow_objects": True}) for key, value in v.items(): if _is_target(value): - d[key] = _call(value) + d[key] = instantiate(value) elif OmegaConf.is_config(value): d[key] = _get_kwargs(value) else: @@ -617,7 +617,7 @@ def _get_kwargs( lst = OmegaConf.create([], flags={"allow_objects": True}) for x in v: if _is_target(x): - lst.append(_call(x)) + lst.append(instantiate(x)) elif OmegaConf.is_config(x): lst.append(_get_kwargs(x)) else: diff --git a/hydra/utils.py b/hydra/utils.py index 1c9122e4d60..9309d79ea3d 100644 --- a/hydra/utils.py +++ b/hydra/utils.py @@ -15,18 +15,21 @@ log = logging.getLogger(__name__) -def call(config: Any, *args: Any, **kwargs: Any) -> Any: - return _call(config, *args, **kwargs) - - -def _call(config: Any, *args: Any, **kwargs: Any) -> Any: +def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: """ - :param config: An object describing what to call and what params to use. - _target_ : str : Mandatory target (class name, function name etc) - _recursive_: bool = True : recursive instantiation, defaults to True - :param args: optional positional parameters pass-through - :param kwargs: optional named parameters pass-through - :return: the return value from the specified class or method + :param config: An config object describing what to call and what params to use. + In addition to the parameters, the config must contain: + _target_ : target class or callable name (str) + _recursive_: Construct nested objects as well (bool). + True by default. + may be overridden via a _recursive_ key in + the kwargs + :param args: Optional positional parameters pass-through + :param kwargs: Optional named parameters to override + parameters in the config object. Parameters not present + in the config objects are being passed as is to the target. + :return: if _target_ is a class name: the instantiated object + if _target_ is a callable: the return value of the call """ if OmegaConf.is_none(config): @@ -70,8 +73,8 @@ def _call(config: Any, *args: Any, **kwargs: Any) -> Any: raise type(e)(f"Error instantiating/calling '{cls}' : {e}") -# Alias for call -instantiate = call +# Alias for instantiate +call = instantiate def get_class(path: str) -> type: diff --git a/tests/test_examples/test_instantiate_examples.py b/tests/test_examples/test_instantiate_examples.py index 95210f45564..a334fed9779 100644 --- a/tests/test_examples/test_instantiate_examples.py +++ b/tests/test_examples/test_instantiate_examples.py @@ -89,3 +89,38 @@ def test_instantiate_schema_recursive( cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides result, _err = get_run_output(cmd) assert_text_same(result, expected) + + +@pytest.mark.parametrize( # type: ignore + "overrides,expected", + [ + ( + [], + dedent( + """\ + Optimizer(algo=SGD,lr=0.01) + Optimizer(algo=SGD,lr=0.2) + Trainer( + optimizer=Optimizer(algo=SGD,lr=0.01), + dataset=Dataset(name=Imagenet, path=/datasets/imagenet) + ) + Trainer( + optimizer=Optimizer(algo=SGD,lr=0.3), + dataset=Dataset(name=cifar10, path=/datasets/cifar10) + ) + Trainer( + optimizer={'_target_': 'my_app.Optimizer', 'algo': 'SGD', 'lr': 0.01}, + dataset={'_target_': 'my_app.Dataset', 'name': 'Imagenet', 'path': '/datasets/imagenet'} + ) + """ + ), + ), + ], +) +def test_instantiate_docs_example( + monkeypatch: Any, tmpdir: Path, overrides: List[str], expected: str +) -> None: + monkeypatch.chdir("examples/instantiate/docs_example") + cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides + result, _err = get_run_output(cmd) + assert_text_same(result, expected) diff --git a/website/docs/patterns/instantiate_objects/config_files.md b/website/docs/patterns/instantiate_objects/config_files.md index 7db41729637..23ae40bf1c6 100644 --- a/website/docs/patterns/instantiate_objects/config_files.md +++ b/website/docs/patterns/instantiate_objects/config_files.md @@ -3,7 +3,7 @@ id: config_files title: Config files example sidebar_label: Config files example --- -[![Example application](https://img.shields.io/badge/-Example%20application-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate/object) +[![Example applications](https://img.shields.io/badge/-Example%20applications-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate) This example demonstrates the use of config files to instantiated objects. diff --git a/website/docs/patterns/instantiate_objects/overview.md b/website/docs/patterns/instantiate_objects/overview.md index 21542fd334e..cf51b9e3727 100644 --- a/website/docs/patterns/instantiate_objects/overview.md +++ b/website/docs/patterns/instantiate_objects/overview.md @@ -3,44 +3,58 @@ id: overview title: Instantiating objects with Hydra sidebar_label: Overview --- +[![Example applications](https://img.shields.io/badge/-Example%20applications-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate) One of the best ways to drive different behavior in an application is to instantiate different implementations of an interface. The code using the instantiated object only knows the interface which remains constant, but the behavior is determined by the actual object instance. -Hydra provides `hydra.utils.call()` (and its alias `hydra.utils.instantiate()`) for instantiating objects and calling functions. Prefer `instantiate` for creating objects and `call` for invoking functions. +Hydra provides `hydra.utils.instantiate()` (and its alias `hydra.utils.call()`) for instantiating objects and calling functions. Prefer `instantiate` for creating objects and `call` for invoking functions. Call/instantiate supports: -- Class names : Call the `__init__` method -- Callables like functions, static functions, class methods and objects +- Constructing an object by calling the `__init__` method +- Calling functions, static functions, class methods and other callable global objects ```python -def call(config: Any, *args: Any, **kwargs: Any) -> Any: +def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: """ - :param config: An object describing what to call and what params to use. - Must have a _target_ field. - :param args: optional positional parameters pass-through - :param kwargs: optional named parameters pass-through - :return: the return value from the specified class or method + :param config: An config object describing what to call and what params to use. + In addition to the parameters, the config must contain: + _target_ : target class or callable name (str) + _recursive_: Construct nested objects as well (bool). + True by default. + may be overridden via a _recursive_ key in + the kwargs + :param args: Optional positional parameters pass-through + :param kwargs: Optional named parameters to override + parameters in the config object. Parameters not present + in the config objects are being passed as is to the target. + :return: if _target_ is a class name: the instantiated object + if _target_ is a callable: the return value of the call """ ... -# Alias for call -instantiate = call +# Alias for instantiate +call = instantiate ``` The config passed to these functions must have a key called `_target_`, with the value of a fully qualified class name, class method, static method or callable. Any additional parameters are passed as keyword arguments to tha target. +For convenience, `None` config results in a `None` object. -For example, your application may have a User class that looks like this: -```python title="user.py" -class User: - name: str - age : int - - def __init__(self, name: str, age: int): - self.name = name - self.age = age +### Simple usage +Your application might have an Optimizer class: +```python title="Example class" +class Optimizer: + algo: str + lr: float + + def __init__(self, algo: str, lr: float): + self.algo = algo + self.lr = lr + + def __repr__(self): + return f"Optimizer(algo={self.algo},lr={self.lr})" ```
@@ -48,10 +62,14 @@ class User:
```yaml title="Config" -bond: - _target_: user.User - name: Bond - age: 7 +optimizer: + _target_: my_app.Optimizer + algo: SGD + lr: 0.01 + + + + ``` @@ -60,17 +78,96 @@ bond:
```python title="Instantiation" -user : User = instantiate(cfg.bond) -assert isinstance(user, user.User) -assert user.name == "Bond" -assert user.age == 7 +opt = instantiate(cfg.optimizer) +print(opt) +# Optimizer(algo=SGD,lr=0.01) + +# override parameters on the call-site +opt = instantiate(cfg.optimizer, lr=0.2) +print(opt) +# Optimizer(algo=SGD,lr=0.2) ```
-For convenience, instantiate/call returns `None` when receiving `None` as input. +### Recursive instantiation +Hydra will instantiate nested objects automatically. +```python title="Additional classes" +class Dataset: + name: str + path: str + + def __init__(self, name: str, path: str): + self.name = name + self.path = path + + +class Trainer: + def __init__(self, optimizer: Optimizer, dataset: Dataset): + self.optimizer = optimizer + self.dataset = dataset +``` + + +```yaml title="Example config" +trainer: + _target_: my_app.Trainer + optimizer: + _target_: my_app.Optimizer + algo: SGD + lr: 0.01 + dataset: + _target_: my_app.Dataset + name: Imagenet + path: /datasets/imagenet +``` + +Instantiate is recursive by default +```python +trainer = instantiate(cfg.trainer) +print(trainer) +``` +Output: ```python -assert instantiate(None) is None +Trainer( + optimizer=Optimizer(algo=SGD,lr=0.01), + dataset=Dataset(name=Imagenet, path=/datasets/imagenet) +) ``` + +You can override parameters for nested objects: +```python +trainer = instantiate( + cfg.trainer, + optimizer={"lr": 0.3}, + dataset={"name": "cifar10", "path": "/datasets/cifar10"}, +) +print(trainer) +``` +Output: +```python +Trainer( + optimizer=Optimizer(algo=SGD,lr=0.3), + dataset=Dataset(name=cifar10, path=/datasets/cifar10) +) +``` + +You can disable it by setting `_recursive_` to `False` in the config node or in the call-site +Note that in that case you will receive an OmegaConf DictConfig instead of the real object. +```python +optimizer = instantiate(cfg.trainer, _recursive_=False) +print(optimizer) +``` + +Output: +```python +Trainer( + optimizer={ + '_target_': 'my_app.Optimizer', 'algo': 'SGD', 'lr': 0.01 + }, + dataset={ + '_target_': 'my_app.Dataset', 'name': 'Imagenet', 'path': '/datasets/imagenet' + } +) \ No newline at end of file diff --git a/website/docs/patterns/instantiate_objects/structured_config.md b/website/docs/patterns/instantiate_objects/structured_config.md index 31cc9b3609e..2a2db5b7599 100644 --- a/website/docs/patterns/instantiate_objects/structured_config.md +++ b/website/docs/patterns/instantiate_objects/structured_config.md @@ -4,7 +4,7 @@ title: Structured Configs example sidebar_label: Structured Configs example --- -[![Example application](https://img.shields.io/badge/-Example%20application-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate/schema/my_app.py) +[![Example applications](https://img.shields.io/badge/-Example%20applications-informational)](https://github.com/facebookresearch/hydra/tree/master/examples/instantiate) This example demonstrates the use of Structured Configs to instantiated objects.