Skip to content

Commit

Permalink
Adding support for bool and float as DictConfig key types (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
omry authored Feb 22, 2021
2 parents 42e8702 + baa918c commit 1dbe70f
Show file tree
Hide file tree
Showing 16 changed files with 1,058 additions and 136 deletions.
18 changes: 17 additions & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,23 @@ From a dictionary
3: c
<BLANKLINE>

OmegaConf supports `str`, `int` and Enums as dictionary key types.
Here is an example of various supported key types:

.. doctest::

>>> from enum import Enum
>>> class Color(Enum):
... RED = 1
... BLUE = 2
>>>
>>> conf = OmegaConf.create(
... {"key": "str", 123: "int", True: "bool", 3.14: "float", Color.RED: "Color"}
... )
>>>
>>> print(conf)
{'key': 'str', 123: 'int', True: 'bool', 3.14: 'float', <Color.RED: 1>: 'Color'}

OmegaConf supports `str`, `int`, `bool`, `float` and Enums as dictionary key types.

From a list
^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions news/483.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add DictConfig support for keys of type float and bool
1 change: 1 addition & 0 deletions news/554.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
When a dictconfig has enum-typed keys, __delitem__ can now be called with a string naming the enum member to be deleted.
4 changes: 2 additions & 2 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .grammar_parser import parse
from .grammar_visitor import GrammarVisitor

DictKeyType = Union[str, int, Enum]
DictKeyType = Union[str, int, Enum, float, bool]

_MARKER_ = object()

Expand Down Expand Up @@ -179,7 +179,7 @@ def _format_and_raise(
assert False

@abstractmethod
def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
...

def _dereference_node(
Expand Down
8 changes: 5 additions & 3 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, parent: Optional["Container"], metadata: ContainerMetadata):

def _resolve_with_default(
self,
key: Union[str, int, Enum],
key: Union[DictKeyType, int],
value: Any,
default_value: Any = DEFAULT_VALUE_MARKER,
) -> Any:
Expand Down Expand Up @@ -697,11 +697,11 @@ def _validate_set(self, key: Any, value: Any) -> None:
def _value(self) -> Any:
return self.__dict__["_content"]

def _get_full_key(self, key: Union[str, Enum, int, slice, None]) -> str:
def _get_full_key(self, key: Union[DictKeyType, int, slice, None]) -> str:
from .listconfig import ListConfig
from .omegaconf import _select_one

if not isinstance(key, (int, str, Enum, slice, type(None))):
if not isinstance(key, (int, str, Enum, float, bool, slice, type(None))):
return ""

def _slice_to_str(x: slice) -> str:
Expand All @@ -715,6 +715,8 @@ def prepand(full_key: str, parent_type: Any, cur_type: Any, key: Any) -> str:
key = _slice_to_str(key)
elif isinstance(key, Enum):
key = key.name
elif isinstance(key, (int, float, bool)):
key = str(key)

if issubclass(parent_type, ListConfig):
if full_key != "":
Expand Down
35 changes: 19 additions & 16 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,28 +275,26 @@ def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
if key_type is Any:
for t in DictKeyType.__args__: # type: ignore
try:
return self._s_validate_and_normalize_key(key_type=t, key=key)
except KeyValidationError:
pass
if isinstance(key, t):
return key # type: ignore
raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
elif key_type == str:
if not isinstance(key, str):
raise KeyValidationError(
f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
)

return key
elif key_type == int:
if not isinstance(key, int):
elif key_type is bool and key in [0, 1]:
# Python treats True as 1 and False as 0 when used as dict keys
# assert hash(0) == hash(False)
# assert hash(1) == hash(True)
return bool(key)
elif key_type in (str, int, float, bool): # primitive type
if not isinstance(key, key_type):
raise KeyValidationError(
f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
)

return key
return key # type: ignore
elif issubclass(key_type, Enum):
try:
ret = EnumNode.validate_and_convert_to_enum(key_type, key)
ret = EnumNode.validate_and_convert_to_enum(
key_type, key, allow_none=False
)
assert ret is not None
return ret
except ValidationError:
Expand Down Expand Up @@ -377,6 +375,7 @@ def __getitem__(self, key: DictKeyType) -> Any:
self._format_and_raise(key=key, value=None, cause=e)

def __delitem__(self, key: DictKeyType) -> None:
key = self._validate_and_normalize_key(key)
if self._get_flag("readonly"):
self._format_and_raise(
key=key,
Expand All @@ -402,7 +401,11 @@ def __delitem__(self, key: DictKeyType) -> None:
),
)

del self.__dict__["_content"][key]
try:
del self.__dict__["_content"][key]
except KeyError:
msg = "Key not found: '$KEY'"
self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg))

def get(self, key: DictKeyType, default_value: Any = None) -> Any:
"""Return the value for `key` if `key` is in the dictionary, else
Expand Down
8 changes: 4 additions & 4 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_value_kind,
is_primitive_container,
)
from omegaconf.base import Container, Metadata, Node
from omegaconf.base import Container, DictKeyType, Metadata, Node
from omegaconf.errors import (
ConfigKeyError,
ReadonlyConfigError,
Expand Down Expand Up @@ -122,7 +122,7 @@ def _is_missing(self) -> bool:
def _is_interpolation(self) -> bool:
return _is_interpolation(self._value())

def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
parent = self._get_parent()
if parent is None:
if self._metadata.key is None:
Expand Down Expand Up @@ -366,9 +366,9 @@ def validate_and_convert(self, value: Any) -> Optional[Enum]:

@staticmethod
def validate_and_convert_to_enum(
enum_type: Type[Enum], value: Any
enum_type: Type[Enum], value: Any, allow_none: bool = True
) -> Optional[Enum]:
if value is None:
if allow_none and value is None:
return None

if not isinstance(value, (str, int)) and not isinstance(value, enum_type):
Expand Down
6 changes: 5 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ class UntypedDict:

@dataclass
class SubscriptedDict:
dict: Dict[str, int] = field(default_factory=lambda: {"foo": 4})
dict_str: Dict[str, int] = field(default_factory=lambda: {"foo": 4})
dict_enum: Dict[Color, int] = field(default_factory=lambda: {Color.RED: 4})
dict_int: Dict[int, int] = field(default_factory=lambda: {123: 4})
dict_float: Dict[float, int] = field(default_factory=lambda: {123.45: 4})
dict_bool: Dict[bool, int] = field(default_factory=lambda: {True: 4, False: 5})


@dataclass
Expand Down
14 changes: 10 additions & 4 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,8 @@ class DictExamples:
"blue": Color.BLUE,
}
int_keys: Dict[int, str] = {1: "one", 2: "two"}


@attr.s(auto_attribs=True)
class DictWithEnumKeys:
float_keys: Dict[float, str] = {1.1: "one", 2.2: "two"}
bool_keys: Dict[bool, str] = {True: "T", False: "F"}
enum_key: Dict[Color, str] = {Color.RED: "red", Color.GREEN: "green"}


Expand All @@ -414,6 +412,14 @@ class Str2Str(Dict[str, str]):
class Int2Str(Dict[int, str]):
pass

@attr.s(auto_attribs=True)
class Float2Str(Dict[float, str]):
pass

@attr.s(auto_attribs=True)
class Bool2Str(Dict[bool, str]):
pass

@attr.s(auto_attribs=True)
class Color2Str(Dict[Color, str]):
pass
Expand Down
16 changes: 12 additions & 4 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,10 @@ class DictExamples:
}
)
int_keys: Dict[int, str] = field(default_factory=lambda: {1: "one", 2: "two"})


@dataclass
class DictWithEnumKeys:
float_keys: Dict[float, str] = field(
default_factory=lambda: {1.1: "one", 2.2: "two"}
)
bool_keys: Dict[bool, str] = field(default_factory=lambda: {True: "T", False: "F"})
enum_key: Dict[Color, str] = field(
default_factory=lambda: {Color.RED: "red", Color.GREEN: "green"}
)
Expand All @@ -433,6 +433,14 @@ class Str2Str(Dict[str, str]):
class Int2Str(Dict[int, str]):
pass

@dataclass
class Float2Str(Dict[float, str]):
pass

@dataclass
class Bool2Str(Dict[bool, str]):
pass

@dataclass
class Color2Str(Dict[Color, str]):
pass
Expand Down
Loading

0 comments on commit 1dbe70f

Please sign in to comment.