Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Standardized exception handling
Browse files Browse the repository at this point in the history
omry committed Apr 4, 2020
1 parent 9edcb9d commit 9ddbc0c
Showing 22 changed files with 1,237 additions and 415 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@
source_suffix = ".rst"

# The master toctree document.
master_doc = "index"
master_doc = "key"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
1 change: 1 addition & 0 deletions news/186.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Standardize exception messages
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -50,6 +50,7 @@ def lint(session):
session.run("mypy", ".", "--strict", silent=True)
session.run("isort", ".", "--check", silent=True)
session.run("black", "--check", ".", silent=True)
session.run("flake8")


@nox.session(python=PYTHON_VERSIONS)
79 changes: 78 additions & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import re
import string
import sys
import warnings
from enum import Enum
from typing import Any, Dict, List, Match, Optional, Tuple, Type, Union

import yaml

from .errors import KeyValidationError, ValidationError
from .errors import KeyValidationError, ValidationError, _OmegaConfException

try:
import dataclasses
@@ -369,6 +372,14 @@ def is_primitive_type(type_: Any) -> bool:
return issubclass(type_, Enum) or type_ in (int, float, bool, str, type(None))


def _is_primitive_type(type_: Any) -> bool:
warnings.warn(
"use omegaconf._utils.is_primitive_type", DeprecationWarning, stacklevel=2,
)

return is_primitive_type(type_)


