Skip to content

Commit

Permalink
Instantiate structured configs (#502)
Browse files Browse the repository at this point in the history
* to_container: instantiate_structured_configs flag

* add a failing test

* fix typo

* fix bug

* another bugfix

* solved an issue

* add type assert

* updates

* test instantiation of subclass of Dict[str, User]

* test instantiate_structured_configs-Str2UserWithField

* fix bug with allow_objects flag

* add comment

* remove dependence on ref_type; use object_type

* use keyword args in call to _instantiate_structured_config_impl

* refactor for clearer control flow

* move method _instantiate_structured_config_impl

* fix lint/mypy errors

* remove unnecessary allow_objects flag

* rename parameter 'instantiate_structured_configs' -> 'instantiate'

* create OmegaConf.to_object alias for OmegaConf.to_container

* One use case per test

* coverage: use to_object(cfg) instead of to_container(object, instantiate=True)

* rename tests: to_object instead of to_container

* tests: user str key instead of int key

* tests: change 'assert ... is MISSING' -> 'assert ... == MISSING'

* add tests for object nested inside object

* one use case per tests: dict subclass

* test_structured_config.py: consolidate instantiate=True tests

* finish rebase against master

* Move TestInstantiateStructuredConfigs to test_to_container.py

* Create get_structured_config_field_names function

* OmegaConf.to_object: resolve=True by default

* change _instantiate_structured_config_impl fn signature

* separate positive and negative test cases

* switch order of cases in _instantiate_structured_config_impl

* switch order of cases in _instantiate_structured_config_impl

* regroup tests for extracting structured config info

* Undo a stylistic change to tests/structured_conf/test_structured_config.py

This change would be best left for another PR.

* add failing tests for throw if MISSING

* Update omegaconf/_utils.py

Co-authored-by: Olivier Delalleau <[email protected]>

* fix mypy and flake8 issues

* implement MissingMandatoryValue in case of MISSING param to dataclass instance

* update a test to reflect new behavior r.e. MISSING

* use correct-typed value in test of KeyValidationError

* modify to_object docstring

* use a set for _instantiate_structured_config_impl field names

* remove redundant call to set()

* refactor TestInstantiateStructuredConfigs

- parametrize the `module` fixture directly
- reorder tests for increased consistency
  (test non-missing case before missing case)

* TestInstantiateStructuredConfigs: remove redundant isinstance assertions

Testing `isinstance(a, A)` is redundant if we are testing `type(a) is A`
on the next line.

* Use setattr(instance, k, v) when structured config has extra fields

* add news fragment

* refactoring: rename variables

* Test error message for MissingMandatoryValue

* Formatting: delete whitespace

* include $OBJECT_TYPE in MissingMandatoryValue err msg

* change _instantiate_structured_config_impl to an instance method

* simplify `retdict` & `retstruct` to `ret`

* rename `conf` -> `self` in _instantiate_structured_config_impl

* remove `resolve` arg from `to_object`

* Docs example for SCMode.INSTANTIATE

* docs: OmegaConf.to_object example

* Docs minor edit

* updates to to_object docs

* Revert test_structured_config.py (remove redundant test)

* dict subclass: DictConfig items become instance attributes

* docs: use `show` instead of `print`/`assert`

* minor doc fix

Co-authored-by: Olivier Delalleau <[email protected]>

* docs: Improve introduction to `to_object` method

* docs: Remove explanation r.e. equivalent OmegaConf.to_container calls

* docs: clarification on ducktyping

Co-authored-by: Omry Yadan <[email protected]>

* to_container docs: explicitly document the new SCMode.INSTANTIATE member

* update `to_object` docstring

* docs: fix typos

Co-authored-by: Olivier Delalleau <[email protected]>

* empty commit (to trigger CI workflow)

* refactor test_SCMode

* lowercase test fn name (test_SCMode -> test_scmode)

* StructuredConfigs have resolve=True and enum_to_str=False

* minor: revert whitespace addition

* Edit to news/472.feature

Co-authored-by: Omry Yadan <[email protected]>

* don't mention enum_to_str

* formatting and title for structured_config_mode docs

* remove TODO comment

* fix comment formatting

* move `import get_structured_config_field_names` to top of file

* one last formatting adjustment

Co-authored-by: Olivier Delalleau <[email protected]>
Co-authored-by: Omry Yadan <[email protected]>
  • Loading branch information
3 people authored Apr 7, 2021
1 parent a03f681 commit 0249455
Show file tree
Hide file tree
Showing 12 changed files with 456 additions and 53 deletions.
37 changes: 37 additions & 0 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,25 @@ If resolve is set to True, interpolations will be resolved during conversion.
>>> show(resolved)
type: dict, value: {'foo': 'bar', 'foo2': 'bar'}


Using ``structured_config_mode``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You can customize the treatment of ``OmegaConf.to_container()`` for
Structured Config nodes using the ``structured_config_mode`` option.
By default, Structured Config nodes are converted to plain dict.

Using ``structured_config_mode=SCMode.DICT_CONFIG`` causes such nodes to remain
as DictConfig, allowing attribute style access on the resulting node.

Using ``structured_config_mode=SCMode.INSTANTIATE``, Structured Config nodes
are converted to instances of the backing dataclass or attrs class. Note that
when ``structured_config_mode=SCMode.INSTANTIATE``, interpolations nested within
a structured config node will be resolved, even if ``OmegaConf.to_container`` is called
with the the keyword argument ``resolve=False``, so that interpolations are resolved before
being used to instantiate dataclass/attr class instances. Interpolations within
non-structured parent nodes will be resolved (or not) as usual, according to
the ``resolve`` keyword arg.

.. doctest::

>>> from omegaconf import SCMode
Expand All @@ -795,6 +808,30 @@ as DictConfig, allowing attribute style access on the resulting node.
>>> show(container["structured_config"])
type: DictConfig, value: {'port': 80, 'host': 'localhost'}

OmegaConf.to_object
^^^^^^^^^^^^^^^^^^^^^^
The ``OmegaConf.to_object`` method recursively converts DictConfig and ListConfig objects
into dicts and lists, with the exception that Structured Config objects are
converted into instances of the backing dataclass or attr class. All OmegaConf
interpolations are resolved before conversion to Python containers.

.. doctest::

>>> container = OmegaConf.to_object(conf)
>>> show(container)
type: dict, value: {'structured_config': MyConfig(port=80, host='localhost')}
>>> show(container["structured_config"])
type: MyConfig, value: MyConfig(port=80, host='localhost')

Note that here, ``container["structured_config"]`` is actually an instance of
``MyConfig``, whereas in the previous examples we had a ``dict`` or a
``DictConfig`` object that was duck-typed to look like an instance of
``MyConfig``.

The call ``OmegaConf.to_object(conf)`` is equivalent to
``OmegaConf.to_container(conf, resolve=True,
structured_config_mode=SCMode.INSTANTIATE)``.

OmegaConf.resolve
^^^^^^^^^^^^^^^^^
.. code-block:: python
Expand Down
1 change: 1 addition & 0 deletions news/472.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add the OmegaConf.to_object method, which converts Structured Configs to native instances of the underlying `@dataclass` or `@attr.s` class.
19 changes: 19 additions & 0 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
return type_


def get_attr_class_field_names(obj: Any) -> List[str]:
is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
return list(attr.fields_dict(obj_type))


def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

Expand Down Expand Up @@ -240,6 +246,10 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A
return d


def get_dataclass_field_names(obj: Any) -> List[str]:
return [field.name for field in dataclasses.fields(obj)]


def get_dataclass_data(
obj: Any, allow_objects: Optional[bool] = None
) -> Dict[str, Any]:
Expand Down Expand Up @@ -332,6 +342,15 @@ def is_structured_config_frozen(obj: Any) -> bool:
return False


def get_structured_config_field_names(obj: Any) -> List[str]:
if is_dataclass(obj):
return get_dataclass_field_names(obj)
elif is_attr_class(obj):
return get_attr_class_field_names(obj)
else:
raise ValueError(f"Unsupported type: {type(obj).__name__}")


def get_structured_config_data(
obj: Any, allow_objects: Optional[bool] = None
) -> Dict[str, Any]:
Expand Down
3 changes: 2 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,5 +728,6 @@ def _has_ref_type(self) -> bool:


class SCMode(Enum):
DICT = 1 # convert to plain dict
DICT = 1 # Convert to plain dict
DICT_CONFIG = 2 # Keep as OmegaConf DictConfig
INSTANTIATE = 3 # Create a dataclass or attrs class instance
4 changes: 4 additions & 0 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def convert(val: Node) -> Any:
and structured_config_mode == SCMode.DICT_CONFIG
):
return conf
if structured_config_mode == SCMode.INSTANTIATE and is_structured_config(
conf._metadata.object_type
):
return conf._to_object()

