Skip to content

Commit

Permalink
Validate and convert interpolation results to their intended type
Browse files Browse the repository at this point in the history
This commit fixes several issues:

* For nested resolvers (ex: `${f:${g:x}}`), intermediate resolver
  outputs (of `g` in this example) were wrapped in a ValueNode just to
  be unwrapped immediately, which was wasteful. This commit pushes the
  node wrapping to the very last step of the interpolation resolution.

* There was no type checking to make sure that the result of an
  interpolation had a type consistent with the node's type (when
  specified). Now a check is made and the interpolation result may be
  converted into the desired type (see omry#488).

* If a resolver interpolation returns a dict / list, it is now wrapped
  into a DictConfig / ListConfig, instead of a ValueNode. This makes it
  possible to generate configs from resolvers (see omry#540)

Fixes omry#488
Fixes omry#540
  • Loading branch information
odelalleau committed Mar 3, 2021
1 parent cb5e556 commit 55ed2af
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 44 deletions.
19 changes: 14 additions & 5 deletions docs/source/structured_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ Optional fields
Interpolations
^^^^^^^^^^^^^^

:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to an other types.
:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to another type.
To work around it, use SI and II described below.

.. doctest::
Expand All @@ -333,18 +333,27 @@ To work around it, use SI and II described below.
>>> assert conf.c == 100


Type validation is performed on assignment, but not on values returned by interpolation, e.g:
Type validation (and implicit conversion when possible) is performed both on assignment and on values returned by interpolations, e.g:

.. doctest::

>>> from omegaconf import SI
>>> from omegaconf import II
>>> @dataclass
... class Interpolation:
... int_key: int = II("str_key")
... str_key: str = "string"
... int_key: int = II("str_key")

>>> cfg = OmegaConf.structured(Interpolation)
>>> assert cfg.int_key == "string"
>>> cfg.int_key # fails due to type mismatch
Traceback (most recent call last):
...
omegaconf.errors.ValidationError: Value 'string' could not be converted to Integer
full_key: int_key
object_type=Interpolation
>>> cfg.str_key = 1234 # convert int to str (assignment)
>>> assert cfg.str_key == "1234"
>>> assert cfg.int_key == 1234 # convert str to int (interpolation)


Frozen
^^^^^^
Expand Down
13 changes: 13 additions & 0 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,19 @@ simply use quotes to bypass character limitations in strings.
'Hello, World'


Custom resolvers can return lists or dictionaries, that are automatically converted into config objects:

.. doctest::

>>> OmegaConf.register_new_resolver(
... "min_max", lambda *a: {"min": min(a), "max": max(a)}
... )
>>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'})
>>> assert isinstance(c.stats, DictConfig)
>>> c.stats.min, c.stats.max
(-10, 5)


You can take advantage of nested interpolations to perform custom operations over variables:

.. doctest::
Expand Down
1 change: 1 addition & 0 deletions news/488.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
If the value of a typed node is obtained from an interpolation, it is now validated (and possibly converted) based on the node's type.
1 change: 1 addition & 0 deletions news/540.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
A custom resolver interpolation whose output is a list or dictionary is now automatically converted into a ListConfig or DictConfig.
52 changes: 33 additions & 19 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
MissingMandatoryValue,
OmegaConfBaseException,
UnsupportedInterpolationType,
ValidationError,
)
from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
from .grammar_parser import parse
Expand Down Expand Up @@ -337,8 +338,6 @@ def _select_impl(
) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]:
"""
Select a value using dot separated key sequence
:param key:
:return:
"""
from .omegaconf import _select_one

Expand Down Expand Up @@ -400,7 +399,9 @@ def _resolve_interpolation_from_parse_tree(
parse_tree: OmegaConfGrammarParser.ConfigValueContext,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import StringNode
from .basecontainer import BaseContainer
from .nodes import AnyNode, ValueNode
from .omegaconf import _node_wrap

try:
resolved = self.resolve_parse_tree(
Expand All @@ -413,14 +414,36 @@ def _resolve_interpolation_from_parse_tree(
raise
return None

assert resolved is not None
if isinstance(resolved, str):
# Result is a string: create a new StringNode for it.
return StringNode(
value=resolved,
key=key,
# If the output is not a Node already (e.g., because it is the output of a
# custom resolver), then we will need to wrap it within a Node.
must_wrap = not isinstance(resolved, Node)

# If the node is typed, validate (and possibly convert) the result.
if isinstance(value, ValueNode) and not isinstance(value, AnyNode):
res_value = _get_value(resolved)
try:
conv_value = value.validate_and_convert(res_value)
except ValidationError as e:
if throw_on_resolution_failure:
self._format_and_raise(key=key, value=res_value, cause=e)
return None

# If the same object is returned, it means the value is already valid
# "as is", and we can thus use it directly. Otherwise, the converted
# value has to be wrapped into a node.
if conv_value is not res_value:
must_wrap = True
resolved = conv_value

if must_wrap:
assert parent is None or isinstance(parent, BaseContainer)
return _node_wrap(
type_=value._metadata.ref_type,
parent=parent,
is_optional=value._metadata.optional,
value=resolved,
key=key,
ref_type=value._metadata.ref_type,
)
else:
assert isinstance(resolved, Node)
Expand Down Expand Up @@ -467,19 +490,10 @@ def _evaluate_custom_resolver(
) -> Any:
from omegaconf import OmegaConf

from .nodes import ValueNode

resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
value = resolver(root_node, inter_args, inter_args_str)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=Any, object_type=Any, key=key, optional=True
),
)
return resolver(root_node, inter_args, inter_args_str)
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
Expand Down
24 changes: 7 additions & 17 deletions omegaconf/grammar_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GrammarVisitor(OmegaConfGrammarParserVisitor):
def __init__(
self,
node_interpolation_callback: Callable[[str], Optional["Node"]],
resolver_interpolation_callback: Callable[..., Optional["Node"]],
resolver_interpolation_callback: Callable[..., Any],
quoted_string_callback: Callable[[str], str],
**kw: Dict[Any, Any],
):
Expand Down Expand Up @@ -96,22 +96,16 @@ def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
)
return child.symbol.text

def visitConfigValue(
self, ctx: OmegaConfGrammarParser.ConfigValueContext
) -> Union[str, Optional["Node"]]:
def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
# (toplevelStr | (toplevelStr? (interpolation toplevelStr?)+)) EOF
# Visit all children (except last one which is EOF)
vals = [self.visit(c) for c in list(ctx.getChildren())[:-1]]
assert vals
if len(vals) == 1 and isinstance(
ctx.getChild(0), OmegaConfGrammarParser.InterpolationContext
):
from .base import Node # noqa F811

# Single interpolation: return the resulting node "as is".
ret = vals[0]
assert ret is None or isinstance(ret, Node), ret
return ret
# Single interpolation: return the result "as is".
return vals[0]
# Concatenation of multiple components.
return "".join(map(str, vals))

Expand All @@ -135,13 +129,9 @@ def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:

def visitInterpolation(
self, ctx: OmegaConfGrammarParser.InterpolationContext
) -> Optional["Node"]:
from .base import Node # noqa F811

) -> Any:
assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
ret = self.visit(ctx.getChild(0))
assert ret is None or isinstance(ret, Node)
return ret
return self.visit(ctx.getChild(0))

def visitInterpolationNode(
self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
Expand All @@ -168,7 +158,7 @@ def visitInterpolationNode(

def visitInterpolationResolver(
self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
) -> Optional["Node"]:
) -> Any:

# INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
assert 4 <= ctx.getChildCount() <= 5
Expand Down
3 changes: 2 additions & 1 deletion tests/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytest import raises

from omegaconf import (
AnyNode,
Container,
DictConfig,
IntegerNode,
Expand Down Expand Up @@ -510,7 +511,7 @@ def test_resolve_str_interpolation(query: str, result: Any) -> None:
cfg._maybe_resolve_interpolation(
parent=None,
key=None,
value=StringNode(value=query),
value=AnyNode(value=query),
throw_on_resolution_failure=True,
)
== result
Expand Down
116 changes: 115 additions & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from omegaconf import (
II,
SI,
Container,
DictConfig,
IntegerNode,
ListConfig,
Node,
OmegaConf,
Resolver,
Expand All @@ -26,7 +28,7 @@
UnsupportedInterpolationType,
)

from . import StructuredWithMissing
from . import StructuredWithMissing, User

# file deepcode ignore CopyPasteError:
# The above comment is a statement to stop DeepCode from raising a warning on
Expand Down Expand Up @@ -747,3 +749,115 @@ def fail_if_called(x: Any) -> None:
x_node = cfg._get_node("x")
assert isinstance(x_node, Node)
assert x_node._dereference_node(throw_on_resolution_failure=False) is None


@pytest.mark.parametrize(
("cfg", "key", "expected_value", "expected_node_type"),
[
pytest.param(
User(name="Bond", age=SI("${cast:int,'7'}")),
"age",
7,
IntegerNode,
id="expected_type",
),
pytest.param(
# This example specifically test the case where intermediate resolver results
# are not of the same type as the key.
User(name="Bond", age=SI("${cast:int,${drop_last:${drop_last:7xx}}}")),
"age",
7,
IntegerNode,
id="intermediate_type_mismatch_ok",
),
pytest.param(
User(name="Bond", age=SI("${cast:str,'7'}")),
"age",
7,
IntegerNode,
id="convert_str_to_int",
),
],
)
def test_interpolation_type_validated_ok(
cfg: Any,
key: str,
expected_value: Any,
expected_node_type: Any,
restore_resolvers: Any,
) -> Any:
def cast(t: Any, v: Any) -> Any:
return {"str": str, "int": int}[t](v) # cast `v` to type `t`

def drop_last(s: str) -> str:
return s[0:-1] # drop last character from string `s`

OmegaConf.register_new_resolver("cast", cast)
OmegaConf.register_new_resolver("drop_last", drop_last)

cfg = OmegaConf.structured(cfg)

val = cfg[key]
assert val == expected_value

node = cfg._get_node(key)
assert isinstance(node, Node)
assert isinstance(node._dereference_node(), expected_node_type)


@pytest.mark.parametrize(
("cfg", "key", "expected_error"),
[
pytest.param(
User(name="Bond", age=SI("${cast:str,seven}")),
"age",
pytest.raises(
ValidationError,
match=re.escape(
"Value 'seven' could not be converted to Integer\n full_key: age"
),
),
id="type_mismatch_resolver",
),
pytest.param(
User(name="Bond", age=SI("${name}")),
"age",
pytest.raises(
ValidationError,
match=re.escape(
"Value 'Bond' could not be converted to Integer\n full_key: age"
),
),
id="type_mismatch_node_interpolation",
),
],
)
def test_interpolation_type_validated_error(
cfg: Any,
key: str,
expected_error: Any,
restore_resolvers: Any,
) -> Any:
def cast(t: Any, v: Any) -> Any:
return {"str": str, "int": int}[t](v) # cast `v` to type `t`

OmegaConf.register_new_resolver("cast", cast)

cfg = OmegaConf.structured(cfg)

with expected_error:
cfg[key]


def test_resolver_output_dictconfig(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver("dict", lambda: {"a": 0, "b": 1})
cfg = OmegaConf.create({"x": "${dict:}"})
assert isinstance(cfg.x, DictConfig)
assert cfg.x.a == 0 and cfg.x.b == 1


def test_resolver_output_listconfig(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver("list", lambda: [0, 1])
cfg = OmegaConf.create({"x": "${list:}"})
assert isinstance(cfg.x, ListConfig)
assert cfg.x == [0, 1]
2 changes: 1 addition & 1 deletion tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_none_construction(self, node_type: Any, values: Any) -> None:
def test_interpolation(
self, node_type: Any, values: Any, restore_resolvers: Any, register_func: Any
) -> None:
resolver_output = 9999
resolver_output = "9999"
register_func("func", lambda: resolver_output)
values = copy.deepcopy(values)
for value in values:
Expand Down

0 comments on commit 55ed2af

Please sign in to comment.