Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dict int key type #454

Merged
merged 34 commits into from
Dec 24, 2020
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
1eadbdf
preliminary DictConfig support for int key type
Jasha10 Dec 10, 2020
29941b5
DictConfig[int, ...]: comment out offending tests
Jasha10 Dec 11, 2020
1848190
black formatting
Jasha10 Dec 11, 2020
c4b12b9
black formatting
Jasha10 Dec 11, 2020
120ee32
Merge branch 'dict-int-key-type' of https://github.com/7iW/omegaconf …
Jasha10 Dec 11, 2020
df962eb
Change ErrorDictIntKey to ErrorDictObjectKey
Jasha10 Dec 13, 2020
cbfed03
tests: call OmegaConf.create on dictionary instead of string
Jasha10 Dec 13, 2020
fec336c
fix typo in test_basic_ops_dict
Jasha10 Dec 13, 2020
a3581c6
for invalid key test, use object instead of int
Jasha10 Dec 13, 2020
ecdef40
Use DictKeyType instead of Union[...]
Jasha10 Dec 13, 2020
56ea336
add "DictKeyType" to __init__.__all__
Jasha10 Dec 13, 2020
379b232
Test Dict[int, str] in structured config
Jasha10 Dec 13, 2020
457ec4d
black formatting
Jasha10 Dec 13, 2020
181eaa5
More test coverage for Dict[int, ...]
Jasha10 Dec 13, 2020
2ef70e2
Use DictKeyType for type annotation.
Jasha10 Dec 13, 2020
3b193b6
Change DictConfig MutableMapping parameter to Any.
Jasha10 Dec 14, 2020
a83120f
mypy: more flexible OmegaConf.create signature
Jasha10 Dec 14, 2020
4e056e6
modified: docs/source/structured_config.rst
Jasha10 Dec 16, 2020
b9ab0bf
test Structured config extending Dict[int,str]
Jasha10 Dec 16, 2020
35c796a
test_basic_ops_dict.py: parametrize by key type
Jasha10 Dec 16, 2020
c563e60
test OmegaConf.to_yaml for DictConfig[int, str]
Jasha10 Dec 16, 2020
2a700a5
update docs
Jasha10 Dec 17, 2020
819ad28
test_basic_ops_dict.py: add type annotations
Jasha10 Dec 17, 2020
d5cdc18
mypy test_basic_ops_dict: more specific annotation
Jasha10 Dec 17, 2020
d250253
isort
Jasha10 Dec 17, 2020
0649b39
docs: explicit note about supported dict key types
Jasha10 Dec 17, 2020
263d17e
Update docs/source/usage.rst
Jasha10 Dec 21, 2020
d330cb1
Update docs/source/usage.rst
Jasha10 Dec 21, 2020
d50f2f2
move DictKeyType defn basecontainer.py -> base.py
Jasha10 Dec 21, 2020
1c77264
test KeyValidationError dict[int,Any]:mistyped_key
Jasha10 Dec 21, 2020
d6170a6
docs usage.rst: fix doctest failure
Jasha10 Dec 21, 2020
9af1d0b
add news fragment
Jasha10 Dec 23, 2020
fe0b0ad
Revert "add news fragment"
Jasha10 Dec 23, 2020
1ecd152
add towncrier news fragment
Jasha10 Dec 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/source/structured_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ OmegaConf verifies at runtime that your Lists contains only values of the correc

Dictionaries
^^^^^^^^^^^^
Dictionaries are supported as well. Keys must be strings or enums, and values can be any of any type supported by OmegaConf
(Any, int, float, bool, str and Enums as well as arbitrary Structured configs)
Dictionaries are supported as well. Keys must be strings, ints or enums, and values can
be any of any type supported by OmegaConf (Any, int, float, bool, str and Enums as well
as arbitrary Structured configs)

Misc
----
Expand Down
12 changes: 8 additions & 4 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,28 @@ From a dictionary

.. doctest::

