diff --git a/docs/source/conf.py b/docs/source/conf.py index f0157ca56..c570079bb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -61,7 +61,7 @@ source_suffix = ".rst" # The master toctree document. -master_doc = "key" +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 57f5f7261..436c0565a 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -323,17 +323,30 @@ def is_primitive_container(obj: Any) -> bool: return is_primitive_list(obj) or is_primitive_dict(obj) -def get_key_value_types(annotated_type: Any) -> Tuple[Any, Any]: +def get_list_element_type(ref_type: Optional[Type[Any]]) -> Optional[Type[Any]]: + args = getattr(ref_type, "__args__", None) + if ref_type is not List and args is not None and args[0] is not Any: + element_type = args[0] + else: + element_type = None + + if not (valid_value_annotation_type(element_type)): + raise ValidationError(f"Unsupported value type : {element_type}") + assert element_type is None or isinstance(element_type, type) + return element_type + + +def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]: - args = getattr(annotated_type, "__args__", None) + args = getattr(ref_type, "__args__", None) if args is None: - bases = getattr(annotated_type, "__orig_bases__", None) + bases = getattr(ref_type, "__orig_bases__", None) if bases is not None and len(bases) > 0: args = getattr(bases[0], "__args__", None) key_type: Any element_type: Any - if annotated_type is None: + if ref_type is None: key_type = None element_type = None else: @@ -406,9 +419,9 @@ def format_and_raise( node: Any, key: Any, value: Any, - exception_type: Any, msg: str, - cause: Optional[Exception] = None, + cause: Exception, + type_override: Any = None, ) -> None: def type_str(t: Any) -> str: if isinstance(t, type): @@ -469,4 +482,5 @@ def type_str(t: Any) -> str: message = s.substitute( REF_TYPE=rt, OBJECT_TYPE=object_type, MSG=msg, FULL_KEY=full_key, ) + exception_type = type(cause) if type_override is None else type_override raise exception_type(f"{message}").with_traceback(sys.exc_info()[2]) from cause diff --git a/omegaconf/base.py b/omegaconf/base.py index b8aa14503..6249b5197 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -84,19 +84,22 @@ def _get_flag(self, flag: str) -> Optional[bool]: # noinspection PyProtectedMember return parent._get_flag(flag) - # TODO: move to utils - def _translate_exception( - self, e: Exception, key: Any, value: Any, type_override: Any = None + def _format_and_raise( + self, key: Any, value: Any, cause: Exception, type_override: Any = None, ) -> None: - etype = type(e) if type_override is None else type_override format_and_raise( - exception_type=etype, node=self, key=key, value=value, msg=str(e), cause=e, + node=self, + key=key, + value=value, + msg=str(cause), + cause=cause, + type_override=type_override, ) - assert False # pragma: no cover + assert False @abstractmethod def _get_full_key(self, key: Union[str, Enum, int, None]) -> str: - ... + ... # pragma: no cover def _dereference_node(self, throw_on_missing: bool = False) -> "Node": from .nodes import StringNode @@ -127,7 +130,7 @@ def _dereference_node(self, throw_on_missing: bool = False) -> "Node": parent=parent, is_optional=self._metadata.optional, ) - assert False # pragma: no cover + assert False else: # not interpolation, compare directly if throw_on_missing: @@ -202,7 +205,7 @@ def get_node(self, key: Any) -> Optional[Node]: ... # pragma: no cover @abstractmethod - def __delitem__(self, key: Union[str, int, slice]) -> None: + def __delitem__(self, key: Any) -> None: ... # pragma: no cover @abstractmethod @@ -286,14 +289,18 @@ def _resolve_interpolation( else: resolver = OmegaConf.get_resolver(inter_type) if resolver is not None: - value = resolver(root_node, inter_key) - return ValueNode( - value=value, - parent=self, - metadata=Metadata( - ref_type=None, object_type=None, key=key, optional=True - ), - ) + try: + value = resolver(root_node, inter_key) + return ValueNode( + value=value, + parent=self, + metadata=Metadata( + ref_type=None, object_type=None, key=key, optional=True + ), + ) + except Exception as e: + self._format_and_raise(key=inter_key, value=None, cause=e) + assert False else: raise UnsupportedInterpolationType( f"Unsupported interpolation type {inter_type}" @@ -336,4 +343,4 @@ def _resolve_str_interpolation( new += orig[last_index:] return StringNode(value=new, key=key) else: - assert False # pragma: no cover + assert False diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 54b9f60bf..85be36341 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -1,5 +1,4 @@ import copy -import string import sys import warnings from abc import ABC, abstractmethod @@ -21,9 +20,9 @@ is_structured_config, ) from .base import Container, ContainerMetadata, Node -from .errors import MissingMandatoryValue, ReadonlyConfigError, ValidationError +from .errors import MissingMandatoryValue, ReadonlyConfigError -DEFAULT_VALUE_MARKER: Any = object() +DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__") class BaseContainer(Container, ABC): @@ -94,10 +93,9 @@ def __getstate__(self) -> Dict[str, Any]: def __setstate__(self, d: Dict[str, Any]) -> None: self.__dict__.update(d) - def __delitem__(self, key: Union[str, int, slice]) -> None: - if self._get_flag("readonly"): - raise ReadonlyConfigError(self._get_full_key(str(key))) - del self.__dict__["_content"][key] + @abstractmethod + def __delitem__(self, key: Any) -> None: + ... # pragma: no cover def __len__(self) -> int: return self.__dict__["_content"].__len__() # type: ignore @@ -169,7 +167,7 @@ def select(self, key: str, throw_on_missing: bool = False) -> Any: return _get_value(value) except Exception as e: - self._translate_exception(e=e, key=key, value=None) + self._format_and_raise(key=key, value=None, cause=e) def is_empty(self) -> bool: """return true if config is empty""" @@ -213,7 +211,7 @@ def convert(val: Any) -> Any: retlist.append(item) return retlist - assert False # pragma: no cover + assert False def to_container(self, resolve: bool = False) -> Union[Dict[str, Any], List[Any]]: warnings.warn( @@ -256,22 +254,13 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: type_backup = dest._metadata.object_type dest._metadata.object_type = None - if (dest._get_flag("readonly")) or dest._get_flag("readonly"): - raise ReadonlyConfigError("Cannot merge into read-only node") dest._validate_set_merge_impl(key=None, value=src, is_assign=False) for key, src_value in src.items_ex(resolve=False): - if dest._is_missing(): - if dest._metadata.ref_type is not None and not is_dict_annotation( - dest._metadata.ref_type - ): - dest._set_value(dest._metadata.ref_type()) - else: - dest._set_value({}) dest_element_type = dest._metadata.element_type element_typed = dest_element_type not in (None, Any) if OmegaConf.is_missing(dest, key): if isinstance(src_value, DictConfig): - if key not in dest: + if OmegaConf.is_missing(dest, key): dest[key] = src_value if (dest.get_node(key) is not None) or element_typed: @@ -388,79 +377,27 @@ def assign(value_key: Any, value_to_assign: Any) -> None: value_to_assign._set_key(value_key) self.__dict__["_content"][value_key] = value_to_assign - try: - if is_primitive_container(value): - self.__dict__["_content"][key] = wrap(key, value) - elif input_node and target_node: - # both nodes, replace existing node with new one - assign(key, value) - elif not input_node and target_node: - # input is not node, can be primitive or config - if input_config: - assign(key, value) - else: - self.__dict__["_content"][key]._set_value(value) - elif input_node and not target_node: - # target must be config, replace target with input node + if is_primitive_container(value): + self.__dict__["_content"][key] = wrap(key, value) + elif input_node and target_node: + # both nodes, replace existing node with new one + assign(key, value) + elif not input_node and target_node: + # input is not node, can be primitive or config + if input_config: assign(key, value) - elif not input_node and not target_node: - if should_set_value: - self.__dict__["_content"][key]._set_value(value) - elif input_config: - assign(key, value) - else: - self.__dict__["_content"][key] = wrap(key, value) - except ValidationError as ve: - self._format_and_raise( - exception_type=ValidationError, key=key, msg=f"{ve}", - ) - - # TODO: cleanup - def _format_and_raise( - self, exception_type: Any, msg: str, key: Optional[str] - ) -> None: - def type_str(t: Any) -> str: - if isinstance(t, type): - return t.__name__ else: - return str(node._metadata.object_type) - - full_key = self._get_full_key(key=key) - if key is None: - node = self - else: - node = ( - self.__dict__["_content"][key] - if key in self.__dict__["_content"] - else None - ) - if node is None: - node = self - - ref_type: Any = node._metadata.ref_type - if ref_type is None: - ref_type = Any - object_type = type_str(node._metadata.object_type) - - rt = type_str(ref_type) - if node._metadata.optional: - if ref_type is Any: - rt = "Any" + self.__dict__["_content"][key]._set_value(value) + elif input_node and not target_node: + # target must be config, replace target with input node + assign(key, value) + elif not input_node and not target_node: + if should_set_value: + self.__dict__["_content"][key]._set_value(value) + elif input_config: + assign(key, value) else: - rt = f"Optional[{rt}]" - - s = string.Template( - """$MSG -\tfull_key: $FULL_KEY -\treference_type=$REF_TYPE -\tobject_type=$OBJECT_TYPE -""" - ) - message = s.substitute( - REF_TYPE=rt, OBJECT_TYPE=object_type, MSG=msg, FULL_KEY=full_key, - ) - - raise exception_type(f"{message}").with_traceback(sys.exc_info()[2]) from None + self.__dict__["_content"][key] = wrap(key, value) @staticmethod def _item_eq( diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 9e73033b6..c14df8bcc 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -17,7 +17,7 @@ from ._utils import ( ValueKind, _is_interpolation, - get_key_value_types, + get_dict_key_value_types, get_structured_config_data, get_type_of, get_value_kind, @@ -50,7 +50,7 @@ def __init__( ref_type: Optional[Type[Any]] = None, is_optional: bool = True, ) -> None: - key_type, element_type = get_key_value_types(ref_type) + key_type, element_type = get_dict_key_value_types(ref_type) super().__init__( parent=parent, metadata=ContainerMetadata( @@ -111,7 +111,7 @@ def _validate_get(self, key: Any, value: Any = None) -> None: if is_typed or is_struct: if is_typed: assert self._metadata.object_type is not None - msg = f"Key '{key}' not in ({self._metadata.object_type.__name__})" + msg = f"Key '{key}' not in '{self._metadata.object_type.__name__}'" else: msg = f"Key '{key}' in not in struct" raise AttributeError(msg) @@ -163,9 +163,7 @@ def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> Non msg = f"Cannot assign to read-only node : {value}" else: msg = f"Cannot merge into read-only node : {value}" - self._format_and_raise( - exception_type=ReadonlyConfigError, msg=msg, key=key, - ) + raise ReadonlyConfigError(msg) if target is None: return @@ -205,15 +203,14 @@ def is_typed(c: Any) -> bool: assert value_type is not None assert target_type is not None msg = ( - f"Invalid type assigned : {value_type.__name__} is not a " + f"Invalid type assigned : {value_type.__name__} is not a " f"subclass of {target_type.__name__}. value: {value}" ) - self._format_and_raise(exception_type=ValidationError, key=key, msg=msg) + raise ValidationError(msg) def _validate_and_normalize_key(self, key: Any) -> Union[str, Enum]: return self._s_validate_and_normalize_key(self._metadata.key_type, key) - # TODO: this function is a mess. def _s_validate_and_normalize_key( self, key_type: Any, key: Any ) -> Union[str, Enum]: @@ -233,7 +230,7 @@ def _s_validate_and_normalize_key( return key elif issubclass(key_type, Enum): try: - ret = EnumNode.validate_and_convert_to_enum(self, key_type, key) + ret = EnumNode.validate_and_convert_to_enum(key_type, key) assert ret is not None return ret except ValidationError as e: @@ -241,16 +238,18 @@ def _s_validate_and_normalize_key( f"Key '$KEY' is incompatible with ({key_type.__name__}) : {e}" ) else: - assert False # pragma: no cover + assert False def __setitem__(self, key: Union[str, Enum], value: Any) -> None: try: self.__set_impl(key=key, value=value) except AttributeError as e: - self._translate_exception(e=e, key=key, value=value, type_override=KeyError) + self._format_and_raise( + key=key, value=value, type_override=KeyError, cause=e + ) except Exception as e: - self._translate_exception(e=e, key=key, value=value) + self._format_and_raise(key=key, value=value, cause=e) def __set_impl(self, key: Union[str, Enum], value: Any) -> None: key = self._validate_and_normalize_key(key) @@ -272,8 +271,8 @@ def __setattr__(self, key: str, value: Any) -> None: try: self.__set_impl(key, value) except Exception as e: - self._translate_exception(e=e, key=key, value=value) - assert False # pragma: no cover + self._format_and_raise(key=key, value=value, cause=e) + assert False def __getattr__(self, key: str) -> Any: """ @@ -291,7 +290,7 @@ def __getattr__(self, key: str) -> Any: try: return self._get_impl(key=key, default_value=DEFAULT_VALUE_MARKER) except Exception as e: - self._translate_exception(e=e, key=key, value=None) + self._format_and_raise(key=key, value=None, cause=e) def __getitem__(self, key: Union[str, Enum]) -> Any: """ @@ -305,7 +304,18 @@ def __getitem__(self, key: Union[str, Enum]) -> Any: except AttributeError as e: raise KeyError(f"Error getting '{key}' : {e}") except Exception as e: - self._translate_exception(e=e, key=key, value=None) + self._format_and_raise(key=key, value=None, cause=e) + + def __delitem__(self, key: Union[str, int, Enum]) -> None: + if self._get_flag("readonly"): + self._format_and_raise( + key=key, + value=None, + cause=ReadonlyConfigError( + "Cannot delete item from read-only DictConfig" + ), + ) + del self.__dict__["_content"][key] def get( self, key: Union[str, Enum], default_value: Any = DEFAULT_VALUE_MARKER, @@ -313,9 +323,9 @@ def get( try: return self._get_impl(key=key, default_value=default_value) except Exception as e: - self._translate_exception(e=e, key=key, value=None) + self._format_and_raise(key=key, value=None, cause=e) - def _get_impl(self, key: Union[str, Enum], default_value: Any,) -> Any: + def _get_impl(self, key: Union[str, Enum], default_value: Any) -> Any: key = self._validate_and_normalize_key(key) node = self.get_node_ex(key=key, default_value=default_value) return self._resolve_with_default( @@ -340,37 +350,44 @@ def get_node_ex( value = default_value else: raise - else: - if default_value is not DEFAULT_VALUE_MARKER: - value = default_value - return value - - __pop_marker = str("__POP_MARKER__") - - def pop(self, key: Union[str, Enum], default: Any = __pop_marker) -> Any: - key = self._validate_and_normalize_key(key) - if self._get_flag("readonly"): - self._translate_exception( - e=ReadonlyConfigError("Cannot pop from read-only node"), - key=key, - value=None, - ) - value = self._resolve_with_default( - key=key, - value=self.__dict__["_content"].pop(key, default), - default_value=default, - ) - - if value is self.__pop_marker: - full_key = self._get_full_key(key) - msg = f"Cannot pop key '{key}'" - if key != full_key: - msg += f", path='{full_key}'" + if default_value is not DEFAULT_VALUE_MARKER: + value = default_value - raise KeyError(msg) return value + def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any: + try: + key = self._validate_and_normalize_key(key) + if self._get_flag("readonly"): + self._format_and_raise( + key=key, + value=None, + cause=ReadonlyConfigError("Cannot pop from read-only node"), + ) + + node = self.get_node(key=key) + if node is not None: + value = self._resolve_with_default( + key=key, value=node, default_value=default + ) + del self[key] + return value + else: + if default is not DEFAULT_VALUE_MARKER: + return default + else: + full = self._get_full_key(key) + if full != key: + raise KeyError(f"Key not found: '{key}' (path: '{full}')") + else: + raise KeyError(f"Key not found: '{key}'") + except Exception as e: + if isinstance(e, KeyError): + raise + else: + self._format_and_raise(key=key, value=None, cause=e) + def keys(self) -> Any: if self._is_missing() or self._is_interpolation() or self._is_none(): return list() @@ -453,11 +470,7 @@ def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None: if type_or_prototype is None: return if not is_structured_config(type_or_prototype): - self._format_and_raise( - exception_type=ValueError, - key=None, - msg=f"Expected structured config class : {type_or_prototype}", - ) + raise ValueError(f"Expected structured config class : {type_or_prototype}") from omegaconf import OmegaConf @@ -496,7 +509,8 @@ def _set_value(self, value: Any) -> None: self._metadata.object_type = dict for k, v in value.items_ex(resolve=False): self.__setitem__(k, v) - self._metadata.object_type = OmegaConf.get_type(value) + self.__dict__["_metadata"] = copy.deepcopy(value._metadata) + elif isinstance(value, dict): self._metadata.object_type = self._metadata.ref_type for k, v in value.items(): diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index a4eb67fb9..6426e318b 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -14,7 +14,7 @@ Union, ) -from ._utils import ValueKind, get_value_kind, is_primitive_list +from ._utils import ValueKind, get_list_element_type, get_value_kind, is_primitive_list from .base import Container, ContainerMetadata, Node from .basecontainer import BaseContainer from .errors import ( @@ -23,7 +23,7 @@ ReadonlyConfigError, ValidationError, ) -from .nodes import AnyNode, ValueNode +from .nodes import ValueNode class ListConfig(BaseContainer, MutableSequence[Any]): @@ -37,8 +37,8 @@ def __init__( parent: Optional[Container] = None, ref_type: Optional[Type[Any]] = None, is_optional: bool = True, - element_type: Optional[Type[Any]] = None, ) -> None: + element_type = get_list_element_type(ref_type) super().__init__( parent=parent, metadata=ContainerMetadata( @@ -115,7 +115,7 @@ def __getitem__(self, index: Union[int, slice]) -> Any: else: return self._resolve_with_default(key=index, value=self._content[index]) except Exception as e: - self._translate_exception(e=e, key=index, value=None) + self._format_and_raise(key=index, value=None, cause=e) def _set_at_index(self, index: Union[int, slice], value: Any) -> None: self._set_item_impl(index, value) @@ -124,7 +124,7 @@ def __setitem__(self, index: Union[int, slice], value: Any) -> None: try: self._set_at_index(index, value) except Exception as e: - self._translate_exception(e=e, key=index, value=value) + self._format_and_raise(key=index, value=value, cause=e) def append(self, item: Any) -> None: try: @@ -142,10 +142,18 @@ def append(self, item: Any) -> None: ) self.__dict__["_content"].append(node) except Exception as e: - self._translate_exception(e=e, key=index, value=item) - assert False # pragma: no cover + self._format_and_raise(key=index, value=item, cause=e) + assert False + + def _update_keys(self) -> None: + for i in range(len(self)): + node = self.get_node(i) + if node is not None: + node._metadata.key = i def insert(self, index: int, item: Any) -> None: + from omegaconf.omegaconf import OmegaConf, _maybe_wrap + try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot insert into a read-only ListConfig") @@ -155,30 +163,28 @@ def insert(self, index: int, item: Any) -> None: ) if self._is_missing(): raise MissingMandatoryValue("Cannot insert into missing ListConfig") + try: - # TODO: not respecting element type - # TODO: this and other list ops like delete are not updating key in list element nodes - # from omegaconf.omegaconf import OmegaConf, _maybe_wrap - # - # index = len(self) - # self._validate_set(key=index, value=item) - # - # node = _maybe_wrap( - # ref_type=self._metadata.element_type, - # key=index, - # value=item, - # is_optional=OmegaConf.is_optional(item), - # parent=self, - # ) assert isinstance(self._content, list) - self._content.insert(index, AnyNode(None)) - self._set_at_index(index, item) + # insert place holder + self._content.insert(index, None) + node = _maybe_wrap( + ref_type=self._metadata.element_type, + key=index, + value=item, + is_optional=OmegaConf.is_optional(item), + parent=self, + ) + self._validate_set(key=index, value=node) + self._set_at_index(index, node) + self._update_keys() except Exception: del self.__dict__["_content"][index] + self._update_keys() raise except Exception as e: - self._translate_exception(e=e, key=index, value=item) - assert False # pragma: no cover + self._format_and_raise(key=index, value=item, cause=e) + assert False def extend(self, lst: Iterable[Any]) -> None: assert isinstance(lst, (tuple, list, ListConfig)) @@ -188,6 +194,18 @@ def extend(self, lst: Iterable[Any]) -> None: def remove(self, x: Any) -> None: del self[self.index(x)] + def __delitem__(self, key: Union[str, int, slice]) -> None: + if self._get_flag("readonly"): + self._format_and_raise( + key=key, + value=None, + cause=ReadonlyConfigError( + "Cannot delete item from read-only ListConfig" + ), + ) + del self.__dict__["_content"][key] + self._update_keys() + def clear(self) -> None: del self[:] @@ -209,10 +227,10 @@ def index( if found_idx != -1: return found_idx else: - self._translate_exception( - e=ValueError("Item not found in ListConfig"), key=None, value=None + self._format_and_raise( + key=None, value=None, cause=ValueError("Item not found in ListConfig") ) - assert False # pragma: no cover + assert False def count(self, x: Any) -> int: c = 0 @@ -240,8 +258,8 @@ def get_node_ex(self, key: Any, validate_access: bool = True) -> Optional[Node]: return self._content[key] # type: ignore except (IndexError, TypeError) as e: if validate_access: - self._translate_exception(e=e, key=key, value=None) - assert False # pragma: no cover + self._format_and_raise(key=key, value=None, cause=e) + assert False else: return None @@ -257,8 +275,8 @@ def get(self, index: int, default_value: Any = None) -> Any: key=index, value=self._content[index], default_value=default_value ) except Exception as e: - self._translate_exception(e=e, key=index, value=None) - assert False # pragma: no cover + self._format_and_raise(key=index, value=None, cause=e) + assert False def pop(self, index: int = -1) -> Any: try: @@ -270,13 +288,15 @@ def pop(self, index: int = -1) -> Any: raise MissingMandatoryValue("Cannot pop from a missing ListConfig") assert isinstance(self._content, list) - - return self._resolve_with_default( - key=index, value=self._content.pop(index), default_value=None + ret = self._resolve_with_default( + key=index, value=self.get_node(index), default_value=None ) + del self._content[index] + self._update_keys() + return ret except (ReadonlyConfigError, IndexError) as e: - self._translate_exception(e=e, key=index, value=None) - assert False # pragma: no cover + self._format_and_raise(key=index, value=None, cause=e) + assert False def sort( self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False @@ -303,8 +323,8 @@ def key1(x: Any) -> Any: self._content.sort(key=key1, reverse=reverse) except (ReadonlyConfigError, IndexError) as e: - self._translate_exception(e=e, key=None, value=None) - assert False # pragma: no cover + self._format_and_raise(key=None, value=None, cause=e) + assert False def __eq__(self, other: Any) -> bool: if isinstance(other, (list, tuple)) or other is None: @@ -346,8 +366,8 @@ def next(self) -> Any: assert isinstance(self._content, list) return MyItems(self._content) except (ReadonlyConfigError, TypeError, MissingMandatoryValue) as e: - self._translate_exception(e=e, key=None, value=None) - assert False # pragma: no cover + self._format_and_raise(key=None, value=None, cause=e) + assert False def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig": # res is sharing this list's parent to allow interpolation to work as expected @@ -384,12 +404,16 @@ def _set_value(self, value: Any) -> None: self.__dict__["_content"] = value else: assert is_primitive_list(value) or isinstance(value, ListConfig) + self.__dict__["_content"] = [] if isinstance(value, ListConfig): self._metadata = copy.deepcopy(value._metadata) - self.__dict__["_content"] = [] + self._metadata.flags = {} for item in value: self.append(item) + if isinstance(value, ListConfig): + self._metadata.flags = value._metadata.flags + @staticmethod def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool: diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index b9f6be62d..fde6d653d 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -93,10 +93,12 @@ def _is_interpolation(self) -> bool: return _is_interpolation(self._value()) def _get_full_key(self, key: Union[str, Enum, int, None]) -> str: - # TODO: add testing for get_full_key on value nodes, including value nodes without a parent. parent = self._get_parent() if parent is None: - return str(self._metadata.key) + if self._metadata.key is None: + return "" + else: + return str(self._metadata.key) else: return parent._get_full_key(self._metadata.key) @@ -322,13 +324,11 @@ def __init__( ) def validate_and_convert(self, value: Any) -> Optional[Enum]: - return self.validate_and_convert_to_enum( - self, enum_type=self.enum_type, value=value - ) + return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value) @staticmethod def validate_and_convert_to_enum( - node: Node, enum_type: Type[Enum], value: Any + enum_type: Type[Enum], value: Any ) -> Optional[Enum]: if value is None: return None @@ -346,7 +346,7 @@ def validate_and_convert_to_enum( raise ValueError if isinstance(value, int): - return enum_type(value) # TODO: does this ever work?? + return enum_type(value) if isinstance(value, str): prefix = f"{enum_type.__name__}." @@ -354,7 +354,7 @@ def validate_and_convert_to_enum( value = value[len(prefix) :] return enum_type[value] - assert False # pragma: no cover + assert False except (ValueError, KeyError) as e: valid = "\n".join([f"\t{x}" for x in enum_type.__members__.keys()]) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 8013e3f59..e609aebdf 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -109,8 +109,7 @@ def env(key: str, default: Optional[str] = None) -> Any: if default is not None: return decode_primitive(default) else: - # TODO: validate and test - raise KeyError("Environment variable '{}' not found".format(key)) + raise ValidationError(f"Environment variable '{key}' not found") OmegaConf.register_resolver("env", env) @@ -183,7 +182,7 @@ def _create_impl( # noqa F811 return DictConfig(content=obj, parent=parent, ref_type=ref_type) elif is_primitive_list(obj) or OmegaConf.is_list(obj): ref_type = OmegaConf.get_type(obj) - return ListConfig(content=obj, parent=parent, ref_type=ref_type) + return ListConfig(ref_type=ref_type, content=obj, parent=parent) else: if isinstance(obj, type): raise ValidationError( @@ -261,7 +260,8 @@ def merge( if is_primitive_container(target) or is_structured_config(target): target = OmegaConf.create(target) assert isinstance(target, (DictConfig, ListConfig)) - target.merge_with(*others[1:]) + with flag_override(target, "readonly", False): + target.merge_with(*others[1:]) return target @staticmethod @@ -472,7 +472,7 @@ def get_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]: elif isinstance(c, (list, tuple)): return list else: - assert False # pragma: no cover + assert False @staticmethod def get_ref_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]: @@ -558,7 +558,9 @@ def _node_wrap( elif type_ == str: node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional) else: - raise ValidationError(f"Unexpected object type : {type_.__name__}") + raise ValidationError( + f"Unexpected object type : {type_.__name__}" + ) # pragma: no cover return node @@ -575,7 +577,7 @@ def _maybe_wrap( value._set_key(key) value._set_parent(parent) return value - ret: Node # pragma: no cover + ret: Node origin_ = getattr(ref_type, "__origin__", None) is_dict = ( type(value) in (dict, DictConfig) @@ -598,23 +600,14 @@ def _maybe_wrap( is_optional=is_optional, ) elif is_list: - args = getattr(ref_type, "__args__", None) - if ref_type is not List and args is not None: - element_type = args[0] - else: - element_type = None - - if not (valid_value_annotation_type(element_type)): - raise ValidationError(f"Unsupported value type : {element_type}") ret = ListConfig( content=value, key=key, parent=parent, is_optional=is_optional, - element_type=element_type, + ref_type=ref_type, ) - elif ( is_structured_config(ref_type) and ( @@ -623,7 +616,7 @@ def _maybe_wrap( or value_kind == ValueKind.INTERPOLATION or value is None ) - ) or is_structured_config(value): + ) or is_structured_config(ref_type): from . import DictConfig ret = DictConfig( @@ -682,6 +675,6 @@ def _select_one( else: val = c.get_node(ret_key) else: - assert False # pragma: no cover + assert False return val, ret_key diff --git a/tests/examples/test_dataclass_example.py b/tests/examples/test_dataclass_example.py index 902b1902a..4e3021496 100644 --- a/tests/examples/test_dataclass_example.py +++ b/tests/examples/test_dataclass_example.py @@ -277,9 +277,7 @@ def test_enum_key() -> None: def test_dict_of_objects() -> None: conf: WebServer = OmegaConf.structured(WebServer) conf.domains["blog"] = Domain(name="blog.example.com", path="/www/blog.example.com") - with pytest.raises( - ValidationError - ): # TODO: improve exception, error makes no sense. + with pytest.raises(ValidationError): conf.domains.foo = 10 # type: ignore assert conf.domains["blog"].name == "blog.example.com" diff --git a/tests/test_base_config.py b/tests/test_base_config.py index 4d51650fd..fd936a0d1 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -261,7 +261,7 @@ def test_deepcopy_and_merge_and_flags() -> None: @pytest.mark.parametrize( # type: ignore - "cfg", [ListConfig(content=[], element_type=int), DictConfig(content={})], + "cfg", [ListConfig(ref_type=List[int], content=[]), DictConfig(content={})], ) def test_deepcopy_preserves_container_type(cfg: Container) -> None: cp: Container = copy.deepcopy(cfg) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index 5cf9e0f18..d937168ea 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -230,42 +230,44 @@ def test_iterate_dictionary() -> None: @pytest.mark.parametrize( # type: ignore - "cfg, key, default_, expected, expectation", + "cfg, key, default_, expected", [ - (dict(a=1, b=2), "a", None, 1, does_not_raise()), - (dict(a=1, b=2), "not_found", "default", "default", does_not_raise()), - (dict(a=1, b=2), "not_found", None, None, pytest.raises(KeyError)), + ({"a": 1, "b": 2}, "a", None, 1), + ({"a": 1, "b": 2}, "not_found", "default", "default"), # Interpolations - (dict(a="${b}", b=2), "a", None, 2, does_not_raise()), - (dict(a="???", b=2), "a", None, None, pytest.raises(KeyError)), - ( - dict(a="${b}", b="???"), - "a", - None, - None, - pytest.raises(MissingMandatoryValue), - ), + ({"a": "${b}", "b": 2}, "a", None, 2), # enum key - ({Enum1.FOO: "bar"}, Enum1.FOO, None, "bar", does_not_raise()), - ({Enum1.FOO: "bar"}, Enum1.BAR, "default", "default", does_not_raise()), - ({Enum1.FOO: "bar"}, Enum1.BAR, None, None, pytest.raises(KeyError)), + ({Enum1.FOO: "bar"}, Enum1.FOO, None, "bar"), + ({Enum1.FOO: "bar"}, Enum1.BAR, "default", "default"), ], ) -def test_dict_pop( - cfg: Dict[Any, Any], key: Any, default_: Any, expected: Any, expectation: Any -) -> None: +def test_dict_pop(cfg: Dict[Any, Any], key: Any, default_: Any, expected: Any) -> None: c = OmegaConf.create(cfg) - with expectation: - if default_ is not None: - val = c.pop(key, default_) - else: - val = c.pop(key) + if default_ is not None: + val = c.pop(key, default_) + else: + val = c.pop(key) - assert val == expected - assert type(val) == type(expected) + assert val == expected + assert type(val) == type(expected) -# TODO: test that a failed pop does not mutate the dict +@pytest.mark.parametrize( # type: ignore + "cfg, key, expectation", + [ + ({"a": 1, "b": 2}, "not_found", pytest.raises(KeyError)), + # Interpolations + ({"a": "???", "b": 2}, "a", pytest.raises(MissingMandatoryValue)), + ({"a": "${b}", "b": "???"}, "a", pytest.raises(MissingMandatoryValue),), + # enum key + ({Enum1.FOO: "bar"}, Enum1.BAR, pytest.raises(KeyError)), + ], +) +def test_dict_pop_error(cfg: Dict[Any, Any], key: Any, expectation: Any) -> None: + c = OmegaConf.create(cfg) + with expectation: + c.pop(key) + assert c == cfg @pytest.mark.parametrize( # type: ignore diff --git a/tests/test_basic_ops_list.py b/tests/test_basic_ops_list.py index 92de1a4e2..611fa518f 100644 --- a/tests/test_basic_ops_list.py +++ b/tests/test_basic_ops_list.py @@ -9,6 +9,7 @@ KeyValidationError, MissingMandatoryValue, UnsupportedValueType, + ValidationError, ) from omegaconf.nodes import IntegerNode, StringNode @@ -100,6 +101,23 @@ def test_list_pop() -> None: assert c == [2, 3] with pytest.raises(IndexError): c.pop(100) + validate_list_keys(c) + + +def test_list_pop_on_unexpected_exception_not_modifying(mocker: Any) -> None: + src = [1, 2, 3, 4] + c = OmegaConf.create(src) + + def resolve_with_default_throws( + key: Any, value: Any, default_value: Any = None, + ) -> None: + raise Exception("mocked_exception") + + mocker.patch.object(c, "_resolve_with_default", resolve_with_default_throws) + + with pytest.raises(Exception, match="mocked_exception"): + c.pop(0) + assert c == src def test_in_list() -> None: @@ -146,6 +164,8 @@ def test_list_delitem() -> None: with pytest.raises(IndexError): del c[100] + validate_list_keys(c) + @pytest.mark.parametrize( # type: ignore "lst,expected", @@ -178,6 +198,8 @@ def test_list_append() -> None: c.append([]) assert c == [1, 2, {}, []] + validate_list_keys(c) + def test_pretty_without_resolve() -> None: c = OmegaConf.create([100, "${0}"]) @@ -210,22 +232,60 @@ def test_list_dir() -> None: assert ["0", "1", "2"] == dir(c) +def validate_list_keys(c: Any) -> None: + # validate keys are maintained + for i in range(len(c)): + assert c.get_node(i)._metadata.key == i + + @pytest.mark.parametrize( # type: ignore - "input_, index, value, expected, expected_node_type", + "input_, index, value, expected, expected_node_type, expectation", [ - (["a", "b", "c"], 1, 100, ["a", 100, "b", "c"], AnyNode), - (["a", "b", "c"], 1, IntegerNode(100), ["a", 100, "b", "c"], IntegerNode), - (["a", "b", "c"], 1, "foo", ["a", "foo", "b", "c"], AnyNode), - (["a", "b", "c"], 1, StringNode("foo"), ["a", "foo", "b", "c"], StringNode), + (["a", "b", "c"], 1, 100, ["a", 100, "b", "c"], AnyNode, None), + ( + ["a", "b", "c"], + 1, + IntegerNode(100), + ["a", 100, "b", "c"], + IntegerNode, + None, + ), + (["a", "b", "c"], 1, "foo", ["a", "foo", "b", "c"], AnyNode, None), + ( + ["a", "b", "c"], + 1, + StringNode("foo"), + ["a", "foo", "b", "c"], + StringNode, + None, + ), + ( + ListConfig(ref_type=List[int], content=[]), + 0, + "foo", + None, + None, + ValidationError, + ), ], ) def test_insert( - input_: List[str], index: int, value: Any, expected: Any, expected_node_type: type + input_: List[str], + index: int, + value: Any, + expected: Any, + expected_node_type: type, + expectation: Any, ) -> None: c = OmegaConf.create(input_) - c.insert(index, value) - assert c == expected - assert type(c.get_node(index)) == expected_node_type + if expectation is None: + c.insert(index, value) + assert c == expected + assert type(c.get_node(index)) == expected_node_type + else: + with pytest.raises(expectation): + c.insert(index, value) + validate_list_keys(c) @pytest.mark.parametrize( # type: ignore @@ -344,6 +404,7 @@ def test_append_throws_not_changing_list() -> None: c.append(v) assert len(c) == 0 assert c == [] + validate_list_keys(c) def test_hash() -> None: diff --git a/tests/test_create.py b/tests/test_create.py index 261813dfb..d413c0e6e 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -5,10 +5,10 @@ import pytest -from omegaconf import OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.errors import UnsupportedValueType -from . import IllegalType +from . import ConcretePlugin, IllegalType, Plugin @pytest.mark.parametrize( # type: ignore @@ -116,5 +116,19 @@ def test_create_from_oc_with_flags() -> None: assert c1._metadata.flags == c2._metadata.flags -# TODO: test that DictConfig created with OmegaConf.create(DictConfig(...)) -# is preserving the metadata of the inner dict config, same for listconfig +def test_create_from_dictconfig_preserves_metadata() -> None: + cfg1 = DictConfig(ref_type=Plugin, is_optional=False, content=ConcretePlugin) + OmegaConf.set_struct(cfg1, True) + OmegaConf.set_readonly(cfg1, True) + cfg2 = OmegaConf.create(cfg1) + assert cfg1 == cfg2 + assert cfg1._metadata == cfg2._metadata + + +def test_create_from_listconfig_preserves_metadata() -> None: + cfg1 = ListConfig(ref_type=List[int], is_optional=False, content=[1, 2, 3]) + OmegaConf.set_struct(cfg1, True) + OmegaConf.set_readonly(cfg1, True) + cfg2 = OmegaConf.create(cfg1) + assert cfg1 == cfg2 + assert cfg1._metadata == cfg2._metadata diff --git a/tests/test_errors.py b/tests/test_errors.py index 59213aea5..0471de3c6 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,6 +1,6 @@ import re from dataclasses import dataclass -from typing import Any, Dict +from typing import Any, Dict, List import pytest @@ -29,415 +29,464 @@ def create_readonly(cfg: Any) -> Any: return cfg -def create_struct(cfg: Any) -> Any: - cfg = OmegaConf.create(cfg) - OmegaConf.set_struct(cfg, True) - return cfg +params = [ + ############## + # DictConfig # + ############## + # update_node + pytest.param( + lambda: OmegaConf.structured(StructuredWithMissing), + lambda cfg: cfg.update_node("num", "hello"), + ValidationError, + "Value 'hello' could not be converted to Integer", + id="structured:update_node_with_invalid_value", + ), + pytest.param( + lambda: OmegaConf.structured(StructuredWithMissing), + lambda cfg: cfg.update_node("num", None), + ValidationError, + "field 'num' is not Optional", + id="structured:update_node:none_to_non_optional", + ), + pytest.param( + lambda: OmegaConf.create({}), + lambda cfg: cfg.update_node("a", IllegalType()), + UnsupportedValueType, + "Value 'IllegalType' is not a supported primitive type", + id="dict:update_node:object_of_illegal_type", + ), + # pop + pytest.param( + lambda: create_readonly({"foo": "bar"}), + lambda cfg: cfg.pop("foo"), + ReadonlyConfigError, + "Cannot pop from read-only node", + id="dict,readonly:pop", + ), + pytest.param( + lambda: OmegaConf.create({"foo": "bar"}), + lambda cfg: cfg.pop("nevermind"), + KeyError, + "Key not found: 'nevermind'", + id="dict:pop_invalid", + ), + pytest.param( + lambda: OmegaConf.create({"foo": {}}), + lambda cfg: cfg.foo.pop("nevermind"), + KeyError, + "Key not found: 'nevermind' (path: 'foo.nevermind')", + id="dict:pop_invalid", + ), + pytest.param( + lambda: OmegaConf.structured(ConcretePlugin), + lambda cfg: getattr(cfg, "fail"), + AttributeError, + "Key 'fail' not in 'ConcretePlugin'", + id="structured:access_invalid_attribute", + ), + # getattr + pytest.param( + lambda: create_struct({"foo": "bar"}), + lambda cfg: getattr(cfg, "fail"), + AttributeError, + "Key 'fail' in not in struct", + id="dict,struct:access_invalid_attribute", + ), + # setattr + pytest.param( + lambda: create_struct({"foo": "bar"}), + lambda cfg: setattr(cfg, "zlonk", "zlank"), + AttributeError, + "Key 'zlonk' in not in struct", + id="dict,struct:set_invalid_attribute", + ), + pytest.param( + lambda: OmegaConf.structured(ConcretePlugin), + lambda cfg: setattr(cfg, "params", 10), + ValidationError, + "Invalid type assigned : int is not a subclass of FoobarParams. value: 10", + id="structured:setattr,invalid_type", + ), + pytest.param( + lambda: create_readonly({"foo": "bar"}), + lambda cfg: setattr(cfg, "foo", 20), + ReadonlyConfigError, + "Cannot assign to read-only node : 20", + id="dict,readonly:set_attribute", + ), + pytest.param( + lambda: OmegaConf.create({"foo": DictConfig(is_optional=False, content={})}), + lambda cfg: setattr(cfg, "foo", None), + ValidationError, + "field 'foo' is not Optional", + id="dict,none_optional:set_none", + ), + # setitem + pytest.param( + lambda: create_struct({"foo": "bar"}), + lambda cfg: cfg.__setitem__("zoo", "zonk"), + KeyError, + "Error setting zoo=zonk : Key 'zoo' in not in struct", + id="dict,struct:setitem_on_none_existing_key", + ), + # getitem + pytest.param( + lambda: create_struct({"foo": "bar"}), + lambda cfg: cfg.__getitem__("zoo"), + KeyError, + "Error getting 'zoo' : Key 'zoo' in not in struct", + id="dict,struct:getitem_key_not_in_struct", + ), + pytest.param( + lambda: DictConfig(ref_type=Dict[Color, str], content={}), + lambda cfg: cfg.__getitem__("foo"), + KeyValidationError, + "Key 'foo' is incompatible with (Color)", + id="dict,reftype=Dict[Color,str]:,getitem_str_key", + ), + pytest.param( + lambda: DictConfig(ref_type=Dict[str, str], content={}), + lambda cfg: cfg.__getitem__(Color.RED), + KeyValidationError, + "Key Color.RED (Color) is incompatible with (str)", + id="dict,reftype=Dict[str,str]:,getitem_color_key", + ), + # merge + pytest.param( + lambda: create_readonly({"foo": "bar"}), + lambda cfg: cfg.merge_with(OmegaConf.create()), + ReadonlyConfigError, + "Cannot merge into read-only node", + id="dict,readonly:merge_with", + ), + # merge_with + pytest.param( + lambda: OmegaConf.structured(ConcretePlugin), + lambda cfg: cfg.merge_with(Plugin), + ValidationError, + "Plugin is not a subclass of ConcretePlugin. value: {'name': '???', 'params': '???'}", + id="structured:merge_invalid_dataclass", + ), + # get + pytest.param( + lambda: OmegaConf.create(), + lambda cfg: cfg.get(IllegalType), + KeyValidationError, + "Incompatible key type 'type'", + id="dict:get_illegal_type", + ), + pytest.param( + lambda: OmegaConf.create(), + lambda cfg: cfg.get(IllegalType()), + KeyValidationError, + "Incompatible key type 'IllegalType'", + id="dict:get_object_of_illegal_type", + ), + # create + pytest.param( + lambda: None, + lambda cfg: OmegaConf.structured(NonOptionalAssignedNone), + ValidationError, + "Non optional field cannot be assigned None", + id="dict_create_none_optional_with_none", + ), + pytest.param( + lambda: None, + lambda cfg: OmegaConf.structured(IllegalType), + ValidationError, + "Input class 'IllegalType' is not a structured config. did you forget to decorate it as a dataclass?", + id="dict_create_from_illegal_type", + ), + pytest.param( + lambda: None, + lambda cfg: OmegaConf.structured(IllegalType()), + ValidationError, + "Object of unsupported type: 'IllegalType'", + id="structured:create_from_unsupported_object", + ), + # assign + pytest.param( + lambda: DictConfig(ref_type=ConcretePlugin, content="???"), + lambda cfg: cfg._set_value(1), + ValidationError, + "Invalid type assigned : int is not a subclass of ConcretePlugin. value: 1", + id="dict:set_value:reftype_mismatch", + ), + pytest.param( + lambda: DictConfig(ref_type=Dict[str, int], content={"foo": 10, "bar": 20}), + lambda cfg: cfg.__setitem__("baz", "fail"), + ValidationError, + "Value 'fail' could not be converted to Integer", + id="dict,int_element_type:assigned_str_value", + ), + # delete + pytest.param( + lambda: create_readonly({"foo": "bar"}), + lambda cfg: cfg.__delitem__("foo"), + ReadonlyConfigError, + "Cannot delete item from read-only DictConfig", + id="dict,readonly:del", + ), + ############## + # ListConfig # + ############## + # get node + pytest.param( + lambda: OmegaConf.create([1, 2, 3]), + lambda cfg: cfg.get_node_ex("foo"), + TypeError, + "list indices must be integers or slices, not str", + id="list:get_nox_ex:invalid_index_type", + ), + pytest.param( + lambda: OmegaConf.create([1, 2, 3]), + lambda cfg: cfg.get_node_ex(20), + IndexError, + "list index out of range", + id="list:get_node_ex:index_out_of_range", + ), + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: cfg.get_node_ex(20), + TypeError, + "Cannot get_node from a ListConfig object representing None", + id="list:get_node_none", + ), + pytest.param( + lambda: ListConfig(content="???"), + lambda cfg: cfg.get_node_ex(20), + MissingMandatoryValue, + "Cannot get_node from a missing ListConfig", + id="list:get_node_missing", + ), + # create + pytest.param( + lambda: None, + lambda cfg: ListConfig(is_optional=False, content=None), + ValidationError, + "Non optional ListConfig cannot be constructed from None", + id="list:create_not_optional_with_none", + ), + # append + pytest.param( + lambda: OmegaConf.create([]), + lambda cfg: cfg.append(IllegalType()), + UnsupportedValueType, + "Value 'IllegalType' is not a supported primitive type", + id="list:append_value_of_illegal_type", + ), + # pop + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.pop(0), + ReadonlyConfigError, + "Cannot pop from read-only ListConfig", + id="dict:readonly:pop", + ), + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.pop("Invalid key type"), + ReadonlyConfigError, + "Cannot pop from read-only ListConfig", + id="dict:readonly:pop", + ), + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: cfg.pop(0), + TypeError, + "Cannot pop from a ListConfig object representing None", + id="list:pop_from_none", + ), + pytest.param( + lambda: ListConfig(content="???"), + lambda cfg: cfg.pop(0), + MissingMandatoryValue, + "Cannot pop from a missing ListConfig", + id="list:pop_from_missing", + ), + # getitem + pytest.param( + lambda: OmegaConf.create(["???"]), + lambda cfg: cfg.__getitem__(slice(0, 1)), + MissingMandatoryValue, + "Missing mandatory value: [slice(0, 1, None)]", + id="list:subscript_slice_with_missing", + ), + pytest.param( + lambda: OmegaConf.create([10, "???"]), + lambda cfg: cfg.__getitem__(1), + MissingMandatoryValue, + "Missing mandatory value: [1]", + id="list:subscript_index_with_missing", + ), + pytest.param( + lambda: OmegaConf.create([1, 2, 3]), + lambda cfg: cfg.__getitem__(20), + IndexError, + "list index out of range", + id="list:subscript:index_out_of_range", + ), + pytest.param( + lambda: OmegaConf.create([1, 2, 3]), + lambda cfg: cfg.__getitem__("foo"), + KeyValidationError, + "Invalid key type 'str'", + id="list:getitem,illegal_key_type", + ), + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: cfg.__getitem__(0), + TypeError, + "ListConfig object representing None is not subscriptable", + id="list:getitem,illegal_key_type", + ), + # setitem + pytest.param( + lambda: OmegaConf.create([None]), + lambda cfg: cfg.__setitem__(0, IllegalType()), + UnsupportedValueType, + "Value 'IllegalType' is not a supported primitive type", + id="list:setitem,illegal_value_type", + ), + pytest.param( + lambda: OmegaConf.create([1, 2, 3]), + lambda cfg: cfg.__setitem__("foo", 4), + KeyValidationError, + "Invalid key type 'str'", + id="list:setitem,illegal_key_type", + ), + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.__setitem__(0, 4), + ReadonlyConfigError, + "ListConfig is read-only", + id="list,readonly:setitem", + ), + # assign + pytest.param( + lambda: ListConfig(ref_type=List[int], content=[1, 2, 3]), + lambda cfg: cfg.__setitem__(0, "foo"), + ValidationError, + "Value 'foo' could not be converted to Integer", + id="list,int_elements:assigned_str_element", + ), + pytest.param( + # make sure OmegaConf.create is not losing critical metadata. + lambda: OmegaConf.create(ListConfig(ref_type=List[int], content=[1, 2, 3])), + lambda cfg: cfg.__setitem__(0, "foo"), + ValidationError, + "Value 'foo' could not be converted to Integer", + id="list,int_elements:assigned_str_element", + ), + pytest.param( + lambda: OmegaConf.create([IntegerNode(is_optional=False, value=0), 2, 3]), + lambda cfg: cfg.__setitem__(0, None), + ValidationError, + "[0] is not optional and cannot be assigned None", + id="list,not_optional:assigned_none", + ), + # index + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.index(99), + ValueError, + "Item not found in ListConfig", + id="list,readonly:index_not_found", + ), + # insert + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.insert(1, 99), + ReadonlyConfigError, + "Cannot insert into a read-only ListConfig", + id="list,readonly:insert", + ), + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: cfg.insert(1, 99), + TypeError, + "Cannot insert into ListConfig object representing None", + id="list:insert_into_none", + ), + pytest.param( + lambda: ListConfig(content="???"), + lambda cfg: cfg.insert(1, 99), + MissingMandatoryValue, + "Cannot insert into missing ListConfig", + id="list:insert_into_missing", + ), + # get + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: cfg.get(0), + TypeError, + "Cannot get from a ListConfig object representing None", + id="list:get_from_none", + ), + pytest.param( + lambda: ListConfig(content="???"), + lambda cfg: cfg.get(0), + MissingMandatoryValue, + "Cannot get from a missing ListConfig", + id="list:get_from_missing", + ), + # sort + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.sort(), + ReadonlyConfigError, + "Cannot sort a read-only ListConfig", + id="list:readonly:sort", + ), + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: cfg.sort(), + TypeError, + "Cannot sort a ListConfig object representing None", + id="list:sort_from_none", + ), + pytest.param( + lambda: ListConfig(content="???"), + lambda cfg: cfg.sort(), + MissingMandatoryValue, + "Cannot sort a missing ListConfig", + id="list:sort_from_missing", + ), + # iter + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.sort(), + ReadonlyConfigError, + "Cannot sort a read-only ListConfig", + id="list:readonly:sort", + ), + pytest.param( + lambda: ListConfig(content=None), + lambda cfg: iter(cfg), + TypeError, + "Cannot iterate on ListConfig object representing None", + id="list:iter_none", + ), + pytest.param( + lambda: ListConfig(content="???"), + lambda cfg: iter(cfg), + MissingMandatoryValue, + "Cannot iterate on a missing ListConfig", + id="list:iter_missing", + ), + # delete + pytest.param( + lambda: create_readonly([1, 2, 3]), + lambda cfg: cfg.__delitem__(0), + ReadonlyConfigError, + "Cannot delete item from read-only ListConfig", + id="list,readonly:del", + ), +] -# TODO: change all ids to be normalized @pytest.mark.parametrize( # type:ignore - "create, op, exception_type, msg", - [ - pytest.param( - lambda: OmegaConf.structured(StructuredWithMissing), - lambda cfg: cfg.update_node("num", "hello"), - ValidationError, - "Value 'hello' could not be converted to Integer", - id="structured:update_node_with_invalid_value", - ), - pytest.param( - lambda: OmegaConf.structured(StructuredWithMissing), - lambda cfg: cfg.update_node("num", None), - ValidationError, - "field 'num' is not Optional", - id="structured:update_node:none_to_non_optional", - ), - pytest.param( - lambda: OmegaConf.create({}), - lambda cfg: cfg.update_node("a", IllegalType()), - UnsupportedValueType, - "Value 'IllegalType' is not a supported primitive type", - id="dict:update_node:object_of_illegal_type", - ), - pytest.param( - lambda: create_readonly({"foo": "bar"}), - lambda cfg: cfg.pop("foo"), - ReadonlyConfigError, - "Cannot pop from read-only node", - id="dict,readonly:pop", - ), - pytest.param( - lambda: OmegaConf.create({"foo": "bar"}), - lambda cfg: cfg.pop("nevermind"), - KeyError, - "Cannot pop key 'nevermind'", - id="dict:pop_invalid", - ), - pytest.param( - lambda: OmegaConf.structured(ConcretePlugin), - lambda cfg: getattr(cfg, "fail"), - AttributeError, - "Key 'fail' not in (ConcretePlugin)", - id="Structured: access invalid attribute", - ), - pytest.param( - lambda: create_struct({"foo": "bar"}), - lambda cfg: getattr(cfg, "fail"), - AttributeError, - "Key 'fail' in not in struct", - id="Dict, Struct: access invalid attribute", - ), - pytest.param( - lambda: create_struct({"foo": "bar"}), - lambda cfg: setattr(cfg, "zlonk", "zlank"), - AttributeError, - "Key 'zlonk' in not in struct", - id="Dict,Struct: set invalid attribute", - ), - pytest.param( - lambda: create_struct({"foo": "bar"}), - lambda cfg: cfg.__setitem__("zoo", "zonk"), - KeyError, - "Error setting zoo=zonk : Key 'zoo' in not in struct", - id="Dict, Struct: setitem on a key that does not exist", - ), - pytest.param( - lambda: create_struct({"foo": "bar"}), - lambda cfg: cfg.__getitem__("zoo"), - KeyError, - "Error getting 'zoo' : Key 'zoo' in not in struct", - ), - pytest.param( - lambda: create_readonly({"foo": "bar"}), - lambda cfg: OmegaConf.merge(cfg, OmegaConf.create()), - ReadonlyConfigError, - "Cannot merge into read-only node", - id="Dict, Readonly: merge into", # TODO: this should actually not throw an error - ), - pytest.param( - lambda: create_readonly({"foo": "bar"}), - lambda cfg: setattr(cfg, "foo", 20), - ReadonlyConfigError, - "Cannot assign to read-only node : 20", - id="Dict, Readonly: set attribute", - ), - pytest.param( - lambda: OmegaConf.structured(ConcretePlugin), - lambda cfg: cfg.merge_with(Plugin), - ValidationError, - "Plugin is not a subclass of ConcretePlugin. value: {'name': '???', 'params': '???'}", - id="Structured, Merge invalid dataclass", - ), - pytest.param( - lambda: OmegaConf.create(), - lambda cfg: cfg.get(IllegalType), - KeyValidationError, - "Incompatible key type 'type'", - id="dict:get_illegal_type", - ), - pytest.param( - lambda: OmegaConf.create(), - lambda cfg: cfg.get(IllegalType()), - KeyValidationError, - "Incompatible key type 'IllegalType'", - id="dict:get_object_of_illegal_type", - ), - pytest.param( - lambda: DictConfig(ref_type=Dict[Color, str], content={}), - lambda cfg: cfg["foo"], - KeyValidationError, - "Key 'foo' is incompatible with (Color)", - id="dict,reftype=Dict[Color,str]:,getitem_str_key", - ), - pytest.param( - lambda: DictConfig(ref_type=Dict[str, str], content={}), - lambda cfg: cfg[Color.RED], - KeyValidationError, - "Key Color.RED (Color) is incompatible with (str)", - id="dict,reftype=Dict[str,str]:,getitem_color_key", - ), - pytest.param( - lambda: OmegaConf.create( - {"foo": DictConfig(is_optional=False, content={})} - ), - lambda cfg: setattr(cfg, "foo", None), - ValidationError, - "field 'foo' is not Optional", - id="Dict: Assigning None to a non optional Dict", - ), - pytest.param( - lambda: None, - lambda cfg: OmegaConf.structured(NonOptionalAssignedNone), - ValidationError, - "Non optional field cannot be assigned None", - id="dict_create_none_optional_with_none", - ), - pytest.param( - lambda: None, - lambda cfg: OmegaConf.structured(IllegalType), - ValidationError, - "Input class 'IllegalType' is not a structured config. did you forget to decorate it as a dataclass?", - id="dict_create_from_illegal_type", - ), - pytest.param( - lambda: None, - lambda cfg: OmegaConf.structured(IllegalType()), - ValidationError, - "Object of unsupported type: 'IllegalType'", - id="structured:create_from_unsupported_object", - ), - ############### - # List - ############### - # get node - pytest.param( - lambda: OmegaConf.create([1, 2, 3]), - lambda cfg: cfg.get_node_ex("foo"), - TypeError, - "list indices must be integers or slices, not str", - id="list:get_nox_ex:invalid_index_type", - ), - pytest.param( - lambda: OmegaConf.create([1, 2, 3]), - lambda cfg: cfg.get_node_ex(20), - IndexError, - "list index out of range", - id="list:get_node_ex:index_out_of_range", - ), - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: cfg.get_node_ex(20), - TypeError, - "Cannot get_node from a ListConfig object representing None", - id="list:get_node_none", - ), - pytest.param( - lambda: ListConfig(content="???"), - lambda cfg: cfg.get_node_ex(20), - MissingMandatoryValue, - "Cannot get_node from a missing ListConfig", - id="list:get_node_missing", - ), - # create - pytest.param( - lambda: None, - lambda cfg: ListConfig(is_optional=False, content=None), - ValidationError, - "Non optional ListConfig cannot be constructed from None", - id="list:create_not_optional_with_none", - ), - # append - pytest.param( - lambda: OmegaConf.create([]), - lambda cfg: cfg.append(IllegalType()), - UnsupportedValueType, - "Value 'IllegalType' is not a supported primitive type", - id="list:append_value_of_illegal_type", - ), - # pop - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.pop(0), - ReadonlyConfigError, - "Cannot pop from read-only ListConfig", - id="dict:readonly:pop", - ), - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.pop("Invalid key type"), - ReadonlyConfigError, - "Cannot pop from read-only ListConfig", - id="dict:readonly:pop", - ), - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: cfg.pop(0), - TypeError, - "Cannot pop from a ListConfig object representing None", - id="list:pop_from_none", - ), - pytest.param( - lambda: ListConfig(content="???"), - lambda cfg: cfg.pop(0), - MissingMandatoryValue, - "Cannot pop from a missing ListConfig", - id="list:pop_from_missing", - ), - # getitem - pytest.param( - lambda: OmegaConf.create(["???"]), - lambda cfg: cfg.__getitem__(slice(0, 1)), - MissingMandatoryValue, - "Missing mandatory value: [slice(0, 1, None)]", - id="list:subscript_slice_with_missing", - ), - pytest.param( - lambda: OmegaConf.create([10, "???"]), - lambda cfg: cfg.__getitem__(1), - MissingMandatoryValue, - "Missing mandatory value: [1]", - id="list:subscript_index_with_missing", - ), - pytest.param( - lambda: OmegaConf.create([1, 2, 3]), - lambda cfg: cfg.__getitem__(20), - IndexError, - "list index out of range", - id="list:subscript:index_out_of_range", - ), - pytest.param( - lambda: OmegaConf.create([1, 2, 3]), - lambda cfg: cfg.__getitem__("foo"), - KeyValidationError, - "Invalid key type 'str'", - id="list:getitem,illegal_key_type", - ), - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: cfg.__getitem__(0), - TypeError, - "ListConfig object representing None is not subscriptable", - id="list:getitem,illegal_key_type", - ), - # setitem - pytest.param( - lambda: OmegaConf.create([None]), - lambda cfg: cfg.__setitem__(0, IllegalType()), - UnsupportedValueType, - "Value 'IllegalType' is not a supported primitive type", - id="list:setitem,illegal_value_type", - ), - pytest.param( - lambda: OmegaConf.create([1, 2, 3]), - lambda cfg: cfg.__setitem__("foo", 4), - KeyValidationError, - "Invalid key type 'str'", - id="list:setitem,illegal_key_type", - ), - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.__setitem__(0, 4), - ReadonlyConfigError, - "ListConfig is read-only", - id="list,readonly:setitem", - ), - # assign - pytest.param( - lambda: ListConfig(ref_type=list, element_type=int, content=[1, 2, 3]), - lambda cfg: cfg.__setitem__(0, "foo"), - ValidationError, - "Value 'foo' could not be converted to Integer", - id="list,int_elements:assigned_str_element", - ), - pytest.param( - # make sure OmegaConf.create is not losing critical metadata. - lambda: OmegaConf.create( - ListConfig(ref_type=list, element_type=int, content=[1, 2, 3]) - ), - lambda cfg: cfg.__setitem__(0, "foo"), - ValidationError, - "Value 'foo' could not be converted to Integer", - id="list,int_elements:assigned_str_element", - ), - pytest.param( - lambda: OmegaConf.create([IntegerNode(is_optional=False, value=0), 2, 3]), - lambda cfg: cfg.__setitem__(0, None), - ValidationError, - "[0] is not optional and cannot be assigned None", - id="list,not_optional:assigned_none", - ), - # index - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.index(99), - ValueError, - "Item not found in ListConfig", - id="list,readonly:index_not_found", - ), - # insert - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.insert(1, 99), - ReadonlyConfigError, - "Cannot insert into a read-only ListConfig", - id="list,readonly:insert", - ), - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: cfg.insert(1, 99), - TypeError, - "Cannot insert into ListConfig object representing None", - id="list:insert_into_none", - ), - pytest.param( - lambda: ListConfig(content="???"), - lambda cfg: cfg.insert(1, 99), - MissingMandatoryValue, - "Cannot insert into missing ListConfig", - id="list:insert_into_missing", - ), - # get - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: cfg.get(0), - TypeError, - "Cannot get from a ListConfig object representing None", - id="list:get_from_none", - ), - pytest.param( - lambda: ListConfig(content="???"), - lambda cfg: cfg.get(0), - MissingMandatoryValue, - "Cannot get from a missing ListConfig", - id="list:get_from_missing", - ), - # sort - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.sort(), - ReadonlyConfigError, - "Cannot sort a read-only ListConfig", - id="list:readonly:sort", - ), - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: cfg.sort(), - TypeError, - "Cannot sort a ListConfig object representing None", - id="list:sort_from_none", - ), - pytest.param( - lambda: ListConfig(content="???"), - lambda cfg: cfg.sort(), - MissingMandatoryValue, - "Cannot sort a missing ListConfig", - id="list:sort_from_missing", - ), - # iter - pytest.param( - lambda: create_readonly([1, 2, 3]), - lambda cfg: cfg.sort(), - ReadonlyConfigError, - "Cannot sort a read-only ListConfig", - id="list:readonly:sort", - ), - pytest.param( - lambda: ListConfig(content=None), - lambda cfg: iter(cfg), - TypeError, - "Cannot iterate on ListConfig object representing None", - id="list:iter_none", - ), - pytest.param( - lambda: ListConfig(content="???"), - lambda cfg: iter(cfg), - MissingMandatoryValue, - "Cannot iterate on a missing ListConfig", - id="list:iter_missing", - ), - # delete - ], + "create, op, exception_type, msg", params ) def test_errors(create: Any, op: Any, exception_type: Any, msg: str) -> None: cfg = create() @@ -447,3 +496,9 @@ def test_errors(create: Any, op: Any, exception_type: Any, msg: str) -> None: except Exception as e: # helps in debugging raise e + + +def create_struct(cfg: Any) -> Any: + cfg = OmegaConf.create(cfg) + OmegaConf.set_struct(cfg, True) + return cfg diff --git a/tests/test_get_full_key.py b/tests/test_get_full_key.py index 6c0acf768..eef4e7157 100644 --- a/tests/test_get_full_key.py +++ b/tests/test_get_full_key.py @@ -2,10 +2,10 @@ import pytest -from omegaconf import OmegaConf +from omegaconf import IntegerNode, OmegaConf -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "cfg, select, key, expected", [ ({}, "", "a", "a"), @@ -67,10 +67,23 @@ # special cases # parent_with_missing_item ({"x": "???", "a": 1, "b": {"c": 1}}, "b", "c", "b.c"), + ({"foo": IntegerNode(value=10)}, "", "foo", "foo"), + ({"foo": {"bar": IntegerNode(value=10)}}, "foo", "bar", "foo.bar"), ], ) -class TestGetFullKeyMatrix: - def test(self, cfg: Any, select: str, key: Any, expected: Any) -> None: - c = OmegaConf.create(cfg) - node = c.select(select) - assert node._get_full_key(key) == expected +def test_get_full_key_from_config( + cfg: Any, select: str, key: Any, expected: Any +) -> None: + c = OmegaConf.create(cfg) + node = c.select(select) + assert node._get_full_key(key) == expected + + +def test_value_node_get_full_key() -> None: + cfg = OmegaConf.create({"foo": IntegerNode(value=10)}) + assert cfg.get_node("foo")._get_full_key(None) == "foo" # type: ignore + + node = IntegerNode(value=10) + assert node._get_full_key(None) == "" + node = IntegerNode(key="foo", value=10) + assert node._get_full_key(None) == "foo" diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index dc5be45d5..d4bf56b48 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1,5 +1,6 @@ import os import random +import re from typing import Any, Dict import pytest @@ -140,7 +141,9 @@ def test_env_interpolation1() -> None: def test_env_interpolation_not_found() -> None: c = OmegaConf.create({"path": "/test/${env:foobar}"}) - with pytest.raises(KeyError): + with pytest.raises( + ValidationError, match=re.escape("Environment variable 'foobar' not found") + ): c.path diff --git a/tests/test_merge.py b/tests/test_merge.py index e6c8138b9..5d6eaeedf 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -2,8 +2,15 @@ import pytest -from omegaconf import MISSING, DictConfig, OmegaConf, ValidationError, _utils, nodes -from omegaconf.errors import ReadonlyConfigError +from omegaconf import ( + MISSING, + DictConfig, + OmegaConf, + ReadonlyConfigError, + ValidationError, + nodes, +) +from omegaconf._utils import is_structured_config from . import ConcretePlugin, ConfWithMissingDict, Group, Plugin, User, Users @@ -72,12 +79,17 @@ pytest.raises(ValidationError), ), ([Plugin, ConcretePlugin], ConcretePlugin), + pytest.param( + [{"user": "???"}, {"user": Group}], + {"user": Group}, + id="merge_into_missing_node", + ), ], ) def test_merge(inputs: Any, expected: Any) -> None: configs = [OmegaConf.create(c) for c in inputs] - if isinstance(expected, (dict, list)) or _utils.is_structured_config(expected): + if isinstance(expected, (dict, list)) or is_structured_config(expected): merged = OmegaConf.merge(*configs) assert merged == expected # test input configs are not changed. @@ -114,7 +126,7 @@ def test_merge_no_eq_verify( def test_merge_with_1() -> None: a = OmegaConf.create() - b = OmegaConf.create(dict(a=1, b=2)) + b = OmegaConf.create({"a": 1, "b": 2}) a.merge_with(b) assert a == b @@ -123,12 +135,7 @@ def test_merge_with_2() -> None: a = OmegaConf.create() assert isinstance(a, DictConfig) a.inner = {} - b = OmegaConf.create( - """ - a : 1 - b : 2 - """ - ) + b = OmegaConf.create({"a": 1, "b": 2}) a.inner.merge_with(b) # type: ignore assert a.inner == b @@ -154,7 +161,7 @@ def test_merge_list_list() -> None: ({}, [], TypeError), ([], {}, TypeError), ([1, 2, 3], None, ValueError), - (dict(a=10), None, ValueError), + ({"a": 10}, None, ValueError), ], ) def test_merge_error(base: Any, merge: Any, exception: Any) -> None: @@ -164,23 +171,29 @@ def test_merge_error(base: Any, merge: Any, exception: Any) -> None: OmegaConf.merge(base, merge) -def test_into_readonly_dict() -> None: - cfg = OmegaConf.create({"foo": "bar"}) +@pytest.mark.parametrize( # type: ignore + "c1, c2", [({"foo": "bar"}, {"zoo": "foo"}), ([1, 2, 3], [4, 5, 6])] +) +def test_with_readonly(c1: Any, c2: Any) -> None: + cfg = OmegaConf.create(c1) OmegaConf.set_readonly(cfg, True) - with pytest.raises(ReadonlyConfigError): - OmegaConf.merge(cfg, {"zoo": "foo"}) + cfg2 = OmegaConf.merge(cfg, c2) + assert OmegaConf.is_readonly(cfg2) -def test_into_readonly_list() -> None: - cfg = OmegaConf.create([1, 2, 3]) +@pytest.mark.parametrize( # type: ignore + "c1, c2", [({"foo": "bar"}, {"zoo": "foo"}), ([1, 2, 3], [4, 5, 6])] +) +def test_into_readonly(c1: Any, c2: Any) -> None: + cfg = OmegaConf.create(c1) OmegaConf.set_readonly(cfg, True) with pytest.raises(ReadonlyConfigError): - OmegaConf.merge(cfg, [4, 5, 6]) + cfg.merge_with(c2) def test_parent_maintained() -> None: - c1 = OmegaConf.create(dict(a=dict(b=10))) - c2 = OmegaConf.create(dict(aa=dict(bb=100))) + c1 = OmegaConf.create({"a": {"b": 10}}) + c2 = OmegaConf.create({"aa": {"bb": 100}}) c3 = OmegaConf.merge(c1, c2) assert isinstance(c1, DictConfig) assert isinstance(c2, DictConfig) diff --git a/tests/test_omegaconf.py b/tests/test_omegaconf.py index c9966e82c..e0813dec9 100644 --- a/tests/test_omegaconf.py +++ b/tests/test_omegaconf.py @@ -346,12 +346,12 @@ def test_is_interpolation(fac): ({"foo": 10.0}, float), ({"foo": True}, bool), ({"foo": "bar"}, str), - ({"foo": None}, type(None)), # TODO: can this be None instead? + ({"foo": None}, type(None)), ({"foo": ConcretePlugin()}, ConcretePlugin), ({"foo": ConcretePlugin}, ConcretePlugin), - # ({"foo": {}}, dict), + ({"foo": {}}, dict), ({"foo": OmegaConf.create()}, dict), - # ({"foo": []}, list), + ({"foo": []}, list), ({"foo": OmegaConf.create([])}, list), ], ) diff --git a/tests/test_readonly.py b/tests/test_readonly.py index f903d23d5..29b93deff 100644 --- a/tests/test_readonly.py +++ b/tests/test_readonly.py @@ -151,5 +151,6 @@ def test_readonly_from_cli() -> None: assert isinstance(c, DictConfig) OmegaConf.set_readonly(c, True) cli = OmegaConf.from_dotlist(["foo.bar=[2]"]) - with raises(ReadonlyConfigError): - OmegaConf.merge(c, cli) + cfg2 = OmegaConf.merge(c, cli) + assert OmegaConf.is_readonly(c) + assert OmegaConf.is_readonly(cfg2) diff --git a/tests/test_utils.py b/tests/test_utils.py index 84e14696f..fd77ea1ae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,7 @@ import attr import pytest -from omegaconf import DictConfig, OmegaConf, _utils +from omegaconf import DictConfig, ListConfig, OmegaConf, _utils from omegaconf.errors import KeyValidationError, ValidationError from omegaconf.nodes import StringNode @@ -263,10 +263,34 @@ def test_get_key_value_types( dt = Dict[key_type, value_type] # type: ignore if expected_key_type is not None and issubclass(expected_key_type, Exception): with pytest.raises(expected_key_type): - _utils.get_key_value_types(dt) + _utils.get_dict_key_value_types(dt) else: - assert _utils.get_key_value_types(dt) == ( + assert _utils.get_dict_key_value_types(dt) == ( expected_key_type, expected_value_type, ) + + +@pytest.mark.parametrize( # type: ignore + "type_, is_primitive", + [ + (int, True), + (float, True), + (bool, True), + (str, True), + (type(None), True), + (Color, True), + (list, False), + (ListConfig, False), + (dict, False), + (DictConfig, False), + ], +) +def test_is_primitive_type(type_: Any, is_primitive: bool) -> None: + assert _utils.is_primitive_type(type_) == is_primitive + + +def test_deprectated_is_primitive_type() -> None: + with pytest.deprecated_call(): + _utils._is_primitive_type(int)