Skip to content

Commit

Permalink
in get_structured_config_data: check for dict subclass data (#653)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 authored Apr 11, 2021
1 parent 72ae5aa commit 9b037f9
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 1 deletion.
1 change: 1 addition & 0 deletions news/584.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix creation of structured config from a dict subclass: data from the dict is no longer thrown away.
40 changes: 39 additions & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,37 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
return type_


def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]]:
"""Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
from omegaconf.omegaconf import _maybe_wrap

if isinstance(obj, type):
return None

obj_type = type(obj)
if is_dict_subclass(obj_type):
dict_subclass_data = {}
key_type, element_type = get_dict_key_value_types(obj_type)
for name, value in obj.items():
is_optional, type_ = _resolve_optional(element_type)
type_ = _resolve_forward(type_, obj.__module__)
try:
dict_subclass_data[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=parent,
)
except ValidationError as ex:
format_and_raise(
node=None, key=name, value=value, cause=ex, msg=str(ex)
)
return dict_subclass_data

return None


def get_attr_class_field_names(obj: Any) -> List[str]:
is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
Expand Down Expand Up @@ -243,6 +274,9 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A
except ValidationError as ex:
format_and_raise(node=None, key=name, value=value, cause=ex, msg=str(ex))
d[name]._set_parent(None)
dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
if dict_subclass_data is not None:
d.update(dict_subclass_data)
return d


Expand All @@ -258,7 +292,8 @@ def get_dataclass_data(
flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
dummy_parent = OmegaConf.create({}, flags=flags)
d = {}
resolved_hints = get_type_hints(get_type_of(obj))
obj_type = get_type_of(obj)
resolved_hints = get_type_hints(obj_type)
for field in dataclasses.fields(obj):
name = field.name
is_optional, type_ = _resolve_optional(resolved_hints[field.name])
Expand Down Expand Up @@ -290,6 +325,9 @@ def get_dataclass_data(
except ValidationError as ex:
format_and_raise(node=None, key=name, value=value, cause=ex, msg=str(ex))
d[name]._set_parent(None)
dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent)
if dict_subclass_data is not None:
d.update(dict_subclass_data)
return d


Expand Down
5 changes: 5 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,8 @@ class InterpolationList:
@dataclass
class InterpolationDict:
dict: Dict[str, int] = II("optimization.lr")


@dataclass
class Str2Int(Dict[str, int]):
pass
4 changes: 4 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@ class DictSubclass:
class Str2Str(Dict[str, str]):
pass

@attr.s(auto_attribs=True)
class Str2Int(Dict[str, int]):
pass

@attr.s(auto_attribs=True)
class Int2Str(Dict[int, str]):
pass
Expand Down
4 changes: 4 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ class DictSubclass:
class Str2Str(Dict[str, str]):
pass

@dataclass
class Str2Int(Dict[str, int]):
pass

@dataclass
class Int2Str(Dict[int, str]):
pass
Expand Down
13 changes: 13 additions & 0 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,19 @@ def test_str2str(self, module: Any) -> None:
with raises(KeyValidationError):
cfg[Color.RED]

def test_dict_subclass_data_preserved_upon_node_creation(self, module: Any) -> None:
src = module.DictSubclass.Str2StrWithField()
src["baz"] = "qux"
cfg = OmegaConf.structured(src)
assert cfg.foo == "bar"
assert cfg.baz == "qux"

def test_create_dict_subclass_with_bad_value_type(self, module: Any) -> None:
src = module.DictSubclass.Str2Int()
src["baz"] = "qux"
with raises(ValidationError):
OmegaConf.structured(src)

def test_str2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str
Expand Down
15 changes: 15 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Module,
Package,
Plugin,
Str2Int,
StructuredWithMissing,
SubscriptedDict,
UnionError,
Expand Down Expand Up @@ -772,6 +773,20 @@ def finalize(self, cfg: Any) -> None:
),
id="dict,structured:del",
),
# creating structured config
param(
Expected(
create=lambda: Str2Int(),
op=lambda src: (src.__setitem__("bar", "qux"), OmegaConf.structured(src)),
exception_type=ValidationError,
msg="Value 'qux' could not be converted to Integer",
object_type=None,
key="bar",
full_key="",
parent_node=lambda cfg: None,
),
id="structured,Dict_subclass:bad_value_type",
),
##############
# ListConfig #
##############
Expand Down

0 comments on commit 9b037f9

Please sign in to comment.