diff --git a/docs/source/structured_config.rst b/docs/source/structured_config.rst index 91bfaf0c6..cba3c8bfe 100644 --- a/docs/source/structured_config.rst +++ b/docs/source/structured_config.rst @@ -257,8 +257,9 @@ OmegaConf verifies at runtime that your Lists contains only values of the correc Dictionaries ^^^^^^^^^^^^ -Dictionaries are supported as well. Keys must be strings or enums, and values can be any of any type supported by OmegaConf -(Any, int, float, bool, str and Enums as well as arbitrary Structured configs) +Dictionaries are supported as well. Keys must be strings, ints or enums, and values can +be any of any type supported by OmegaConf (Any, int, float, bool, str and Enums as well +as arbitrary Structured configs) Misc ---- diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 363ed1830..0cc256358 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -41,26 +41,30 @@ From a dictionary .. doctest:: - >>> conf = OmegaConf.create({"k" : "v", "list" : [1, {"a": "1", "b": "2"}]}) + >>> conf = OmegaConf.create({"k" : "v", "list" : [1, {"a": "1", "b": "2", 3: "c"}]}) >>> print(OmegaConf.to_yaml(conf)) k: v list: - 1 - a: '1' b: '2' + 3: c +OmegaConf supports `str`, `int` and Enums as dictionary key types. + From a list ^^^^^^^^^^^ .. doctest:: - >>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10}}]) + >>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10, 123: "int_key"}}]) >>> print(OmegaConf.to_yaml(conf)) - 1 - a: 10 b: a: 10 + 123: int_key Tuples are supported as an valid option too. @@ -95,6 +99,7 @@ From a YAML string ... list: ... - item1 ... - item2 + ... 123: 456 ... """ >>> conf = OmegaConf.create(s) >>> print(OmegaConf.to_yaml(conf)) @@ -103,6 +108,7 @@ From a YAML string list: - item1 - item2 + 123: 456 From a dot-list @@ -264,7 +270,7 @@ Save/Load YAML file .. doctest:: loaded - >>> conf = OmegaConf.create({"foo": 10, "bar": 20}) + >>> conf = OmegaConf.create({"foo": 10, "bar": 20, 123: 456}) >>> with tempfile.NamedTemporaryFile() as fp: ... OmegaConf.save(config=conf, f=fp.name) ... loaded = OmegaConf.load(fp.name) @@ -279,7 +285,7 @@ Note that the saved file may be incompatible across different major versions of .. doctest:: loaded - >>> conf = OmegaConf.create({"foo": 10, "bar": 20}) + >>> conf = OmegaConf.create({"foo": 10, "bar": 20, 123: 456}) >>> with tempfile.TemporaryFile() as fp: ... pickle.dump(conf, fp) ... fp.flush() diff --git a/news/149.feature b/news/149.feature new file mode 100644 index 000000000..c85412de5 --- /dev/null +++ b/news/149.feature @@ -0,0 +1 @@ +Add support for `int` key type in OmegaConf dictionaries diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 10f26285c..2a1e82872 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,4 +1,4 @@ -from .base import Container, Node +from .base import Container, DictKeyType, Node from .dictconfig import DictConfig from .errors import ( KeyValidationError, @@ -39,6 +39,7 @@ "Container", "ListConfig", "DictConfig", + "DictKeyType", "OmegaConf", "Resolver", "flag_override", diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 3b893ebd0..16b0cbf8d 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -491,7 +491,9 @@ def valid_value_annotation_type(type_: Any) -> bool: def _valid_dict_key_annotation_type(type_: Any) -> bool: - return type_ is None or type_ is Any or issubclass(type_, (str, Enum)) + from omegaconf import DictKeyType + + return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore def is_primitive_type(type_: Any) -> bool: diff --git a/omegaconf/base.py b/omegaconf/base.py index 85b075209..fe1a290a5 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -7,6 +7,8 @@ from ._utils import ValueKind, _get_value, format_and_raise, get_value_kind from .errors import ConfigKeyError, MissingMandatoryValue, UnsupportedInterpolationType +DictKeyType = Union[str, int, Enum] + @dataclass class Metadata: @@ -250,7 +252,7 @@ def __setitem__(self, key: Any, value: Any) -> None: ... @abstractmethod - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[Any]: ... @abstractmethod diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 8084f0882..649d5831d 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -24,7 +24,7 @@ is_primitive_type, is_structured_config, ) -from .base import Container, ContainerMetadata, Node +from .base import Container, ContainerMetadata, DictKeyType, Node from .errors import MissingMandatoryValue, ReadonlyConfigError, ValidationError DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__") @@ -187,7 +187,7 @@ def _to_content( resolve: bool, enum_to_str: bool = False, exclude_structured_configs: bool = False, - ) -> Union[None, Any, str, Dict[str, Any], List[Any]]: + ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]: from .dictconfig import DictConfig from .listconfig import ListConfig @@ -528,7 +528,10 @@ def assign(value_key: Any, value_to_assign: Any) -> None: @staticmethod def _item_eq( - c1: Container, k1: Union[str, int], c2: Container, k2: Union[str, int] + c1: Container, + k1: Union[DictKeyType, int], + c2: Container, + k2: Union[DictKeyType, int], ) -> bool: v1 = c1._get_node(k1) v2 = c2._get_node(k2) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 3fafe8455..8c2ec97e4 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -9,6 +9,7 @@ List, MutableMapping, Optional, + Sequence, Tuple, Type, Union, @@ -31,7 +32,7 @@ type_str, valid_value_annotation_type, ) -from .base import Container, ContainerMetadata, Node +from .base import Container, ContainerMetadata, DictKeyType, Node from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer from .errors import ( ConfigAttributeError, @@ -47,13 +48,13 @@ from .nodes import EnumNode, ValueNode -class DictConfig(BaseContainer, MutableMapping[str, Any]): +class DictConfig(BaseContainer, MutableMapping[Any, Any]): _metadata: ContainerMetadata def __init__( self, - content: Union[Dict[str, Any], Any], + content: Union[Dict[DictKeyType, Any], Any], key: Any = None, parent: Optional[Container] = None, ref_type: Union[Any, Type[Any]] = Any, @@ -245,14 +246,12 @@ def _raise_invalid_value( ) raise ValidationError(msg) - def _validate_and_normalize_key(self, key: Any) -> Union[str, Enum]: + def _validate_and_normalize_key(self, key: Any) -> DictKeyType: return self._s_validate_and_normalize_key(self._metadata.key_type, key) - def _s_validate_and_normalize_key( - self, key_type: Any, key: Any - ) -> Union[str, Enum]: + def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType: if key_type is Any: - for t in (str, Enum): + for t in DictKeyType.__args__: # type: ignore try: return self._s_validate_and_normalize_key(key_type=t, key=key) except KeyValidationError: @@ -264,6 +263,13 @@ def _s_validate_and_normalize_key( f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})" ) + return key + elif key_type == int: + if not isinstance(key, int): + raise KeyValidationError( + f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})" + ) + return key elif issubclass(key_type, Enum): try: @@ -278,7 +284,7 @@ def _s_validate_and_normalize_key( else: assert False, f"Unsupported key type {key_type}" - def __setitem__(self, key: Union[str, Enum], value: Any) -> None: + def __setitem__(self, key: DictKeyType, value: Any) -> None: try: self.__set_impl(key=key, value=value) except AttributeError as e: @@ -288,7 +294,7 @@ def __setitem__(self, key: Union[str, Enum], value: Any) -> None: except Exception as e: self._format_and_raise(key=key, value=value, cause=e) - def __set_impl(self, key: Union[str, Enum], value: Any) -> None: + def __set_impl(self, key: DictKeyType, value: Any) -> None: key = self._validate_and_normalize_key(key) self._set_item_impl(key, value) @@ -331,7 +337,7 @@ def __getattr__(self, key: str) -> Any: except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def __getitem__(self, key: Union[str, Enum]) -> Any: + def __getitem__(self, key: DictKeyType) -> Any: """ Allow map style access :param key: @@ -347,7 +353,7 @@ def __getitem__(self, key: Union[str, Enum]) -> Any: except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def __delitem__(self, key: Union[str, int, Enum]) -> None: + def __delitem__(self, key: DictKeyType) -> None: if self._get_flag("readonly"): self._format_and_raise( key=key, @@ -375,15 +381,13 @@ def __delitem__(self, key: Union[str, int, Enum]) -> None: del self.__dict__["_content"][key] - def get( - self, key: Union[str, Enum], default_value: Any = DEFAULT_VALUE_MARKER - ) -> Any: + def get(self, key: DictKeyType, default_value: Any = DEFAULT_VALUE_MARKER) -> Any: try: return self._get_impl(key=key, default_value=default_value) except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any: + def _get_impl(self, key: DictKeyType, default_value: Any) -> Any: try: node = self._get_node(key=key) except ConfigAttributeError: @@ -396,7 +400,7 @@ def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any: ) def _get_node( - self, key: Union[str, Enum], validate_access: bool = True + self, key: DictKeyType, validate_access: bool = True ) -> Optional[Node]: try: key = self._validate_and_normalize_key(key) @@ -413,7 +417,7 @@ def _get_node( return value - def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any: + def pop(self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER) -> Any: try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot pop from read-only node") @@ -479,13 +483,13 @@ def __contains__(self, key: object) -> bool: except (MissingMandatoryValue, KeyError): return False - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[DictKeyType]: return iter(self.keys()) - def items(self) -> AbstractSet[Tuple[str, Any]]: + def items(self) -> AbstractSet[Tuple[DictKeyType, Any]]: return self.items_ex(resolve=True, keys=None) - def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any: + def setdefault(self, key: DictKeyType, default: Any = None) -> Any: if key in self: ret = self.__getitem__(key) else: @@ -494,9 +498,9 @@ def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any: return ret def items_ex( - self, resolve: bool = True, keys: Optional[List[str]] = None - ) -> AbstractSet[Tuple[str, Any]]: - items: List[Tuple[str, Any]] = [] + self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None + ) -> AbstractSet[Tuple[DictKeyType, Any]]: + items: List[Tuple[DictKeyType, Any]] = [] for key in self.keys(): if resolve: value = self.get(key) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 22ce8b277..4f3e1f1f5 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -28,7 +28,7 @@ import yaml from typing_extensions import Protocol -from . import DictConfig, ListConfig +from . import DictConfig, DictKeyType, ListConfig from ._utils import ( _ensure_container, _get_value, @@ -183,7 +183,7 @@ def create( @staticmethod @overload def create( - obj: Union[Dict[str, Any], None] = None, + obj: Optional[Dict[Any, Any]] = None, parent: Optional[BaseContainer] = None, flags: Optional[Dict[str, bool]] = None, ) -> DictConfig: @@ -467,7 +467,7 @@ def to_container( resolve: bool = False, enum_to_str: bool = False, exclude_structured_configs: bool = False, - ) -> Union[Dict[str, Any], List[Any], None, str]: + ) -> Union[Dict[DictKeyType, Any], List[Any], None, str]: """ Resursively converts an OmegaConf config to a primitive container (dict or list). :param cfg: the config to convert @@ -508,7 +508,7 @@ def is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool: return True @staticmethod - def is_none(obj: Any, key: Optional[Union[int, str]] = None) -> bool: + def is_none(obj: Any, key: Optional[Union[int, DictKeyType]] = None) -> bool: if key is not None: assert isinstance(obj, Container) obj = obj._get_node(key) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index d2deac274..a8710d43a 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -300,9 +300,9 @@ class WithTypedDict: @attr.s(auto_attribs=True) -class ErrorDictIntKey: +class ErrorDictObjectKey: # invalid dict key, must be str - dict: Dict[int, str] = {10: "foo", 20: "bar"} + dict: Dict[object, str] = {object(): "foo", object(): "bar"} class RegularClass: @@ -350,6 +350,7 @@ class DictExamples: "green": Color.GREEN, "blue": Color.BLUE, } + int_keys: Dict[int, str] = {1: "one", 2: "two"} @attr.s(auto_attribs=True) @@ -372,6 +373,10 @@ class DictSubclass: class Str2Str(Dict[str, str]): pass + @attr.s(auto_attribs=True) + class Int2Str(Dict[int, str]): + pass + @attr.s(auto_attribs=True) class Color2Str(Dict[Color, str]): pass diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 117dcfc58..de4611463 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -303,9 +303,11 @@ class WithTypedDict: @dataclass -class ErrorDictIntKey: +class ErrorDictObjectKey: # invalid dict key, must be str - dict: Dict[int, str] = field(default_factory=lambda: {10: "foo", 20: "bar"}) + dict: Dict[object, str] = field( + default_factory=lambda: {object(): "foo", object(): "bar"} + ) class RegularClass: @@ -363,6 +365,7 @@ class DictExamples: "blue": Color.BLUE, } ) + int_keys: Dict[int, str] = field(default_factory=lambda: {1: "one", 2: "two"}) @dataclass @@ -389,6 +392,10 @@ class DictSubclass: class Str2Str(Dict[str, str]): pass + @dataclass + class Int2Str(Dict[int, str]): + pass + @dataclass class Color2Str(Dict[Color, str]): pass diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index d97c00020..b8282ce07 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -548,7 +548,7 @@ def test_merge_dict_with_correct_type(self, class_type: str) -> None: def test_typed_dict_key_error(self, class_type: str) -> None: module: Any = import_module(class_type) - input_ = module.ErrorDictIntKey + input_ = module.ErrorDictObjectKey with pytest.raises(KeyValidationError): OmegaConf.structured(input_) @@ -675,6 +675,16 @@ def test_any(name: str) -> None: "f": Color.BLUE, } + # test int_keys + with pytest.raises(KeyValidationError): + conf.int_keys.foo_key = "foo_value" + conf.int_keys[3] = "three" + assert conf.int_keys == { + 1: "one", + 2: "two", + 3: "three", + } + def test_enum_key(self, class_type: str) -> None: module: Any = import_module(class_type) conf = OmegaConf.structured(module.DictWithEnumKeys) @@ -894,6 +904,49 @@ def test_str2str_as_sub_node(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg.foo[Color.RED] = "fail" + with pytest.raises(KeyValidationError): + cfg.foo[123] = "fail" + + def test_int2str(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.structured(module.DictSubclass.Int2Str()) + + cfg[10] = "ten" # okay + assert cfg[10] == "ten" + + with pytest.raises(KeyValidationError): + cfg[10.0] = "float" # fail + + with pytest.raises(KeyValidationError): + cfg["10"] = "string" # fail + + with pytest.raises(KeyValidationError): + cfg.hello = "fail" + + with pytest.raises(KeyValidationError): + cfg[Color.RED] = "fail" + + def test_int2str_as_sub_node(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str}) + assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str + assert _utils.get_ref_type(cfg.foo) == Optional[module.DictSubclass.Int2Str] + + cfg.foo[10] = "ten" + assert cfg.foo[10] == "ten" + + with pytest.raises(KeyValidationError): + cfg.foo[10.0] = "float" # fail + + with pytest.raises(KeyValidationError): + cfg.foo["10"] = "string" # fail + + with pytest.raises(KeyValidationError): + cfg.foo.hello = "fail" + + with pytest.raises(KeyValidationError): + cfg.foo[Color.RED] = "fail" + def test_color2str(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Color2Str()) @@ -902,6 +955,9 @@ def test_color2str(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg.greeen = "nope" + with pytest.raises(KeyValidationError): + cfg[123] = "nope" + def test_color2color(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Color2Color()) @@ -924,6 +980,10 @@ def test_color2color(self, class_type: str) -> None: # bad value cfg[Color.GREEN] = 10 + with pytest.raises(ValidationError): + # bad value + cfg[Color.GREEN] = "this string is not a color" + with pytest.raises(KeyValidationError): # bad key cfg.greeen = "nope" diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 294169d8d..2636f9d9f 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -7,6 +7,7 @@ from omegaconf import ( DictConfig, + DictKeyType, ListConfig, MissingMandatoryValue, OmegaConf, @@ -66,28 +67,71 @@ def test_getattr_dict() -> None: assert {"b": 1} == c.a -def test_mandatory_value() -> None: - c = OmegaConf.create({"a": "???"}) - with pytest.raises(MissingMandatoryValue, match="a"): - c.a - - -def test_nested_dict_mandatory_value() -> None: - c = OmegaConf.create(dict(a=dict(b="???"))) - with pytest.raises(MissingMandatoryValue): - c.a.b - - -def test_subscript_get() -> None: - c = OmegaConf.create("a: b") - assert isinstance(c, DictConfig) - assert "b" == c["a"] - - -def test_subscript_set() -> None: - c = OmegaConf.create() - c["a"] = "b" - assert {"a": "b"} == c +@pytest.mark.parametrize( + "key", + ["a", 1], +) +class TestDictKeyTypes: + def test_mandatory_value(self, key: DictKeyType) -> None: + c = OmegaConf.create({key: "???"}) + with pytest.raises(MissingMandatoryValue, match=str(key)): + c[key] + if isinstance(key, str): + with pytest.raises(MissingMandatoryValue, match=key): + getattr(c, key) + + def test_nested_dict_mandatory_value(self, key: DictKeyType) -> None: + c = OmegaConf.create({"b": {key: "???"}}) + with pytest.raises(MissingMandatoryValue): + c.b[key] + if isinstance(key, str): + with pytest.raises(MissingMandatoryValue): + getattr(c.b, key) + + c = OmegaConf.create({key: {"b": "???"}}) + with pytest.raises(MissingMandatoryValue): + c[key].b + if isinstance(key, str): + with pytest.raises(MissingMandatoryValue): + getattr(c, key).b + + def test_subscript_get(self, key: DictKeyType) -> None: + c = OmegaConf.create({key: "b"}) + assert isinstance(c, DictConfig) + assert "b" == c[key] + + def test_subscript_set(self, key: DictKeyType) -> None: + c = OmegaConf.create() + c[key] = "b" + assert {key: "b"} == c + + +@pytest.mark.parametrize( + "src,key,expected", + [ + ({"a": 10, "b": 11}, "a", {"b": 11}), + ({1: "a", 2: "b"}, 1, {2: "b"}), + ], +) +class TestDelitemKeyTypes: + def test_dict_delitem(self, src: Any, key: DictKeyType, expected: Any) -> None: + c = OmegaConf.create(src) + assert c == src + del c[key] + assert c == expected + with pytest.raises(KeyError): + del c["not_found"] + + def test_dict_struct_delitem( + self, src: Any, key: DictKeyType, expected: Any + ) -> None: + c = OmegaConf.create(src) + OmegaConf.set_struct(c, True) + with pytest.raises(ConfigTypeError): + del c[key] + with open_dict(c): + del c[key] + assert key not in c def test_default_value() -> None: @@ -353,6 +397,10 @@ def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None "incompatible_key_type", False, ), + ({1: "a", 2: {}}, 1, True), + ({1: "a", 2: {}}, 2, True), + ({1: "a", 2: {}}, 3, False), + ({1: "a", 2: "???"}, 2, False), ], ) def test_in_dict(conf: Any, key: str, expected: Any) -> None: @@ -384,27 +432,6 @@ def test_dict_config() -> None: assert isinstance(c, DictConfig) -def test_dict_delitem() -> None: - src = {"a": 10, "b": 11} - c = OmegaConf.create(src) - assert c == src - del c["a"] - assert c == {"b": 11} - with pytest.raises(KeyError): - del c["not_found"] - - -def test_dict_struct_delitem() -> None: - src = {"a": 10, "b": 11} - c = OmegaConf.create(src) - OmegaConf.set_struct(c, True) - with pytest.raises(ConfigTypeError): - del c["a"] - with open_dict(c): - del c["a"] - assert "a" not in c - - def test_dict_structured_delitem() -> None: c = OmegaConf.structured(User(name="Bond")) with pytest.raises(ConfigTypeError): @@ -544,19 +571,19 @@ def test_masked_copy_is_deep() -> None: def test_creation_with_invalid_key() -> None: with pytest.raises(KeyValidationError): - OmegaConf.create({1: "a"}) # type: ignore + OmegaConf.create({object(): "a"}) def test_set_with_invalid_key() -> None: cfg = OmegaConf.create() with pytest.raises(KeyValidationError): - cfg[1] = "a" # type: ignore + cfg[object()] = "a" # type: ignore def test_get_with_invalid_key() -> None: cfg = OmegaConf.create() with pytest.raises(KeyValidationError): - cfg[1] # type: ignore + cfg[object()] # type: ignore def test_hasattr() -> None: diff --git a/tests/test_errors.py b/tests/test_errors.py index 551c952fd..cf88a4d6f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,6 +1,7 @@ import re from dataclasses import dataclass from enum import Enum +from textwrap import dedent from typing import Any, Dict, List, Optional, Type, Union import pytest @@ -435,6 +436,23 @@ def finalize(self, cfg: Any) -> None: ), id="dict:get_object_of_illegal_type", ), + pytest.param( + Expected( + create=lambda: DictConfig({}, key_type=int), + op=lambda cfg: cfg.get("foo"), + exception_type=KeyValidationError, + msg=dedent( + """\ + Key foo (str) is incompatible with (int) + \tfull_key: foo + \treference_type=Optional[Dict[int, Any]] + \tobject_type=dict""" + ), + key="foo", + full_key="foo", + ), + id="dict[int,Any]:mistyped_key", + ), # dict:create pytest.param( Expected( diff --git a/tests/test_to_yaml.py b/tests/test_to_yaml.py index 55b70d283..3d9d7faac 100644 --- a/tests/test_to_yaml.py +++ b/tests/test_to_yaml.py @@ -14,6 +14,8 @@ [ (["item1", "item2", {"key3": "value3"}], "- item1\n- item2\n- key3: value3\n"), ({"hello": "world", "list": [1, 2]}, "hello: world\nlist:\n- 1\n- 2\n"), + ({"abc": "str key"}, "abc: str key\n"), + ({123: "int key"}, "123: int key\n"), ], ) def test_to_yaml(input_: Any, expected: str) -> None: