diff --git a/news/584.bugfix b/news/584.bugfix new file mode 100644 index 000000000..efe509887 --- /dev/null +++ b/news/584.bugfix @@ -0,0 +1 @@ +Fix creation of structured config from a dict subclass: data from the dict is no longer thrown away. diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index acd79af91..4752b29cf 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -201,6 +201,37 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]: return type_ +def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]]: + """Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values.""" + from omegaconf.omegaconf import _maybe_wrap + + if isinstance(obj, type): + return None + + obj_type = type(obj) + if is_dict_subclass(obj_type): + dict_subclass_data = {} + key_type, element_type = get_dict_key_value_types(obj_type) + for name, value in obj.items(): + is_optional, type_ = _resolve_optional(element_type) + type_ = _resolve_forward(type_, obj.__module__) + try: + dict_subclass_data[name] = _maybe_wrap( + ref_type=type_, + is_optional=is_optional, + key=name, + value=value, + parent=parent, + ) + except ValidationError as ex: + format_and_raise( + node=None, key=name, value=value, cause=ex, msg=str(ex) + ) + return dict_subclass_data + + return None + + def get_attr_class_field_names(obj: Any) -> List[str]: is_type = isinstance(obj, type) obj_type = obj if is_type else type(obj) @@ -243,6 +274,9 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A except ValidationError as ex: format_and_raise(node=None, key=name, value=value, cause=ex, msg=str(ex)) d[name]._set_parent(None) + dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent) + if dict_subclass_data is not None: + d.update(dict_subclass_data) return d @@ -258,7 +292,8 @@ def get_dataclass_data( flags = {"allow_objects": allow_objects} if allow_objects is not None else {} dummy_parent = OmegaConf.create({}, flags=flags) d = {} - resolved_hints = get_type_hints(get_type_of(obj)) + obj_type = get_type_of(obj) + resolved_hints = get_type_hints(obj_type) for field in dataclasses.fields(obj): name = field.name is_optional, type_ = _resolve_optional(resolved_hints[field.name]) @@ -290,6 +325,9 @@ def get_dataclass_data( except ValidationError as ex: format_and_raise(node=None, key=name, value=value, cause=ex, msg=str(ex)) d[name]._set_parent(None) + dict_subclass_data = extract_dict_subclass_data(obj=obj, parent=dummy_parent) + if dict_subclass_data is not None: + d.update(dict_subclass_data) return d diff --git a/tests/__init__.py b/tests/__init__.py index b6904dd9f..8a98583f6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -211,3 +211,8 @@ class InterpolationList: @dataclass class InterpolationDict: dict: Dict[str, int] = II("optimization.lr") + + +@dataclass +class Str2Int(Dict[str, int]): + pass diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index fc7cf3d5c..dbd3ab933 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -416,6 +416,10 @@ class DictSubclass: class Str2Str(Dict[str, str]): pass + @attr.s(auto_attribs=True) + class Str2Int(Dict[str, int]): + pass + @attr.s(auto_attribs=True) class Int2Str(Dict[int, str]): pass diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index ef5ccf69d..a5121fa34 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -437,6 +437,10 @@ class DictSubclass: class Str2Str(Dict[str, str]): pass + @dataclass + class Str2Int(Dict[str, int]): + pass + @dataclass class Int2Str(Dict[int, str]): pass diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 4e08921b9..c5f639357 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -904,6 +904,19 @@ def test_str2str(self, module: Any) -> None: with raises(KeyValidationError): cfg[Color.RED] + def test_dict_subclass_data_preserved_upon_node_creation(self, module: Any) -> None: + src = module.DictSubclass.Str2StrWithField() + src["baz"] = "qux" + cfg = OmegaConf.structured(src) + assert cfg.foo == "bar" + assert cfg.baz == "qux" + + def test_create_dict_subclass_with_bad_value_type(self, module: Any) -> None: + src = module.DictSubclass.Str2Int() + src["baz"] = "qux" + with raises(ValidationError): + OmegaConf.structured(src) + def test_str2str_as_sub_node(self, module: Any) -> None: cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str}) assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str diff --git a/tests/test_errors.py b/tests/test_errors.py index 541c85eec..a09cf2b10 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -38,6 +38,7 @@ Module, Package, Plugin, + Str2Int, StructuredWithMissing, SubscriptedDict, UnionError, @@ -772,6 +773,20 @@ def finalize(self, cfg: Any) -> None: ), id="dict,structured:del", ), + # creating structured config + param( + Expected( + create=lambda: Str2Int(), + op=lambda src: (src.__setitem__("bar", "qux"), OmegaConf.structured(src)), + exception_type=ValidationError, + msg="Value 'qux' could not be converted to Integer", + object_type=None, + key="bar", + full_key="", + parent_node=lambda cfg: None, + ), + id="structured,Dict_subclass:bad_value_type", + ), ############## # ListConfig # ##############