From 1eadbdf8ff2dd11948902f35cc42cbd486167801 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 9 Dec 2020 18:47:48 -0600 Subject: [PATCH 01/33] preliminary DictConfig support for int key type --- omegaconf/_utils.py | 2 +- omegaconf/dictconfig.py | 29 ++++++++++++++++++----------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 3b893ebd0..cdc8d3484 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -491,7 +491,7 @@ def valid_value_annotation_type(type_: Any) -> bool: def _valid_dict_key_annotation_type(type_: Any) -> bool: - return type_ is None or type_ is Any or issubclass(type_, (str, Enum)) + return type_ is None or type_ is Any or issubclass(type_, (str, int, Enum)) def is_primitive_type(type_: Any) -> bool: diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 3fafe8455..bc5392127 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -245,14 +245,14 @@ def _raise_invalid_value( ) raise ValidationError(msg) - def _validate_and_normalize_key(self, key: Any) -> Union[str, Enum]: + def _validate_and_normalize_key(self, key: Any) -> Union[str, int, Enum]: return self._s_validate_and_normalize_key(self._metadata.key_type, key) def _s_validate_and_normalize_key( self, key_type: Any, key: Any - ) -> Union[str, Enum]: + ) -> Union[str, int, Enum]: if key_type is Any: - for t in (str, Enum): + for t in (str, int, Enum): try: return self._s_validate_and_normalize_key(key_type=t, key=key) except KeyValidationError: @@ -264,6 +264,13 @@ def _s_validate_and_normalize_key( f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})" ) + return key + elif key_type == int: + if not isinstance(key, int): + raise KeyValidationError( + f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})" + ) + return key elif issubclass(key_type, Enum): try: @@ -278,7 +285,7 @@ def _s_validate_and_normalize_key( else: assert False, f"Unsupported key type {key_type}" - def __setitem__(self, key: Union[str, Enum], value: Any) -> None: + def __setitem__(self, key: Union[str, int, Enum], value: Any) -> None: try: self.__set_impl(key=key, value=value) except AttributeError as e: @@ -288,7 +295,7 @@ def __setitem__(self, key: Union[str, Enum], value: Any) -> None: except Exception as e: self._format_and_raise(key=key, value=value, cause=e) - def __set_impl(self, key: Union[str, Enum], value: Any) -> None: + def __set_impl(self, key: Union[str, int, Enum], value: Any) -> None: key = self._validate_and_normalize_key(key) self._set_item_impl(key, value) @@ -331,7 +338,7 @@ def __getattr__(self, key: str) -> Any: except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def __getitem__(self, key: Union[str, Enum]) -> Any: + def __getitem__(self, key: Union[str, int, Enum]) -> Any: """ Allow map style access :param key: @@ -376,14 +383,14 @@ def __delitem__(self, key: Union[str, int, Enum]) -> None: del self.__dict__["_content"][key] def get( - self, key: Union[str, Enum], default_value: Any = DEFAULT_VALUE_MARKER + self, key: Union[str, int, Enum], default_value: Any = DEFAULT_VALUE_MARKER ) -> Any: try: return self._get_impl(key=key, default_value=default_value) except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any: + def _get_impl(self, key: Union[str, int, Enum], default_value: Any) -> Any: try: node = self._get_node(key=key) except ConfigAttributeError: @@ -396,7 +403,7 @@ def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any: ) def _get_node( - self, key: Union[str, Enum], validate_access: bool = True + self, key: Union[str, int, Enum], validate_access: bool = True ) -> Optional[Node]: try: key = self._validate_and_normalize_key(key) @@ -413,7 +420,7 @@ def _get_node( return value - def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any: + def pop(self, key: Union[str, int, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any: try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot pop from read-only node") @@ -485,7 +492,7 @@ def __iter__(self) -> Iterator[str]: def items(self) -> AbstractSet[Tuple[str, Any]]: return self.items_ex(resolve=True, keys=None) - def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any: + def setdefault(self, key: Union[str, int, Enum], default: Any = None) -> Any: if key in self: ret = self.__getitem__(key) else: From 29941b5c68f42adda40b3963bbf5592cf41b14ad Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 10 Dec 2020 20:09:20 -0600 Subject: [PATCH 02/33] DictConfig[int, ...]: comment out offending tests --- tests/structured_conf/data/attr_classes.py | 8 +++--- .../structured_conf/test_structured_config.py | 10 +++---- tests/test_basic_ops_dict.py | 27 +++++++++++-------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index d2deac274..11ac6d7e1 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -299,10 +299,10 @@ class WithTypedDict: dict: Dict[str, int] = {"foo": 10, "bar": 20} -@attr.s(auto_attribs=True) -class ErrorDictIntKey: - # invalid dict key, must be str - dict: Dict[int, str] = {10: "foo", 20: "bar"} +# @attr.s(auto_attribs=True) +# class ErrorDictIntKey: +# # invalid dict key, must be str +# dict: Dict[int, str] = {10: "foo", 20: "bar"} class RegularClass: diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index d97c00020..143da6a16 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -546,11 +546,11 @@ def test_merge_dict_with_correct_type(self, class_type: str) -> None: res = OmegaConf.merge(cfg, {"dict": {"foo": user}}) assert res.dict == {"foo": user} - def test_typed_dict_key_error(self, class_type: str) -> None: - module: Any = import_module(class_type) - input_ = module.ErrorDictIntKey - with pytest.raises(KeyValidationError): - OmegaConf.structured(input_) + # def test_typed_dict_key_error(self, class_type: str) -> None: + # module: Any = import_module(class_type) + # input_ = module.ErrorDictIntKey + # with pytest.raises(KeyValidationError): + # OmegaConf.structured(input_) def test_typed_dict_value_error(self, class_type: str) -> None: module: Any = import_module(class_type) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 294169d8d..6378917a2 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -83,6 +83,11 @@ def test_subscript_get() -> None: assert isinstance(c, DictConfig) assert "b" == c["a"] +def test_subscript_get_int_key() -> None: + c = OmegaConf.create("1: b") + assert isinstance(c, DictConfig) + assert "b" == c[1] + def test_subscript_set() -> None: c = OmegaConf.create() @@ -542,21 +547,21 @@ def test_masked_copy_is_deep() -> None: OmegaConf.masked_copy("fail", []) # type: ignore -def test_creation_with_invalid_key() -> None: - with pytest.raises(KeyValidationError): - OmegaConf.create({1: "a"}) # type: ignore +# def test_creation_with_invalid_key() -> None: +# with pytest.raises(KeyValidationError): +# OmegaConf.create({1: "a"}) # type: ignore -def test_set_with_invalid_key() -> None: - cfg = OmegaConf.create() - with pytest.raises(KeyValidationError): - cfg[1] = "a" # type: ignore +# def test_set_with_invalid_key() -> None: +# cfg = OmegaConf.create() +# with pytest.raises(KeyValidationError): +# cfg[1] = "a" # type: ignore -def test_get_with_invalid_key() -> None: - cfg = OmegaConf.create() - with pytest.raises(KeyValidationError): - cfg[1] # type: ignore +# def test_get_with_invalid_key() -> None: +# cfg = OmegaConf.create() +# with pytest.raises(KeyValidationError): +# cfg[1] # type: ignore def test_hasattr() -> None: From 1848190cc55e94ffbdb11a5fd3558a2d940029ab Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 11 Dec 2020 12:34:32 -0600 Subject: [PATCH 03/33] black formatting --- omegaconf/dictconfig.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index bc5392127..e041836c8 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -420,7 +420,9 @@ def _get_node( return value - def pop(self, key: Union[str, int, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any: + def pop( + self, key: Union[str, int, Enum], default: Any = DEFAULT_VALUE_MARKER + ) -> Any: try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot pop from read-only node") From c4b12b905e6d9f480174927654af7768d290acea Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 11 Dec 2020 12:34:32 -0600 Subject: [PATCH 04/33] black formatting --- omegaconf/dictconfig.py | 4 +++- tests/test_basic_ops_dict.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index bc5392127..e041836c8 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -420,7 +420,9 @@ def _get_node( return value - def pop(self, key: Union[str, int, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any: + def pop( + self, key: Union[str, int, Enum], default: Any = DEFAULT_VALUE_MARKER + ) -> Any: try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot pop from read-only node") diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 6378917a2..d503ca492 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -83,6 +83,7 @@ def test_subscript_get() -> None: assert isinstance(c, DictConfig) assert "b" == c["a"] + def test_subscript_get_int_key() -> None: c = OmegaConf.create("1: b") assert isinstance(c, DictConfig) From df962ebbdb65d5753dd068ef07fa16e79551aa55 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 12:42:11 -0600 Subject: [PATCH 05/33] Change ErrorDictIntKey to ErrorDictObjectKey We want to test that DictConfig throws KeyValidationError if invalid key type is used. But `int` type is no longer invalid, so we test on Dict[object, str] instead of on Dict[int, str]. --- tests/structured_conf/data/attr_classes.py | 8 ++++---- tests/structured_conf/data/dataclasses.py | 6 ++++-- tests/structured_conf/test_structured_config.py | 10 +++++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index 11ac6d7e1..76368f544 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -299,10 +299,10 @@ class WithTypedDict: dict: Dict[str, int] = {"foo": 10, "bar": 20} -# @attr.s(auto_attribs=True) -# class ErrorDictIntKey: -# # invalid dict key, must be str -# dict: Dict[int, str] = {10: "foo", 20: "bar"} +@attr.s(auto_attribs=True) +class ErrorDictObjectKey: + # invalid dict key, must be str + dict: Dict[object, str] = {object(): "foo", object(): "bar"} class RegularClass: diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 117dcfc58..349503007 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -303,9 +303,11 @@ class WithTypedDict: @dataclass -class ErrorDictIntKey: +class ErrorDictObjectKey: # invalid dict key, must be str - dict: Dict[int, str] = field(default_factory=lambda: {10: "foo", 20: "bar"}) + dict: Dict[object, str] = field( + default_factory=lambda: {object(): "foo", object(): "bar"} + ) class RegularClass: diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 143da6a16..4f91f4eb3 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -546,11 +546,11 @@ def test_merge_dict_with_correct_type(self, class_type: str) -> None: res = OmegaConf.merge(cfg, {"dict": {"foo": user}}) assert res.dict == {"foo": user} - # def test_typed_dict_key_error(self, class_type: str) -> None: - # module: Any = import_module(class_type) - # input_ = module.ErrorDictIntKey - # with pytest.raises(KeyValidationError): - # OmegaConf.structured(input_) + def test_typed_dict_key_error(self, class_type: str) -> None: + module: Any = import_module(class_type) + input_ = module.ErrorDictObjectKey + with pytest.raises(KeyValidationError): + OmegaConf.structured(input_) def test_typed_dict_value_error(self, class_type: str) -> None: module: Any = import_module(class_type) From cbfed039bb5916fffd7056943f2832ef0ae30395 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 12:46:52 -0600 Subject: [PATCH 06/33] tests: call OmegaConf.create on dictionary instead of string Based on suggested changes https://github.com/omry/omegaconf/pull/454#discussion_r541522788 Co-authored-by: Omry Yadan --- tests/test_basic_ops_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index d503ca492..58046fb0b 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -85,7 +85,7 @@ def test_subscript_get() -> None: def test_subscript_get_int_key() -> None: - c = OmegaConf.create("1: b") + c = OmegaConf.create({1: b}) assert isinstance(c, DictConfig) assert "b" == c[1] From fec336c7f57abae0d3df8017f22851eaaea94a09 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 12:58:13 -0600 Subject: [PATCH 07/33] fix typo in test_basic_ops_dict --- tests/test_basic_ops_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 58046fb0b..a746317fa 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -85,7 +85,7 @@ def test_subscript_get() -> None: def test_subscript_get_int_key() -> None: - c = OmegaConf.create({1: b}) + c = OmegaConf.create({1: "b"}) assert isinstance(c, DictConfig) assert "b" == c[1] From a3581c66e05e6931138ca4f54c994addef76b6e6 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 13:00:57 -0600 Subject: [PATCH 08/33] for invalid key test, use object instead of int --- tests/test_basic_ops_dict.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index a746317fa..c2edd1338 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -548,21 +548,21 @@ def test_masked_copy_is_deep() -> None: OmegaConf.masked_copy("fail", []) # type: ignore -# def test_creation_with_invalid_key() -> None: -# with pytest.raises(KeyValidationError): -# OmegaConf.create({1: "a"}) # type: ignore +def test_creation_with_invalid_key() -> None: + with pytest.raises(KeyValidationError): + OmegaConf.create({object(): "a"}) # type: ignore -# def test_set_with_invalid_key() -> None: -# cfg = OmegaConf.create() -# with pytest.raises(KeyValidationError): -# cfg[1] = "a" # type: ignore +def test_set_with_invalid_key() -> None: + cfg = OmegaConf.create() + with pytest.raises(KeyValidationError): + cfg[object()] = "a" # type: ignore -# def test_get_with_invalid_key() -> None: -# cfg = OmegaConf.create() -# with pytest.raises(KeyValidationError): -# cfg[1] # type: ignore +def test_get_with_invalid_key() -> None: + cfg = OmegaConf.create() + with pytest.raises(KeyValidationError): + cfg[object()] # type: ignore def test_hasattr() -> None: From ecdef40d00b6e19cc66428e9dcc9fa0ee97688c6 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 13:45:47 -0600 Subject: [PATCH 09/33] Use DictKeyType instead of Union[...] --- omegaconf/__init__.py | 2 +- omegaconf/_utils.py | 4 +++- omegaconf/dictconfig.py | 26 ++++++++++++++------------ 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 10f26285c..5461faff4 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,5 +1,5 @@ from .base import Container, Node -from .dictconfig import DictConfig +from .dictconfig import DictConfig, DictKeyType from .errors import ( KeyValidationError, MissingMandatoryValue, diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index cdc8d3484..16b0cbf8d 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -491,7 +491,9 @@ def valid_value_annotation_type(type_: Any) -> bool: def _valid_dict_key_annotation_type(type_: Any) -> bool: - return type_ is None or type_ is Any or issubclass(type_, (str, int, Enum)) + from omegaconf import DictKeyType + + return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore def is_primitive_type(type_: Any) -> bool: diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index e041836c8..a40dd2a26 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -46,6 +46,8 @@ ) from .nodes import EnumNode, ValueNode +DictKeyType = Union[str, int, Enum] + class DictConfig(BaseContainer, MutableMapping[str, Any]): @@ -245,14 +247,14 @@ def _raise_invalid_value( ) raise ValidationError(msg) - def _validate_and_normalize_key(self, key: Any) -> Union[str, int, Enum]: + def _validate_and_normalize_key(self, key: Any) -> DictKeyType: return self._s_validate_and_normalize_key(self._metadata.key_type, key) def _s_validate_and_normalize_key( self, key_type: Any, key: Any - ) -> Union[str, int, Enum]: + ) -> DictKeyType: if key_type is Any: - for t in (str, int, Enum): + for t in DictKeyType.__args__: # type: ignore try: return self._s_validate_and_normalize_key(key_type=t, key=key) except KeyValidationError: @@ -285,7 +287,7 @@ def _s_validate_and_normalize_key( else: assert False, f"Unsupported key type {key_type}" - def __setitem__(self, key: Union[str, int, Enum], value: Any) -> None: + def __setitem__(self, key: DictKeyType, value: Any) -> None: try: self.__set_impl(key=key, value=value) except AttributeError as e: @@ -295,7 +297,7 @@ def __setitem__(self, key: Union[str, int, Enum], value: Any) -> None: except Exception as e: self._format_and_raise(key=key, value=value, cause=e) - def __set_impl(self, key: Union[str, int, Enum], value: Any) -> None: + def __set_impl(self, key: DictKeyType, value: Any) -> None: key = self._validate_and_normalize_key(key) self._set_item_impl(key, value) @@ -338,7 +340,7 @@ def __getattr__(self, key: str) -> Any: except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def __getitem__(self, key: Union[str, int, Enum]) -> Any: + def __getitem__(self, key: DictKeyType) -> Any: """ Allow map style access :param key: @@ -354,7 +356,7 @@ def __getitem__(self, key: Union[str, int, Enum]) -> Any: except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def __delitem__(self, key: Union[str, int, Enum]) -> None: + def __delitem__(self, key: DictKeyType) -> None: if self._get_flag("readonly"): self._format_and_raise( key=key, @@ -383,14 +385,14 @@ def __delitem__(self, key: Union[str, int, Enum]) -> None: del self.__dict__["_content"][key] def get( - self, key: Union[str, int, Enum], default_value: Any = DEFAULT_VALUE_MARKER + self, key: DictKeyType, default_value: Any = DEFAULT_VALUE_MARKER ) -> Any: try: return self._get_impl(key=key, default_value=default_value) except Exception as e: self._format_and_raise(key=key, value=None, cause=e) - def _get_impl(self, key: Union[str, int, Enum], default_value: Any) -> Any: + def _get_impl(self, key: DictKeyType, default_value: Any) -> Any: try: node = self._get_node(key=key) except ConfigAttributeError: @@ -403,7 +405,7 @@ def _get_impl(self, key: Union[str, int, Enum], default_value: Any) -> Any: ) def _get_node( - self, key: Union[str, int, Enum], validate_access: bool = True + self, key: DictKeyType, validate_access: bool = True ) -> Optional[Node]: try: key = self._validate_and_normalize_key(key) @@ -421,7 +423,7 @@ def _get_node( return value def pop( - self, key: Union[str, int, Enum], default: Any = DEFAULT_VALUE_MARKER + self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER ) -> Any: try: if self._get_flag("readonly"): @@ -494,7 +496,7 @@ def __iter__(self) -> Iterator[str]: def items(self) -> AbstractSet[Tuple[str, Any]]: return self.items_ex(resolve=True, keys=None) - def setdefault(self, key: Union[str, int, Enum], default: Any = None) -> Any: + def setdefault(self, key: DictKeyType, default: Any = None) -> Any: if key in self: ret = self.__getitem__(key) else: From 56ea33692a652b653512485c2262bd808d15c14b Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 13:57:06 -0600 Subject: [PATCH 10/33] add "DictKeyType" to __init__.__all__ --- omegaconf/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 5461faff4..53a382115 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -39,6 +39,7 @@ "Container", "ListConfig", "DictConfig", + "DictKeyType", "OmegaConf", "Resolver", "flag_override", From 379b232c6265a317c8d82f2ac7262452b710c148 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 14:25:22 -0600 Subject: [PATCH 11/33] Test Dict[int, str] in structured config --- tests/structured_conf/data/attr_classes.py | 1 + tests/structured_conf/data/dataclasses.py | 1 + tests/structured_conf/test_structured_config.py | 10 ++++++++++ 3 files changed, 12 insertions(+) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index 76368f544..b10cf4028 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -350,6 +350,7 @@ class DictExamples: "green": Color.GREEN, "blue": Color.BLUE, } + int_keys: Dict[int, str] = {1: "one", 2: "two"} @attr.s(auto_attribs=True) diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 349503007..193731477 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -365,6 +365,7 @@ class DictExamples: "blue": Color.BLUE, } ) + int_keys: Dict[int, str] = field(default_factory=lambda: {1: "one", 2: "two"}) @dataclass diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 4f91f4eb3..c592ba489 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -675,6 +675,16 @@ def test_any(name: str) -> None: "f": Color.BLUE, } + # test int_keys + with pytest.raises(KeyValidationError): + conf.int_keys.foo_key = "foo_value" + conf.int_keys[3] = "three" + assert conf.int_keys == { + 1: "one", + 2: "two", + 3: "three", + } + def test_enum_key(self, class_type: str) -> None: module: Any = import_module(class_type) conf = OmegaConf.structured(module.DictWithEnumKeys) From 457ec4d9ced96ff108a1d511aaee79b6907842b8 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 14:26:30 -0600 Subject: [PATCH 12/33] black formatting --- omegaconf/dictconfig.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index a40dd2a26..1f508978d 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -250,9 +250,7 @@ def _raise_invalid_value( def _validate_and_normalize_key(self, key: Any) -> DictKeyType: return self._s_validate_and_normalize_key(self._metadata.key_type, key) - def _s_validate_and_normalize_key( - self, key_type: Any, key: Any - ) -> DictKeyType: + def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType: if key_type is Any: for t in DictKeyType.__args__: # type: ignore try: @@ -384,9 +382,7 @@ def __delitem__(self, key: DictKeyType) -> None: del self.__dict__["_content"][key] - def get( - self, key: DictKeyType, default_value: Any = DEFAULT_VALUE_MARKER - ) -> Any: + def get(self, key: DictKeyType, default_value: Any = DEFAULT_VALUE_MARKER) -> Any: try: return self._get_impl(key=key, default_value=default_value) except Exception as e: @@ -422,9 +418,7 @@ def _get_node( return value - def pop( - self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER - ) -> Any: + def pop(self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER) -> Any: try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot pop from read-only node") From 181eaa5ab3268aee10da3742068f924ccf62537a Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 14:50:52 -0600 Subject: [PATCH 13/33] More test coverage for Dict[int, ...] --- tests/test_basic_ops_dict.py | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index c2edd1338..1f0f3b3f2 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -72,12 +72,27 @@ def test_mandatory_value() -> None: c.a +def test_mandatory_value_int_key() -> None: + c = OmegaConf.create({1: "???"}) + with pytest.raises(MissingMandatoryValue, match="1"): + c[1] + + def test_nested_dict_mandatory_value() -> None: c = OmegaConf.create(dict(a=dict(b="???"))) with pytest.raises(MissingMandatoryValue): c.a.b +def test_nested_dict_mandatory_value_int_key() -> None: + c = OmegaConf.create({1: dict(b="???")}) + with pytest.raises(MissingMandatoryValue): + c[1].b + c2 = OmegaConf.create(dict(a={2: "???"})) + with pytest.raises(MissingMandatoryValue): + c2.a[2] + + def test_subscript_get() -> None: c = OmegaConf.create("a: b") assert isinstance(c, DictConfig) @@ -96,6 +111,12 @@ def test_subscript_set() -> None: assert {"a": "b"} == c +def test_subscript_set_int_key() -> None: + c = OmegaConf.create() + c[1] = "b" + assert {1: "b"} == c + + def test_default_value() -> None: c = OmegaConf.create() assert c.missing_key or "a default value" == "a default value" @@ -359,6 +380,10 @@ def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None "incompatible_key_type", False, ), + ({1: "a", 2: {}}, 1, True), + ({1: "a", 2: {}}, 2, True), + ({1: "a", 2: {}}, 3, False), + ({1: "a", 2: "???"}, 2, False), ], ) def test_in_dict(conf: Any, key: str, expected: Any) -> None: @@ -400,6 +425,18 @@ def test_dict_delitem() -> None: del c["not_found"] +def test_dict_delitem_int_key() -> None: + src = {1: "a", 2: "b"} + c = OmegaConf.create(src) + assert c == src + del c[1] + assert c == {2: "b"} + with pytest.raises(KeyError): + del c["not_found"] + with pytest.raises(KeyError): + del c[3] + + def test_dict_struct_delitem() -> None: src = {"a": 10, "b": 11} c = OmegaConf.create(src) @@ -411,6 +448,17 @@ def test_dict_struct_delitem() -> None: assert "a" not in c +def test_dict_struct_delitem_int_key() -> None: + src = {1: "a", 2: "b"} + c = OmegaConf.create(src) + OmegaConf.set_struct(c, True) + with pytest.raises(ConfigTypeError): + del c[1] + with open_dict(c): + del c[1] + assert 1 not in c + + def test_dict_structured_delitem() -> None: c = OmegaConf.structured(User(name="Bond")) with pytest.raises(ConfigTypeError): From 2ef70e2d5e77c1ca188423a69eaf0cb6ae86ff97 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 13 Dec 2020 15:34:07 -0600 Subject: [PATCH 14/33] Use DictKeyType for type annotation. The changes in this commit were primarly driven by feedback from mypy. Note that the definition of `DictKeyType` has been moved from dictconfig.py to basecontainer.py due to a circular import issue. --- omegaconf/__init__.py | 3 ++- omegaconf/base.py | 2 +- omegaconf/basecontainer.py | 9 +++++++-- omegaconf/dictconfig.py | 19 +++++++++---------- omegaconf/omegaconf.py | 8 ++++---- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 53a382115..efbdcef0f 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,5 +1,6 @@ from .base import Container, Node -from .dictconfig import DictConfig, DictKeyType +from .basecontainer import DictKeyType +from .dictconfig import DictConfig from .errors import ( KeyValidationError, MissingMandatoryValue, diff --git a/omegaconf/base.py b/omegaconf/base.py index 85b075209..32b79edc0 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -250,7 +250,7 @@ def __setitem__(self, key: Any, value: Any) -> None: ... @abstractmethod - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[Any]: ... @abstractmethod diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 8084f0882..4131fae8c 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -29,6 +29,8 @@ DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__") +DictKeyType = Union[str, int, Enum] + class BaseContainer(Container, ABC): # static @@ -187,7 +189,7 @@ def _to_content( resolve: bool, enum_to_str: bool = False, exclude_structured_configs: bool = False, - ) -> Union[None, Any, str, Dict[str, Any], List[Any]]: + ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]: from .dictconfig import DictConfig from .listconfig import ListConfig @@ -528,7 +530,10 @@ def assign(value_key: Any, value_to_assign: Any) -> None: @staticmethod def _item_eq( - c1: Container, k1: Union[str, int], c2: Container, k2: Union[str, int] + c1: Container, + k1: Union[DictKeyType, int], + c2: Container, + k2: Union[DictKeyType, int], ) -> bool: v1 = c1._get_node(k1) v2 = c2._get_node(k2) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 1f508978d..61839f01f 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -9,6 +9,7 @@ List, MutableMapping, Optional, + Sequence, Tuple, Type, Union, @@ -32,7 +33,7 @@ valid_value_annotation_type, ) from .base import Container, ContainerMetadata, Node -from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer +from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer, DictKeyType from .errors import ( ConfigAttributeError, ConfigKeyError, @@ -46,16 +47,14 @@ ) from .nodes import EnumNode, ValueNode -DictKeyType = Union[str, int, Enum] - -class DictConfig(BaseContainer, MutableMapping[str, Any]): +class DictConfig(BaseContainer, MutableMapping[DictKeyType, Any]): _metadata: ContainerMetadata def __init__( self, - content: Union[Dict[str, Any], Any], + content: Union[Dict[DictKeyType, Any], Any], key: Any = None, parent: Optional[Container] = None, ref_type: Union[Any, Type[Any]] = Any, @@ -484,10 +483,10 @@ def __contains__(self, key: object) -> bool: except (MissingMandatoryValue, KeyError): return False - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[DictKeyType]: return iter(self.keys()) - def items(self) -> AbstractSet[Tuple[str, Any]]: + def items(self) -> AbstractSet[Tuple[DictKeyType, Any]]: return self.items_ex(resolve=True, keys=None) def setdefault(self, key: DictKeyType, default: Any = None) -> Any: @@ -499,9 +498,9 @@ def setdefault(self, key: DictKeyType, default: Any = None) -> Any: return ret def items_ex( - self, resolve: bool = True, keys: Optional[List[str]] = None - ) -> AbstractSet[Tuple[str, Any]]: - items: List[Tuple[str, Any]] = [] + self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None + ) -> AbstractSet[Tuple[DictKeyType, Any]]: + items: List[Tuple[DictKeyType, Any]] = [] for key in self.keys(): if resolve: value = self.get(key) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 22ce8b277..078c563ab 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -28,7 +28,7 @@ import yaml from typing_extensions import Protocol -from . import DictConfig, ListConfig +from . import DictConfig, DictKeyType, ListConfig from ._utils import ( _ensure_container, _get_value, @@ -183,7 +183,7 @@ def create( @staticmethod @overload def create( - obj: Union[Dict[str, Any], None] = None, + obj: Optional[Dict[DictKeyType, Any]] = None, parent: Optional[BaseContainer] = None, flags: Optional[Dict[str, bool]] = None, ) -> DictConfig: @@ -467,7 +467,7 @@ def to_container( resolve: bool = False, enum_to_str: bool = False, exclude_structured_configs: bool = False, - ) -> Union[Dict[str, Any], List[Any], None, str]: + ) -> Union[Dict[DictKeyType, Any], List[Any], None, str]: """ Resursively converts an OmegaConf config to a primitive container (dict or list). :param cfg: the config to convert @@ -508,7 +508,7 @@ def is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool: return True @staticmethod - def is_none(obj: Any, key: Optional[Union[int, str]] = None) -> bool: + def is_none(obj: Any, key: Optional[Union[int, DictKeyType]] = None) -> bool: if key is not None: assert isinstance(obj, Container) obj = obj._get_node(key) From 3b193b601a43302a2c5d6107ef2df0afa744992f Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 14 Dec 2020 12:47:23 -0600 Subject: [PATCH 15/33] Change DictConfig MutableMapping parameter to Any. modified: omegaconf/dictconfig.py Changing the declaration of DictConfig from class DictConfig(BaseContainer, MutableMapping[DictKeyType, Any]): to class DictConfig(BaseContainer, MutableMapping[Any, Any]): This eliminates some mypy errors. For example, the following code was giving mypy errors when using the old declaration: from omegaconf import OmegaConf cfg = OmegaConf.create() upd = dict(a=1) cfg.update(upd) The mypy error was: error: Argument 1 to "update" of "MutableMapping" has incompatible type "Dict[str, int]"; expected "Mapping[Union[str, int, Enum], Any]" An alternative to using `MutableMapping[Any, Any]` would have been something like ``` KT = TypeVar("KT", bound=DictKeyType) class DictConfig(BaseContainer, MutableMapping[KT, Any]): ... ``` but this alternative would have introduced additional complexity, such as requiring parametrization of DictConfig instances, e.g. `DictConfig[int]` or `DictConfig[str]`. --- omegaconf/dictconfig.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 61839f01f..ef5cb6d87 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -48,7 +48,7 @@ from .nodes import EnumNode, ValueNode -class DictConfig(BaseContainer, MutableMapping[DictKeyType, Any]): +class DictConfig(BaseContainer, MutableMapping[Any, Any]): _metadata: ContainerMetadata From a83120f44fa73336e2e55428a9ca41f1fa484416 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 14 Dec 2020 15:36:16 -0600 Subject: [PATCH 16/33] mypy: more flexible OmegaConf.create signature --- omegaconf/omegaconf.py | 2 +- tests/test_basic_ops_dict.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 078c563ab..4f3e1f1f5 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -183,7 +183,7 @@ def create( @staticmethod @overload def create( - obj: Optional[Dict[DictKeyType, Any]] = None, + obj: Optional[Dict[Any, Any]] = None, parent: Optional[BaseContainer] = None, flags: Optional[Dict[str, bool]] = None, ) -> DictConfig: diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 1f0f3b3f2..eaf61c34e 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -598,7 +598,7 @@ def test_masked_copy_is_deep() -> None: def test_creation_with_invalid_key() -> None: with pytest.raises(KeyValidationError): - OmegaConf.create({object(): "a"}) # type: ignore + OmegaConf.create({object(): "a"}) def test_set_with_invalid_key() -> None: From 4e056e6c9976601a705ba544e2da7a028d02208e Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 14:02:00 -0600 Subject: [PATCH 17/33] modified: docs/source/structured_config.rst --- docs/source/structured_config.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/structured_config.rst b/docs/source/structured_config.rst index 91bfaf0c6..cba3c8bfe 100644 --- a/docs/source/structured_config.rst +++ b/docs/source/structured_config.rst @@ -257,8 +257,9 @@ OmegaConf verifies at runtime that your Lists contains only values of the correc Dictionaries ^^^^^^^^^^^^ -Dictionaries are supported as well. Keys must be strings or enums, and values can be any of any type supported by OmegaConf -(Any, int, float, bool, str and Enums as well as arbitrary Structured configs) +Dictionaries are supported as well. Keys must be strings, ints or enums, and values can +be any of any type supported by OmegaConf (Any, int, float, bool, str and Enums as well +as arbitrary Structured configs) Misc ---- From b9ab0bfc50334f05e133de43d8709b5d52f2315f Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 16:52:46 -0600 Subject: [PATCH 18/33] test Structured config extending Dict[int,str] --- tests/structured_conf/data/attr_classes.py | 4 ++ tests/structured_conf/data/dataclasses.py | 4 ++ .../structured_conf/test_structured_config.py | 50 +++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index b10cf4028..a8710d43a 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -373,6 +373,10 @@ class DictSubclass: class Str2Str(Dict[str, str]): pass + @attr.s(auto_attribs=True) + class Int2Str(Dict[int, str]): + pass + @attr.s(auto_attribs=True) class Color2Str(Dict[Color, str]): pass diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 193731477..de4611463 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -392,6 +392,10 @@ class DictSubclass: class Str2Str(Dict[str, str]): pass + @dataclass + class Int2Str(Dict[int, str]): + pass + @dataclass class Color2Str(Dict[Color, str]): pass diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index c592ba489..b8282ce07 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -904,6 +904,49 @@ def test_str2str_as_sub_node(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg.foo[Color.RED] = "fail" + with pytest.raises(KeyValidationError): + cfg.foo[123] = "fail" + + def test_int2str(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.structured(module.DictSubclass.Int2Str()) + + cfg[10] = "ten" # okay + assert cfg[10] == "ten" + + with pytest.raises(KeyValidationError): + cfg[10.0] = "float" # fail + + with pytest.raises(KeyValidationError): + cfg["10"] = "string" # fail + + with pytest.raises(KeyValidationError): + cfg.hello = "fail" + + with pytest.raises(KeyValidationError): + cfg[Color.RED] = "fail" + + def test_int2str_as_sub_node(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str}) + assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str + assert _utils.get_ref_type(cfg.foo) == Optional[module.DictSubclass.Int2Str] + + cfg.foo[10] = "ten" + assert cfg.foo[10] == "ten" + + with pytest.raises(KeyValidationError): + cfg.foo[10.0] = "float" # fail + + with pytest.raises(KeyValidationError): + cfg.foo["10"] = "string" # fail + + with pytest.raises(KeyValidationError): + cfg.foo.hello = "fail" + + with pytest.raises(KeyValidationError): + cfg.foo[Color.RED] = "fail" + def test_color2str(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Color2Str()) @@ -912,6 +955,9 @@ def test_color2str(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg.greeen = "nope" + with pytest.raises(KeyValidationError): + cfg[123] = "nope" + def test_color2color(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Color2Color()) @@ -934,6 +980,10 @@ def test_color2color(self, class_type: str) -> None: # bad value cfg[Color.GREEN] = 10 + with pytest.raises(ValidationError): + # bad value + cfg[Color.GREEN] = "this string is not a color" + with pytest.raises(KeyValidationError): # bad key cfg.greeen = "nope" From 35c796a68ec4e8580fa2914ca2ca59f077184941 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 17:42:56 -0600 Subject: [PATCH 19/33] test_basic_ops_dict.py: parametrize by key type --- tests/test_basic_ops_dict.py | 156 ++++++++++++++--------------------- 1 file changed, 63 insertions(+), 93 deletions(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index eaf61c34e..5a85a44c2 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -66,55 +66,69 @@ def test_getattr_dict() -> None: assert {"b": 1} == c.a -def test_mandatory_value() -> None: - c = OmegaConf.create({"a": "???"}) - with pytest.raises(MissingMandatoryValue, match="a"): - c.a - - -def test_mandatory_value_int_key() -> None: - c = OmegaConf.create({1: "???"}) - with pytest.raises(MissingMandatoryValue, match="1"): - c[1] - - -def test_nested_dict_mandatory_value() -> None: - c = OmegaConf.create(dict(a=dict(b="???"))) - with pytest.raises(MissingMandatoryValue): - c.a.b - - -def test_nested_dict_mandatory_value_int_key() -> None: - c = OmegaConf.create({1: dict(b="???")}) - with pytest.raises(MissingMandatoryValue): - c[1].b - c2 = OmegaConf.create(dict(a={2: "???"})) - with pytest.raises(MissingMandatoryValue): - c2.a[2] - - -def test_subscript_get() -> None: - c = OmegaConf.create("a: b") - assert isinstance(c, DictConfig) - assert "b" == c["a"] - - -def test_subscript_get_int_key() -> None: - c = OmegaConf.create({1: "b"}) - assert isinstance(c, DictConfig) - assert "b" == c[1] - - -def test_subscript_set() -> None: - c = OmegaConf.create() - c["a"] = "b" - assert {"a": "b"} == c - - -def test_subscript_set_int_key() -> None: - c = OmegaConf.create() - c[1] = "b" - assert {1: "b"} == c +@pytest.mark.parametrize( + "key", + ["a", 1], +) +class TestDictKeyTypes: + def test_mandatory_value(self, key) -> None: + c = OmegaConf.create({key: "???"}) + with pytest.raises(MissingMandatoryValue, match=str(key)): + c[key] + if isinstance(key, str): + with pytest.raises(MissingMandatoryValue, match=key): + getattr(c, key) + + def test_nested_dict_mandatory_value(self, key) -> None: + c = OmegaConf.create({"b": {key: "???"}}) + with pytest.raises(MissingMandatoryValue): + c.b[key] + if isinstance(key, str): + with pytest.raises(MissingMandatoryValue): + getattr(c.b, key) + + c = OmegaConf.create({key: {"b": "???"}}) + with pytest.raises(MissingMandatoryValue): + c[key].b + if isinstance(key, str): + with pytest.raises(MissingMandatoryValue): + getattr(c, key).b + + def test_subscript_get(self, key) -> None: + c = OmegaConf.create({key: "b"}) + assert isinstance(c, DictConfig) + assert "b" == c[key] + + def test_subscript_set(self, key) -> None: + c = OmegaConf.create() + c[key] = "b" + assert {key: "b"} == c + + +@pytest.mark.parametrize( + "src,key,expected", + [ + ({"a": 10, "b": 11}, "a", {"b": 11}), + ({1: "a", 2: "b"}, 1, {2: "b"}), + ], +) +class TestDelitemKeyTypes: + def test_dict_delitem(self, src, key, expected) -> None: + c = OmegaConf.create(src) + assert c == src + del c[key] + assert c == expected + with pytest.raises(KeyError): + del c["not_found"] + + def test_dict_struct_delitem(self, src, key, expected) -> None: + c = OmegaConf.create(src) + OmegaConf.set_struct(c, True) + with pytest.raises(ConfigTypeError): + del c[key] + with open_dict(c): + del c[key] + assert key not in c def test_default_value() -> None: @@ -415,50 +429,6 @@ def test_dict_config() -> None: assert isinstance(c, DictConfig) -def test_dict_delitem() -> None: - src = {"a": 10, "b": 11} - c = OmegaConf.create(src) - assert c == src - del c["a"] - assert c == {"b": 11} - with pytest.raises(KeyError): - del c["not_found"] - - -def test_dict_delitem_int_key() -> None: - src = {1: "a", 2: "b"} - c = OmegaConf.create(src) - assert c == src - del c[1] - assert c == {2: "b"} - with pytest.raises(KeyError): - del c["not_found"] - with pytest.raises(KeyError): - del c[3] - - -def test_dict_struct_delitem() -> None: - src = {"a": 10, "b": 11} - c = OmegaConf.create(src) - OmegaConf.set_struct(c, True) - with pytest.raises(ConfigTypeError): - del c["a"] - with open_dict(c): - del c["a"] - assert "a" not in c - - -def test_dict_struct_delitem_int_key() -> None: - src = {1: "a", 2: "b"} - c = OmegaConf.create(src) - OmegaConf.set_struct(c, True) - with pytest.raises(ConfigTypeError): - del c[1] - with open_dict(c): - del c[1] - assert 1 not in c - - def test_dict_structured_delitem() -> None: c = OmegaConf.structured(User(name="Bond")) with pytest.raises(ConfigTypeError): From c563e60fb58edd968b41b7abf48664380be023f4 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 17:53:33 -0600 Subject: [PATCH 20/33] test OmegaConf.to_yaml for DictConfig[int, str] --- tests/test_to_yaml.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_to_yaml.py b/tests/test_to_yaml.py index 55b70d283..3d9d7faac 100644 --- a/tests/test_to_yaml.py +++ b/tests/test_to_yaml.py @@ -14,6 +14,8 @@ [ (["item1", "item2", {"key3": "value3"}], "- item1\n- item2\n- key3: value3\n"), ({"hello": "world", "list": [1, 2]}, "hello: world\nlist:\n- 1\n- 2\n"), + ({"abc": "str key"}, "abc: str key\n"), + ({123: "int key"}, "123: int key\n"), ], ) def test_to_yaml(input_: Any, expected: str) -> None: From 2a700a55a7456ce62bf23ee51f3d8b6995955275 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 19:20:39 -0600 Subject: [PATCH 21/33] update docs --- docs/source/usage.rst | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 363ed1830..cc89c57e2 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -41,13 +41,14 @@ From a dictionary .. doctest:: - >>> conf = OmegaConf.create({"k" : "v", "list" : [1, {"a": "1", "b": "2"}]}) + >>> conf = OmegaConf.create({"k" : "v", "list" : [1, {"a": "1", "b": "2", 3: "c"}]}) >>> print(OmegaConf.to_yaml(conf)) k: v list: - 1 - a: '1' b: '2' + 3: c From a list @@ -55,12 +56,13 @@ From a list .. doctest:: - >>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10}}]) + >>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10, 123: 456}}]) >>> print(OmegaConf.to_yaml(conf)) - 1 - a: 10 b: a: 10 + 123: 456 Tuples are supported as an valid option too. @@ -95,6 +97,7 @@ From a YAML string ... list: ... - item1 ... - item2 + ... 123: 456 ... """ >>> conf = OmegaConf.create(s) >>> print(OmegaConf.to_yaml(conf)) @@ -103,6 +106,7 @@ From a YAML string list: - item1 - item2 + 123: 456 From a dot-list @@ -264,7 +268,7 @@ Save/Load YAML file .. doctest:: loaded - >>> conf = OmegaConf.create({"foo": 10, "bar": 20}) + >>> conf = OmegaConf.create({"foo": 10, "bar": 20, 123: 456}) >>> with tempfile.NamedTemporaryFile() as fp: ... OmegaConf.save(config=conf, f=fp.name) ... loaded = OmegaConf.load(fp.name) @@ -279,7 +283,7 @@ Note that the saved file may be incompatible across different major versions of .. doctest:: loaded - >>> conf = OmegaConf.create({"foo": 10, "bar": 20}) + >>> conf = OmegaConf.create({"foo": 10, "bar": 20, 123: 456}) >>> with tempfile.TemporaryFile() as fp: ... pickle.dump(conf, fp) ... fp.flush() From 819ad2869ac461b43a15193c3b19abbac1718526 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 19:30:35 -0600 Subject: [PATCH 22/33] test_basic_ops_dict.py: add type annotations --- tests/test_basic_ops_dict.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 5a85a44c2..146c8edc9 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -71,7 +71,7 @@ def test_getattr_dict() -> None: ["a", 1], ) class TestDictKeyTypes: - def test_mandatory_value(self, key) -> None: + def test_mandatory_value(self, key: Any) -> None: c = OmegaConf.create({key: "???"}) with pytest.raises(MissingMandatoryValue, match=str(key)): c[key] @@ -79,7 +79,7 @@ def test_mandatory_value(self, key) -> None: with pytest.raises(MissingMandatoryValue, match=key): getattr(c, key) - def test_nested_dict_mandatory_value(self, key) -> None: + def test_nested_dict_mandatory_value(self, key: Any) -> None: c = OmegaConf.create({"b": {key: "???"}}) with pytest.raises(MissingMandatoryValue): c.b[key] @@ -94,12 +94,12 @@ def test_nested_dict_mandatory_value(self, key) -> None: with pytest.raises(MissingMandatoryValue): getattr(c, key).b - def test_subscript_get(self, key) -> None: + def test_subscript_get(self, key: Any) -> None: c = OmegaConf.create({key: "b"}) assert isinstance(c, DictConfig) assert "b" == c[key] - def test_subscript_set(self, key) -> None: + def test_subscript_set(self, key: Any) -> None: c = OmegaConf.create() c[key] = "b" assert {key: "b"} == c @@ -113,7 +113,7 @@ def test_subscript_set(self, key) -> None: ], ) class TestDelitemKeyTypes: - def test_dict_delitem(self, src, key, expected) -> None: + def test_dict_delitem(self, src: Any, key: Any, expected: Any) -> None: c = OmegaConf.create(src) assert c == src del c[key] @@ -121,7 +121,7 @@ def test_dict_delitem(self, src, key, expected) -> None: with pytest.raises(KeyError): del c["not_found"] - def test_dict_struct_delitem(self, src, key, expected) -> None: + def test_dict_struct_delitem(self, src: Any, key: Any, expected: Any) -> None: c = OmegaConf.create(src) OmegaConf.set_struct(c, True) with pytest.raises(ConfigTypeError): From d5cdc18eb1635f482ffcd73965c4cf95223577e5 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 21:07:49 -0600 Subject: [PATCH 23/33] mypy test_basic_ops_dict: more specific annotation --- tests/test_basic_ops_dict.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 146c8edc9..fbe3d7fc1 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -8,6 +8,7 @@ from omegaconf import ( DictConfig, ListConfig, + DictKeyType, MissingMandatoryValue, OmegaConf, UnsupportedValueType, @@ -71,7 +72,7 @@ def test_getattr_dict() -> None: ["a", 1], ) class TestDictKeyTypes: - def test_mandatory_value(self, key: Any) -> None: + def test_mandatory_value(self, key: DictKeyType) -> None: c = OmegaConf.create({key: "???"}) with pytest.raises(MissingMandatoryValue, match=str(key)): c[key] @@ -79,7 +80,7 @@ def test_mandatory_value(self, key: Any) -> None: with pytest.raises(MissingMandatoryValue, match=key): getattr(c, key) - def test_nested_dict_mandatory_value(self, key: Any) -> None: + def test_nested_dict_mandatory_value(self, key: DictKeyType) -> None: c = OmegaConf.create({"b": {key: "???"}}) with pytest.raises(MissingMandatoryValue): c.b[key] @@ -94,12 +95,12 @@ def test_nested_dict_mandatory_value(self, key: Any) -> None: with pytest.raises(MissingMandatoryValue): getattr(c, key).b - def test_subscript_get(self, key: Any) -> None: + def test_subscript_get(self, key: DictKeyType) -> None: c = OmegaConf.create({key: "b"}) assert isinstance(c, DictConfig) assert "b" == c[key] - def test_subscript_set(self, key: Any) -> None: + def test_subscript_set(self, key: DictKeyType) -> None: c = OmegaConf.create() c[key] = "b" assert {key: "b"} == c @@ -113,7 +114,7 @@ def test_subscript_set(self, key: Any) -> None: ], ) class TestDelitemKeyTypes: - def test_dict_delitem(self, src: Any, key: Any, expected: Any) -> None: + def test_dict_delitem(self, src: Any, key: DictKeyType, expected: Any) -> None: c = OmegaConf.create(src) assert c == src del c[key] @@ -121,7 +122,9 @@ def test_dict_delitem(self, src: Any, key: Any, expected: Any) -> None: with pytest.raises(KeyError): del c["not_found"] - def test_dict_struct_delitem(self, src: Any, key: Any, expected: Any) -> None: + def test_dict_struct_delitem( + self, src: Any, key: DictKeyType, expected: Any + ) -> None: c = OmegaConf.create(src) OmegaConf.set_struct(c, True) with pytest.raises(ConfigTypeError): From d2502530d60225f8e6b180c25ffc1455c404893d Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 17 Dec 2020 03:53:46 -0600 Subject: [PATCH 24/33] isort --- tests/test_basic_ops_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index fbe3d7fc1..2636f9d9f 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -7,8 +7,8 @@ from omegaconf import ( DictConfig, - ListConfig, DictKeyType, + ListConfig, MissingMandatoryValue, OmegaConf, UnsupportedValueType, From 0649b399f5d8cd2a0be9fdaf4d1ac4e5636fc319 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 17 Dec 2020 11:41:09 -0600 Subject: [PATCH 25/33] docs: explicit note about supported dict key types --- docs/source/usage.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index cc89c57e2..83684dabe 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -51,6 +51,9 @@ From a dictionary 3: c +Currently, OmegaConf supports the following primitive key types for dictionaries: `str`, +`int`, and subclasses of `Enum`. + From a list ^^^^^^^^^^^ From 263d17e5fea6d1c2ebf169086b62ce074d86bcdc Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Mon, 21 Dec 2020 16:44:51 -0600 Subject: [PATCH 26/33] Update docs/source/usage.rst Co-authored-by: Omry Yadan --- docs/source/usage.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 83684dabe..8f2e3a5b2 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -51,8 +51,7 @@ From a dictionary 3: c -Currently, OmegaConf supports the following primitive key types for dictionaries: `str`, -`int`, and subclasses of `Enum`. +OmegaConf supports `str`, `int` and Enums as dictionary key types. From a list ^^^^^^^^^^^ From d330cb1a71fb161944b595ced8eddb0d080437b6 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Mon, 21 Dec 2020 16:45:15 -0600 Subject: [PATCH 27/33] Update docs/source/usage.rst Co-authored-by: Omry Yadan --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 8f2e3a5b2..1db4cde61 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -58,7 +58,7 @@ From a list .. doctest:: - >>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10, 123: 456}}]) + >>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10, 123: "int_key"}}]) >>> print(OmegaConf.to_yaml(conf)) - 1 - a: 10 From d50f2f2ba49416e971962ed41ab6e64cfd193256 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 21 Dec 2020 17:04:09 -0600 Subject: [PATCH 28/33] move DictKeyType defn basecontainer.py -> base.py --- omegaconf/__init__.py | 3 +-- omegaconf/base.py | 2 ++ omegaconf/basecontainer.py | 4 +--- omegaconf/dictconfig.py | 4 ++-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index efbdcef0f..2a1e82872 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,5 +1,4 @@ -from .base import Container, Node -from .basecontainer import DictKeyType +from .base import Container, DictKeyType, Node from .dictconfig import DictConfig from .errors import ( KeyValidationError, diff --git a/omegaconf/base.py b/omegaconf/base.py index 32b79edc0..fe1a290a5 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -7,6 +7,8 @@ from ._utils import ValueKind, _get_value, format_and_raise, get_value_kind from .errors import ConfigKeyError, MissingMandatoryValue, UnsupportedInterpolationType +DictKeyType = Union[str, int, Enum] + @dataclass class Metadata: diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 4131fae8c..649d5831d 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -24,13 +24,11 @@ is_primitive_type, is_structured_config, ) -from .base import Container, ContainerMetadata, Node +from .base import Container, ContainerMetadata, DictKeyType, Node from .errors import MissingMandatoryValue, ReadonlyConfigError, ValidationError DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__") -DictKeyType = Union[str, int, Enum] - class BaseContainer(Container, ABC): # static diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index ef5cb6d87..8c2ec97e4 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -32,8 +32,8 @@ type_str, valid_value_annotation_type, ) -from .base import Container, ContainerMetadata, Node -from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer, DictKeyType +from .base import Container, ContainerMetadata, DictKeyType, Node +from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer from .errors import ( ConfigAttributeError, ConfigKeyError, From 1c772640f7b56a22d34777785a7bd40525f7bb65 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 21 Dec 2020 17:21:31 -0600 Subject: [PATCH 29/33] test KeyValidationError dict[int,Any]:mistyped_key --- tests/test_errors.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_errors.py b/tests/test_errors.py index 551c952fd..cf88a4d6f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,6 +1,7 @@ import re from dataclasses import dataclass from enum import Enum +from textwrap import dedent from typing import Any, Dict, List, Optional, Type, Union import pytest @@ -435,6 +436,23 @@ def finalize(self, cfg: Any) -> None: ), id="dict:get_object_of_illegal_type", ), + pytest.param( + Expected( + create=lambda: DictConfig({}, key_type=int), + op=lambda cfg: cfg.get("foo"), + exception_type=KeyValidationError, + msg=dedent( + """\ + Key foo (str) is incompatible with (int) + \tfull_key: foo + \treference_type=Optional[Dict[int, Any]] + \tobject_type=dict""" + ), + key="foo", + full_key="foo", + ), + id="dict[int,Any]:mistyped_key", + ), # dict:create pytest.param( Expected( From d6170a69e7ee21dfc5d7c354e2bbe82f34bc01e8 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 21 Dec 2020 17:24:59 -0600 Subject: [PATCH 30/33] docs usage.rst: fix doctest failure --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 1db4cde61..0cc256358 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -64,7 +64,7 @@ From a list - a: 10 b: a: 10 - 123: 456 + 123: int_key Tuples are supported as an valid option too. From 9af1d0b08ede4b680be67663435c3ab523cd3a43 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 22 Dec 2020 22:28:57 -0600 Subject: [PATCH 31/33] add news fragment --- NEWS.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/NEWS.md b/NEWS.md index 63a50c243..10263ab39 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +## 2.0.6 (2020-12-22) + +### Features + +- OmegaConf now supports `int` for dictionary key types ([#149](https://github.com/omry/omegaconf/issues/149)) + ## 2.0.5 (2020-11-11) From fe0b0ad9db71ee8160cdb29567133e95a57d5e40 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Dec 2020 13:26:59 -0600 Subject: [PATCH 32/33] Revert "add news fragment" This reverts commit 9af1d0b08ede4b680be67663435c3ab523cd3a43. --- NEWS.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index 10263ab39..63a50c243 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,9 +1,3 @@ -## 2.0.6 (2020-12-22) - -### Features - -- OmegaConf now supports `int` for dictionary key types ([#149](https://github.com/omry/omegaconf/issues/149)) - ## 2.0.5 (2020-11-11) From 1ecd152a3324f35cfabb42eb4ea0e30dd9709d4c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Dec 2020 13:33:17 -0600 Subject: [PATCH 33/33] add towncrier news fragment --- news/149.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 news/149.feature diff --git a/news/149.feature b/news/149.feature new file mode 100644 index 000000000..c85412de5 --- /dev/null +++ b/news/149.feature @@ -0,0 +1 @@ +Add support for `int` key type in OmegaConf dictionaries