Skip to content

Commit

Permalink
Deprecate subclassing dict (#707)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 authored May 11, 2021
1 parent 0b99ead commit 1349008
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 50 deletions.
15 changes: 0 additions & 15 deletions docs/source/structured_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,6 @@ The resulting object and will also rejects attempts to access or set fields that
>>> with raises(AttributeError):
... conf.does_not_exist

You can create a config with specified fields that can also accept arbitrary values by extending Dict:

.. doctest::

>>> @dataclass
... class DictWithFields(Dict[str, Any]):
... num: int = 10
>>>
>>> conf = OmegaConf.structured(DictWithFields)
>>> assert conf.num == 10
>>>
>>> conf.foo = "bar"
>>> assert conf.foo == "bar"



Static type checker support
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions news/663.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for Structured Configs that subclass `typing.Dict` is now deprecated.
24 changes: 18 additions & 6 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import string
import sys
import warnings
from contextlib import contextmanager
from enum import Enum
from textwrap import dedent
Expand Down Expand Up @@ -234,11 +235,22 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
"""Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
from omegaconf.omegaconf import _maybe_wrap

if isinstance(obj, type):
return None
is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
subclasses_dict = is_dict_subclass(obj_type)

if subclasses_dict:
warnings.warn(
f"Class `{obj_type.__name__}` subclasses `Dict`."
+ " Subclassing `Dict` in Structured Config classes is deprecated,"
+ " see github.com/omry/omegaconf/issues/663",
UserWarning,
stacklevel=9,
)

obj_type = type(obj)
if is_dict_subclass(obj_type):
if is_type:
return None
elif subclasses_dict:
dict_subclass_data = {}
key_type, element_type = get_dict_key_value_types(obj_type)
for name, value in obj.items():
Expand All @@ -257,8 +269,8 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
node=None, key=name, value=value, cause=ex, msg=str(ex)
)
return dict_subclass_data

return None
else:
return None


def get_attr_class_field_names(obj: Any) -> List[str]:
Expand Down
13 changes: 13 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import attr
from pytest import warns

from omegaconf import II, MISSING

Expand Down Expand Up @@ -220,3 +222,14 @@ class InterpolationDict:
@dataclass
class Str2Int(Dict[str, int]):
pass


def warns_dict_subclass_deprecated(dict_subclass: Any) -> Any:
return warns(
UserWarning,
match=re.escape(
f"Class `{dict_subclass.__name__}` subclasses `Dict`."
+ " Subclassing `Dict` in Structured Config classes is deprecated,"
+ " see github.com/omry/omegaconf/issues/663"
),
)
39 changes: 27 additions & 12 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_utils,
)
from omegaconf.errors import ConfigKeyError
from tests import Color, User
from tests import Color, User, warns_dict_subclass_deprecated


@fixture(
Expand Down Expand Up @@ -933,7 +933,8 @@ def test_dataclass_frozen() -> None:

class TestDictSubclass:
def test_str2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2Str())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Str):
cfg = OmegaConf.structured(module.DictSubclass.Str2Str())
cfg.hello = "world"
assert cfg.hello == "world"

Expand All @@ -943,18 +944,21 @@ def test_str2str(self, module: Any) -> None:
def test_dict_subclass_data_preserved_upon_node_creation(self, module: Any) -> None:
src = module.DictSubclass.Str2StrWithField()
src["baz"] = "qux"
cfg = OmegaConf.structured(src)
with warns_dict_subclass_deprecated(module.DictSubclass.Str2StrWithField):
cfg = OmegaConf.structured(src)
assert cfg.foo == "bar"
assert cfg.baz == "qux"

def test_create_dict_subclass_with_bad_value_type(self, module: Any) -> None:
src = module.DictSubclass.Str2Int()
src["baz"] = "qux"
with raises(ValidationError):
OmegaConf.structured(src)
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Int):
OmegaConf.structured(src)

def test_str2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Str):
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str
assert _utils.get_ref_type(cfg.foo) == Any

Expand All @@ -968,7 +972,8 @@ def test_str2str_as_sub_node(self, module: Any) -> None:
cfg.foo[123] = "fail"

def test_int2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Int2Str())
with warns_dict_subclass_deprecated(module.DictSubclass.Int2Str):
cfg = OmegaConf.structured(module.DictSubclass.Int2Str())

cfg[10] = "ten" # okay
assert cfg[10] == "ten"
Expand All @@ -986,7 +991,8 @@ def test_int2str(self, module: Any) -> None:
cfg[Color.RED] = "fail"

def test_int2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
with warns_dict_subclass_deprecated(module.DictSubclass.Int2Str):
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str
assert _utils.get_ref_type(cfg.foo) == Any

Expand All @@ -1006,7 +1012,8 @@ def test_int2str_as_sub_node(self, module: Any) -> None:
cfg.foo[Color.RED] = "fail"

def test_color2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Color2Str())
with warns_dict_subclass_deprecated(module.DictSubclass.Color2Str):
cfg = OmegaConf.structured(module.DictSubclass.Color2Str())
cfg[Color.RED] = "red"

with raises(KeyValidationError):
Expand All @@ -1016,7 +1023,8 @@ def test_color2str(self, module: Any) -> None:
cfg[123] = "nope"

def test_color2color(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Color2Color())
with warns_dict_subclass_deprecated(module.DictSubclass.Color2Color):
cfg = OmegaConf.structured(module.DictSubclass.Color2Color())

# add key
cfg[Color.RED] = "GREEN"
Expand Down Expand Up @@ -1045,7 +1053,8 @@ def test_color2color(self, module: Any) -> None:
cfg.greeen = "nope"

def test_str2user(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2User):
cfg = OmegaConf.structured(module.DictSubclass.Str2User())

cfg.bond = module.User(name="James Bond", age=7)
assert cfg.bond.name == "James Bond"
Expand All @@ -1060,7 +1069,8 @@ def test_str2user(self, module: Any) -> None:
cfg[Color.BLUE] = "nope"

def test_str2str_with_field(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2StrWithField):
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
assert cfg.foo == "bar"
cfg.hello = "world"
assert cfg.hello == "world"
Expand All @@ -1074,10 +1084,15 @@ def test_usr2str(self, module: Any) -> None:
OmegaConf.structured(module.DictSubclass.Error.User2Str())

def test_str2int_with_field_of_different_type(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with warns_dict_subclass_deprecated(
module.DictSubclass.Str2IntWithStrField
):
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with raises(ValidationError):
cfg.foo = "str"


class TestConfigs2:
def test_construct_from_another_retain_node_types(self, module: Any) -> None:
cfg1 = OmegaConf.create(module.User(name="James Bond", age=7))
with raises(ValidationError):
Expand Down
40 changes: 27 additions & 13 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SubscriptedDict,
UnionError,
User,
warns_dict_subclass_deprecated,
)


Expand Down Expand Up @@ -793,19 +794,6 @@ def finalize(self, cfg: Any) -> None:
id="dict,readonly:delattr",
),
# creating structured config
param(
Expected(
create=lambda: Str2Int(),
op=lambda src: (src.__setitem__("bar", "qux"), OmegaConf.structured(src)),
exception_type=ValidationError,
msg="Value 'qux' could not be converted to Integer",
object_type=None,
key="bar",
full_key="bar",
parent_node=lambda cfg: None,
),
id="structured,Dict_subclass:bad_value_type",
),
param(
Expected(
create=lambda: None,
Expand Down Expand Up @@ -1445,3 +1433,29 @@ def test_get_full_key_failure_in_format_and_raise() -> None:

with raises(RecursionError, match=match):
c.x


def test_dict_subclass_error() -> None:
"""
Test calling OmegaConf.structured(malformed_dict_subclass).
We expect a ValueError and a UserWarning (deprecation) to be raised simultaneously.
We are using a separate function instead of adding
warning support to the giant `test_errors` function above,
"""
src = Str2Int()
src["bar"] = "qux" # type: ignore
with raises(
ValidationError,
match=re.escape("Value 'qux' could not be converted to Integer"),
) as einfo:
with warns_dict_subclass_deprecated(Str2Int):
OmegaConf.structured(src)
ex = einfo.value
assert isinstance(ex, OmegaConfBaseException)

assert ex.key == "bar"
assert ex.full_key == "bar"
assert ex.ref_type is None
assert ex.object_type is None
assert ex.parent_node is None
assert ex.child_node is None
11 changes: 7 additions & 4 deletions tests/test_to_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
open_dict,
)
from omegaconf.errors import InterpolationResolutionError
from tests import Color, User
from tests import Color, User, warns_dict_subclass_deprecated


@mark.parametrize(
Expand Down Expand Up @@ -377,7 +377,8 @@ def test_nested_object_with_Any_ref_type(self, module: Any) -> None:
assert nested.var.interpolation == 456

def test_str2user_instantiate(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2User):
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
cfg.bond = module.User(name="James Bond", age=7)
data = self.round_trip_to_object(cfg)

Expand All @@ -386,7 +387,8 @@ def test_str2user_instantiate(self, module: Any) -> None:
assert data.bond == module.User("James Bond", 7)

def test_str2user_with_field_instantiate(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2UserWithField):
cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField())
cfg.mp = module.User(name="Moneypenny", age=11)
data = self.round_trip_to_object(cfg)

Expand All @@ -397,7 +399,8 @@ def test_str2user_with_field_instantiate(self, module: Any) -> None:
assert data.mp == module.User("Moneypenny", 11)

def test_str2str_with_field_instantiate(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2StrWithField):
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
cfg.hello = "world"
data = self.round_trip_to_object(cfg)

Expand Down

0 comments on commit 1349008

Please sign in to comment.