>>> conf = OmegaConf.create({"k" : "v", "list" : [1, {"a": "1", "b": "2"}]})
>>> conf = OmegaConf.create({"k" : "v", "list" : [1, {"a": "1", "b": "2", 3: "c"}]})
>>> print(OmegaConf.to_yaml(conf))
k: v
list:
- 1
- a: '1'
b: '2'
3: c
<BLANKLINE>

From a list
^^^^^^^^^^^

.. doctest::

>>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10}}])
>>> conf = OmegaConf.create([1, {"a":10, "b": {"a":10, 123: 456}}])
Jasha10 marked this conversation as resolved.
Show resolved Hide resolved
>>> print(OmegaConf.to_yaml(conf))
- 1
- a: 10
b:
a: 10
123: 456
<BLANKLINE>

Tuples are supported as an valid option too.
Expand Down Expand Up @@ -95,6 +97,7 @@ From a YAML string
... list:
... - item1
... - item2
... 123: 456
... """
>>> conf = OmegaConf.create(s)
>>> print(OmegaConf.to_yaml(conf))
Expand All @@ -103,6 +106,7 @@ From a YAML string
list:
- item1
- item2
123: 456
<BLANKLINE>

From a dot-list
Expand Down Expand Up @@ -264,7 +268,7 @@ Save/Load YAML file

.. doctest:: loaded

>>> conf = OmegaConf.create({"foo": 10, "bar": 20})
>>> conf = OmegaConf.create({"foo": 10, "bar": 20, 123: 456})
>>> with tempfile.NamedTemporaryFile() as fp:
... OmegaConf.save(config=conf, f=fp.name)
... loaded = OmegaConf.load(fp.name)
Expand All @@ -279,7 +283,7 @@ Note that the saved file may be incompatible across different major versions of

.. doctest:: loaded

>>> conf = OmegaConf.create({"foo": 10, "bar": 20})
>>> conf = OmegaConf.create({"foo": 10, "bar": 20, 123: 456})
>>> with tempfile.TemporaryFile() as fp:
... pickle.dump(conf, fp)
... fp.flush()
Expand Down
2 changes: 2 additions & 0 deletions omegaconf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base import Container, Node
from .basecontainer import DictKeyType
from .dictconfig import DictConfig
from .errors import (
KeyValidationError,
Expand Down Expand Up @@ -39,6 +40,7 @@
"Container",
"ListConfig",
"DictConfig",
"DictKeyType",
"OmegaConf",
"Resolver",
"flag_override",
Expand Down
4 changes: 3 additions & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,9 @@ def valid_value_annotation_type(type_: Any) -> bool:


def _valid_dict_key_annotation_type(type_: Any) -> bool:
return type_ is None or type_ is Any or issubclass(type_, (str, Enum))
from omegaconf import DictKeyType

return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore


def is_primitive_type(type_: Any) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __setitem__(self, key: Any, value: Any) -> None:
...

@abstractmethod
def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[Any]:
...

@abstractmethod
Expand Down
9 changes: 7 additions & 2 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__")

DictKeyType = Union[str, int, Enum]

Jasha10 marked this conversation as resolved.
Show resolved Hide resolved

class BaseContainer(Container, ABC):
# static
Expand Down Expand Up @@ -187,7 +189,7 @@ def _to_content(
resolve: bool,
enum_to_str: bool = False,
exclude_structured_configs: bool = False,
) -> Union[None, Any, str, Dict[str, Any], List[Any]]:
) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]:
from .dictconfig import DictConfig
from .listconfig import ListConfig

Expand Down Expand Up @@ -528,7 +530,10 @@ def assign(value_key: Any, value_to_assign: Any) -> None:

@staticmethod
def _item_eq(
c1: Container, k1: Union[str, int], c2: Container, k2: Union[str, int]
c1: Container,
k1: Union[DictKeyType, int],
Jasha10 marked this conversation as resolved.
Show resolved Hide resolved
c2: Container,
k2: Union[DictKeyType, int],
) -> bool:
v1 = c1._get_node(k1)
v2 = c2._get_node(k2)
Expand Down
52 changes: 28 additions & 24 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand All @@ -32,7 +33,7 @@
valid_value_annotation_type,
)
from .base import Container, ContainerMetadata, Node
from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer
from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer, DictKeyType
from .errors import (
ConfigAttributeError,
ConfigKeyError,
Expand All @@ -47,13 +48,13 @@
from .nodes import EnumNode, ValueNode


class DictConfig(BaseContainer, MutableMapping[str, Any]):
class DictConfig(BaseContainer, MutableMapping[Any, Any]):

_metadata: ContainerMetadata

def __init__(
self,
content: Union[Dict[str, Any], Any],
content: Union[Dict[DictKeyType, Any], Any],
key: Any = None,
parent: Optional[Container] = None,
ref_type: Union[Any, Type[Any]] = Any,
Expand Down Expand Up @@ -245,14 +246,12 @@ def _raise_invalid_value(
)
raise ValidationError(msg)

def _validate_and_normalize_key(self, key: Any) -> Union[str, Enum]:
def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
return self._s_validate_and_normalize_key(self._metadata.key_type, key)

def _s_validate_and_normalize_key(
self, key_type: Any, key: Any
) -> Union[str, Enum]:
def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
if key_type is Any:
for t in (str, Enum):
for t in DictKeyType.__args__: # type: ignore
try:
return self._s_validate_and_normalize_key(key_type=t, key=key)
except KeyValidationError:
Expand All @@ -264,6 +263,13 @@ def _s_validate_and_normalize_key(
f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
)

return key
elif key_type == int:
if not isinstance(key, int):
raise KeyValidationError(
f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
)

Jasha10 marked this conversation as resolved.
Show resolved Hide resolved
return key
elif issubclass(key_type, Enum):
try:
Expand All @@ -278,7 +284,7 @@ def _s_validate_and_normalize_key(
else:
assert False, f"Unsupported key type {key_type}"

def __setitem__(self, key: Union[str, Enum], value: Any) -> None:
def __setitem__(self, key: DictKeyType, value: Any) -> None:
try:
self.__set_impl(key=key, value=value)
except AttributeError as e:
Expand All @@ -288,7 +294,7 @@ def __setitem__(self, key: Union[str, Enum], value: Any) -> None:
except Exception as e:
self._format_and_raise(key=key, value=value, cause=e)

def __set_impl(self, key: Union[str, Enum], value: Any) -> None:
def __set_impl(self, key: DictKeyType, value: Any) -> None:
key = self._validate_and_normalize_key(key)
self._set_item_impl(key, value)

Expand Down Expand Up @@ -331,7 +337,7 @@ def __getattr__(self, key: str) -> Any:
except Exception as e:
self._format_and_raise(key=key, value=None, cause=e)

def __getitem__(self, key: Union[str, Enum]) -> Any:
def __getitem__(self, key: DictKeyType) -> Any:
"""
Allow map style access
:param key:
Expand All @@ -347,7 +353,7 @@ def __getitem__(self, key: Union[str, Enum]) -> Any:
except Exception as e:
self._format_and_raise(key=key, value=None, cause=e)

