Skip to content

Commit

Permalink
allow passthrough of non primitive objects by name
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Feb 27, 2020
1 parent 218779f commit fe90b66
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
19 changes: 15 additions & 4 deletions hydra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import Any

from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, OmegaConf, _utils

from hydra.conf import PluginConf
from hydra.core.hydra_config import HydraConfig
Expand Down Expand Up @@ -67,9 +67,20 @@ def instantiate(config: PluginConf, *args: Any, **kwargs: Any) -> Any:
), "Input config params are expected to be a mapping, found {}".format(
type(config.params)
)
params.merge_with(OmegaConf.create(kwargs))

return clazz(*args, **params)
primitives = {}
rest = {}
for k, v in kwargs.items():
if _utils._is_primitive_type(v) or isinstance(v, (dict, list)):
primitives[k] = v
else:
rest[k] = v
params.merge_with(OmegaConf.create(primitives))
final_kwargs = OmegaConf.to_container(params, resolve=True)
assert isinstance(final_kwargs, DictConfig)
for k, v in rest.items():
final_kwargs[k] = v

return clazz(*args, **final_kwargs)
except Exception as e:
log.error(f"Error instantiating '{classname}' : {e}")
raise e
Expand Down
1 change: 1 addition & 0 deletions news/400.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow hydra.utils.instantiate() to accept non primitive objects for passthrough by name
16 changes: 16 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def __ne__(self, other: Any) -> Any:
return NotImplemented


class Foo:
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Any) -> Any:
if isinstance(other, Foo):
return self.x == other.x
return False


@pytest.mark.parametrize("path,expected_type", [("tests.test_utils.Bar", Bar)]) # type: ignore
def test_get_class(path: str, expected_type: type) -> None:
assert utils.get_class(path) == expected_type
Expand Down Expand Up @@ -103,6 +113,12 @@ def test_get_static_method(path: str, return_value: Any) -> None:
{"a": 10, "d": 40},
Bar(10, 200, 200, 40),
),
(
{"cls": "tests.test_utils.Bar", "params": {"b": 200, "c": "${params.b}"}},
None,
{"a": 10, "d": Foo(99)},
Bar(10, 200, 200, Foo(99)),
),
],
)
def test_class_instantiate(
Expand Down

0 comments on commit fe90b66

Please sign in to comment.