Skip to content

Commit

Permalink
cleanup and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Apr 5, 2020
1 parent a236556 commit 2c6ea6a
Show file tree
Hide file tree
Showing 20 changed files with 895 additions and 722 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
source_suffix = ".rst"

# The master toctree document.
master_doc = "key"
master_doc = "index"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
26 changes: 20 additions & 6 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,17 +323,30 @@ def is_primitive_container(obj: Any) -> bool:
return is_primitive_list(obj) or is_primitive_dict(obj)


def get_key_value_types(annotated_type: Any) -> Tuple[Any, Any]:
def get_list_element_type(ref_type: Optional[Type[Any]]) -> Optional[Type[Any]]:
args = getattr(ref_type, "__args__", None)
if ref_type is not List and args is not None and args[0] is not Any:
element_type = args[0]
else:
element_type = None

if not (valid_value_annotation_type(element_type)):
raise ValidationError(f"Unsupported value type : {element_type}")
assert element_type is None or isinstance(element_type, type)
return element_type


def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:

args = getattr(annotated_type, "__args__", None)
args = getattr(ref_type, "__args__", None)
if args is None:
bases = getattr(annotated_type, "__orig_bases__", None)
bases = getattr(ref_type, "__orig_bases__", None)
if bases is not None and len(bases) > 0:
args = getattr(bases[0], "__args__", None)

key_type: Any
element_type: Any
if annotated_type is None:
if ref_type is None:
key_type = None
element_type = None
else:
Expand Down Expand Up @@ -406,9 +419,9 @@ def format_and_raise(
node: Any,
key: Any,
value: Any,
exception_type: Any,
msg: str,
cause: Optional[Exception] = None,
cause: Exception,
type_override: Any = None,
) -> None:
def type_str(t: Any) -> str:
if isinstance(t, type):
Expand Down Expand Up @@ -469,4 +482,5 @@ def type_str(t: Any) -> str:
message = s.substitute(
REF_TYPE=rt, OBJECT_TYPE=object_type, MSG=msg, FULL_KEY=full_key,
)
exception_type = type(cause) if type_override is None else type_override
raise exception_type(f"{message}").with_traceback(sys.exc_info()[2]) from cause
43 changes: 25 additions & 18 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,22 @@ def _get_flag(self, flag: str) -> Optional[bool]:
# noinspection PyProtectedMember
return parent._get_flag(flag)

# TODO: move to utils
def _translate_exception(
self, e: Exception, key: Any, value: Any, type_override: Any = None
def _format_and_raise(
self, key: Any, value: Any, cause: Exception, type_override: Any = None,
) -> None:
etype = type(e) if type_override is None else type_override
format_and_raise(
exception_type=etype, node=self, key=key, value=value, msg=str(e), cause=e,
node=self,
key=key,
value=value,
msg=str(cause),
cause=cause,
type_override=type_override,
)
assert False # pragma: no cover
assert False

@abstractmethod
def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
...
... # pragma: no cover

def _dereference_node(self, throw_on_missing: bool = False) -> "Node":
from .nodes import StringNode
Expand Down Expand Up @@ -127,7 +130,7 @@ def _dereference_node(self, throw_on_missing: bool = False) -> "Node":
parent=parent,
is_optional=self._metadata.optional,
)
assert False # pragma: no cover
assert False
else:
# not interpolation, compare directly
if throw_on_missing:
Expand Down Expand Up @@ -202,7 +205,7 @@ def get_node(self, key: Any) -> Optional[Node]:
... # pragma: no cover

@abstractmethod
def __delitem__(self, key: Union[str, int, slice]) -> None:
def __delitem__(self, key: Any) -> None:
... # pragma: no cover

@abstractmethod
Expand Down Expand Up @@ -286,14 +289,18 @@ def _resolve_interpolation(
else:
resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
value = resolver(root_node, inter_key)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=None, object_type=None, key=key, optional=True
),
)
try:
value = resolver(root_node, inter_key)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=None, object_type=None, key=key, optional=True
),
)
except Exception as e:
self._format_and_raise(key=inter_key, value=None, cause=e)
assert False
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
Expand Down Expand Up @@ -336,4 +343,4 @@ def _resolve_str_interpolation(
new += orig[last_index:]
return StringNode(value=new, key=key)
else:
assert False # pragma: no cover
assert False
115 changes: 26 additions & 89 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import string
import sys
import warnings
from abc import ABC, abstractmethod
Expand All @@ -21,9 +20,9 @@
is_structured_config,
)
from .base import Container, ContainerMetadata, Node
from .errors import MissingMandatoryValue, ReadonlyConfigError, ValidationError
from .errors import MissingMandatoryValue, ReadonlyConfigError

DEFAULT_VALUE_MARKER: Any = object()
DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__")


class BaseContainer(Container, ABC):
Expand Down Expand Up @@ -94,10 +93,9 @@ def __getstate__(self) -> Dict[str, Any]:
def __setstate__(self, d: Dict[str, Any]) -> None:
self.__dict__.update(d)