def __delitem__(self, key: Union[str, int, Enum]) -> None:
def __delitem__(self, key: DictKeyType) -> None:
if self._get_flag("readonly"):
self._format_and_raise(
key=key,
Expand Down Expand Up @@ -375,15 +381,13 @@ def __delitem__(self, key: Union[str, int, Enum]) -> None:

del self.__dict__["_content"][key]

def get(
self, key: Union[str, Enum], default_value: Any = DEFAULT_VALUE_MARKER
) -> Any:
def get(self, key: DictKeyType, default_value: Any = DEFAULT_VALUE_MARKER) -> Any:
try:
return self._get_impl(key=key, default_value=default_value)
except Exception as e:
self._format_and_raise(key=key, value=None, cause=e)

def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any:
def _get_impl(self, key: DictKeyType, default_value: Any) -> Any:
try:
node = self._get_node(key=key)
except ConfigAttributeError:
Expand All @@ -396,7 +400,7 @@ def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any:
)

def _get_node(
self, key: Union[str, Enum], validate_access: bool = True
self, key: DictKeyType, validate_access: bool = True
) -> Optional[Node]:
try:
key = self._validate_and_normalize_key(key)
Expand All @@ -413,7 +417,7 @@ def _get_node(

return value

def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any:
def pop(self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER) -> Any:
try:
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot pop from read-only node")
Expand Down Expand Up @@ -479,13 +483,13 @@ def __contains__(self, key: object) -> bool:
except (MissingMandatoryValue, KeyError):
return False

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[DictKeyType]:
return iter(self.keys())

def items(self) -> AbstractSet[Tuple[str, Any]]:
def items(self) -> AbstractSet[Tuple[DictKeyType, Any]]:
return self.items_ex(resolve=True, keys=None)

def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any:
def setdefault(self, key: DictKeyType, default: Any = None) -> Any:
if key in self:
ret = self.__getitem__(key)
else:
Expand All @@ -494,9 +498,9 @@ def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any:
return ret

def items_ex(
self, resolve: bool = True, keys: Optional[List[str]] = None
) -> AbstractSet[Tuple[str, Any]]:
items: List[Tuple[str, Any]] = []
self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None
) -> AbstractSet[Tuple[DictKeyType, Any]]:
items: List[Tuple[DictKeyType, Any]] = []
odelalleau marked this conversation as resolved.
Show resolved Hide resolved
for key in self.keys():
if resolve:
value = self.get(key)
Expand Down
8 changes: 4 additions & 4 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import yaml
from typing_extensions import Protocol

