Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate subclassing dict #707

Merged
merged 14 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions docs/source/structured_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,6 @@ The resulting object and will also rejects attempts to access or set fields that
>>> with raises(AttributeError):
... conf.does_not_exist

You can create a config with specified fields that can also accept arbitrary values by extending Dict:

.. doctest::

>>> @dataclass
... class DictWithFields(Dict[str, Any]):
... num: int = 10
>>>
>>> conf = OmegaConf.structured(DictWithFields)
>>> assert conf.num == 10
>>>
>>> conf.foo = "bar"
>>> assert conf.foo == "bar"



Static type checker support
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions news/663.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for Structured Configs that subclass `typing.Dict` is now deprecated.
24 changes: 18 additions & 6 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import string
import sys
import warnings
from contextlib import contextmanager
from enum import Enum
from textwrap import dedent
Expand Down Expand Up @@ -234,11 +235,22 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
"""Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
from omegaconf.omegaconf import _maybe_wrap

if isinstance(obj, type):
return None
is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
subclasses_dict = is_dict_subclass(obj_type)

if subclasses_dict:
warnings.warn(
f"Class `{obj_type.__name__}` subclasses `Dict`."
+ " Subclassing `Dict` in Structured Config classes is deprecated,"
+ " see github.com/omry/omegaconf/issues/663",
UserWarning,
stacklevel=9,
)

obj_type = type(obj)
if is_dict_subclass(obj_type):
if is_type:
return None
elif subclasses_dict:
dict_subclass_data = {}
key_type, element_type = get_dict_key_value_types(obj_type)
for name, value in obj.items():
Expand All @@ -257,8 +269,8 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
node=None, key=name, value=value, cause=ex, msg=str(ex)
)
return dict_subclass_data

return None
else:
return None


def get_attr_class_field_names(obj: Any) -> List[str]:
Expand Down
13 changes: 13 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import attr
from pytest import warns

from omegaconf import II, MISSING

Expand Down Expand Up @@ -220,3 +222,14 @@ class InterpolationDict:
@dataclass
class Str2Int(Dict[str, int]):
pass


def warns_dict_subclass_deprecated(dict_subclass: Any) -> Any:
return warns(
UserWarning,
match=re.escape(
f"Class `{dict_subclass.__name__}` subclasses `Dict`."
+ " Subclassing `Dict` in Structured Config classes is deprecated,"
+ " see github.com/omry/omegaconf/issues/663"
),
)
39 changes: 27 additions & 12 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_utils,
)
from omegaconf.errors import ConfigKeyError
from tests import Color, User
from tests import Color, User, warns_dict_subclass_deprecated


@fixture(
Expand Down Expand Up @@ -933,7 +933,8 @@ def test_dataclass_frozen() -> None:

class TestDictSubclass:
def test_str2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2Str())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Str):
cfg = OmegaConf.structured(module.DictSubclass.Str2Str())
cfg.hello = "world"
assert cfg.hello == "world"

Expand All @@ -943,18 +944,21 @@ def test_str2str(self, module: Any) -> None:
def test_dict_subclass_data_preserved_upon_node_creation(self, module: Any) -> None:
src = module.DictSubclass.Str2StrWithField()
src["baz"] = "qux"
cfg = OmegaConf.structured(src)
with warns_dict_subclass_deprecated(module.DictSubclass.Str2StrWithField):
cfg = OmegaConf.structured(src)
assert cfg.foo == "bar"
assert cfg.baz == "qux"

def test_create_dict_subclass_with_bad_value_type(self, module: Any) -> None:
src = module.DictSubclass.Str2Int()
src["baz"] = "qux"
with raises(ValidationError):
OmegaConf.structured(src)
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Int):
OmegaConf.structured(src)

def test_str2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Str):
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str
assert _utils.get_ref_type(cfg.foo) == Any

Expand All @@ -968,7 +972,8 @@ def test_str2str_as_sub_node(self, module: Any) -> None:
cfg.foo[123] = "fail"

def test_int2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Int2Str())
with warns_dict_subclass_deprecated(module.DictSubclass.Int2Str):
cfg = OmegaConf.structured(module.DictSubclass.Int2Str())

cfg[10] = "ten" # okay
assert cfg[10] == "ten"
Expand All @@ -986,7 +991,8 @@ def test_int2str(self, module: Any) -> None:
cfg[Color.RED] = "fail"

def test_int2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
with warns_dict_subclass_deprecated(module.DictSubclass.Int2Str):
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str
assert _utils.get_ref_type(cfg.foo) == Any

Expand All @@ -1006,7 +1012,8 @@ def test_int2str_as_sub_node(self, module: Any) -> None:
cfg.foo[Color.RED] = "fail"

def test_color2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Color2Str())
with warns_dict_subclass_deprecated(module.DictSubclass.Color2Str):
cfg = OmegaConf.structured(module.DictSubclass.Color2Str())
cfg[Color.RED] = "red"

with raises(KeyValidationError):
Expand All @@ -1016,7 +1023,8 @@ def test_color2str(self, module: Any) -> None:
cfg[123] = "nope"

def test_color2color(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Color2Color())
with warns_dict_subclass_deprecated(module.DictSubclass.Color2Color):
cfg = OmegaConf.structured(module.DictSubclass.Color2Color())

# add key
cfg[Color.RED] = "GREEN"
Expand Down Expand Up @@ -1045,7 +1053,8 @@ def test_color2color(self, module: Any) -> None:
cfg.greeen = "nope"

def test_str2user(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2User):
cfg = OmegaConf.structured(module.DictSubclass.Str2User())

cfg.bond = module.User(name="James Bond", age=7)
assert cfg.bond.name == "James Bond"
Expand All @@ -1060,7 +1069,8 @@ def test_str2user(self, module: Any) -> None:
cfg[Color.BLUE] = "nope"

def test_str2str_with_field(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2StrWithField):
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
assert cfg.foo == "bar"
cfg.hello = "world"
assert cfg.hello == "world"
Expand All @@ -1074,10 +1084,15 @@ def test_usr2str(self, module: Any) -> None:
OmegaConf.structured(module.DictSubclass.Error.User2Str())

def test_str2int_with_field_of_different_type(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with warns_dict_subclass_deprecated(
module.DictSubclass.Str2IntWithStrField
):
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with raises(ValidationError):
cfg.foo = "str"


class TestConfigs2:
def test_construct_from_another_retain_node_types(self, module: Any) -> None:
cfg1 = OmegaConf.create(module.User(name="James Bond", age=7))
with raises(ValidationError):
Expand Down
40 changes: 27 additions & 13 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SubscriptedDict,
UnionError,
User,
warns_dict_subclass_deprecated,
)


Expand Down Expand Up @@ -793,19 +794,6 @@ def finalize(self, cfg: Any) -> None:
id="dict,readonly:delattr",
),
# creating structured config
param(
Expected(
create=lambda: Str2Int(),
op=lambda src: (src.__setitem__("bar", "qux"), OmegaConf.structured(src)),
exception_type=ValidationError,
msg="Value 'qux' could not be converted to Integer",
object_type=None,
key="bar",
full_key="bar",
parent_node=lambda cfg: None,
),
id="structured,Dict_subclass:bad_value_type",
),
param(
Jasha10 marked this conversation as resolved.
Show resolved Hide resolved
Expected(
create=lambda: None,
Expand Down Expand Up @@ -1445,3 +1433,29 @@ def test_get_full_key_failure_in_format_and_raise() -> None:

with raises(RecursionError, match=match):
c.x


def test_dict_subclass_error() -> None:
"""
Test calling OmegaConf.structured(malformed_dict_subclass).
We expect a ValueError and a UserWarning (deprecation) to be raised simultaneously.
We are using a separate function instead of adding
warning support to the giant `test_errors` function above,
"""
src = Str2Int()
src["bar"] = "qux" # type: ignore
with raises(
ValidationError,
match=re.escape("Value 'qux' could not be converted to Integer"),
) as einfo:
with warns_dict_subclass_deprecated(Str2Int):
OmegaConf.structured(src)
ex = einfo.value
assert isinstance(ex, OmegaConfBaseException)

assert ex.key == "bar"
assert ex.full_key == "bar"
assert ex.ref_type is None
assert ex.object_type is None
assert ex.parent_node is None
assert ex.child_node is None
11 changes: 7 additions & 4 deletions tests/test_to_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
open_dict,
)
from omegaconf.errors import InterpolationResolutionError
from tests import Color, User
from tests import Color, User, warns_dict_subclass_deprecated


@mark.parametrize(
Expand Down Expand Up @@ -377,7 +377,8 @@ def test_nested_object_with_Any_ref_type(self, module: Any) -> None:
assert nested.var.interpolation == 456

def test_str2user_instantiate(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2User):
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
cfg.bond = module.User(name="James Bond", age=7)
data = self.round_trip_to_object(cfg)

Expand All @@ -386,7 +387,8 @@ def test_str2user_instantiate(self, module: Any) -> None:
assert data.bond == module.User("James Bond", 7)

def test_str2user_with_field_instantiate(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2UserWithField):
cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField())
cfg.mp = module.User(name="Moneypenny", age=11)
data = self.round_trip_to_object(cfg)

Expand All @@ -397,7 +399,8 @@ def test_str2user_with_field_instantiate(self, module: Any) -> None:
assert data.mp == module.User("Moneypenny", 11)

def test_str2str_with_field_instantiate(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
with warns_dict_subclass_deprecated(module.DictSubclass.Str2StrWithField):
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
cfg.hello = "world"
data = self.round_trip_to_object(cfg)

Expand Down