retdict: Dict[str, Any] = {}
for key in conf.keys():
Expand Down
49 changes: 48 additions & 1 deletion omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
ValueKind,
_get_value,
_is_interpolation,
_is_missing_literal,
_is_missing_value,
_is_none,
_valid_dict_key_annotation_type,
format_and_raise,
get_structured_config_data,
get_structured_config_field_names,
get_type_of,
get_value_kind,
is_container_annotation,
Expand All @@ -35,7 +37,7 @@
type_str,
valid_value_annotation_type,
)
from .base import Container, ContainerMetadata, DictKeyType, Node
from .base import Container, ContainerMetadata, DictKeyType, Node, SCMode
from .basecontainer import BaseContainer
from .errors import (
ConfigAttributeError,
Expand Down Expand Up @@ -682,3 +684,48 @@ def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool:
return False

return True

def _to_object(self) -> Any:
"""
Instantiate an instance of `self._metadata.object_type`.
This requires `self` to be a structured config.
Nested subconfigs are converted to_container with resolve=True.
"""
object_type = self._metadata.object_type
assert is_structured_config(object_type)
object_type_field_names = set(get_structured_config_field_names(object_type))

field_items: Dict[str, Any] = {}
nonfield_items: Dict[str, Any] = {}
for k in self.keys():
node = self._get_node(k)
assert isinstance(node, Node)
node = node._dereference_node(throw_on_resolution_failure=True)
assert node is not None
if isinstance(node, Container):
v = BaseContainer._to_content(
node,
resolve=True,
enum_to_str=False,
structured_config_mode=SCMode.INSTANTIATE,
)
else:
v = node._value()

if _is_missing_literal(v):
self._format_and_raise(
key=k,
value=None,
cause=MissingMandatoryValue(
"Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY"
),
)
if k in object_type_field_names:
field_items[k] = v
else:
nonfield_items[k] = v

result = object_type(**field_items)
for k, v in nonfield_items.items():
setattr(result, k, v)
return result
27 changes: 27 additions & 0 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,9 @@ def to_container(
:param structured_config_mode: Specify how Structured Configs (DictConfigs backed by a dataclass) are handled.
By default (`structured_config_mode=SCMode.DICT`) structured configs are converted to plain dicts.
If `structured_config_mode=SCMode.DICT_CONFIG`, structured config nodes will remain as DictConfig.
If `structured_config_mode=SCMode.INSTANTIATE`, this function will instantiate structured configs
(DictConfigs backed by a dataclass), by creating an instance of the underlying dataclass.
See also OmegaConf.to_object.
:return: A dict or a list representing this config as a primitive container.
"""
if not OmegaConf.is_config(cfg):
Expand All @@ -593,6 +596,30 @@ def to_container(
structured_config_mode=structured_config_mode,
)

@staticmethod
def to_object(
cfg: Any,
*,
enum_to_str: bool = False,
) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
"""
Resursively converts an OmegaConf config to a primitive container (dict or list).
Any DictConfig objects backed by dataclasses or attrs classes are instantiated
as instances of those backing classes.
This is an alias for OmegaConf.to_container(..., resolve=True, structured_config_mode=SCMode.INSTANTIATE)
:param cfg: the config to convert
:param enum_to_str: True to convert Enum values to strings
:return: A dict or a list or dataclass representing this config.
"""
return OmegaConf.to_container(
cfg=cfg,
resolve=True,
enum_to_str=enum_to_str,
structured_config_mode=SCMode.INSTANTIATE,
)

@staticmethod
def is_missing(cfg: Any, key: DictKeyType) -> bool:
assert isinstance(cfg, Container)
Expand Down
12 changes: 12 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,14 @@ class Interpolation:
z2: str = SI("${x}_${y}")


@attr.s(auto_attribs=True)
class RelativeInterpolation:
x: int = 100
y: int = 200
z1: int = II(".x")
z2: str = SI("${.x}_${.y}")


@attr.s(auto_attribs=True)
class BoolOptional:
with_default: Optional[bool] = True
Expand Down Expand Up @@ -440,6 +448,10 @@ class Str2StrWithField(Dict[str, str]):
class Str2IntWithStrField(Dict[str, int]):
foo: int = 1

@attr.s(auto_attribs=True)
class Str2UserWithField(Dict[str, User]):
foo: User = User("Bond", 7)

class Error:
@attr.s(auto_attribs=True)
class User2Str(Dict[User, str]):
Expand Down
12 changes: 12 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ class Interpolation:
z2: str = SI("${x}_${y}")


@dataclass
class RelativeInterpolation:
x: int = 100
y: int = 200
z1: int = II(".x")
z2: str = SI("${.x}_${.y}")


@dataclass
class BoolOptional:
with_default: Optional[bool] = True
Expand Down Expand Up @@ -461,6 +469,10 @@ class Str2StrWithField(Dict[str, str]):
class Str2IntWithStrField(Dict[str, int]):
foo: int = 1

@dataclass
class Str2UserWithField(Dict[str, User]):
foo: User = User("Bond", 7)

class Error:
@dataclass
class User2Str(Dict[User, str]):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,18 @@ def finalize(self, cfg: Any) -> None:
),
id="list,readonly:del",
),
# to_object
param(
Expected(
create=lambda: OmegaConf.structured(User),
op=lambda cfg: OmegaConf.to_object(cfg),
exception_type=MissingMandatoryValue,
msg="Structured config of type `User` has missing mandatory value: name",
key="name",
child_node=lambda cfg: cfg._get_node("name"),
),
id="to_object:structured-missing-field",
),
]


Expand Down
Loading

0 comments on commit 0249455

Please sign in to comment.