from . import DictConfig, ListConfig
from . import DictConfig, DictKeyType, ListConfig
from ._utils import (
_ensure_container,
_get_value,
Expand Down Expand Up @@ -183,7 +183,7 @@ def create(
@staticmethod
@overload
def create(
obj: Union[Dict[str, Any], None] = None,
obj: Optional[Dict[Any, Any]] = None,
odelalleau marked this conversation as resolved.
Show resolved Hide resolved
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> DictConfig:
Expand Down Expand Up @@ -467,7 +467,7 @@ def to_container(
resolve: bool = False,
enum_to_str: bool = False,
exclude_structured_configs: bool = False,
) -> Union[Dict[str, Any], List[Any], None, str]:
) -> Union[Dict[DictKeyType, Any], List[Any], None, str]:
"""
Resursively converts an OmegaConf config to a primitive container (dict or list).
:param cfg: the config to convert
Expand Down Expand Up @@ -508,7 +508,7 @@ def is_optional(obj: Any, key: Optional[Union[int, str]] = None) -> bool:
return True

@staticmethod
def is_none(obj: Any, key: Optional[Union[int, str]] = None) -> bool:
def is_none(obj: Any, key: Optional[Union[int, DictKeyType]] = None) -> bool:
if key is not None:
assert isinstance(obj, Container)
obj = obj._get_node(key)
Expand Down
9 changes: 7 additions & 2 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,9 @@ class WithTypedDict:


@attr.s(auto_attribs=True)
class ErrorDictIntKey:
class ErrorDictObjectKey:
# invalid dict key, must be str
dict: Dict[int, str] = {10: "foo", 20: "bar"}
dict: Dict[object, str] = {object(): "foo", object(): "bar"}


class RegularClass:
Expand Down Expand Up @@ -350,6 +350,7 @@ class DictExamples:
"green": Color.GREEN,
"blue": Color.BLUE,
}
int_keys: Dict[int, str] = {1: "one", 2: "two"}


@attr.s(auto_attribs=True)
Expand All @@ -372,6 +373,10 @@ class DictSubclass:
class Str2Str(Dict[str, str]):
pass

@attr.s(auto_attribs=True)
class Int2Str(Dict[int, str]):
pass

@attr.s(auto_attribs=True)
class Color2Str(Dict[Color, str]):
pass
Expand Down
Loading