def _is_interpolation(v: Any) -> bool:
if isinstance(v, str):
ret = get_value_kind(v) in (
@@ -389,3 +400,69 @@ def _get_value(value: Any) -> Any:
if isinstance(value, ValueNode):
value = value._value()
return value


def format_and_raise(oce: _OmegaConfException) -> None:
def type_str(t: Any) -> str:
if isinstance(t, type):
return t.__name__
else:
return str(node._metadata.object_type)

key = oce.key
if key is None:
node = oce.node
else:
if oce.node is None or oce.node._is_none() or oce.node._is_missing():
node = None
else:
if isinstance(key, slice):
node = oce.node
else:
node = oce.node.get_node_ex(key, validate_access=False)
if node is None:
node = oce.node
else:
key = None

if node is None:
full_key = None
object_type = None
rt = None
else:
full_key = node._get_full_key(key=key)

ref_type: Any = node._metadata.ref_type
if ref_type is None:
ref_type = Any
object_type = type_str(node._metadata.object_type)
rt = type_str(ref_type)
if node._metadata.optional:
if ref_type is Any:
rt = "Any"
else:
rt = f"Optional[{rt}]"

msg = string.Template(oce.msg).substitute(
REF_TYPE=rt,
OBJECT_TYPE=object_type,
KEY=key,
FULL_KEY=full_key,
VALUE=oce.value,
VALUE_TYPE=f"{type(oce.value).__name__}",
KEY_TYPE=f"{type(key).__name__}",
)

s = string.Template(
"""$MSG
\tfull_key: $FULL_KEY
\treference_type=$REF_TYPE
\tobject_type=$OBJECT_TYPE
"""
)
message = s.substitute(
REF_TYPE=rt, OBJECT_TYPE=object_type, MSG=msg, FULL_KEY=full_key,
)
raise oce.exception_type(f"{message}").with_traceback(
sys.exc_info()[2]
) from oce.cause
85 changes: 28 additions & 57 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,12 @@
from enum import Enum
from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union

from ._utils import ValueKind, _get_value, get_value_kind
from .errors import MissingMandatoryValue, UnsupportedInterpolationType
from ._utils import ValueKind, _get_value, format_and_raise, get_value_kind
from .errors import (
MissingMandatoryValue,
UnsupportedInterpolationType,
_OmegaConfException,
)


@dataclass
@@ -84,58 +88,25 @@ def _get_flag(self, flag: str) -> Optional[bool]:
# noinspection PyProtectedMember
return parent._get_flag(flag)

def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
from .listconfig import ListConfig
from .omegaconf import _select_one

def prepand(full_key: str, parent_type: Any, cur_type: Any, key: Any) -> str:
if issubclass(parent_type, ListConfig):
if full_key != "":
if issubclass(cur_type, ListConfig):
full_key = f"[{key}]{full_key}"
else:
full_key = f"[{key}].{full_key}"
else:
full_key = f"[{key}]"
else:
if full_key == "":
full_key = key
else:
if issubclass(cur_type, ListConfig):
full_key = f"{key}{full_key}"
else:
full_key = f"{key}.{full_key}"
return full_key

if key is not None and key != "":
assert isinstance(self, Container)
cur, _ = _select_one(c=self, key=str(key), throw_on_missing=False)
if cur is None:
cur = self
full_key = prepand("", type(cur), None, key)
if cur._key() is not None:
full_key = prepand(
full_key, type(cur._get_parent()), type(cur), cur._key()
)
else:
full_key = prepand("", type(cur._get_parent()), type(cur), cur._key())
else:
cur = self
if cur._key() is None:
return ""
full_key = self._key()

assert cur is not None
while cur._get_parent() is not None:
cur = cur._get_parent()
assert cur is not None
key = cur._key()
if key is not None:
full_key = prepand(
full_key, type(cur._get_parent()), type(cur), cur._key()
)
# TODO: move to utils
def _translate_exception(
self, e: Exception, key: Any, value: Any, type_override: Any = None
) -> None:
format_and_raise(
_OmegaConfException(
exception_type=type(e) if type_override is None else type_override,
node=self,
key=key,
value=value,
msg=str(e),
cause=e,
)
)
assert False # pragma: no cover

return full_key
@abstractmethod
def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
...

def _dereference_node(self, throw_on_missing: bool = False) -> "Node":
from .nodes import StringNode
@@ -172,7 +143,7 @@ def _dereference_node(self, throw_on_missing: bool = False) -> "Node":
if throw_on_missing:
value = self._value()
if value == "???":
raise MissingMandatoryValue(self._get_full_key(""))
raise MissingMandatoryValue("Missing mandatory value")
return self

@abstractmethod
@@ -237,7 +208,7 @@ def update_node(self, key: str, value: Any = None) -> None:
def select(self, key: str, throw_on_missing: bool = False) -> Any:
... # pragma: no cover

def get_node(self, key: Any) -> Node:
def get_node(self, key: Any) -> Optional[Node]:
... # pragma: no cover

@abstractmethod
@@ -318,7 +289,7 @@ def _resolve_interpolation(

if parent is None or (value is None and last_key not in parent): # type: ignore
raise KeyError(
"{} interpolation key '{}' not found".format(inter_type, inter_key)
f"{inter_type} interpolation key '{inter_key}' not found"
)
assert isinstance(value, Node)
return value
@@ -335,7 +306,7 @@ def _resolve_interpolation(
)
else:
raise UnsupportedInterpolationType(
"Unsupported interpolation type {}".format(inter_type)
f"Unsupported interpolation type {inter_type}"
)

def _resolve_str_interpolation(
151 changes: 129 additions & 22 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import string
import sys
import warnings
from abc import ABC, abstractmethod
@@ -64,15 +65,14 @@ def is_mandatory_missing(val: Any) -> bool:
):
value = default_value

value = self._resolve_str_interpolation(
resolved = self._resolve_str_interpolation(
key=key, value=value, throw_on_missing=True
)
if is_mandatory_missing(value):
raise MissingMandatoryValue(self._get_full_key(str(key)))
if is_mandatory_missing(resolved):
raise MissingMandatoryValue("Missing mandatory value: $FULL_KEY")
resolved2 = _get_value(resolved)

value = _get_value(value)

return value
return resolved2

def __str__(self) -> str:
return self.__repr__()
@@ -157,14 +157,19 @@ def update_node(self, key: str, value: Any = None) -> None:
root[idx] = value

def select(self, key: str, throw_on_missing: bool = False) -> Any:
_root, _last_key, value = self._select_impl(key, throw_on_missing=False)
if value is not None and value._is_missing():
if throw_on_missing:
raise MissingMandatoryValue(value._get_full_key(""))
else:
return None
try:
_root, _last_key, value = self._select_impl(key, throw_on_missing=False)
if value is not None and value._is_missing():
if throw_on_missing:
raise MissingMandatoryValue(
f"Missing mandatory value : {self._get_full_key('')}"
)
else:
return None

return _get_value(value)
return _get_value(value)
except Exception as e:
self._translate_exception(e=e, key=key, value=None)

def is_empty(self) -> bool:
"""return true if config is empty"""
@@ -220,15 +225,15 @@ def to_container(self, resolve: bool = False) -> Union[Dict[str, Any], List[Any]
return BaseContainer._to_content(self, resolve)

def pretty(self, resolve: bool = False, sort_keys: bool = False) -> str:
from omegaconf import OmegaConf

"""
returns a yaml dump of this config object.
:param resolve: if True, will return a string with the interpolations resolved, otherwise
interpolations are preserved
:param sort_keys: If True, will print dict keys in sorted order. default False.
:return: A string containing the yaml representation.
"""
from omegaconf import OmegaConf

container = OmegaConf.to_container(self, resolve=resolve, enum_to_str=True)
return yaml.dump( # type: ignore
container, default_flow_style=False, allow_unicode=True, sort_keys=sort_keys
@@ -251,6 +256,9 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
type_backup = dest._metadata.object_type
dest._metadata.object_type = None

if (dest._get_flag("readonly")) or dest._get_flag("readonly"):
raise ReadonlyConfigError("Cannot merge into read-only node")
dest._validate_set_merge_impl(key=None, value=src, is_assign=False)
for key, src_value in src.items_ex(resolve=False):
if dest._is_missing():
if dest._metadata.ref_type is not None and not is_dict_annotation(
@@ -316,7 +324,7 @@ def merge_with(
for item in other:
self.append(item)
else:
raise TypeError("Merging DictConfig with ListConfig is not supported")
raise TypeError("Cannot merge DictConfig with ListConfig")

# recursively correct the parent hierarchy after the merge
self._re_parent()
@@ -327,7 +335,6 @@ def _set_item_impl(self, key: Any, value: Any) -> None:

from .nodes import ValueNode

self._validate_get(key)
if isinstance(value, Node):
try:
old = value._key()
@@ -404,11 +411,56 @@ def assign(value_key: Any, value_to_assign: Any) -> None:
else:
self.__dict__["_content"][key] = wrap(key, value)
except ValidationError as ve:
import sys
self._format_and_raise(
exception_type=ValidationError, key=key, msg=f"{ve}",
)

# TODO: cleanup
def _format_and_raise(
self, exception_type: Any, msg: str, key: Optional[str]
) -> None:
def type_str(t: Any) -> str:
if isinstance(t, type):
return t.__name__
else:
return str(node._metadata.object_type)

full_key = self._get_full_key(key=key)
if key is None:
node = self
else:
node = (
self.__dict__["_content"][key]
if key in self.__dict__["_content"]
else None
)
if node is None:
node = self

ref_type: Any = node._metadata.ref_type
if ref_type is None:
ref_type = Any
object_type = type_str(node._metadata.object_type)

rt = type_str(ref_type)
if node._metadata.optional:
if ref_type is Any:
rt = "Any"
else:
rt = f"Optional[{rt}]"

s = string.Template(
"""$MSG
\tfull_key: $FULL_KEY
\treference_type=$REF_TYPE
\tobject_type=$OBJECT_TYPE
"""
)
message = s.substitute(
REF_TYPE=rt, OBJECT_TYPE=object_type, MSG=msg, FULL_KEY=full_key,
)

raise type(ve)(
f"Error setting '{self._get_full_key(str(key))} = {value}' : {ve}"
).with_traceback(sys.exc_info()[2]) from None
raise exception_type(f"{message}").with_traceback(sys.exc_info()[2]) from None

@staticmethod
def _item_eq(
@@ -500,7 +552,7 @@ def _is_interpolation(self) -> bool:
return _is_interpolation(self.__dict__["_content"])

@abstractmethod
def _validate_get(self, key: Any) -> None:
def _validate_get(self, key: Any, value: Any = None) -> None:
... # pragma: no cover

@abstractmethod
@@ -509,3 +561,58 @@ def _validate_set(self, key: Any, value: Any) -> None:

def _value(self) -> Any:
return self.__dict__["_content"]

def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
from .listconfig import ListConfig
from .omegaconf import _select_one

def prepand(full_key: str, parent_type: Any, cur_type: Any, key: Any) -> str:
if issubclass(parent_type, ListConfig):
if full_key != "":
if issubclass(cur_type, ListConfig):
full_key = f"[{key}]{full_key}"
else:
full_key = f"[{key}].{full_key}"
else:
full_key = f"[{key}]"
else:
if full_key == "":
full_key = key
else:
if issubclass(cur_type, ListConfig):
full_key = f"{key}{full_key}"
else:
full_key = f"{key}.{full_key}"
return full_key

if key is not None and key != "":
assert isinstance(self, Container)
cur, _ = _select_one(
c=self, key=str(key), throw_on_missing=False, throw_on_type_error=False,
)
if cur is None:
cur = self
full_key = prepand("", type(cur), None, key)
if cur._key() is not None:
full_key = prepand(
full_key, type(cur._get_parent()), type(cur), cur._key()
)
else:
full_key = prepand("", type(cur._get_parent()), type(cur), cur._key())
else:
cur = self
if cur._key() is None:
return ""
full_key = self._key()

assert cur is not None
while cur._get_parent() is not None:
cur = cur._get_parent()
assert cur is not None
key = cur._key()
if key is not None:
full_key = prepand(
full_key, type(cur._get_parent()), type(cur), cur._key()
)

return full_key
211 changes: 107 additions & 104 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,6 @@
MissingMandatoryValue,
ReadonlyConfigError,
UnsupportedInterpolationType,
UnsupportedValueType,
ValidationError,
)
from .nodes import EnumNode, ValueNode
@@ -98,7 +97,7 @@ def __copy__(self) -> "DictConfig":
def copy(self) -> "DictConfig":
return copy.copy(self)

def _validate_get(self, key: Union[int, str, Enum]) -> None:
def _validate_get(self, key: Any, value: Any = None) -> None:
is_typed = self._metadata.object_type not in (Any, None,) and not is_dict(
self._metadata.object_type
)
@@ -112,11 +111,9 @@ def _validate_get(self, key: Union[int, str, Enum]) -> None:
if is_typed or is_struct:
if is_typed:
assert self._metadata.object_type is not None
msg = f"Accessing unknown key in {self._metadata.object_type.__name__} : {self._get_full_key(key)}"
msg = f"Key '{key}' not in ({self._metadata.object_type.__name__})"
else:
msg = "Accessing unknown key in a struct : {}".format(
self._get_full_key(key)
)
msg = f"Key '{key}' in not in struct"
raise AttributeError(msg)

def _validate_merge(self, key: Any, value: Any) -> None:
@@ -125,27 +122,6 @@ def _validate_merge(self, key: Any, value: Any) -> None:
def _validate_set(self, key: Any, value: Any) -> None:
self._validate_set_merge_impl(key, value, is_assign=True)

def _raise(self, exception_type: Any, msg: str) -> None:
ref_type: Any = self._metadata.ref_type
if ref_type is None:
ref_type = Any
object_type = self._metadata.object_type

rt = ref_type.__name__ if ref_type is not Any else "Any"
if self._metadata.optional:
rt = f"Optional[{rt}]"

if ref_type is not Any and object_type is not None:
ot = object_type.__name__
raise exception_type(f"{rt} = {ot}() :: {msg}")
elif ref_type is Any and object_type is not None:
ot = object_type.__name__
raise exception_type(f"{ot}() :: {msg}")
elif ref_type is not Any and object_type is None:
raise exception_type(f"{rt} :: {msg}")
else:
raise exception_type(f"{msg}")

def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> None:
from omegaconf import OmegaConf

@@ -155,36 +131,41 @@ def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> Non
if isinstance(value, (str, ValueNode)) and vk == ValueKind.STR_INTERPOLATION:
return

target: Node
if key is None:
target = self
else:
target = self.get_node(key)

if OmegaConf.is_none(value):
if key is not None:
node = self.get_node(key)
if node is not None and not node._is_optional():
self._raise(
exception_type=ValidationError,
msg=f"field '{self._get_full_key(key)}' is not Optional",
)
raise ValidationError("field '$FULL_KEY' is not Optional")
else:
if not self._is_optional():
self._raise(
exception_type=ValidationError,
msg=f"field '{self._get_full_key(None)}' is not Optional",
)
raise ValidationError("field '$FULL_KEY' is not Optional")

if value == "???":
return

if target is not None:
if target._get_flag("readonly"):
raise ReadonlyConfigError(self._get_full_key(key))
# validate get
if key is not None:
try:
self._validate_get(key, value)
except AttributeError as e:
raise AttributeError(f"Error setting $KEY=$VALUE : {e}")

target: Optional[Node]
if key is None:
target = self
else:
if self._get_flag("readonly"):
raise ReadonlyConfigError(self._get_full_key(key))
target = self.get_node(key)

if (target is not None and target._get_flag("readonly")) or self._get_flag(
"readonly"
):
if is_assign:
msg = f"Cannot assign to read-only node : {value}"
else:
msg = f"Cannot merge into read-only node : {value}"
self._format_and_raise(
exception_type=ReadonlyConfigError, msg=msg, key=key,
)

if target is None:
return
@@ -221,63 +202,59 @@ def is_typed(c: Any) -> bool:
)

if validation_error:
assert isinstance(value_type, type)
assert isinstance(target_type, type)
raise ValidationError(
f"Invalid type assigned : {value_type.__name__} "
f"is not a subclass of {target_type.__name__}. value: {value}"
assert value_type is not None
assert target_type is not None
msg = (
f"Invalid type assigned : {value_type.__name__} is not a "
f"subclass of {target_type.__name__}. value: {value}"
)
self._format_and_raise(exception_type=ValidationError, key=key, msg=msg)

def _validate_and_normalize_key(self, key: Any) -> Union[str, Enum]:
return self._s_validate_and_normalize_key(self._metadata.key_type, key)

@staticmethod
def _s_validate_and_normalize_key(key_type: Any, key: Any) -> Union[str, Enum]:
# TODO: this function is a mess.
def _s_validate_and_normalize_key(
self, key_type: Any, key: Any
) -> Union[str, Enum]:
if key_type is None:
for t in (str, Enum):
try:
return DictConfig._s_validate_and_normalize_key(key_type=t, key=key)
return self._s_validate_and_normalize_key(key_type=t, key=key)
except KeyValidationError:
pass
raise KeyValidationError(
f"Unsupported key type {type(key).__name__} : {key}"
)

if key_type == str:
if not isinstance(key, str):
raise KeyValidationError(
f"Key {key} is incompatible with {key_type.__name__}"
)
return key

try:
ret = EnumNode.validate_and_convert_to_enum(key_type, key)
assert ret is not None
return ret
except ValidationError as e:
raise KeyValidationError(
f"Key {key} is incompatible with {key_type.__name__} : {e}"
)
raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
else:
if key_type == str:
if not isinstance(key, str):
raise KeyValidationError(
f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
)
return key
elif issubclass(key_type, Enum):
try:
ret = EnumNode.validate_and_convert_to_enum(self, key_type, key)
assert ret is not None
return ret
except ValidationError as e:
raise KeyValidationError(
f"Key '$KEY' is incompatible with ({key_type.__name__}) : {e}"
)
else:
assert False # pragma: no cover

def __setitem__(self, key: Union[str, Enum], value: Any) -> None:

try:
self.__set_impl(key, value)
self.__set_impl(key=key, value=value)
except AttributeError as e:
import sys

raise KeyError(
f"Error setting '{self._get_full_key(str(key))} = {value}' : {e}"
).with_traceback(sys.exc_info()[2]) from None
self._translate_exception(e=e, key=key, value=value, type_override=KeyError)
except Exception as e:
self._translate_exception(e=e, key=key, value=value)

def __set_impl(self, key: Union[str, Enum], value: Any) -> None:
key = self._validate_and_normalize_key(key)

try:
self._set_item_impl(key, value)
except UnsupportedValueType as ex:
raise UnsupportedValueType(
f"'{type(value).__name__}' is not a supported type (key: {self._get_full_key(key)}) : {ex}"
)
self._set_item_impl(key, value)

# hide content while inspecting in debugger
def __dir__(self) -> Iterable[str]:
@@ -292,7 +269,11 @@ def __setattr__(self, key: str, value: Any) -> None:
:param value:
:return:
"""
self.__set_impl(key, value)
try:
self.__set_impl(key, value)
except Exception as e:
self._translate_exception(e=e, key=key, value=value)
assert False # pragma: no cover

def __getattr__(self, key: str) -> Any:
"""
@@ -304,37 +285,52 @@ def __getattr__(self, key: str) -> Any:
if key == "__members__":
raise AttributeError()

return self.get(key=key)
if key == "__name__":
raise AttributeError()

try:
return self._get_impl(key=key, default_value=DEFAULT_VALUE_MARKER)
except Exception as e:
self._translate_exception(e=e, key=key, value=None)

def __getitem__(self, key: Union[str, Enum]) -> Any:
"""
Allow map style access
:param key:
:return:
"""

try:
return self.get(key=key)
return self._get_impl(key=key, default_value=DEFAULT_VALUE_MARKER)
except AttributeError as e:
raise KeyError(str(e))
raise KeyError(f"Error getting '{key}' : {e}")
except Exception as e:
self._translate_exception(e=e, key=key, value=None)

def get(
self, key: Union[str, Enum], default_value: Any = DEFAULT_VALUE_MARKER,
) -> Any:
try:
return self._get_impl(key=key, default_value=default_value)
except Exception as e:
self._translate_exception(e=e, key=key, value=None)

def _get_impl(self, key: Union[str, Enum], default_value: Any,) -> Any:
key = self._validate_and_normalize_key(key)
node = self.get_node_ex(key=key, default_value=default_value)
return self._resolve_with_default(
key=key, value=node, default_value=default_value,
)

def get_node(self, key: Union[str, Enum]) -> Node:
def get_node(self, key: Union[str, Enum]) -> Optional[Node]:
return self.get_node_ex(key, default_value=DEFAULT_VALUE_MARKER)

def get_node_ex(
self,
key: Union[str, Enum],
default_value: Any = DEFAULT_VALUE_MARKER,
validate_access: bool = True,
) -> Node:
) -> Optional[Node]:
value: Node = self.__dict__["_content"].get(key)
if validate_access:
try:
@@ -349,19 +345,30 @@ def get_node_ex(
value = default_value
return value

__pop_marker = object()
__pop_marker = str("__POP_MARKER__")

def pop(self, key: Union[str, Enum], default: Any = __pop_marker) -> Any:
key = self._validate_and_normalize_key(key)
if self._get_flag("readonly"):
raise ReadonlyConfigError(self._get_full_key(key))
self._translate_exception(
e=ReadonlyConfigError("Cannot pop from read-only node"),
key=key,
value=None,
)

value = self._resolve_with_default(
key=key,
value=self.__dict__["_content"].pop(key, default),
default_value=default,
)

if value is self.__pop_marker:
raise KeyError(key)
full_key = self._get_full_key(key)
msg = f"Cannot pop key '{key}'"
if key != full_key:
msg += f", path='{full_key}'"

raise KeyError(msg)
return value

def keys(self) -> Any:
@@ -446,7 +453,11 @@ def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
if type_or_prototype is None:
return
if not is_structured_config(type_or_prototype):
raise ValueError("Expected structured config class")
self._format_and_raise(
exception_type=ValueError,
key=None,
msg=f"Expected structured config class : {type_or_prototype}",
)

from omegaconf import OmegaConf

@@ -485,14 +496,6 @@ def _set_value(self, value: Any) -> None:
self._metadata.object_type = dict
for k, v in value.items_ex(resolve=False):
self.__setitem__(k, v)
# TODO: validate that it's needed
# try:
# backup = self._metadata.object_type
# self._metadata.object_type = dict
# for k, v in value.items_ex(resolve=False):
# self.__setitem__(k, v)
# finally:
# self._metadata.object_type = backup
self._metadata.object_type = OmegaConf.get_type(value)
elif isinstance(value, dict):
self._metadata.object_type = self._metadata.ref_type
23 changes: 23 additions & 0 deletions omegaconf/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any, Optional


class MissingMandatoryValue(Exception):
"""Thrown when a variable flagged with '???' value is accessed to
indicate that the value was not set"""
@@ -31,3 +34,23 @@ class UnsupportedInterpolationType(ValueError):
"""
Thrown when an attempt to use an unregistered interpolation is made
"""


class _OmegaConfException(Exception):
def __init__(
self,
node: Any,
key: Any,
value: Any,
exception_type: Any,
msg: str,
cause: Optional[Exception] = None,
):
super().__init__(msg)
assert exception_type != _OmegaConfException
self.node = node
self.key = key
self.value = value
self.exception_type = exception_type
self.msg = msg
self.cause = cause
289 changes: 189 additions & 100 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
@@ -14,19 +14,22 @@
Union,
)

from ._utils import ValueKind, get_value_kind, is_primitive_list, isint
from ._utils import ValueKind, get_value_kind, is_primitive_list
from .base import Container, ContainerMetadata, Node
from .basecontainer import BaseContainer
from .errors import (
KeyValidationError,
MissingMandatoryValue,
ReadonlyConfigError,
UnsupportedValueType,
ValidationError,
)
from .nodes import AnyNode, ValueNode


class ListConfig(BaseContainer, MutableSequence[Any]):

_content: Union[List[Optional[Node]], None, str]

def __init__(
self,
content: Union[List[Any], Tuple[Any, ...], str, None],
@@ -47,22 +50,26 @@ def __init__(
key_type=int,
),
)
self._content = None
self._set_value(value=content)

def _validate_get(self, index: Any) -> None:
if not isinstance(index, (int, slice)):
raise KeyValidationError(f"Key type {type(index).__name__} is invalid")
def _validate_get(self, key: Any, value: Any = None) -> None:
if not isinstance(key, (int, slice)):
raise KeyValidationError("Invalid key type '$KEY_TYPE'")

def _validate_set(self, key: Any, value: Any) -> None:

self._validate_get(key, value)

if self._get_flag("readonly"):
raise ReadonlyConfigError(self._get_full_key(f"{key}"))
raise ReadonlyConfigError("ListConfig is read-only")

if 0 <= key < self.__len__():
target = self.get_node(key)
if isinstance(target, Container):
if target is not None:
if value is None and not target._is_optional():
raise ValidationError(
"Non optional ListConfig node cannot be assigned None"
"$FULL_KEY is not optional and cannot be assigned None"
)

def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "ListConfig":
@@ -72,84 +79,106 @@ def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "ListConfig":
res._re_parent()
return res

def __getattr__(self, key: str) -> Any:
if isinstance(key, str) and isint(key):
return self.__getitem__(int(key))
else:
raise AttributeError()

# hide content while inspecting in debugger
def __dir__(self) -> Iterable[str]:
return [str(x) for x in range(0, len(self))]

def __len__(self) -> int:
if self._is_none():
return 0
if self._is_missing():
return 0
assert isinstance(self._content, list)
return len(self._content)

def __getitem__(self, index: Union[int, slice]) -> Any:
assert isinstance(index, (int, slice))
self._validate_get(index)

if isinstance(index, slice):
result = []
for slice_idx in itertools.islice(
range(0, len(self)), index.start, index.stop, index.step
):
val = self._resolve_with_default(
key=slice_idx, value=self._content[slice_idx], default_value=None
try:
if self._is_missing():
raise MissingMandatoryValue("ListConfig is missing")
self._validate_get(index, None)
if self._is_none():
raise TypeError(
"ListConfig object representing None is not subscriptable"
)
result.append(val)
return result
else:
return self._resolve_with_default(
key=index, value=self._content[index], default_value=None
)

def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
try:
self._set_item_impl(index, value)
except UnsupportedValueType:
full_key = self._get_full_key(str(index))
assert isinstance(self._content, list)
if isinstance(index, slice):
result = []
for slice_idx in itertools.islice(
range(0, len(self)), index.start, index.stop, index.step
):
val = self._resolve_with_default(
key=slice_idx, value=self._content[slice_idx]
)
result.append(val)
return result
else:
return self._resolve_with_default(key=index, value=self._content[index])
except Exception as e:
self._translate_exception(e=e, key=index, value=None)

raise UnsupportedValueType(
f"{type(value).__name__} is not a supported type (key: {full_key})"
)
def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
self._set_item_impl(index, value)

def __setitem__(self, index: Union[int, slice], value: Any) -> None:
self._set_at_index(index, value)
try:
self._set_at_index(index, value)
except Exception as e:
self._translate_exception(e=e, key=index, value=value)

def append(self, item: Any) -> None:
index = len(self)
self._validate_set(key=index, value=item)

try:
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

self.__dict__["_content"].append(
_maybe_wrap(
ref_type=self._metadata.element_type,
key=index,
value=item,
is_optional=OmegaConf.is_optional(item),
parent=self,
)
)
except UnsupportedValueType:
full_key = self._get_full_key(f"{len(self)}")
raise UnsupportedValueType(
f"{type(item).__name__} is not a supported type (key: {full_key})"
index = len(self)
self._validate_set(key=index, value=item)

node = _maybe_wrap(
ref_type=self._metadata.element_type,
key=index,
value=item,
is_optional=OmegaConf.is_optional(item),
parent=self,
)
self.__dict__["_content"].append(node)
except Exception as e:
self._translate_exception(e=e, key=index, value=item)
assert False # pragma: no cover

def insert(self, index: int, item: Any) -> None:
if self._get_flag("readonly"):
raise ReadonlyConfigError(self._get_full_key(str(index)))
try:
self._content.insert(index, AnyNode(None))
self._set_at_index(index, item)
except Exception:
del self.__dict__["_content"][index]
raise
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
if self._is_none():
raise TypeError(
"Cannot insert into ListConfig object representing None"
)
if self._is_missing():
raise MissingMandatoryValue("Cannot insert into missing ListConfig")
try:
# TODO: not respecting element type
# TODO: this and other list ops like delete are not updating key in list element nodes
# from omegaconf.omegaconf import OmegaConf, _maybe_wrap
#
# index = len(self)
# self._validate_set(key=index, value=item)
#
# node = _maybe_wrap(
# ref_type=self._metadata.element_type,
# key=index,
# value=item,
# is_optional=OmegaConf.is_optional(item),
# parent=self,
# )
assert isinstance(self._content, list)
self._content.insert(index, AnyNode(None))
self._set_at_index(index, item)
except Exception:
del self.__dict__["_content"][index]
raise
except Exception as e:
self._translate_exception(e=e, key=index, value=item)
assert False # pragma: no cover

def extend(self, lst: Iterable[Any]) -> None:
assert isinstance(lst, (tuple, list, ListConfig))
@@ -180,7 +209,10 @@ def index(
if found_idx != -1:
return found_idx
else:
raise ValueError("Item not found in ListConfig")
self._translate_exception(
e=ValueError("Item not found in ListConfig"), key=None, value=None
)
assert False # pragma: no cover

def count(self, x: Any) -> int:
c = 0
@@ -192,42 +224,87 @@ def count(self, x: Any) -> int:
def copy(self) -> "ListConfig":
return copy.copy(self)

def get_node(self, index: int) -> Node:
assert type(index) == int
return self.__dict__["_content"][index] # type: ignore
def get_node(self, key: Any) -> Optional[Node]:
return self.get_node_ex(key)

def get_node_ex(self, key: Any, validate_access: bool = True) -> Optional[Node]:
if self._is_none():
raise TypeError(
"Cannot get_node from a ListConfig object representing None"
)
if self._is_missing():
raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")

try:
assert isinstance(self._content, list)
return self._content[key] # type: ignore
except (IndexError, TypeError) as e:
if validate_access:
self._translate_exception(e=e, key=key, value=None)
assert False # pragma: no cover
else:
return None

def get(self, index: int, default_value: Any = None) -> Any:
assert type(index) == int
return self._resolve_with_default(
key=index, value=self._content[index], default_value=default_value
)
try:
if self._is_none():
raise TypeError("Cannot get from a ListConfig object representing None")
if self._is_missing():
raise MissingMandatoryValue("Cannot get from a missing ListConfig")
self._validate_get(index, None)
assert isinstance(self._content, list)
return self._resolve_with_default(
key=index, value=self._content[index], default_value=default_value
)
except Exception as e:
self._translate_exception(e=e, key=index, value=None)
assert False # pragma: no cover

def pop(self, index: int = -1) -> Any:
if self._get_flag("readonly"):
raise ReadonlyConfigError(
self._get_full_key(str(index if index != -1 else ""))
try:
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
if self._is_none():
raise TypeError("Cannot pop from a ListConfig object representing None")
if self._is_missing():
raise MissingMandatoryValue("Cannot pop from a missing ListConfig")

assert isinstance(self._content, list)

return self._resolve_with_default(
key=index, value=self._content.pop(index), default_value=None
)
return self._resolve_with_default(
key=index, value=self._content.pop(index), default_value=None
)
except (ReadonlyConfigError, IndexError) as e:
self._translate_exception(e=e, key=index, value=None)
assert False # pragma: no cover

def sort(
self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
) -> None:
if self._get_flag("readonly"):
raise ReadonlyConfigError()
try:
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
if self._is_none():
raise TypeError("Cannot sort a ListConfig object representing None")
if self._is_missing():
raise MissingMandatoryValue("Cannot sort a missing ListConfig")

if key is None:
if key is None:

def key1(x: Any) -> Any:
return x._value()
def key1(x: Any) -> Any:
return x._value()

else:
else:

def key1(x: Any) -> Any:
return key(x._value()) # type: ignore
def key1(x: Any) -> Any:
return key(x._value()) # type: ignore

self._content.sort(key=key1, reverse=reverse)
assert isinstance(self._content, list)
self._content.sort(key=key1, reverse=reverse)

except (ReadonlyConfigError, IndexError) as e:
self._translate_exception(e=e, key=None, value=None)
assert False # pragma: no cover

def __eq__(self, other: Any) -> bool:
if isinstance(other, (list, tuple)) or other is None:
@@ -246,21 +323,31 @@ def __hash__(self) -> int:
return hash(str(self))

def __iter__(self) -> Iterator[Any]:
class MyItems(Iterator[Any]):
def __init__(self, lst: List[Any]) -> None:
self.lst = lst
self.iterator = iter(lst)

def __next__(self) -> Any:
return self.next()

def next(self) -> Any:
v = next(self.iterator)
if isinstance(v, ValueNode):
v = v._value()
return v

return MyItems(self._content)
try:
if self._is_none():
raise TypeError("Cannot iterate on ListConfig object representing None")
if self._is_missing():
raise MissingMandatoryValue("Cannot iterate on a missing ListConfig")

class MyItems(Iterator[Any]):
def __init__(self, lst: List[Any]) -> None:
self.lst = lst
self.iterator = iter(lst)

def __next__(self) -> Any:
return self.next()

def next(self) -> Any:
v = next(self.iterator)
if isinstance(v, ValueNode):
v = v._value()
return v

assert isinstance(self._content, list)
return MyItems(self._content)
except (ReadonlyConfigError, TypeError, MissingMandatoryValue) as e:
self._translate_exception(e=e, key=None, value=None)
assert False # pragma: no cover

def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
# res is sharing this list's parent to allow interpolation to work as expected
@@ -297,6 +384,8 @@ def _set_value(self, value: Any) -> None:
self.__dict__["_content"] = value
else:
assert is_primitive_list(value) or isinstance(value, ListConfig)
if isinstance(value, ListConfig):
self._metadata = copy.deepcopy(value._metadata)
self.__dict__["_content"] = []
for item in value:
self.append(item)
49 changes: 28 additions & 21 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import copy
import math
import sys
from enum import Enum
from typing import Any, Dict, Optional, Type, Union

from omegaconf._utils import _is_interpolation
from omegaconf._utils import _is_interpolation, get_type_of
from omegaconf.base import Container, Metadata, Node
from omegaconf.errors import (
MissingMandatoryValue,
@@ -91,6 +92,14 @@ def _is_missing(self) -> bool:
def _is_interpolation(self) -> bool:
return _is_interpolation(self._value())

def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
# TODO: add testing for get_full_key on value nodes, including value nodes without a parent.
parent = self._get_parent()
if parent is None:
return str(self._metadata.key)
else:
return parent._get_full_key(self._metadata.key)


class AnyNode(ValueNode):
def __init__(
@@ -112,8 +121,9 @@ def validate_and_convert(self, value: Any) -> Any:
from ._utils import is_primitive_type

if not is_primitive_type(value):
t = get_type_of(value)
raise UnsupportedValueType(
f"Unsupported value type, type={type(value)}, value={value}"
f"Value '{t.__name__}' is not a supported primitive type"
)
return value

@@ -173,9 +183,7 @@ def validate_and_convert(self, value: Any) -> Optional[int]:
else:
raise ValueError()
except ValueError:
raise ValidationError(
f"Value '{value}' could not be converted to Integer"
) from None
raise ValidationError("Value '$VALUE' could not be converted to Integer")
return val

def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "IntegerNode":
@@ -209,9 +217,7 @@ def validate_and_convert(self, value: Any) -> Optional[float]:
else:
raise ValueError()
except ValueError:
raise ValidationError(
f"Value '{value}' could not be converted to float"
) from None
raise ValidationError("Value '$VALUE' could not be converted to Float")

def __eq__(self, other: Any) -> bool:
if isinstance(other, ValueNode):
@@ -263,18 +269,18 @@ def validate_and_convert(self, value: Any) -> Optional[bool]:
elif isinstance(value, str):
try:
return self.validate_and_convert(int(value))
except ValueError:
except ValueError as e:
if value.lower() in ("yes", "y", "on", "true"):
return True
elif value.lower() in ("no", "n", "off", "false"):
return False
else:
raise ValidationError(
"Value '{}' is not a valid bool".format(value)
) from None
"Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
).with_traceback(sys.exc_info()[2]) from e
else:
raise ValidationError(
f"Value '{value}' is not a valid bool (type {type(value).__name__})"
"Value '$VALUE' is not a valid bool (type $VALUE_TYPE)"
)

def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "BooleanNode":
@@ -316,19 +322,20 @@ def __init__(
)

def validate_and_convert(self, value: Any) -> Optional[Enum]:
return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value)
return self.validate_and_convert_to_enum(
self, enum_type=self.enum_type, value=value
)

@staticmethod
def validate_and_convert_to_enum(
enum_type: Type[Enum], value: Any
node: Node, enum_type: Type[Enum], value: Any
) -> Optional[Enum]:
if value is None:
return None

if not isinstance(value, (str, int)) and not isinstance(value, enum_type):
# if type(value) not in (str, int) and not isinstance(value, enum_type):
raise ValidationError(
f"Value {value} ({type(value).__name__}) is not a valid input for {enum_type}"
f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}"
)

if isinstance(value, enum_type):
@@ -339,21 +346,21 @@ def validate_and_convert_to_enum(
raise ValueError

if isinstance(value, int):
return enum_type(value)
return enum_type(value) # TODO: does this ever work??

if isinstance(value, str):
prefix = "{}.".format(enum_type.__name__)
prefix = f"{enum_type.__name__}."
if value.startswith(prefix):
value = value[len(prefix) :]
return enum_type[value]

assert False # pragma: no cover

except (ValueError, KeyError):
except (ValueError, KeyError) as e:
valid = "\n".join([f"\t{x}" for x in enum_type.__members__.keys()])
raise ValidationError(
f"Invalid value '{value}', expected one of:\n{valid}"
) from None
f"Invalid value '$VALUE', expected one of:\n{valid}"
).with_traceback(sys.exc_info()[2]) from e

def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "EnumNode":
res = EnumNode(enum_type=self.enum_type)
40 changes: 29 additions & 11 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
@@ -109,6 +109,7 @@ def env(key: str, default: Optional[str] = None) -> Any:
if default is not None:
return decode_primitive(default)
else:
# TODO: validate and test
raise KeyError("Environment variable '{}' not found".format(key))

OmegaConf.register_resolver("env", env)
@@ -148,6 +149,12 @@ def create(
@staticmethod
def create( # noqa F811
obj: Any = _EMPTY_MARKER_, parent: Optional[BaseContainer] = None
) -> Union[DictConfig, ListConfig]:
return OmegaConf._create_impl(obj=obj, parent=parent)

@staticmethod
def _create_impl( # noqa F811
obj: Any = _EMPTY_MARKER_, parent: Optional[BaseContainer] = None
) -> Union[DictConfig, ListConfig]:
from ._utils import get_yaml_loader
from .dictconfig import DictConfig
@@ -185,7 +192,7 @@ def create( # noqa F811
)
else:
raise ValidationError(
"Unsupported type {}".format(type(obj).__name__)
f"Object of unsupported type: '{type(obj).__name__}'"
)

@staticmethod
@@ -376,6 +383,8 @@ def to_container(
def is_missing(cfg: BaseContainer, key: Union[int, str]) -> bool:
try:
node = cfg.get_node(key)
if node is None:
return False
return node._is_missing()
except (UnsupportedInterpolationType, KeyError, AttributeError):
return False
@@ -563,6 +572,8 @@ def _maybe_wrap(
from . import DictConfig, ListConfig

if isinstance(value, ValueNode):
value._set_key(key)
value._set_parent(parent)
return value
ret: Node # pragma: no cover
origin_ = getattr(ref_type, "__origin__", None)
@@ -635,7 +646,7 @@ def _maybe_wrap(


def _select_one(
c: Container, key: str, throw_on_missing: bool
c: Container, key: str, throw_on_missing: bool, throw_on_type_error: bool = True
) -> Tuple[Optional[Node], Union[str, int]]:
from .dictconfig import DictConfig
from .listconfig import ListConfig
@@ -644,25 +655,32 @@ def _select_one(
assert isinstance(c, (DictConfig, ListConfig)), f"Unexpected type : {c}"
if isinstance(c, DictConfig):
assert isinstance(ret_key, str)
val: Optional[Node]
if c.get_node_ex(ret_key, validate_access=False) is not None:
val = c.get_node(ret_key)
val: Optional[Node] = c.get_node_ex(ret_key, validate_access=False)
if val is not None:
if val._is_missing():
if throw_on_missing:
raise MissingMandatoryValue(c._get_full_key(ret_key))
raise MissingMandatoryValue(
f"Missing mandatory value : {c._get_full_key(ret_key)}"
)
else:
return val, ret_key
else:
val = None
elif isinstance(c, ListConfig):
assert isinstance(ret_key, str)
if not isint(ret_key):
raise TypeError("Index {} is not an int".format(ret_key))
ret_key = int(ret_key)
if ret_key < 0 or ret_key + 1 > len(c):
val = None
if throw_on_type_error:
raise TypeError(
f"Index '{ret_key}' ({type(ret_key).__name__}) is not an int"
)
else:
val = None
else:
val = c.get_node(ret_key)
ret_key = int(ret_key)
if ret_key < 0 or ret_key + 1 > len(c):
val = None
else:
val = c.get_node(ret_key)
else:
assert False # pragma: no cover

6 changes: 4 additions & 2 deletions tests/examples/test_dataclass_example.py
Original file line number Diff line number Diff line change
@@ -277,7 +277,9 @@ def test_enum_key() -> None:
def test_dict_of_objects() -> None:
conf: WebServer = OmegaConf.structured(WebServer)
conf.domains["blog"] = Domain(name="blog.example.com", path="/www/blog.example.com")
with pytest.raises(ValidationError):
with pytest.raises(
ValidationError
): # TODO: improve exception, error makes no sense.
conf.domains.foo = 10 # type: ignore

assert conf.domains["blog"].name == "blog.example.com"
@@ -337,7 +339,7 @@ class Config:

schema: Config = OmegaConf.structured(Config)
cfg = OmegaConf.create(yaml)
merged = OmegaConf.merge(schema, cfg)
merged: Any = OmegaConf.merge(schema, cfg)
assert merged == {
"num": 10,
"user": {"name": "Omry", "height": "???"},
2 changes: 1 addition & 1 deletion tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ def test_merge(self, class_type: str) -> None:
cfg1 = OmegaConf.create({"plugin": module.Plugin})
cfg2 = OmegaConf.create({"plugin": module.ConcretePlugin})
assert cfg2.plugin == module.ConcretePlugin
res = OmegaConf.merge(cfg1, cfg2)
res: Any = OmegaConf.merge(cfg1, cfg2)
assert OmegaConf.get_type(res.plugin) == module.ConcretePlugin
assert (
OmegaConf.get_type(res.plugin.params)
28 changes: 17 additions & 11 deletions tests/test_base_config.py
Original file line number Diff line number Diff line change
@@ -269,35 +269,41 @@ def test_deepcopy_preserves_container_type(cfg: Container) -> None:


@pytest.mark.parametrize( # type: ignore
"src, flag_name, flag_value, func, expectation",
"src, flag_name, func, expectation",
[
({}, "struct", False, lambda c: c.__setitem__("foo", 1), raises(KeyError),),
(
pytest.param(
{},
"struct",
lambda c: c.__setitem__("foo", 1),
raises(KeyError),
id="struct_setiitem",
),
pytest.param(
{},
"struct",
False,
lambda c: c.__setattr__("foo", 1),
raises(AttributeError),
id="struct_setattr",
),
(
pytest.param(
{},
"readonly",
False,
lambda c: c.__setitem__("foo", 1),
raises(ReadonlyConfigError),
id="readonly",
),
],
)
def test_flag_override(
src: Dict[str, Any], flag_name: str, flag_value: bool, func: Any, expectation: Any
src: Dict[str, Any], flag_name: str, func: Any, expectation: Any
) -> None:
c = OmegaConf.create(src)
c._set_flag(flag_name, True)
with expectation:
func(c)

with does_not_raise():
with flag_override(c, flag_name, flag_value):
with flag_override(c, flag_name, False):
func(c)


@@ -446,7 +452,7 @@ def test_omegaconf_create() -> None:


@pytest.mark.parametrize( # type: ignore
"parent, index, value, expected",
"parent, key, value, expected",
[
([10, 11], 0, ["a", "b"], [["a", "b"], 11]),
([None], 0, {"foo": "bar"}, [{"foo": "bar"}]),
@@ -456,7 +462,7 @@ def test_omegaconf_create() -> None:
({}, "foo", OmegaConf.create({"foo": "bar"}), {"foo": {"foo": "bar"}}),
],
)
def test_assign(parent: Any, index: Union[str, int], value: Any, expected: Any) -> None:
def test_assign(parent: Any, key: Union[str, int], value: Any, expected: Any) -> None:
c = OmegaConf.create(parent)
c[index] = value
c[key] = value
assert c == expected
26 changes: 15 additions & 11 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
@@ -7,13 +7,13 @@

from omegaconf import (
DictConfig,
KeyValidationError,
MissingMandatoryValue,
OmegaConf,
UnsupportedValueType,
ValidationError,
)
from omegaconf.basecontainer import BaseContainer
from omegaconf.errors import KeyValidationError

from . import (
ConcretePlugin,
@@ -52,11 +52,13 @@ def test_setattr_deep_map() -> None:

def test_getattr() -> None:
c = OmegaConf.create("a: b")
assert isinstance(c, DictConfig)
assert "b" == c.a


def test_getattr_dict() -> None:
c = OmegaConf.create("a: {b: 1}")
assert isinstance(c, DictConfig)
assert {"b": 1} == c.a


@@ -131,6 +133,7 @@ def test_get_default_value() -> None:

def test_scientific_notation_float() -> None:
c = OmegaConf.create("a: 10e-3")
assert isinstance(c, DictConfig)
assert 10e-3 == c.a


@@ -187,6 +190,7 @@ def test_items_with_interpolation() -> None:

def test_dict_keys() -> None:
c = OmegaConf.create("{a: 2, b: 10}")
assert isinstance(c, DictConfig)
assert {"a": 2, "b": 10}.keys() == c.keys()


@@ -261,6 +265,9 @@ def test_dict_pop(
assert type(val) == type(expected)


# TODO: test that a failed pop does not mutate the dict


@pytest.mark.parametrize( # type: ignore
"conf,key,expected",
[
@@ -293,6 +300,7 @@ def test_get_root_of_merged() -> None:

c2 = OmegaConf.create(dict(b=dict(b1="???", b2=4, bb=dict(bb1=3, bb2=4))))
c3 = OmegaConf.merge(c1, c2)
assert isinstance(c3, DictConfig)

assert c3._get_root() == c3
assert c3.a._get_root() == c3
@@ -531,11 +539,11 @@ def test_is_missing() -> None:
"missing_node_inter": "${missing_node}",
}
)
assert cfg.get_node("foo")._is_missing()
assert cfg.get_node("inter")._is_missing()
assert cfg.get_node("str_inter")._is_missing()
assert cfg.get_node("missing_node")._is_missing()
assert cfg.get_node("missing_node_inter")._is_missing()
assert cfg.get_node("foo")._is_missing() # type:ignore
assert cfg.get_node("inter")._is_missing() # type:ignore
assert cfg.get_node("str_inter")._is_missing() # type:ignore
assert cfg.get_node("missing_node")._is_missing() # type:ignore
assert cfg.get_node("missing_node_inter")._is_missing() # type:ignore


@pytest.mark.parametrize("ref_type", [None, Any]) # type: ignore
@@ -600,9 +608,5 @@ def test_assign_to_reftype_plugin(
)
with expectation():
cfg2 = OmegaConf.merge(cfg, {"foo": assign})
assert isinstance(cfg2, DictConfig)
assert cfg2.foo == assign


# TODO:
# define behavior when interpolation would cause an incompatibility between ref type and object type.
# Test that assignment changes object type but not ref type
87 changes: 70 additions & 17 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,11 @@
import pytest

from omegaconf import AnyNode, ListConfig, OmegaConf
from omegaconf.errors import KeyValidationError, UnsupportedValueType
from omegaconf.errors import (
KeyValidationError,
MissingMandatoryValue,
UnsupportedValueType,
)
from omegaconf.nodes import IntegerNode, StringNode

from . import IllegalType, does_not_raise
@@ -119,7 +123,7 @@ def test_list_config_with_tuple() -> None:
def test_items_on_list() -> None:
c = OmegaConf.create([1, 2])
with pytest.raises(AttributeError):
c.items()
c.items() # type: ignore


def test_list_enumerate() -> None:
@@ -143,14 +147,26 @@ def test_list_delitem() -> None:
del c[100]


def test_list_len() -> None:
c = OmegaConf.create([1, 2])
assert len(c) == 2
@pytest.mark.parametrize( # type: ignore
"lst,expected",
[
(OmegaConf.create([1, 2]), 2),
(ListConfig(content=None), 0),
(ListConfig(content="???"), 0),
],
)
def test_list_len(lst: Any, expected: Any) -> None:
assert len(lst) == expected


def test_nested_list_assign_illegal_value() -> None:
c = OmegaConf.create(dict(a=[None]))
with pytest.raises(UnsupportedValueType, match=re.escape("key: a[0]")):
c = OmegaConf.create({"a": [None]})
with pytest.raises(
UnsupportedValueType,
match=re.escape(
"Value 'IllegalType' is not a supported primitive type\n\tfull_key: a[0]"
),
):
c.a[0] = IllegalType()


@@ -194,15 +210,6 @@ def test_list_dir() -> None:
assert ["0", "1", "2"] == dir(c)


def test_getattr() -> None:
c = OmegaConf.create(["a", "b", "c"])
assert getattr(c, "0") == "a"
assert getattr(c, "1") == "b"
assert getattr(c, "2") == "c"
with pytest.raises(AttributeError):
getattr(c, "anything")


@pytest.mark.parametrize( # type: ignore
"input_, index, value, expected, expected_node_type",
[
@@ -221,6 +228,18 @@ def test_insert(
assert type(c.get_node(index)) == expected_node_type


@pytest.mark.parametrize( # type: ignore
"lst,idx,value,expectation",
[
(ListConfig(content=None), 0, 10, pytest.raises(TypeError)),
(ListConfig(content="???"), 0, 10, pytest.raises(MissingMandatoryValue)),
],
)
def test_insert_special_list(lst: Any, idx: Any, value: Any, expectation: Any) -> None:
with expectation:
lst.insert(idx, value)


@pytest.mark.parametrize( # type: ignore
"src, append, result",
[
@@ -320,8 +339,9 @@ def test_insert_throws_not_changing_list() -> None:

def test_append_throws_not_changing_list() -> None:
c = OmegaConf.create([])
v = IllegalType()
with pytest.raises(ValueError):
c.append(IllegalType())
c.append(v)
assert len(c) == 0
assert c == []

@@ -372,3 +392,36 @@ def test_set_with_invalid_key() -> None:
cfg = OmegaConf.create([1, 2, 3])
with pytest.raises(KeyValidationError):
cfg["foo"] = 4 # type: ignore


@pytest.mark.parametrize( # type: ignore
"lst,idx,expected",
[
(OmegaConf.create([1, 2]), 0, 1),
(ListConfig(content=None), 0, TypeError),
(ListConfig(content="???"), 0, MissingMandatoryValue),
],
)
def test_getitem(lst: Any, idx: Any, expected: Any) -> None:
if isinstance(expected, type):
with pytest.raises(expected):
lst.__getitem__(idx)
else:
lst.__getitem__(idx) == expected


@pytest.mark.parametrize( # type: ignore
"lst,idx,expected",
[
(OmegaConf.create([1, 2]), 0, 1),
(OmegaConf.create([1, 2]), "foo", KeyValidationError),
(ListConfig(content=None), 0, TypeError),
(ListConfig(content="???"), 0, MissingMandatoryValue),
],
)
def test_get(lst: Any, idx: Any, expected: Any) -> None:
if isinstance(expected, type):
with pytest.raises(expected):
lst.get(idx)
else:
lst.__getitem__(idx) == expected
7 changes: 6 additions & 1 deletion tests/test_create.py
Original file line number Diff line number Diff line change
@@ -86,8 +86,9 @@ def test_create_list_with_illegal_value_idx0() -> None:


def test_create_list_with_illegal_value_idx1() -> None:
lst = [1, IllegalType(), 3]
with pytest.raises(UnsupportedValueType, match=re.escape("key: [1]")):
OmegaConf.create([1, IllegalType(), 3])
OmegaConf.create(lst)


def test_create_dict_with_illegal_value() -> None:
@@ -113,3 +114,7 @@ def test_create_from_oc_with_flags() -> None:
c2 = OmegaConf.create(c1)
assert c1 == c2
assert c1._metadata.flags == c2._metadata.flags


# TODO: test that DictConfig created with OmegaConf.create(DictConfig(...))
# is preserving the metadata of the inner dict config, same for listconfig
449 changes: 449 additions & 0 deletions tests/test_errors.py

Large diffs are not rendered by default.

13 changes: 5 additions & 8 deletions tests/test_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import re
from typing import Any

import pytest
@@ -67,19 +68,14 @@ def verify(
lambda value, is_optional, key=None: DictConfig(
is_optional=is_optional, content=value, key=key
),
[
{},
{"foo": "bar"},
OmegaConf.create({}),
OmegaConf.create({"foo": "bar"}),
],
[{}, {"foo": "bar"}],
),
# ListConfig
(
lambda value, is_optional, key=None: ListConfig(
is_optional=is_optional, content=value, key=key
),
[[], [1, 2, 3], OmegaConf.create([]), OmegaConf.create([1, 2, 3])],
[[], [1, 2, 3]],
),
# dataclass
(
@@ -110,7 +106,8 @@ def test_none_assignment_and_merging_in_dict(
data = {"node": node}
cfg = OmegaConf.create(obj=data)
verify(cfg, "node", none=False, opt=False, missing=False, inter=False)
with pytest.raises(ValidationError):
msg = "field 'node' is not Optional"
with pytest.raises(ValidationError, match=re.escape(msg)):
cfg.node = None

with pytest.raises(ValidationError):
2 changes: 1 addition & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
@@ -419,7 +419,7 @@ class UserClass:

from omegaconf.omegaconf import _node_wrap

with pytest.raises(ValueError):
with pytest.raises(ValidationError):
_node_wrap(
type_=UserClass, value=UserClass(), is_optional=False, parent=None, key=None
)
2 changes: 1 addition & 1 deletion tests/test_readonly.py
Original file line number Diff line number Diff line change
@@ -151,5 +151,5 @@ def test_readonly_from_cli() -> None:
assert isinstance(c, DictConfig)
OmegaConf.set_readonly(c, True)
cli = OmegaConf.from_dotlist(["foo.bar=[2]"])
with raises(ReadonlyConfigError, match="foo"):
with raises(ReadonlyConfigError):
OmegaConf.merge(c, cli)
99 changes: 54 additions & 45 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -9,66 +9,75 @@
from omegaconf.errors import KeyValidationError, ValidationError
from omegaconf.nodes import StringNode

from . import Color, does_not_raise
from . import Color, IllegalType, does_not_raise


@pytest.mark.parametrize( # type: ignore
"target_type, value, expectation",
[
# Any
(Any, "foo", does_not_raise()),
(Any, True, does_not_raise()),
(Any, 1, does_not_raise()),
(Any, 1.0, does_not_raise()),
(Any, Color.RED, does_not_raise()),
(Any, "foo", None),
(Any, True, None),
(Any, 1, None),
(Any, 1.0, None),
(Any, Color.RED, None),
# int
(int, "foo", pytest.raises(ValidationError)),
(int, True, pytest.raises(ValidationError)),
(int, 1, does_not_raise()),
(int, 1.0, pytest.raises(ValidationError)),
(int, Color.RED, pytest.raises(ValidationError)),
(int, "foo", ValidationError),
(int, True, ValidationError),
(int, 1, None),
(int, 1.0, ValidationError),
(int, Color.RED, ValidationError),
# float
(float, "foo", pytest.raises(ValidationError)),
(float, True, pytest.raises(ValidationError)),
(float, 1, does_not_raise()),
(float, 1.0, does_not_raise()),
(float, Color.RED, pytest.raises(ValidationError)),
(float, "foo", ValidationError),
(float, True, ValidationError),
(float, 1, None),
(float, 1.0, None),
(float, Color.RED, ValidationError),
# bool
(bool, "foo", pytest.raises(ValidationError)),
(bool, True, does_not_raise()),
(bool, 1, does_not_raise()),
(bool, 0, does_not_raise()),
(bool, 1.0, pytest.raises(ValidationError)),
(bool, Color.RED, pytest.raises(ValidationError)),
(bool, "true", does_not_raise()),
(bool, "false", does_not_raise()),
(bool, "on", does_not_raise()),
(bool, "off", does_not_raise()),
(bool, "foo", ValidationError),
(bool, True, None),
(bool, 1, None),
(bool, 0, None),
(bool, 1.0, ValidationError),
(bool, Color.RED, ValidationError),
(bool, "true", None),
(bool, "false", None),
(bool, "on", None),
(bool, "off", None),
# str
(str, "foo", does_not_raise()),
(str, True, does_not_raise()),
(str, 1, does_not_raise()),
(str, 1.0, does_not_raise()),
(str, Color.RED, does_not_raise()),
(str, "foo", None),
(str, True, None),
(str, 1, None),
(str, 1.0, None),
(str, Color.RED, None),
# Color
(Color, "foo", pytest.raises(ValidationError)),
(Color, True, pytest.raises(ValidationError)),
(Color, 1, does_not_raise()),
(Color, 1.0, pytest.raises(ValidationError)),
(Color, Color.RED, does_not_raise()),
(Color, "RED", does_not_raise()),
(Color, "Color.RED", does_not_raise()),
(Color, "foo", ValidationError),
(Color, True, ValidationError),
(Color, 1, None),
(Color, 1.0, ValidationError),
(Color, Color.RED, None),
(Color, "RED", None),
(Color, "Color.RED", None),
# bad type
(Exception, "nope", pytest.raises(ValueError)),
(IllegalType, "nope", ValidationError),
],
)
def test_maybe_wrap(target_type: type, value: Any, expectation: Any) -> None:
with expectation:
from omegaconf.omegaconf import _maybe_wrap
from omegaconf.omegaconf import _maybe_wrap

if expectation is None:
_maybe_wrap(
ref_type=target_type, key=None, value=value, is_optional=False, parent=None,
)
else:
with pytest.raises(expectation):
_maybe_wrap(
ref_type=target_type,
key=None,
value=value,
is_optional=False,
parent=None,
)


class _TestEnum(Enum):
@@ -202,8 +211,8 @@ def test_value_kind(value: Any, kind: _utils.ValueKind) -> None:
def test_re_parent() -> None:
def validate(cfg1: DictConfig) -> None:
assert cfg1._get_parent() is None
assert cfg1.get_node("str")._get_parent() == cfg1
assert cfg1.get_node("list")._get_parent() == cfg1
assert cfg1.get_node("str")._get_parent() == cfg1 # type:ignore
assert cfg1.get_node("list")._get_parent() == cfg1 # type:ignore
assert cfg1.list.get_node(0)._get_parent() == cfg1.list

cfg = OmegaConf.create({})
@@ -213,8 +222,8 @@ def validate(cfg1: DictConfig) -> None:

validate(cfg)

cfg.get_node("str")._set_parent(None)
cfg.get_node("list")._set_parent(None)
cfg.get_node("str")._set_parent(None) # type:ignore
cfg.get_node("list")._set_parent(None) # type:ignore
cfg.list.get_node(0)._set_parent(None) # type: ignore
# noinspection PyProtectedMember
cfg._re_parent()

0 comments on commit 9ddbc0c

Please sign in to comment.