Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate subclassing dict #707

Merged
merged 14 commits into from
May 11, 2021
Merged
2 changes: 2 additions & 0 deletions docs/source/structured_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ You can create a config with specified fields that can also accept arbitrary val
>>> conf.foo = "bar"
>>> assert conf.foo == "bar"

This feature is deprecated; OmegaConf's ability to handle structured configs
that subclass ``Dict`` is planned to be removed in a future release.
Jasha10 marked this conversation as resolved.
Show resolved Hide resolved


Static type checker support
Expand Down
10 changes: 9 additions & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import string
import sys
import warnings
from contextlib import contextmanager
from enum import Enum
from textwrap import dedent
Expand Down Expand Up @@ -234,10 +235,17 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
"""Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
from omegaconf.omegaconf import _maybe_wrap

obj_type = type(obj)
if is_dict_subclass(obj) or is_dict_subclass(obj_type):
warnings.warn(
"Subclassing of `Dict` by Structured Config classes is deprecated",
UserWarning,
stacklevel=2,
)

if isinstance(obj, type):
return None

obj_type = type(obj)
if is_dict_subclass(obj_type):
dict_subclass_data = {}
key_type, element_type = get_dict_key_value_types(obj_type)
Expand Down
43 changes: 32 additions & 11 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
import sys
from importlib import import_module
from typing import Any, Dict, List, Optional

from pytest import fixture, mark, param, raises
from pytest import fixture, mark, param, raises, warns

from omegaconf import (
MISSING,
Expand Down Expand Up @@ -932,8 +933,17 @@ def test_dataclass_frozen() -> None:


class TestDictSubclass:
def warns_deprecated(self) -> Any:
return warns(
UserWarning,
match=re.escape(
"Subclassing of `Dict` by Structured Config classes is deprecated"
),
)

def test_str2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2Str())
with self.warns_deprecated():
cfg = OmegaConf.structured(module.DictSubclass.Str2Str())
cfg.hello = "world"
assert cfg.hello == "world"

Expand All @@ -943,7 +953,8 @@ def test_str2str(self, module: Any) -> None:
def test_dict_subclass_data_preserved_upon_node_creation(self, module: Any) -> None:
src = module.DictSubclass.Str2StrWithField()
src["baz"] = "qux"
cfg = OmegaConf.structured(src)
with self.warns_deprecated():
cfg = OmegaConf.structured(src)
Jasha10 marked this conversation as resolved.
Show resolved Hide resolved
assert cfg.foo == "bar"
assert cfg.baz == "qux"

Expand All @@ -954,7 +965,8 @@ def test_create_dict_subclass_with_bad_value_type(self, module: Any) -> None:
OmegaConf.structured(src)

def test_str2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
with self.warns_deprecated():
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str
assert _utils.get_ref_type(cfg.foo) == Any

Expand All @@ -968,7 +980,8 @@ def test_str2str_as_sub_node(self, module: Any) -> None:
cfg.foo[123] = "fail"

def test_int2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Int2Str())
with self.warns_deprecated():
cfg = OmegaConf.structured(module.DictSubclass.Int2Str())

cfg[10] = "ten" # okay
assert cfg[10] == "ten"
Expand All @@ -986,7 +999,8 @@ def test_int2str(self, module: Any) -> None:
cfg[Color.RED] = "fail"

def test_int2str_as_sub_node(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
with self.warns_deprecated():
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str
assert _utils.get_ref_type(cfg.foo) == Any

Expand All @@ -1006,7 +1020,8 @@ def test_int2str_as_sub_node(self, module: Any) -> None:
cfg.foo[Color.RED] = "fail"

def test_color2str(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Color2Str())
with self.warns_deprecated():
cfg = OmegaConf.structured(module.DictSubclass.Color2Str())
cfg[Color.RED] = "red"

with raises(KeyValidationError):
Expand All @@ -1016,7 +1031,8 @@ def test_color2str(self, module: Any) -> None:
cfg[123] = "nope"

def test_color2color(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Color2Color())
with self.warns_deprecated():
cfg = OmegaConf.structured(module.DictSubclass.Color2Color())

# add key
cfg[Color.RED] = "GREEN"
Expand Down Expand Up @@ -1045,7 +1061,8 @@ def test_color2color(self, module: Any) -> None:
cfg.greeen = "nope"

def test_str2user(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2User())
with self.warns_deprecated():
cfg = OmegaConf.structured(module.DictSubclass.Str2User())

cfg.bond = module.User(name="James Bond", age=7)
assert cfg.bond.name == "James Bond"
Expand All @@ -1060,7 +1077,8 @@ def test_str2user(self, module: Any) -> None:
cfg[Color.BLUE] = "nope"

def test_str2str_with_field(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
with self.warns_deprecated():
cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField())
assert cfg.foo == "bar"
cfg.hello = "world"
assert cfg.hello == "world"
Expand All @@ -1074,10 +1092,13 @@ def test_usr2str(self, module: Any) -> None:
OmegaConf.structured(module.DictSubclass.Error.User2Str())

def test_str2int_with_field_of_different_type(self, module: Any) -> None:
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with warns(UserWarning):
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with raises(ValidationError):
cfg.foo = "str"


class TestConfigs2:
def test_construct_from_another_retain_node_types(self, module: Any) -> None:
cfg1 = OmegaConf.create(module.User(name="James Bond", age=7))
with raises(ValidationError):
Expand Down