From fe90b66e549c3df3d213f5d7348df08af8aac546 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Sun, 9 Feb 2020 18:37:02 -0800 Subject: [PATCH] allow passthrough of non primitive objects by name --- hydra/utils.py | 19 +++++++++++++++---- news/400.bugfix | 1 + tests/test_utils.py | 16 ++++++++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) create mode 100644 news/400.bugfix diff --git a/hydra/utils.py b/hydra/utils.py index 0767a5869ba..72a65c4ca08 100644 --- a/hydra/utils.py +++ b/hydra/utils.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, _utils from hydra.conf import PluginConf from hydra.core.hydra_config import HydraConfig @@ -67,9 +67,20 @@ def instantiate(config: PluginConf, *args: Any, **kwargs: Any) -> Any: ), "Input config params are expected to be a mapping, found {}".format( type(config.params) ) - params.merge_with(OmegaConf.create(kwargs)) - - return clazz(*args, **params) + primitives = {} + rest = {} + for k, v in kwargs.items(): + if _utils._is_primitive_type(v) or isinstance(v, (dict, list)): + primitives[k] = v + else: + rest[k] = v + params.merge_with(OmegaConf.create(primitives)) + final_kwargs = OmegaConf.to_container(params, resolve=True) + assert isinstance(final_kwargs, DictConfig) + for k, v in rest.items(): + final_kwargs[k] = v + + return clazz(*args, **final_kwargs) except Exception as e: log.error(f"Error instantiating '{classname}' : {e}") raise e diff --git a/news/400.bugfix b/news/400.bugfix new file mode 100644 index 00000000000..cf257c0d393 --- /dev/null +++ b/news/400.bugfix @@ -0,0 +1 @@ +Allow hydra.utils.instantiate() to accept non primitive objects for passthrough by name diff --git a/tests/test_utils.py b/tests/test_utils.py index 9d0384ba9b2..b54f8fb1bee 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -48,6 +48,16 @@ def __ne__(self, other: Any) -> Any: return NotImplemented +class Foo: + def __init__(self, x: int) -> None: + self.x = x + + def __eq__(self, other: Any) -> Any: + if isinstance(other, Foo): + return self.x == other.x + return False + + @pytest.mark.parametrize("path,expected_type", [("tests.test_utils.Bar", Bar)]) # type: ignore def test_get_class(path: str, expected_type: type) -> None: assert utils.get_class(path) == expected_type @@ -103,6 +113,12 @@ def test_get_static_method(path: str, return_value: Any) -> None: {"a": 10, "d": 40}, Bar(10, 200, 200, 40), ), + ( + {"cls": "tests.test_utils.Bar", "params": {"b": 200, "c": "${params.b}"}}, + None, + {"a": 10, "d": Foo(99)}, + Bar(10, 200, 200, Foo(99)), + ), ], ) def test_class_instantiate(