Type validation of interpolations #578

Merged
merged 22 commits into from
Mar 13, 2021
Commits
21224f2
Validate and convert interpolation results to their intended type
odelalleau Feb 12, 2021
9923827
Make `ValueNode` an abstract class
odelalleau Mar 10, 2021
1749581
Check if conversion is needed based on type rather than object identity
odelalleau Mar 5, 2021
a38eb8a
Consistent exceptions when resolving interpolations
odelalleau Mar 9, 2021
69a32d2
Set wrapped nodes as read-only
odelalleau Mar 9, 2021
dca7e01
Update news item according to suggestion
odelalleau Mar 10, 2021
36167ba
Add news item to describe the new feature from #540
odelalleau Mar 10, 2021
6159f3e
InterpolationResolutionError also inherits from ValidationError
odelalleau Mar 10, 2021
3d350db
Add docstring for `_resolve_interpolation_from_parse_tree()`
odelalleau Mar 10, 2021
d09277d
Refactor: split `_resolve_interpolation_from_parse_tree()`
odelalleau Mar 10, 2021
a99f776
More accurate comment
odelalleau Mar 11, 2021
0d63bf9
Update tests/test_interpolation.py
odelalleau Mar 11, 2021
c2c1c03
Use a fixture instead of duplicating the identity resolver
odelalleau Mar 11, 2021
4af8eb5
Add comment to explain test
odelalleau Mar 11, 2021
804845f
Clearer test names
odelalleau Mar 11, 2021
96e55f3
Move `cast` resolver to fixture in interpolation tests
odelalleau Mar 12, 2021
b9a3197
Add comment to explain test
odelalleau Mar 12, 2021
fefffc1
Check readonly flag directly
odelalleau Mar 12, 2021
5080322
Update tests/test_interpolation.py
odelalleau Mar 12, 2021
572e125
Check readonly flag for dict/list resolver outputs
odelalleau Mar 12, 2021
ce92f60
Fix syntax error
odelalleau Mar 12, 2021
57ee921
Remove now redundant definition of `cast()`
odelalleau Mar 12, 2021
30 changes: 25 additions & 5 deletions docs/source/structured_config.rst
@@ -309,7 +309,7 @@ Optional fields
Interpolations
^^^^^^^^^^^^^^

:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to an other types.
:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to another type.
To work around it, use SI and II described below.

.. doctest::
@@ -333,18 +333,38 @@ To work around it, use SI and II described below.
>>> assert conf.c == 100


Type validation is performed on assignment, but not on values returned by interpolation, e.g:
Interpolated values are validated, and converted when possible, to the annotated type when the interpolation is accessed, e.g:

.. doctest::

>>> from omegaconf import SI
>>> from omegaconf import II
>>> @dataclass
... class Interpolation:
... int_key: int = II("str_key")
... str_key: str = "string"
... int_key: int = II("str_key")

>>> cfg = OmegaConf.structured(Interpolation)
>>> assert cfg.int_key == "string"
>>> cfg.int_key # fails due to type mismatch
Traceback (most recent call last):
...
omegaconf.errors.InterpolationValidationError: Value 'string' could not be converted to Integer
full_key: int_key
object_type=Interpolation
>>> cfg.str_key = "1234" # string value
>>> assert cfg.int_key == 1234 # automatically convert str to int

Note however that this validation step is currently skipped for container node interpolations:

.. doctest::

>>> @dataclass
... class NotValidated:
... some_int: int = 0
... some_dict: Dict[str, str] = II("some_int")

>>> cfg = OmegaConf.structured(NotValidated)
>>> assert cfg.some_dict == 0 # type mismatch, but no error


Frozen
^^^^^^
13 changes: 13 additions & 0 deletions docs/source/usage.rst
@@ -472,6 +472,19 @@ simply use quotes to bypass character limitations in strings.
'Hello, World'


Custom resolvers can return lists or dictionaries, that are automatically converted into DictConfig and ListConfig:
Owner

I am starting to think that we need a dedicated page for interpolations (not as a part of this PR).
Interpolation docs are now about 300 lines.

We can keep the really minimal use cases in usage and mention more advanced use cases and point to the interpolations doc page.

Collaborator Author (odelalleau, Mar 10, 2021)

> I am starting to think that we need a dedicated page for interpolations (not as a part of this PR).

Added to #535 => resolving

Edit 2021-11-22: actually unresolving as otherwise linking to this discussion doesn't work


.. doctest::

>>> OmegaConf.register_new_resolver(
... "min_max", lambda *a: {"min": min(a), "max": max(a)}
... )
>>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'})
>>> assert isinstance(c.stats, DictConfig)
>>> c.stats.min, c.stats.max
(-10, 5)


You can take advantage of nested interpolations to perform custom operations over variables:

.. doctest::
1 change: 1 addition & 0 deletions news/488.api_change
@@ -0,0 +1 @@
When resolving an interpolation of a typed config value, the interpolated value is validated and possibly converted based on the node's type.
1 change: 1 addition & 0 deletions news/540.api_change
@@ -0,0 +1 @@
A custom resolver interpolation whose output is a list or dictionary is now automatically converted into a ListConfig or DictConfig.
1 change: 1 addition & 0 deletions news/540.feature
@@ -0,0 +1 @@
Custom resolvers can now generate transient config nodes dynamically.
141 changes: 119 additions & 22 deletions omegaconf/base.py
@@ -20,9 +20,12 @@
InterpolationKeyError,
InterpolationResolutionError,
InterpolationToMissingValueError,
InterpolationValidationError,
KeyValidationError,
MissingMandatoryValue,
OmegaConfBaseException,
UnsupportedInterpolationType,
ValidationError,
)
from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
from .grammar_parser import parse
@@ -353,8 +356,6 @@ def _select_impl(
) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]:
"""
Select a value using dot separated key sequence
:param key:
:return:
"""
from .omegaconf import _select_one

@@ -416,8 +417,34 @@ def _resolve_interpolation_from_parse_tree(
parse_tree: OmegaConfGrammarParser.ConfigValueContext,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import StringNode

"""
Resolve an interpolation.

This happens in two steps:
1. The parse tree is visited, which outputs either a `Node` (e.g.,
for node interpolations "${foo}"), a string (e.g., for string
interpolations "hello ${name}"), or any other arbitrary value
(e.g., for custom interpolations "${foo:bar}").
2. This output is potentially validated and converted when the node
being resolved (`value`) is typed.

If an error occurs in one of the above steps, an `InterpolationResolutionError`
(or a subclass of it) is raised, *unless* `throw_on_resolution_failure` is set
to `False` (in which case the return value is `None`).

:param parent: Parent of the node being resolved.
:param value: Node being resolved.
:param key: The associated key in the parent.
:param parse_tree: The parse tree as obtained from `grammar_parser.parse()`.
:param throw_on_resolution_failure: If `False`, then exceptions raised during
the resolution of the interpolation are silenced, and instead `None` is
returned.

:return: A `Node` that contains the interpolation result. This may be an existing
node in the config (in the case of a node interpolation "${foo}"), or a new
node that is created to wrap the interpolated value. It is `None` if and only if
`throw_on_resolution_failure` is `False` and an error occurs during resolution.
"""
try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
@@ -429,19 +456,98 @@ def _resolve_interpolation_from_parse_tree(
raise
return None

assert resolved is not None
if isinstance(resolved, str):
# Result is a string: create a new StringNode for it.
return StringNode(
value=resolved,
key=key,
return self._validate_and_convert_interpolation_result(
parent=parent,
value=value,
key=key,
resolved=resolved,
throw_on_resolution_failure=throw_on_resolution_failure,
)

def _validate_and_convert_interpolation_result(
self,
parent: Optional["Container"],
value: "Node",
key: Any,
resolved: Any,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import AnyNode, ValueNode

# If the output is not a Node already (e.g., because it is the output of a
# custom resolver), then we will need to wrap it within a Node.
must_wrap = not isinstance(resolved, Node)
Owner

This function is now a kilometer long.
Refactor it to reduce the size?

Collaborator Author

Done in d09277d


# If the node is typed, validate (and possibly convert) the result.
if isinstance(value, ValueNode) and not isinstance(value, AnyNode):
Comment on lines +481 to +482
Owner

While I am okay with not validating config nodes at this time (could be expensive), I think the docs should reflect that.

The use case is for custom resolvers returning a config:

@dataclass
class Client:
  ports: List[int] = II("oc.decode:${oc.env:PORTS}")

# env.PORTS="[80,8080]" => ListConfig([80, 8080], element_type=int)

Collaborator Author

I added a doc on the lack of validation for container nodes in 1a77110

I tried to keep it simple, so I didn't go into the details of how exactly validation can still occur for resolver interpolations.
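
To make the distinction in this thread concrete, here is a minimal, self-contained sketch of the control flow in `_validate_and_convert_interpolation_result()` as it appears in this diff. It does not use omegaconf: `Node`, `IntNode`, `IntListNode` and `resolve()` are hypothetical toy stand-ins, and the element check inside `IntListNode.wrap()` is an assumption about what wrapping a raw resolver output with a typed `ref_type` amounts to.

```python
from typing import Any


class Node:
    """Hypothetical stand-in for a config node; not the omegaconf class."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def wrap(self, value: Any) -> "Node":
        raise NotImplementedError


class IntNode(Node):
    """Typed value node: interpolation results are validated and converted."""

    def validate_and_convert(self, value: Any) -> int:
        return int(value)  # raises ValueError for e.g. "string"

    def wrap(self, value: Any) -> "IntNode":
        return IntNode(self.validate_and_convert(value))


class IntListNode(Node):
    """Container node: only checked when a raw result has to be wrapped."""

    def wrap(self, value: Any) -> "IntListNode":
        if not isinstance(value, list):
            raise ValueError(f"{value!r} is not a list")
        return IntListNode([int(v) for v in value])


def resolve(target: Node, resolved: Any) -> Node:
    """Toy version of the validate / convert / wrap decision."""
    # Raw values (e.g. resolver outputs) always need a new node.
    must_wrap = not isinstance(resolved, Node)
    # Only typed value nodes go through the validation step.
    if isinstance(target, IntNode):
        raw = resolved.value if isinstance(resolved, Node) else resolved
        converted = target.validate_and_convert(raw)
        if type(converted) != type(raw):
            must_wrap, resolved = True, converted
    return target.wrap(resolved) if must_wrap else resolved


ports = IntListNode(None)  # field declared as a list of ints
other = IntNode(0)         # unrelated int node elsewhere in the config

# Node interpolation from a container-typed field: returned as is, unchecked.
assert resolve(ports, other) is other
# Resolver output (a raw list): wrapped, and therefore checked and converted.
assert resolve(ports, ["80", "8080"]).value == [80, 8080]
```

In this toy model the container field accepts a mismatched node without complaint, while a raw resolver output is still checked because it has to be wrapped first.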

res_value = _get_value(resolved)
try:
conv_value = value.validate_and_convert(res_value)
except ValidationError as e:
if throw_on_resolution_failure:
self._format_and_raise(
key=key,
value=res_value,
cause=e,
type_override=InterpolationValidationError,
)
return None

# If the converted value is of the same type, it means that no conversion
# was actually needed. As a result, we can keep the original `resolved`
# (and otherwise, the converted value must be wrapped into a new node).
if type(conv_value) != type(res_value):
Owner

Deducing it from type of the result type seems fragile.
Maybe validate_and_convert should return a second boolean flag indicating if it converted or not?
Alternatively, we can break it into two functions:

def is_valid_value(self, value) -> bool:
  ...

def convert_value(self, value) -> Any:
  ...

def validate_and_convert(self, value) -> Any:
  if self.is_valid_value(value):
    return value
  else:
    return self.convert_value(value)

This can be a sizeable change. If we do it I suggest we do it as a standalone diff independently of this diff. Once it's in - this diff can utilize the more fine-grained API instead of calling validate_and_convert().

Collaborator Author

> Deducing it from type of the result type seems fragile.
> Maybe validate_and_convert should return a second boolean flag indicating if it converted or not?

Or, we impose the contract that it should return the same object if it is not converted -- back to my first version :)

> Alternatively, we can break it into two functions

A potential concern with breaking into two functions is that it may do the same checks twice.

Let me know what you think makes most sense for this PR, and I'll go with it.

Owner

I think having the same type checks twice during assignment is minor.

Breaking the function into two will also play better when we try to deprecate the automatic conversion on assignment of primitive types. (as opposed to automatic conversion on merge, which I think we want to keep).

Collaborator Author

Ok but how should I proceed for this PR? I think there are 4 options:

  1. Keep current check based on type
  2. Go back to previous version where the check was based on the object returned being the same
  3. Have validate_and_convert() return a boolean flag indicating whether a change was made
  4. Split validate_and_convert() in two functions
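
For reference, option 4 could look roughly like the sketch below. This is only an illustration on a toy class, not omegaconf code: `IntNodeSketch` and `needs_new_node()` are hypothetical names, and the method names simply follow the suggestion made earlier in this thread.

```python
from typing import Any, Tuple


class IntNodeSketch:
    """Hypothetical int-typed node used only to illustrate the split."""

    def is_valid_value(self, value: Any) -> bool:
        # True when the value can be stored as is, without conversion.
        return isinstance(value, int)

    def convert_value(self, value: Any) -> int:
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"Value {value!r} could not be converted to Integer"
            ) from e

    def validate_and_convert(self, value: Any) -> int:
        # Same behavior as before the split, built from the two pieces.
        if self.is_valid_value(value):
            return value
        return self.convert_value(value)


def needs_new_node(node: IntNodeSketch, value: Any) -> Tuple[bool, int]:
    """Caller side: decide whether wrapping is needed without comparing types."""
    if node.is_valid_value(value):
        return False, value
    return True, node.convert_value(value)


node = IntNodeSketch()
assert needs_new_node(node, 3) == (False, 3)
assert needs_new_node(node, "1234") == (True, 1234)
```

With this shape the caller no longer infers "was it converted?" from `type(conv_value) != type(res_value)`. The cost is the one raised above: on the assignment path the validity check may run twice.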

must_wrap = True
resolved = conv_value

if must_wrap:
return self._wrap_interpolation_result(
parent=parent,
is_optional=value._metadata.optional,
value=value,
key=key,
resolved=resolved,
throw_on_resolution_failure=throw_on_resolution_failure,
)
else:
assert isinstance(resolved, Node)
return resolved

def _wrap_interpolation_result(
self,
parent: Optional["Container"],
value: "Node",
key: Any,
resolved: Any,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .basecontainer import BaseContainer
from .omegaconf import _node_wrap

assert parent is None or isinstance(parent, BaseContainer)
try:
wrapped = _node_wrap(
type_=value._metadata.ref_type,
parent=parent,
is_optional=value._metadata.optional,
value=resolved,
key=key,
ref_type=value._metadata.ref_type,
)
except (KeyValidationError, ValidationError) as e:
if throw_on_resolution_failure:
self._format_and_raise(
key=key,
value=resolved,
cause=e,
type_override=InterpolationValidationError,
)
return None
# Since we created a new node on the fly, future changes to this node are
# likely to be lost. We thus set the "readonly" flag to `True` to reduce
# the risk of accidental modifications.
wrapped._set_flag("readonly", True)
return wrapped

def _resolve_node_interpolation(
self,
inter_key: str,
@@ -483,19 +589,10 @@ def _evaluate_custom_resolver(
) -> Any:
from omegaconf import OmegaConf

from .nodes import ValueNode

resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
value = resolver(root_node, inter_args, inter_args_str)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=Any, object_type=Any, key=key, optional=True
),
)
return resolver(root_node, inter_args, inter_args_str)
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
@@ -556,7 +653,7 @@ def quoted_string_callback(quoted_str: str) -> str:
value=quoted_str,
key=key,
parent=parent,
is_optional=False,
is_optional=True,
),
throw_on_resolution_failure=True,
)
4 changes: 1 addition & 3 deletions omegaconf/dictconfig.py
@@ -292,9 +292,7 @@ def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
return key # type: ignore
elif issubclass(key_type, Enum):
try:
ret = EnumNode.validate_and_convert_to_enum(key_type, key)
assert ret is not None
return ret
return EnumNode.validate_and_convert_to_enum(key_type, key)
except ValidationError:
valid = ", ".join([x for x in key_type.__members__.keys()])
raise KeyValidationError(
6 changes: 6 additions & 0 deletions omegaconf/errors.py
@@ -81,6 +81,12 @@ class InterpolationToMissingValueError(InterpolationResolutionError):
"""


class InterpolationValidationError(InterpolationResolutionError, ValidationError):
"""
Thrown when the result of an interpolation fails the validation step.
"""


class ConfigKeyError(OmegaConfBaseException, KeyError):
"""
Thrown from DictConfig when a regular dict access would have caused a KeyError.
24 changes: 7 additions & 17 deletions omegaconf/grammar_visitor.py
@@ -39,7 +39,7 @@ class GrammarVisitor(OmegaConfGrammarParserVisitor):
def __init__(
self,
node_interpolation_callback: Callable[[str], Optional["Node"]],
resolver_interpolation_callback: Callable[..., Optional["Node"]],
resolver_interpolation_callback: Callable[..., Any],
quoted_string_callback: Callable[[str], str],
**kw: Dict[Any, Any],
):
@@ -96,22 +96,16 @@ def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
)
return child.symbol.text

def visitConfigValue(
self, ctx: OmegaConfGrammarParser.ConfigValueContext
) -> Union[str, Optional["Node"]]:
def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
# (toplevelStr | (toplevelStr? (interpolation toplevelStr?)+)) EOF
# Visit all children (except last one which is EOF)
vals = [self.visit(c) for c in list(ctx.getChildren())[:-1]]
assert vals
if len(vals) == 1 and isinstance(
ctx.getChild(0), OmegaConfGrammarParser.InterpolationContext
):
from .base import Node # noqa F811

# Single interpolation: return the resulting node "as is".
ret = vals[0]
assert ret is None or isinstance(ret, Node), ret
return ret
# Single interpolation: return the result "as is".
return vals[0]
# Concatenation of multiple components.
return "".join(map(str, vals))

@@ -135,13 +129,9 @@ def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:

def visitInterpolation(
self, ctx: OmegaConfGrammarParser.InterpolationContext
) -> Optional["Node"]:
from .base import Node # noqa F811

) -> Any:
assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
ret = self.visit(ctx.getChild(0))
assert ret is None or isinstance(ret, Node)
return ret
return self.visit(ctx.getChild(0))

def visitInterpolationNode(
self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
@@ -168,7 +158,7 @@ def visitInterpolationNode(

def visitInterpolationResolver(
self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
) -> Optional["Node"]:
) -> Any:

# INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
assert 4 <= ctx.getChildCount() <= 5
4 changes: 3 additions & 1 deletion omegaconf/nodes.py
@@ -1,6 +1,7 @@
import copy
import math
import sys
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Type, Union

@@ -55,8 +56,9 @@ def validate_and_convert(self, value: Any) -> Any:
# Subclasses can assume that `value` is not None in `_validate_and_convert_impl()`.
return self._validate_and_convert_impl(value)

@abstractmethod
def _validate_and_convert_impl(self, value: Any) -> Any:
return value
...

def __str__(self) -> str:
return str(self._val)