def __delitem__(self, key: Union[str, int, slice]) -> None:
if self._get_flag("readonly"):
raise ReadonlyConfigError(self._get_full_key(str(key)))
del self.__dict__["_content"][key]
@abstractmethod
def __delitem__(self, key: Any) -> None:
... # pragma: no cover

def __len__(self) -> int:
return self.__dict__["_content"].__len__() # type: ignore
Expand Down Expand Up @@ -169,7 +167,7 @@ def select(self, key: str, throw_on_missing: bool = False) -> Any:

return _get_value(value)
except Exception as e:
self._translate_exception(e=e, key=key, value=None)
self._format_and_raise(key=key, value=None, cause=e)

def is_empty(self) -> bool:
"""return true if config is empty"""
Expand Down Expand Up @@ -213,7 +211,7 @@ def convert(val: Any) -> Any:
retlist.append(item)
return retlist

assert False # pragma: no cover
assert False

def to_container(self, resolve: bool = False) -> Union[Dict[str, Any], List[Any]]:
warnings.warn(
Expand Down Expand Up @@ -256,22 +254,13 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
type_backup = dest._metadata.object_type
dest._metadata.object_type = None

if (dest._get_flag("readonly")) or dest._get_flag("readonly"):
raise ReadonlyConfigError("Cannot merge into read-only node")
dest._validate_set_merge_impl(key=None, value=src, is_assign=False)
for key, src_value in src.items_ex(resolve=False):
if dest._is_missing():
if dest._metadata.ref_type is not None and not is_dict_annotation(
dest._metadata.ref_type
):
dest._set_value(dest._metadata.ref_type())
else:
dest._set_value({})
dest_element_type = dest._metadata.element_type
element_typed = dest_element_type not in (None, Any)
if OmegaConf.is_missing(dest, key):
if isinstance(src_value, DictConfig):
if key not in dest:
if OmegaConf.is_missing(dest, key):
dest[key] = src_value

if (dest.get_node(key) is not None) or element_typed:
Expand Down Expand Up @@ -388,79 +377,27 @@ def assign(value_key: Any, value_to_assign: Any) -> None:
value_to_assign._set_key(value_key)
self.__dict__["_content"][value_key] = value_to_assign

try:
if is_primitive_container(value):
self.__dict__["_content"][key] = wrap(key, value)
elif input_node and target_node:
# both nodes, replace existing node with new one
assign(key, value)
elif not input_node and target_node:
# input is not node, can be primitive or config
if input_config:
assign(key, value)
else:
self.__dict__["_content"][key]._set_value(value)
elif input_node and not target_node:
# target must be config, replace target with input node
if is_primitive_container(value):
self.__dict__["_content"][key] = wrap(key, value)
elif input_node and target_node:
# both nodes, replace existing node with new one
assign(key, value)
elif not input_node and target_node:
# input is not node, can be primitive or config
if input_config:
assign(key, value)
elif not input_node and not target_node:
if should_set_value:
self.__dict__["_content"][key]._set_value(value)
elif input_config:
assign(key, value)
else:
self.__dict__["_content"][key] = wrap(key, value)
except ValidationError as ve:
self._format_and_raise(
exception_type=ValidationError, key=key, msg=f"{ve}",
)

# TODO: cleanup
def _format_and_raise(
self, exception_type: Any, msg: str, key: Optional[str]
) -> None:
def type_str(t: Any) -> str:
if isinstance(t, type):
return t.__name__
else:
return str(node._metadata.object_type)

full_key = self._get_full_key(key=key)
if key is None:
node = self
else:
node = (
self.__dict__["_content"][key]
if key in self.__dict__["_content"]
else None
)
if node is None:
node = self

ref_type: Any = node._metadata.ref_type
if ref_type is None:
ref_type = Any
object_type = type_str(node._metadata.object_type)

rt = type_str(ref_type)
if node._metadata.optional:
if ref_type is Any:
rt = "Any"
self.__dict__["_content"][key]._set_value(value)
elif input_node and not target_node:
# target must be config, replace target with input node
assign(key, value)
elif not input_node and not target_node:
if should_set_value:
self.__dict__["_content"][key]._set_value(value)
elif input_config:
assign(key, value)
else:
rt = f"Optional[{rt}]"

s = string.Template(
"""$MSG
\tfull_key: $FULL_KEY
\treference_type=$REF_TYPE
\tobject_type=$OBJECT_TYPE
"""
)
message = s.substitute(
REF_TYPE=rt, OBJECT_TYPE=object_type, MSG=msg, FULL_KEY=full_key,
)

raise exception_type(f"{message}").with_traceback(sys.exc_info()[2]) from None
self.__dict__["_content"][key] = wrap(key, value)

@staticmethod
def _item_eq(
Expand Down
Loading

0 comments on commit 2c6ea6a

Please sign in to comment.