diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 4aae23a20..20bdf4678 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -51,7 +51,23 @@ From a dictionary 3: c -OmegaConf supports `str`, `int` and Enums as dictionary key types. +Here is an example of various supported key types: + +.. doctest:: + + >>> from enum import Enum + >>> class Color(Enum): + ... RED = 1 + ... BLUE = 2 + >>> + >>> conf = OmegaConf.create( + ... {"key": "str", 123: "int", True: "bool", 3.14: "float", Color.RED: "Color"} + ... ) + >>> + >>> print(conf) + {'key': 'str', 123: 'int', True: 'bool', 3.14: 'float', : 'Color'} + +OmegaConf supports `str`, `int`, `bool`, `float` and Enums as dictionary key types. From a list ^^^^^^^^^^^ diff --git a/news/483.feature b/news/483.feature new file mode 100644 index 000000000..c72e40405 --- /dev/null +++ b/news/483.feature @@ -0,0 +1 @@ +Add DictConfig support for keys of type float and bool diff --git a/news/554.bugfix b/news/554.bugfix new file mode 100644 index 000000000..5959ad7a5 --- /dev/null +++ b/news/554.bugfix @@ -0,0 +1 @@ +When a dictconfig has enum-typed keys, __delitem__ can now be called with a string naming the enum member to be deleted. diff --git a/omegaconf/base.py b/omegaconf/base.py index d0da37fc7..4303c152c 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -19,7 +19,7 @@ from .grammar_parser import parse from .grammar_visitor import GrammarVisitor -DictKeyType = Union[str, int, Enum] +DictKeyType = Union[str, int, Enum, float, bool] _MARKER_ = object() @@ -179,7 +179,7 @@ def _format_and_raise( assert False @abstractmethod - def _get_full_key(self, key: Union[str, Enum, int, None]) -> str: + def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str: ... def _dereference_node( diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 40779037a..cab237abc 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -44,7 +44,7 @@ def __init__(self, parent: Optional["Container"], metadata: ContainerMetadata): def _resolve_with_default( self, - key: Union[str, int, Enum], + key: Union[DictKeyType, int], value: Any, default_value: Any = DEFAULT_VALUE_MARKER, ) -> Any: @@ -697,11 +697,11 @@ def _validate_set(self, key: Any, value: Any) -> None: def _value(self) -> Any: return self.__dict__["_content"] - def _get_full_key(self, key: Union[str, Enum, int, slice, None]) -> str: + def _get_full_key(self, key: Union[DictKeyType, int, slice, None]) -> str: from .listconfig import ListConfig from .omegaconf import _select_one - if not isinstance(key, (int, str, Enum, slice, type(None))): + if not isinstance(key, (int, str, Enum, float, bool, slice, type(None))): return "" def _slice_to_str(x: slice) -> str: @@ -715,6 +715,8 @@ def prepand(full_key: str, parent_type: Any, cur_type: Any, key: Any) -> str: key = _slice_to_str(key) elif isinstance(key, Enum): key = key.name + elif isinstance(key, (int, float, bool)): + key = str(key) if issubclass(parent_type, ListConfig): if full_key != "": diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 6141fdc5d..0b9912d3f 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -275,28 +275,26 @@ def _validate_and_normalize_key(self, key: Any) -> DictKeyType: def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType: if key_type is Any: for t in DictKeyType.__args__: # type: ignore - try: - return self._s_validate_and_normalize_key(key_type=t, key=key) - except KeyValidationError: - pass + if isinstance(key, t): + return key # type: ignore raise KeyValidationError("Incompatible key type '$KEY_TYPE'") - elif key_type == str: - if not isinstance(key, str): - raise KeyValidationError( - f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})" - ) - - return key - elif key_type == int: - if not isinstance(key, int): + elif key_type is bool and key in [0, 1]: + # Python treats True as 1 and False as 0 when used as dict keys + # assert hash(0) == hash(False) + # assert hash(1) == hash(True) + return bool(key) + elif key_type in (str, int, float, bool): # primitive type + if not isinstance(key, key_type): raise KeyValidationError( f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})" ) - return key + return key # type: ignore elif issubclass(key_type, Enum): try: - ret = EnumNode.validate_and_convert_to_enum(key_type, key) + ret = EnumNode.validate_and_convert_to_enum( + key_type, key, allow_none=False + ) assert ret is not None return ret except ValidationError: @@ -377,6 +375,7 @@ def __getitem__(self, key: DictKeyType) -> Any: self._format_and_raise(key=key, value=None, cause=e) def __delitem__(self, key: DictKeyType) -> None: + key = self._validate_and_normalize_key(key) if self._get_flag("readonly"): self._format_and_raise( key=key, @@ -402,7 +401,11 @@ def __delitem__(self, key: DictKeyType) -> None: ), ) - del self.__dict__["_content"][key] + try: + del self.__dict__["_content"][key] + except KeyError: + msg = "Key not found: '$KEY'" + self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg)) def get(self, key: DictKeyType, default_value: Any = None) -> Any: """Return the value for `key` if `key` is in the dictionary, else diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index b1e318b9f..8927605d3 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -11,7 +11,7 @@ get_value_kind, is_primitive_container, ) -from omegaconf.base import Container, Metadata, Node +from omegaconf.base import Container, DictKeyType, Metadata, Node from omegaconf.errors import ( ConfigKeyError, ReadonlyConfigError, @@ -122,7 +122,7 @@ def _is_missing(self) -> bool: def _is_interpolation(self) -> bool: return _is_interpolation(self._value()) - def _get_full_key(self, key: Union[str, Enum, int, None]) -> str: + def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str: parent = self._get_parent() if parent is None: if self._metadata.key is None: @@ -366,9 +366,9 @@ def validate_and_convert(self, value: Any) -> Optional[Enum]: @staticmethod def validate_and_convert_to_enum( - enum_type: Type[Enum], value: Any + enum_type: Type[Enum], value: Any, allow_none: bool = True ) -> Optional[Enum]: - if value is None: + if allow_none and value is None: return None if not isinstance(value, (str, int)) and not isinstance(value, enum_type): diff --git a/tests/__init__.py b/tests/__init__.py index d720430b2..fee650f71 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -181,7 +181,11 @@ class UntypedDict: @dataclass class SubscriptedDict: - dict: Dict[str, int] = field(default_factory=lambda: {"foo": 4}) + dict_str: Dict[str, int] = field(default_factory=lambda: {"foo": 4}) + dict_enum: Dict[Color, int] = field(default_factory=lambda: {Color.RED: 4}) + dict_int: Dict[int, int] = field(default_factory=lambda: {123: 4}) + dict_float: Dict[float, int] = field(default_factory=lambda: {123.45: 4}) + dict_bool: Dict[bool, int] = field(default_factory=lambda: {True: 4, False: 5}) @dataclass diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index d7a8d5626..e9d23fd2a 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -388,10 +388,8 @@ class DictExamples: "blue": Color.BLUE, } int_keys: Dict[int, str] = {1: "one", 2: "two"} - - -@attr.s(auto_attribs=True) -class DictWithEnumKeys: + float_keys: Dict[float, str] = {1.1: "one", 2.2: "two"} + bool_keys: Dict[bool, str] = {True: "T", False: "F"} enum_key: Dict[Color, str] = {Color.RED: "red", Color.GREEN: "green"} @@ -414,6 +412,14 @@ class Str2Str(Dict[str, str]): class Int2Str(Dict[int, str]): pass + @attr.s(auto_attribs=True) + class Float2Str(Dict[float, str]): + pass + + @attr.s(auto_attribs=True) + class Bool2Str(Dict[bool, str]): + pass + @attr.s(auto_attribs=True) class Color2Str(Dict[Color, str]): pass diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 0e8f99b64..10548c3ea 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -403,10 +403,10 @@ class DictExamples: } ) int_keys: Dict[int, str] = field(default_factory=lambda: {1: "one", 2: "two"}) - - -@dataclass -class DictWithEnumKeys: + float_keys: Dict[float, str] = field( + default_factory=lambda: {1.1: "one", 2.2: "two"} + ) + bool_keys: Dict[bool, str] = field(default_factory=lambda: {True: "T", False: "F"}) enum_key: Dict[Color, str] = field( default_factory=lambda: {Color.RED: "red", Color.GREEN: "green"} ) @@ -433,6 +433,14 @@ class Str2Str(Dict[str, str]): class Int2Str(Dict[int, str]): pass + @dataclass + class Float2Str(Dict[float, str]): + pass + + @dataclass + class Bool2Str(Dict[bool, str]): + pass + @dataclass class Color2Str(Dict[Color, str]): pass diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 6c0109b39..100ffb6d9 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -639,42 +639,47 @@ def test_any(name: str) -> None: Color.BLUE, ] - def test_dict_examples(self, class_type: str) -> None: + def test_dict_examples_any(self, class_type: str) -> None: module: Any = import_module(class_type) conf = OmegaConf.structured(module.DictExamples) - def test_any(name: str) -> None: - conf[name].c = True - conf[name].d = Color.RED - conf[name].e = 3.1415 - assert conf[name] == { - "a": 1, - "b": "foo", - "c": True, - "d": Color.RED, - "e": 3.1415, - } + dct = conf.any + dct.c = True + dct.d = Color.RED + dct.e = 3.1415 + assert dct == {"a": 1, "b": "foo", "c": True, "d": Color.RED, "e": 3.1415} - # any and untyped - test_any("any") + def test_dict_examples_int(self, class_type: str) -> None: + module: Any = import_module(class_type) + conf = OmegaConf.structured(module.DictExamples) + dct = conf.ints # test ints with pytest.raises(ValidationError): - conf.ints.a = "foo" - conf.ints.c = 10 - assert conf.ints == {"a": 10, "b": 20, "c": 10} + dct.a = "foo" + dct.c = 10 + assert dct == {"a": 10, "b": 20, "c": 10} + + def test_dict_examples_strings(self, class_type: str) -> None: + module: Any = import_module(class_type) + conf = OmegaConf.structured(module.DictExamples) # test strings conf.strings.c = Color.BLUE assert conf.strings == {"a": "foo", "b": "bar", "c": "Color.BLUE"} - # tests booleans + def test_dict_examples_bool(self, class_type: str) -> None: + module: Any = import_module(class_type) + conf = OmegaConf.structured(module.DictExamples) + dct = conf.booleans + + # test bool with pytest.raises(ValidationError): - conf.booleans.a = "foo" - conf.booleans.c = True - conf.booleans.d = "off" - conf.booleans.e = 1 - assert conf.booleans == { + dct.a = "foo" + dct.c = True + dct.d = "off" + dct.e = 1 + assert dct == { "a": True, "b": False, "c": True, @@ -682,60 +687,110 @@ def test_any(name: str) -> None: "e": True, } - # test colors - with pytest.raises(ValidationError): - conf.colors.foo = "foo" - conf.colors.c = Color.BLUE - conf.colors.d = "RED" - conf.colors.e = "Color.GREEN" - conf.colors.f = 3 - assert conf.colors == { - "red": Color.RED, - "green": Color.GREEN, - "blue": Color.BLUE, - "c": Color.BLUE, - "d": Color.RED, - "e": Color.GREEN, - "f": Color.BLUE, - } + class TestDictExamples: + @pytest.fixture + def conf(self, class_type: str) -> DictConfig: + module: Any = import_module(class_type) + conf: DictConfig = OmegaConf.structured(module.DictExamples) + return conf - # test int_keys - with pytest.raises(KeyValidationError): - conf.int_keys.foo_key = "foo_value" - conf.int_keys[3] = "three" - assert conf.int_keys == { - 1: "one", - 2: "two", - 3: "three", - } + def test_dict_examples_colors(self, conf: DictConfig) -> None: + dct = conf.colors - def test_enum_key(self, class_type: str) -> None: - module: Any = import_module(class_type) - conf = OmegaConf.structured(module.DictWithEnumKeys) + # test colors + with pytest.raises(ValidationError): + dct.foo = "foo" + dct.c = Color.BLUE + dct.d = "RED" + dct.e = "Color.GREEN" + dct.f = 3 + assert dct == { + "red": Color.RED, + "green": Color.GREEN, + "blue": Color.BLUE, + "c": Color.BLUE, + "d": Color.RED, + "e": Color.GREEN, + "f": Color.BLUE, + } - # When an Enum is a dictionary key the name of the Enum is actually used - # as the key - assert conf.enum_key.RED == "red" - assert conf.enum_key["GREEN"] == "green" - assert conf.enum_key[Color.GREEN] == "green" + def test_dict_examples_str_keys(self, conf: DictConfig) -> None: + dct = conf.any - conf.enum_key["BLUE"] = "Blue too" - assert conf.enum_key[Color.BLUE] == "Blue too" - with pytest.raises(KeyValidationError): - conf.enum_key["error"] = "error" + with pytest.raises(KeyValidationError): + dct[123] = "bad key type" + dct["c"] = "three" + assert dct == { + "a": 1, + "b": "foo", + "c": "three", + } + + def test_dict_examples_int_keys(self, conf: DictConfig) -> None: + dct = conf.int_keys + + # test int keys + with pytest.raises(KeyValidationError): + dct.foo_key = "foo_value" + dct[3] = "three" + assert dct == { + 1: "one", + 2: "two", + 3: "three", + } + + def test_dict_examples_float_keys(self, conf: DictConfig) -> None: + dct = conf.float_keys + + # test float keys + with pytest.raises(KeyValidationError): + dct.foo_key = "foo_value" + dct[3.3] = "three" + assert dct == { + 1.1: "one", + 2.2: "two", + 3.3: "three", + } + + def test_dict_examples_bool_keys(self, conf: DictConfig) -> None: + dct = conf.bool_keys + + # test bool_keys + with pytest.raises(KeyValidationError): + dct.foo_key = "foo_value" + dct[True] = "new value" + assert dct == { + True: "new value", + False: "F", + } + + def test_dict_examples_enum_key(self, conf: DictConfig) -> None: + dct = conf.enum_key + + # When an Enum is a dictionary key the name of the Enum is actually used + # as the key + assert dct.RED == "red" + assert dct["GREEN"] == "green" + assert dct[Color.GREEN] == "green" + + dct["BLUE"] = "Blue too" + assert dct[Color.BLUE] == "Blue too" + with pytest.raises(KeyValidationError): + dct["error"] = "error" def test_dict_of_objects(self, class_type: str) -> None: module: Any = import_module(class_type) conf = OmegaConf.structured(module.DictOfObjects) - assert conf.users.joe.age == 18 - assert conf.users.joe.name == "Joe" + dct = conf.users + assert dct.joe.age == 18 + assert dct.joe.name == "Joe" - conf.users.bond = module.User(name="James Bond", age=7) - assert conf.users.bond.name == "James Bond" - assert conf.users.bond.age == 7 + dct.bond = module.User(name="James Bond", age=7) + assert dct.bond.name == "James Bond" + assert dct.bond.age == 7 with pytest.raises(ValidationError): - conf.users.fail = "fail" + dct.fail = "fail" def test_list_of_objects(self, class_type: str) -> None: module: Any = import_module(class_type) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index aab594948..88ecc4331 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -73,39 +73,51 @@ def test_getattr_dict() -> None: @pytest.mark.parametrize( - "key", - ["a", 1], + "key,match", + [ + pytest.param("a", "a", id="str"), + pytest.param(1, "1", id="int"), + pytest.param(123.45, "123.45", id="float"), + pytest.param(True, "True", id="bool-T"), + pytest.param(False, "False", id="bool-F"), + pytest.param(Enum1.FOO, "FOO", id="enum"), + ], ) class TestDictKeyTypes: - def test_mandatory_value(self, key: DictKeyType) -> None: + def test_mandatory_value(self, key: DictKeyType, match: str) -> None: c = OmegaConf.create({key: "???"}) - with pytest.raises(MissingMandatoryValue, match=str(key)): + with pytest.raises(MissingMandatoryValue, match=match): c[key] if isinstance(key, str): - with pytest.raises(MissingMandatoryValue, match=key): + with pytest.raises(MissingMandatoryValue, match=match): getattr(c, key) - def test_nested_dict_mandatory_value(self, key: DictKeyType) -> None: + def test_nested_dict_mandatory_value_inner( + self, key: DictKeyType, match: str + ) -> None: c = OmegaConf.create({"b": {key: "???"}}) - with pytest.raises(MissingMandatoryValue): + with pytest.raises(MissingMandatoryValue, match=match): c.b[key] if isinstance(key, str): - with pytest.raises(MissingMandatoryValue): + with pytest.raises(MissingMandatoryValue, match=match): getattr(c.b, key) + def test_nested_dict_mandatory_value_outer( + self, key: DictKeyType, match: str + ) -> None: c = OmegaConf.create({key: {"b": "???"}}) - with pytest.raises(MissingMandatoryValue): + with pytest.raises(MissingMandatoryValue, match=match): c[key].b if isinstance(key, str): - with pytest.raises(MissingMandatoryValue): + with pytest.raises(MissingMandatoryValue, match=match): getattr(c, key).b - def test_subscript_get(self, key: DictKeyType) -> None: + def test_subscript_get(self, key: DictKeyType, match: str) -> None: c = OmegaConf.create({key: "b"}) assert isinstance(c, DictConfig) assert "b" == c[key] - def test_subscript_set(self, key: DictKeyType) -> None: + def test_subscript_set(self, key: DictKeyType, match: str) -> None: c = OmegaConf.create() c[key] = "b" assert {key: "b"} == c @@ -116,6 +128,9 @@ def test_subscript_set(self, key: DictKeyType) -> None: [ ({"a": 10, "b": 11}, "a", {"b": 11}), ({1: "a", 2: "b"}, 1, {2: "b"}), + ({123.45: "a", 67.89: "b"}, 67.89, {123.45: "a"}), + ({True: "a", False: "b"}, False, {True: "a"}), + ({Enum1.FOO: "foo", Enum1.BAR: "bar"}, Enum1.FOO, {Enum1.BAR: "bar"}), ], ) class TestDelitemKeyTypes: @@ -124,8 +139,14 @@ def test_dict_delitem(self, src: Any, key: DictKeyType, expected: Any) -> None: assert c == src del c[key] assert c == expected + + def test_dict_delitem_KeyError( + self, src: Any, key: DictKeyType, expected: Any + ) -> None: + c = OmegaConf.create(expected) + assert c == expected with pytest.raises(KeyError): - del c["not_found"] + del c[key] def test_dict_struct_delitem( self, src: Any, key: DictKeyType, expected: Any @@ -282,6 +303,7 @@ def test_iterate_dict_with_interpolation() -> None: @pytest.mark.parametrize( "cfg, key, default_, expected", [ + # string key pytest.param({"a": 1, "b": 2}, "a", "__NO_DEFAULT__", 1, id="no_default"), pytest.param({"a": 1, "b": 2}, "not_found", None, None, id="none_default"), pytest.param( @@ -312,6 +334,31 @@ def test_iterate_dict_with_interpolation() -> None: "default", id="enum_key_with_default", ), + # other key types + pytest.param( + {123.45: "a", 67.89: "b"}, + 67.89, + "__NO_DEFAULT__", + "b", + id="float_key_no_default", + ), + pytest.param( + {123.45: "a", 67.89: "b"}, + "not found", + None, + None, + id="float_key_with_default", + ), + pytest.param( + {True: "a", False: "b"}, + False, + "__NO_DEFAULT__", + "b", + id="bool_key_no_default", + ), + pytest.param( + {True: "a", False: "b"}, "not found", None, None, id="bool_key_with_default" + ), ], ) def test_dict_pop(cfg: Dict[Any, Any], key: Any, default_: Any, expected: Any) -> None: @@ -364,12 +411,23 @@ def test_dict_structured_mode_pop() -> None: @pytest.mark.parametrize( "cfg, key, expectation", [ + # key not found ({"a": 1, "b": 2}, "not_found", pytest.raises(KeyError)), + ({1: "a", 2: "b"}, 3, pytest.raises(KeyError)), + ({123.45: "a", 67.89: "b"}, 10.11, pytest.raises(KeyError)), + ({True: "a"}, False, pytest.raises(KeyError)), + ({Enum1.FOO: "bar"}, Enum1.BAR, pytest.raises(KeyError)), # Interpolations ({"a": "???", "b": 2}, "a", pytest.raises(MissingMandatoryValue)), + ({1: "???", 2: "b"}, 1, pytest.raises(MissingMandatoryValue)), + ({123.45: "???", 67.89: "b"}, 123.45, pytest.raises(MissingMandatoryValue)), + ({True: "???", False: "b"}, True, pytest.raises(MissingMandatoryValue)), + ( + {Enum1.FOO: "???", Enum1.BAR: "bar"}, + Enum1.FOO, + pytest.raises(MissingMandatoryValue), + ), ({"a": "${b}", "b": "???"}, "a", pytest.raises(MissingMandatoryValue)), - # enum key - ({Enum1.FOO: "bar"}, Enum1.BAR, pytest.raises(KeyError)), ], ) def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None: @@ -382,6 +440,7 @@ def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None @pytest.mark.parametrize( "conf,key,expected", [ + # str key type ({"a": 1, "b": {}}, "a", True), ({"a": 1, "b": {}}, "b", True), ({"a": 1, "b": {}}, "c", False), @@ -392,8 +451,10 @@ def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None ({"a": "${unknown_resolver:bar}"}, "a", True), ({"a": None, "b": "${a}"}, "b", True), ({"a": "cat", "b": "${a}"}, "b", True), + # Enum key type ({Enum1.FOO: 1, "b": {}}, Enum1.FOO, True), ({Enum1.FOO: 1, "b": {}}, "aaa", False), + ({Enum1.FOO: 1, "b": {}}, "FOO", False), ( DictConfig(content={Enum1.FOO: "foo"}, key_type=Enum1, element_type=str), Enum1.FOO, @@ -404,10 +465,39 @@ def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None "incompatible_key_type", False, ), + ( + DictConfig(content={Enum1.FOO: "foo"}, key_type=Enum1, element_type=str), + "FOO", + True, + ), + ( + DictConfig(content={Enum1.FOO: "foo"}, key_type=Enum1, element_type=str), + None, + False, + ), + # int key type ({1: "a", 2: {}}, 1, True), ({1: "a", 2: {}}, 2, True), ({1: "a", 2: {}}, 3, False), ({1: "a", 2: "???"}, 2, False), + ({1: "a", 2: "???"}, None, False), + ({1: "a", 2: "???"}, "1", False), + (DictConfig({1: "a", 2: "???"}, key_type=int), "1", False), + # float key type + ({1.1: "a", 2.2: {}}, 1.1, True), + ({1.1: "a", 2.2: {}}, "1.1", False), + (DictConfig({1.1: "a", 2.2: {}}, key_type=float), "1.1", False), + ({1.1: "a", 2.2: {}}, 2.2, True), + ({1.1: "a", 2.2: {}}, 3.3, False), + ({1.1: "a", 2.2: "???"}, 2.2, False), + ({1.1: "a", 2.2: "???"}, None, False), + # bool key type + ({True: "a", False: {}}, True, True), + ({True: "a", False: {}}, False, True), + ({True: "a", False: {}}, "no", False), + ({True: "a", False: {}}, 1, True), + ({True: "a", False: {}}, None, False), + ({True: "a", False: "???"}, False, False), ], ) def test_in_dict(conf: Any, key: str, expected: Any) -> None: diff --git a/tests/test_compare_dictconfig_vs_dict.py b/tests/test_compare_dictconfig_vs_dict.py new file mode 100644 index 000000000..33905cbd9 --- /dev/null +++ b/tests/test_compare_dictconfig_vs_dict.py @@ -0,0 +1,647 @@ +""" +This file compares DictConfig methods with the corresponding +methods of standard python's dict. +The following methods are compared: + __contains__ + __delitem__ + __eq__ + __getitem__ + __setitem__ + get + pop + keys + values + items + +We have separate test classes for the following cases: + TestUntypedDictConfig: for DictConfig without a set key_type + TestPrimitiveTypeDunderMethods: for DictConfig where key_type is primitive + TestEnumTypeDunderMethods: for DictConfig where key_type is Enum +""" +from copy import deepcopy +from enum import Enum +from typing import Any, Dict, Optional + +from pytest import fixture, mark, param, raises + +from omegaconf import DictConfig, OmegaConf +from omegaconf.errors import ConfigKeyError, ConfigTypeError, KeyValidationError +from tests import Enum1 + + +@fixture( + params=[ + "str", + 1, + 3.1415, + True, + Enum1.FOO, + ] +) +def key(request: Any) -> Any: + """A key to test indexing into DictConfig.""" + return request.param + + +@fixture +def python_dict(data: Dict[Any, Any]) -> Dict[Any, Any]: + """Just a standard python dictionary, to be used in comparison with DictConfig.""" + return deepcopy(data) + + +@fixture(params=[None, False, True]) +def struct_mode(request: Any) -> Optional[bool]: + struct_mode: Optional[bool] = request.param + return struct_mode + + +@mark.parametrize( + "data", + [ + param({"a": 10}, id="str"), + param({1: "a"}, id="int"), + param({123.45: "a"}, id="float"), + param({True: "a"}, id="bool"), + param({Enum1.FOO: "foo"}, id="Enum1"), + ], +) +class TestUntypedDictConfig: + """Compare DictConfig with python dict in the case where key_type is not set.""" + + @fixture + def cfg(self, python_dict: Any, struct_mode: Optional[bool]) -> DictConfig: + """Create a DictConfig instance from the given data""" + cfg: DictConfig = DictConfig(content=python_dict) + OmegaConf.set_struct(cfg, struct_mode) + return cfg + + def test__setitem__( + self, python_dict: Any, cfg: DictConfig, key: Any, struct_mode: Optional[bool] + ) -> None: + """Ensure that __setitem__ has same effect on python dict and on DictConfig.""" + if struct_mode and key not in cfg: + with raises(ConfigKeyError): + cfg[key] = "sentinel" + else: + python_dict[key] = "sentinel" + cfg[key] = "sentinel" + assert python_dict == cfg + + def test__getitem__(self, python_dict: Any, cfg: DictConfig, key: Any) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + try: + result = python_dict[key] + except KeyError: + with raises(ConfigKeyError): + cfg[key] + else: + assert result == cfg[key] + + @mark.parametrize("struct_mode", [False, None]) + def test__delitem__(self, python_dict: Any, cfg: DictConfig, key: Any) -> None: + """Ensure that __delitem__ has same result with python dict as with DictConfig.""" + try: + del python_dict[key] + assert key not in python_dict + except KeyError: + with raises(ConfigKeyError): + del cfg[key] + else: + del cfg[key] + assert key not in cfg + + @mark.parametrize("struct_mode", [True]) + def test__delitem__struct_mode( + self, python_dict: Any, cfg: DictConfig, key: Any + ) -> None: + """Ensure that __delitem__ fails in struct_mode""" + with raises(ConfigTypeError): + del cfg[key] + + def test__contains__(self, python_dict: Any, cfg: Any, key: Any) -> None: + """Ensure that __contains__ has same result with python dict as with DictConfig.""" + assert (key in python_dict) == (key in cfg) + + def test__eq__(self, python_dict: Any, cfg: Any, key: Any) -> None: + assert python_dict == cfg + + def test_get(self, python_dict: Any, cfg: DictConfig, key: Any) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + assert python_dict.get(key) == cfg.get(key) + + def test_get_with_default( + self, python_dict: Any, cfg: DictConfig, key: Any + ) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + assert python_dict.get(key, "DEFAULT") == cfg.get(key, "DEFAULT") + + @mark.parametrize("struct_mode", [False, None]) + def test_pop( + self, + python_dict: Any, + cfg: DictConfig, + key: Any, + ) -> None: + """Ensure that pop has same result with python dict as with DictConfig.""" + try: + result = python_dict.pop(key) + except KeyError: + with raises(ConfigKeyError): + cfg.pop(key) + else: + assert result == cfg.pop(key) + assert python_dict.keys() == cfg.keys() + + @mark.parametrize("struct_mode", [True]) + def test_pop_struct_mode( + self, + python_dict: Any, + cfg: DictConfig, + key: Any, + ) -> None: + """Ensure that pop fails in struct mode.""" + with raises(ConfigTypeError): + cfg.pop(key) + + @mark.parametrize("struct_mode", [False, None]) + def test_pop_with_default( + self, + python_dict: Any, + cfg: DictConfig, + key: Any, + ) -> None: + """Ensure that pop(..., DEFAULT) has same result with python dict as with DictConfig.""" + assert python_dict.pop(key, "DEFAULT") == cfg.pop(key, "DEFAULT") + assert python_dict.keys() == cfg.keys() + + @mark.parametrize("struct_mode", [True]) + def test_pop_with_default_struct_mode( + self, + python_dict: Any, + cfg: DictConfig, + key: Any, + ) -> None: + """Ensure that pop(..., DEFAULT) fails in struct mode.""" + with raises(ConfigTypeError): + cfg.pop(key, "DEFAULT") + + def test_keys(self, python_dict: Any, cfg: Any) -> None: + assert python_dict.keys() == cfg.keys() + + def test_values(self, python_dict: Any, cfg: Any) -> None: + assert list(python_dict.values()) == list(cfg.values()) + + def test_items(self, python_dict: Any, cfg: Any) -> None: + assert list(python_dict.items()) == list(cfg.items()) + + +@fixture +def cfg_typed( + python_dict: Any, cfg_key_type: Any, struct_mode: Optional[bool] +) -> DictConfig: + """Create a DictConfig instance that has strongly-typed keys""" + cfg_typed: DictConfig = DictConfig(content=python_dict, key_type=cfg_key_type) + OmegaConf.set_struct(cfg_typed, struct_mode) + return cfg_typed + + +@mark.parametrize( + "cfg_key_type,data", + [(str, {"a": 10}), (int, {1: "a"}), (float, {123.45: "a"}), (bool, {True: "a"})], +) +class TestPrimitiveTypeDunderMethods: + """Compare DictConfig with python dict in the case where key_type is a primitive type.""" + + def test__setitem__primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + struct_mode: Optional[bool], + ) -> None: + """When DictConfig keys are strongly typed, + ensure that __setitem__ has same effect on python dict and on DictConfig.""" + if struct_mode and key not in cfg_typed: + if isinstance(key, cfg_key_type) or ( + cfg_key_type == bool and key in (0, 1) + ): + with raises(ConfigKeyError): + cfg_typed[key] = "sentinel" + else: + with raises(KeyValidationError): + cfg_typed[key] = "sentinel" + else: + python_dict[key] = "sentinel" + if isinstance(key, cfg_key_type) or ( + cfg_key_type == bool and key in (0, 1) + ): + cfg_typed[key] = "sentinel" + assert python_dict == cfg_typed + else: + with raises(KeyValidationError): + cfg_typed[key] = "sentinel" + + def test__getitem__primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """When Dictconfig keys are strongly typed, + ensure that __getitem__ has same result with python dict as with DictConfig.""" + try: + result = python_dict[key] + except KeyError: + if isinstance(key, cfg_key_type) or ( + cfg_key_type == bool and key in (0, 1) + ): + with raises(ConfigKeyError): + cfg_typed[key] + else: + with raises(KeyValidationError): + cfg_typed[key] + else: + assert result == cfg_typed[key] + + @mark.parametrize("struct_mode", [False, None]) + def test__delitem__primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """When Dictconfig keys are strongly typed, + ensure that __delitem__ has same result with python dict as with DictConfig.""" + try: + del python_dict[key] + assert key not in python_dict + except KeyError: + if isinstance(key, cfg_key_type) or ( + cfg_key_type == bool and key in (0, 1) + ): + with raises(ConfigKeyError): + del cfg_typed[key] + else: + with raises(KeyValidationError): + del cfg_typed[key] + else: + del cfg_typed[key] + assert key not in cfg_typed + + @mark.parametrize("struct_mode", [True]) + def test__delitem__primitive_typed_struct_mode( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure ensure that struct-mode __delitem__ raises ConfigTypeError or KeyValidationError""" + if isinstance(key, cfg_key_type) or (cfg_key_type == bool and key in (0, 1)): + with raises(ConfigTypeError): + del cfg_typed[key] + else: + with raises(KeyValidationError): + del cfg_typed[key] + + def test__contains__primitive_typed( + self, python_dict: Any, cfg_typed: Any, key: Any + ) -> None: + """Ensure that __contains__ has same result with python dict as with DictConfig.""" + assert (key in python_dict) == (key in cfg_typed) + + def test__eq__primitive_typed( + self, python_dict: Any, cfg_typed: Any, key: Any + ) -> None: + assert python_dict == cfg_typed + + def test_get_primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + if isinstance(key, cfg_key_type) or (cfg_key_type == bool and key in (0, 1)): + assert python_dict.get(key) == cfg_typed.get(key) + else: + with raises(KeyValidationError): + cfg_typed.get(key) + + def test_get_with_default_primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + if isinstance(key, cfg_key_type) or (cfg_key_type == bool and key in (0, 1)): + assert python_dict.get(key, "DEFAULT") == cfg_typed.get(key, "DEFAULT") + else: + with raises(KeyValidationError): + cfg_typed.get(key, "DEFAULT") + + @mark.parametrize("struct_mode", [False, None]) + def test_pop_primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop has same result with python dict as with DictConfig.""" + if isinstance(key, cfg_key_type) or (cfg_key_type == bool and key in (0, 1)): + try: + result = python_dict.pop(key) + except KeyError: + with raises(ConfigKeyError): + cfg_typed.pop(key) + else: + assert result == cfg_typed.pop(key) + assert python_dict.keys() == cfg_typed.keys() + else: + with raises(KeyValidationError): + cfg_typed.pop(key) + + @mark.parametrize("struct_mode", [True]) + def test_pop_primitive_typed_struct_mode( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop fails in struct mode.""" + with raises(ConfigTypeError): + cfg_typed.pop(key) + + @mark.parametrize("struct_mode", [False, None]) + def test_pop_with_default_primitive_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop(..., DEFAULT) has same result with python dict as with DictConfig.""" + if isinstance(key, cfg_key_type) or (cfg_key_type == bool and key in (0, 1)): + assert python_dict.pop(key, "DEFAULT") == cfg_typed.pop(key, "DEFAULT") + assert python_dict.keys() == cfg_typed.keys() + else: + with raises(KeyValidationError): + cfg_typed.pop(key, "DEFAULT") + + @mark.parametrize("struct_mode", [True]) + def test_pop_with_default_primitive_typed_struct_mode( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop(..., DEFAULT) fails in struct mode""" + with raises(ConfigTypeError): + cfg_typed.pop(key) + + def test_keys_primitive_typed(self, python_dict: Any, cfg_typed: Any) -> None: + assert python_dict.keys() == cfg_typed.keys() + + def test_values_primitive_typed(self, python_dict: Any, cfg_typed: Any) -> None: + assert list(python_dict.values()) == list(cfg_typed.values()) + + def test_items_primitive_typed(self, python_dict: Any, cfg_typed: Any) -> None: + assert list(python_dict.items()) == list(cfg_typed.items()) + + +@mark.parametrize("cfg_key_type,data", [(Enum1, {Enum1.FOO: "foo"})]) +class TestEnumTypeDunderMethods: + """Compare DictConfig with python dict in the case where key_type is an Enum type.""" + + @fixture + def key_coerced(self, key: Any, cfg_key_type: Any) -> Any: + """ + This handles key coersion in the special case where DictConfig key_type + is a subclass of Enum: keys of type `str` or `int` are coerced to `key_type`. + See https://github.com/omry/omegaconf/pull/484#issuecomment-765772019 + """ + assert issubclass(cfg_key_type, Enum) + if type(key) == str and key in [e.name for e in cfg_key_type]: + return cfg_key_type[key] + elif type(key) == int and key in [e.value for e in cfg_key_type]: + return cfg_key_type(key) + else: + return key + + def test__setitem__enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + struct_mode: Optional[bool], + ) -> None: + """When DictConfig keys are strongly typed, + ensure that __setitem__ has same effect on python dict and on DictConfig.""" + if struct_mode and key_coerced not in cfg_typed: + if isinstance(key_coerced, cfg_key_type): + with raises(ConfigKeyError): + cfg_typed[key] = "sentinel" + else: + with raises(KeyValidationError): + cfg_typed[key] = "sentinel" + else: + python_dict[key_coerced] = "sentinel" + if isinstance(key_coerced, cfg_key_type): + cfg_typed[key] = "sentinel" + assert python_dict == cfg_typed + else: + with raises(KeyValidationError): + cfg_typed[key] = "sentinel" + + def test__getitem__enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """When Dictconfig keys are strongly typed, + ensure that __getitem__ has same result with python dict as with DictConfig.""" + try: + result = python_dict[key_coerced] + except KeyError: + if isinstance(key_coerced, cfg_key_type): + with raises(ConfigKeyError): + cfg_typed[key] + else: + with raises(KeyValidationError): + cfg_typed[key] + else: + assert result == cfg_typed[key] + + @mark.parametrize("struct_mode", [False, None]) + def test__delitem__enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """When Dictconfig keys are strongly typed, + ensure that __delitem__ has same result with python dict as with DictConfig.""" + try: + del python_dict[key_coerced] + assert key_coerced not in python_dict + except KeyError: + if isinstance(key_coerced, cfg_key_type): + with raises(ConfigKeyError): + del cfg_typed[key] + else: + with raises(KeyValidationError): + del cfg_typed[key] + else: + del cfg_typed[key] + assert key not in cfg_typed + + @mark.parametrize("struct_mode", [True]) + def test__delitem__enum_typed_struct_mode( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that __delitem__ errors in struct mode""" + if isinstance(key_coerced, cfg_key_type): + with raises(ConfigTypeError): + del cfg_typed[key] + else: + with raises(KeyValidationError): + del cfg_typed[key] + + def test__contains__enum_typed( + self, python_dict: Any, cfg_typed: Any, key: Any, key_coerced: Any + ) -> None: + """Ensure that __contains__ has same result with python dict as with DictConfig.""" + assert (key_coerced in python_dict) == (key in cfg_typed) + + def test__eq__enum_typed(self, python_dict: Any, cfg_typed: Any, key: Any) -> None: + assert python_dict == cfg_typed + + def test_get_enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + if isinstance(key_coerced, cfg_key_type): + assert python_dict.get(key_coerced) == cfg_typed.get(key) + else: + with raises(KeyValidationError): + cfg_typed.get(key) + + def test_get_with_default_enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that __getitem__ has same result with python dict as with DictConfig.""" + if isinstance(key_coerced, cfg_key_type): + assert python_dict.get(key_coerced, "DEFAULT") == cfg_typed.get( + key, "DEFAULT" + ) + else: + with raises(KeyValidationError): + cfg_typed.get(key, "DEFAULT") + + @mark.parametrize("struct_mode", [False, None]) + def test_pop_enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop has same result with python dict as with DictConfig.""" + if isinstance(key_coerced, cfg_key_type): + try: + result = python_dict.pop(key_coerced) + except KeyError: + with raises(ConfigKeyError): + cfg_typed.pop(key) + else: + assert result == cfg_typed.pop(key) + assert python_dict.keys() == cfg_typed.keys() + else: + with raises(KeyValidationError): + cfg_typed.pop(key) + + @mark.parametrize("struct_mode", [True]) + def test_pop_enum_typed_struct_mode( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop fails in struct mode""" + with raises(ConfigTypeError): + cfg_typed.pop(key) + + @mark.parametrize("struct_mode", [False, None]) + def test_pop_with_default_enum_typed( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop(..., DEFAULT) has same result with python dict as with DictConfig.""" + if isinstance(key_coerced, cfg_key_type): + assert python_dict.pop(key_coerced, "DEFAULT") == cfg_typed.pop( + key, "DEFAULT" + ) + assert python_dict.keys() == cfg_typed.keys() + else: + with raises(KeyValidationError): + cfg_typed.pop(key, "DEFAULT") + + @mark.parametrize("struct_mode", [True]) + def test_pop_with_default_enum_typed_struct_mode( + self, + python_dict: Any, + cfg_typed: DictConfig, + key: Any, + key_coerced: Any, + cfg_key_type: Any, + ) -> None: + """Ensure that pop(..., DEFAULT) errors in struct mode""" + with raises(ConfigTypeError): + cfg_typed.pop(key) + + def test_keys_enum_typed(self, python_dict: Any, cfg_typed: Any) -> None: + assert python_dict.keys() == cfg_typed.keys() + + def test_values_enum_typed(self, python_dict: Any, cfg_typed: Any) -> None: + assert list(python_dict.values()) == list(cfg_typed.values()) + + def test_items_enum_typed(self, python_dict: Any, cfg_typed: Any) -> None: + assert list(python_dict.items()) == list(cfg_typed.items()) diff --git a/tests/test_errors.py b/tests/test_errors.py index 2d0c4d9bb..1ccc811bd 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,7 +1,6 @@ import re from dataclasses import dataclass from enum import Enum -from textwrap import dedent from typing import Any, Dict, List, Optional, Type import pytest @@ -91,7 +90,7 @@ def finalize(self, cfg: Any) -> None: if self.key is None: self.full_key = "" else: - if isinstance(self.key, (str, int, Enum, slice)): + if isinstance(self.key, (str, int, Enum, float, bool, slice)): self.full_key = self.key else: self.full_key = "" @@ -153,25 +152,47 @@ def finalize(self, cfg: Any) -> None: pytest.param( Expected( create=lambda: OmegaConf.create({"foo": "bar"}), - op=lambda cfg: cfg.pop("nevermind"), - key="nevermind", + op=lambda cfg: cfg.pop("not_found"), + key="not_found", exception_type=ConfigKeyError, - msg="Key not found: 'nevermind'", + msg="Key not found: 'not_found'", ), id="dict:pop_invalid", ), pytest.param( Expected( create=lambda: OmegaConf.create({"foo": {}}), - op=lambda cfg: cfg.foo.pop("nevermind"), - key="nevermind", - full_key="foo.nevermind", + op=lambda cfg: cfg.foo.pop("not_found"), + key="not_found", + full_key="foo.not_found", parent_node=lambda cfg: cfg.foo, exception_type=ConfigKeyError, - msg="Key not found: 'nevermind' (path: 'foo.nevermind')", + msg="Key not found: 'not_found' (path: 'foo.not_found')", ), id="dict:pop_invalid_nested", ), + pytest.param( + Expected( + create=lambda: OmegaConf.create({"foo": "bar"}), + op=lambda cfg: cfg.__delitem__("not_found"), + key="not_found", + exception_type=ConfigKeyError, + msg="Key not found: 'not_found'", + ), + id="dict:del_invalid", + ), + pytest.param( + Expected( + create=lambda: OmegaConf.create({"foo": {}}), + op=lambda cfg: cfg.foo.__delitem__("not_found"), + key="not_found", + full_key="foo.not_found", + parent_node=lambda cfg: cfg.foo, + exception_type=ConfigKeyError, + msg="Key not found: 'not_found'", + ), + id="dict:del_invalid_nested", + ), pytest.param( Expected( create=lambda: OmegaConf.structured(ConcretePlugin), @@ -314,6 +335,16 @@ def finalize(self, cfg: Any) -> None: ), id="DictConfig[Color,str]:setitem_bad_key", ), + pytest.param( + Expected( + create=lambda: DictConfig(key_type=Color, element_type=str, content={}), + op=lambda cfg: cfg.__setitem__(None, "bar"), + exception_type=KeyValidationError, + msg="Key 'None' is incompatible with the enum type 'Color', valid: [RED, GREEN, BLUE]", + key=None, + ), + id="DictConfig[Color,str]:setitem_bad_key", + ), pytest.param( Expected( create=lambda: DictConfig(key_type=str, element_type=Color, content={}), @@ -358,6 +389,16 @@ def finalize(self, cfg: Any) -> None: ), id="DictConfig[Color,str]:getitem_str_key", ), + pytest.param( + Expected( + create=lambda: DictConfig(key_type=Color, element_type=str, content={}), + op=lambda cfg: cfg.__getitem__(None), + exception_type=KeyValidationError, + msg="Key 'None' is incompatible with the enum type 'Color', valid: [RED, GREEN, BLUE]", + key=None, + ), + id="DictConfig[Color,str]:getitem_str_key_None", + ), pytest.param( Expected( create=lambda: DictConfig(key_type=str, element_type=str, content={}), @@ -459,17 +500,34 @@ def finalize(self, cfg: Any) -> None: create=lambda: DictConfig({}, key_type=int), op=lambda cfg: cfg.get("foo"), exception_type=KeyValidationError, - msg=dedent( - """\ - Key foo (str) is incompatible with (int) - full_key: foo - object_type=dict""" - ), + msg="Key foo (str) is incompatible with (int)", key="foo", full_key="foo", ), id="dict[int,Any]:mistyped_key", ), + pytest.param( + Expected( + create=lambda: DictConfig({}, key_type=float), + op=lambda cfg: cfg.get("foo"), + exception_type=KeyValidationError, + msg="Key foo (str) is incompatible with (float)", + key="foo", + full_key="foo", + ), + id="dict[float,Any]:mistyped_key", + ), + pytest.param( + Expected( + create=lambda: DictConfig({}, key_type=bool), + op=lambda cfg: cfg.get("foo"), + exception_type=KeyValidationError, + msg="Key foo (str) is incompatible with (bool)", + key="foo", + full_key="foo", + ), + id="dict[bool,Any]:mistyped_key", + ), # dict:create pytest.param( Expected( @@ -567,10 +625,10 @@ def finalize(self, cfg: Any) -> None: pytest.param( Expected( create=lambda: OmegaConf.structured(SubscriptedDict), - op=lambda cfg: cfg.__setitem__("dict", 1), + op=lambda cfg: cfg.__setitem__("dict_str", 1), exception_type=ValidationError, msg="Cannot assign int to Dict[str, int]", - key="dict", + key="dict_str", ref_type=Optional[Dict[str, int]], low_level=True, ), @@ -579,15 +637,39 @@ def finalize(self, cfg: Any) -> None: pytest.param( Expected( create=lambda: OmegaConf.structured(SubscriptedDict), - op=lambda cfg: cfg.__setitem__("dict", User(age=2, name="bar")), + op=lambda cfg: cfg.__setitem__("dict_str", User(age=2, name="bar")), exception_type=ValidationError, msg="Cannot assign User to Dict[str, int]", - key="dict", + key="dict_str", ref_type=Optional[Dict[str, int]], low_level=True, ), id="DictConfig[str,int]:assigned_structured_config", ), + pytest.param( + Expected( + create=lambda: OmegaConf.structured(SubscriptedDict), + op=lambda cfg: cfg.__setitem__("dict_int", "fail"), + exception_type=ValidationError, + msg="Cannot assign str to Dict[int, int]", + key="dict_int", + ref_type=Optional[Dict[int, int]], + low_level=True, + ), + id="DictConfig[int,int]:assigned_primitive_type", + ), + pytest.param( + Expected( + create=lambda: OmegaConf.structured(SubscriptedDict), + op=lambda cfg: cfg.__setitem__("dict_int", User(age=2, name="bar")), + exception_type=ValidationError, + msg="Cannot assign User to Dict[int, int]", + key="dict_int", + ref_type=Optional[Dict[int, int]], + low_level=True, + ), + id="DictConfig[int,int]:assigned_structured_config", + ), # delete pytest.param( Expected( diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 2299b124f..249cf35fa 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -12,6 +12,7 @@ from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf from omegaconf._utils import get_ref_type from tests import ( + Color, PersonA, PersonD, SubscriptedDict, @@ -165,7 +166,11 @@ def test_load_empty_file(tmpdir: str) -> None: True, Optional[Dict[Any, Any]], ), - (SubscriptedDict, "dict", int, str, False, Dict[str, int]), + (SubscriptedDict, "dict_str", int, str, False, Dict[str, int]), + (SubscriptedDict, "dict_int", int, int, False, Dict[int, int]), + (SubscriptedDict, "dict_bool", int, bool, False, Dict[bool, int]), + (SubscriptedDict, "dict_float", int, float, False, Dict[float, int]), + (SubscriptedDict, "dict_enum", int, Color, False, Dict[Color, int]), (SubscriptedList, "list", int, Any, False, List[int]), ( DictConfig( diff --git a/tests/test_to_yaml.py b/tests/test_to_yaml.py index 49ed1807c..636e1d98b 100644 --- a/tests/test_to_yaml.py +++ b/tests/test_to_yaml.py @@ -15,6 +15,8 @@ ({"hello": "world", "list": [1, 2]}, "hello: world\nlist:\n- 1\n- 2\n"), ({"abc": "str key"}, "abc: str key\n"), ({123: "int key"}, "123: int key\n"), + ({123.45: "float key"}, "123.45: float key\n"), + ({True: "bool key", False: "another"}, "true: bool key\nfalse: another\n"), ], ) def test_to_yaml(input_: Any, expected: str) -> None: