Skip to content

Commit

Permalink
Only support the select syntax in oc.dict.{keys,values}
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Apr 7, 2021
1 parent 8dfb06d commit 27e59f8
Showing 1 changed file with 42 additions and 73 deletions.
115 changes: 42 additions & 73 deletions omegaconf/built_in_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import warnings

# from collections.abc import Mapping, MutableMapping
from typing import Any, List, Mapping, Optional, Union
from typing import Any, List, Optional

from ._utils import _DEFAULT_MARKER_, Marker, _get_value, decode_primitive
from .base import Container
Expand Down Expand Up @@ -40,72 +38,40 @@ def decode(expr: Optional[str], _parent_: Container) -> Any:


def dict_keys(
in_dict: Union[str, Mapping[Any, Any]],
_root_: BaseContainer,
key: str,
_parent_: Container,
) -> ListConfig:
assert isinstance(_parent_, BaseContainer)

in_dict = _get_and_validate_dict_input(
in_dict, root=_root_, resolver_name="oc.dict.keys"
key, parent=_parent_, resolver_name="oc.dict.keys"
)
assert isinstance(_parent_, BaseContainer)

ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
assert isinstance(ret, ListConfig)
return ret


def dict_values(
in_dict: Union[str, Mapping[Any, Any]], _root_: BaseContainer, _parent_: Container
) -> ListConfig:
def dict_values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
assert isinstance(_parent_, BaseContainer)
in_dict = _get_and_validate_dict_input(
in_dict, root=_root_, resolver_name="oc.dict.values"
key, parent=_parent_, resolver_name="oc.dict.values"
)

if isinstance(in_dict, DictConfig):
# DictConfig objects are handled in a special way: the goal is to make the
# returned ListConfig point to the DictConfig nodes through interpolations.

dict_key: Optional[str] = None
if in_dict._get_root() is _root_:
# Try to obtain the full key through which we can access `in_dict`.
if in_dict is _root_:
dict_key = ""
else:
dict_key = in_dict._get_full_key(None)
if dict_key:
dict_key += "." # append dot for future concatenation
else:
# This can happen e.g. if `in_dict` is a transient node.
dict_key = None

if dict_key is None:
# No path to `in_dict` in the existing config.
raise NotImplementedError(
"`oc.dict.values` only supports input config nodes that "
"are accessible through the root config. See "
"https://github.com/omry/omegaconf/issues/650 for details."
)

ret = ListConfig([])
content = in_dict._content
assert isinstance(content, dict)

for key, node in content.items():
ref_node = AnyNode(f"${{{dict_key}{key}}}")
ret.append(ref_node)

# Finalize result by setting proper type and parent.
element_type: Any = in_dict._metadata.element_type
ret._metadata.element_type = element_type
ret._metadata.ref_type = List[element_type]
ret._set_parent(_parent_)

return ret

# Other dict-like object: simply create a ListConfig from its values.
assert isinstance(_parent_, BaseContainer)
ret = OmegaConf.create(list(in_dict.values()), parent=_parent_)
assert isinstance(ret, ListConfig)
ret = ListConfig([])
content = in_dict._content
assert isinstance(content, dict)

for k, node in content.items():
ref_node = AnyNode(f"${{{key}.{k}}}")
ret.append(ref_node)

# Finalize result by setting proper type and parent.
element_type: Any = in_dict._metadata.element_type
ret._metadata.element_type = element_type
ret._metadata.ref_type = List[element_type]
ret._set_parent(_parent_)

return ret


Expand Down Expand Up @@ -144,26 +110,29 @@ def legacy_env(key: str, default: Optional[str] = None) -> Any:


def _get_and_validate_dict_input(
in_dict: Union[str, Mapping[Any, Any]],
root: BaseContainer,
key: str,
parent: BaseContainer,
resolver_name: str,
) -> Mapping[Any, Any]:
if isinstance(in_dict, str):
# Path to an existing key in the config: use `select()`.
key = in_dict
if key.startswith("."):
raise NotImplementedError(
f"To use relative interpolations with `{resolver_name}`, please use "
f"the explicit interpolation syntax: ${{{resolver_name}:${{{key}}}}}"
)
in_dict = OmegaConf.select(
root, key, throw_on_missing=True, default=_DEFAULT_SELECT_MARKER_
) -> DictConfig:
if not isinstance(key, str):
raise TypeError(
f"`{resolver_name}` requires a string as input, but obtained `{key}` "
f"of type: {type(key).__name__}"
)
if in_dict is _DEFAULT_SELECT_MARKER_:
raise ConfigKeyError(f"Key not found: '{key}'")
assert in_dict is not None

if not isinstance(in_dict, Mapping):
in_dict = OmegaConf.select(
parent,
key,
throw_on_missing=True,
absolute_key=True,
default=_DEFAULT_SELECT_MARKER_,
)

if in_dict is _DEFAULT_SELECT_MARKER_:
raise ConfigKeyError(f"Key not found: '{key}'")
assert in_dict is not None

if not isinstance(in_dict, DictConfig):
raise TypeError(
f"`{resolver_name}` cannot be applied to objects of type: "
f"{type(in_dict).__name__}"
Expand Down

0 comments on commit 27e59f8

Please sign in to comment.