diff --git a/docs/source/structured_config.rst b/docs/source/structured_config.rst index 706a9a25a..6d66aacb5 100644 --- a/docs/source/structured_config.rst +++ b/docs/source/structured_config.rst @@ -309,7 +309,7 @@ Optional fields Interpolations ^^^^^^^^^^^^^^ -:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to an other types. +:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to another type. To work around it, use SI and II described below. .. doctest:: @@ -333,18 +333,27 @@ To work around it, use SI and II described below. >>> assert conf.c == 100 -Type validation is performed on assignment, but not on values returned by interpolation, e.g: +Type validation (and implicit conversion when possible) is performed both on assignment and on values returned by interpolations, e.g: .. doctest:: - >>> from omegaconf import SI + >>> from omegaconf import II >>> @dataclass ... class Interpolation: - ... int_key: int = II("str_key") ... str_key: str = "string" + ... int_key: int = II("str_key") >>> cfg = OmegaConf.structured(Interpolation) - >>> assert cfg.int_key == "string" + >>> cfg.int_key # fails due to type mismatch + Traceback (most recent call last): + ... + omegaconf.errors.ValidationError: Value 'string' could not be converted to Integer + full_key: int_key + object_type=Interpolation + >>> cfg.str_key = 1234 # convert int to str (assignment) + >>> assert cfg.str_key == "1234" + >>> assert cfg.int_key == 1234 # convert str to int (interpolation) + Frozen ^^^^^^ diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 5850c3c9c..4ee589327 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -472,6 +472,19 @@ simply use quotes to bypass character limitations in strings. 'Hello, World' +Custom resolvers can return lists or dictionaries, that are automatically converted into config objects: + +.. doctest:: + + >>> OmegaConf.register_new_resolver( + ... "min_max", lambda *a: {"min": min(a), "max": max(a)} + ... ) + >>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'}) + >>> assert isinstance(c.stats, DictConfig) + >>> c.stats.min, c.stats.max + (-10, 5) + + You can take advantage of nested interpolations to perform custom operations over variables: .. doctest:: diff --git a/news/488.api_change b/news/488.api_change new file mode 100644 index 000000000..a9d5540ba --- /dev/null +++ b/news/488.api_change @@ -0,0 +1 @@ +If the value of a typed node is obtained from an interpolation, it is now validated (and possibly converted) based on the node's type. diff --git a/news/540.api_change b/news/540.api_change new file mode 100644 index 000000000..b81ba0144 --- /dev/null +++ b/news/540.api_change @@ -0,0 +1 @@ +A custom resolver interpolation whose output is a list or dictionary is now automatically converted into a ListConfig or DictConfig. diff --git a/omegaconf/base.py b/omegaconf/base.py index f9c93f015..6ad7c1501 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -23,6 +23,7 @@ MissingMandatoryValue, OmegaConfBaseException, UnsupportedInterpolationType, + ValidationError, ) from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser from .grammar_parser import parse @@ -337,8 +338,6 @@ def _select_impl( ) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]: """ Select a value using dot separated key sequence - :param key: - :return: """ from .omegaconf import _select_one @@ -400,7 +399,9 @@ def _resolve_interpolation_from_parse_tree( parse_tree: OmegaConfGrammarParser.ConfigValueContext, throw_on_resolution_failure: bool, ) -> Optional["Node"]: - from .nodes import StringNode + from .basecontainer import BaseContainer + from .nodes import AnyNode, ValueNode + from .omegaconf import _node_wrap try: resolved = self.resolve_parse_tree( @@ -413,14 +414,36 @@ def _resolve_interpolation_from_parse_tree( raise return None - assert resolved is not None - if isinstance(resolved, str): - # Result is a string: create a new StringNode for it. - return StringNode( - value=resolved, - key=key, + # If the output is not a Node already (e.g., because it is the output of a + # custom resolver), then we will need to wrap it within a Node. + must_wrap = not isinstance(resolved, Node) + + # If the node is typed, validate (and possibly convert) the result. + if isinstance(value, ValueNode) and not isinstance(value, AnyNode): + res_value = _get_value(resolved) + try: + conv_value = value.validate_and_convert(res_value) + except ValidationError as e: + if throw_on_resolution_failure: + self._format_and_raise(key=key, value=res_value, cause=e) + return None + + # If the same object is returned, it means the value is already valid + # "as is", and we can thus use it directly. Otherwise, the converted + # value has to be wrapped into a node. + if conv_value is not res_value: + must_wrap = True + resolved = conv_value + + if must_wrap: + assert parent is None or isinstance(parent, BaseContainer) + return _node_wrap( + type_=value._metadata.ref_type, parent=parent, is_optional=value._metadata.optional, + value=resolved, + key=key, + ref_type=value._metadata.ref_type, ) else: assert isinstance(resolved, Node) @@ -467,19 +490,10 @@ def _evaluate_custom_resolver( ) -> Any: from omegaconf import OmegaConf - from .nodes import ValueNode - resolver = OmegaConf.get_resolver(inter_type) if resolver is not None: root_node = self._get_root() - value = resolver(root_node, inter_args, inter_args_str) - return ValueNode( - value=value, - parent=self, - metadata=Metadata( - ref_type=Any, object_type=Any, key=key, optional=True - ), - ) + return resolver(root_node, inter_args, inter_args_str) else: raise UnsupportedInterpolationType( f"Unsupported interpolation type {inter_type}" diff --git a/omegaconf/grammar_visitor.py b/omegaconf/grammar_visitor.py index cc0017294..e9c977947 100644 --- a/omegaconf/grammar_visitor.py +++ b/omegaconf/grammar_visitor.py @@ -39,7 +39,7 @@ class GrammarVisitor(OmegaConfGrammarParserVisitor): def __init__( self, node_interpolation_callback: Callable[[str], Optional["Node"]], - resolver_interpolation_callback: Callable[..., Optional["Node"]], + resolver_interpolation_callback: Callable[..., Any], quoted_string_callback: Callable[[str], str], **kw: Dict[Any, Any], ): @@ -96,9 +96,7 @@ def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str: ) return child.symbol.text - def visitConfigValue( - self, ctx: OmegaConfGrammarParser.ConfigValueContext - ) -> Union[str, Optional["Node"]]: + def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any: # (toplevelStr | (toplevelStr? (interpolation toplevelStr?)+)) EOF # Visit all children (except last one which is EOF) vals = [self.visit(c) for c in list(ctx.getChildren())[:-1]] @@ -106,12 +104,8 @@ def visitConfigValue( if len(vals) == 1 and isinstance( ctx.getChild(0), OmegaConfGrammarParser.InterpolationContext ): - from .base import Node # noqa F811 - - # Single interpolation: return the resulting node "as is". - ret = vals[0] - assert ret is None or isinstance(ret, Node), ret - return ret + # Single interpolation: return the result "as is". + return vals[0] # Concatenation of multiple components. return "".join(map(str, vals)) @@ -135,13 +129,9 @@ def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any: def visitInterpolation( self, ctx: OmegaConfGrammarParser.InterpolationContext - ) -> Optional["Node"]: - from .base import Node # noqa F811 - + ) -> Any: assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver - ret = self.visit(ctx.getChild(0)) - assert ret is None or isinstance(ret, Node) - return ret + return self.visit(ctx.getChild(0)) def visitInterpolationNode( self, ctx: OmegaConfGrammarParser.InterpolationNodeContext @@ -168,7 +158,7 @@ def visitInterpolationNode( def visitInterpolationResolver( self, ctx: OmegaConfGrammarParser.InterpolationResolverContext - ) -> Optional["Node"]: + ) -> Any: # INTER_OPEN resolverName COLON sequence? BRACE_CLOSE assert 4 <= ctx.getChildCount() <= 5 diff --git a/tests/test_base_config.py b/tests/test_base_config.py index 8d1d263b5..fe7ec02ab 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -5,6 +5,7 @@ from pytest import raises from omegaconf import ( + AnyNode, Container, DictConfig, IntegerNode, @@ -510,7 +511,7 @@ def test_resolve_str_interpolation(query: str, result: Any) -> None: cfg._maybe_resolve_interpolation( parent=None, key=None, - value=StringNode(value=query), + value=AnyNode(value=query), throw_on_resolution_failure=True, ) == result diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 8a912b657..3f142ee09 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -8,9 +8,11 @@ from omegaconf import ( II, + SI, Container, DictConfig, IntegerNode, + ListConfig, Node, OmegaConf, Resolver, @@ -26,7 +28,7 @@ UnsupportedInterpolationType, ) -from . import StructuredWithMissing +from . import StructuredWithMissing, User # file deepcode ignore CopyPasteError: # The above comment is a statement to stop DeepCode from raising a warning on @@ -747,3 +749,115 @@ def fail_if_called(x: Any) -> None: x_node = cfg._get_node("x") assert isinstance(x_node, Node) assert x_node._dereference_node(throw_on_resolution_failure=False) is None + + +@pytest.mark.parametrize( + ("cfg", "key", "expected_value", "expected_node_type"), + [ + pytest.param( + User(name="Bond", age=SI("${cast:int,'7'}")), + "age", + 7, + IntegerNode, + id="expected_type", + ), + pytest.param( + # This example specifically test the case where intermediate resolver results + # are not of the same type as the key. + User(name="Bond", age=SI("${cast:int,${drop_last:${drop_last:7xx}}}")), + "age", + 7, + IntegerNode, + id="intermediate_type_mismatch_ok", + ), + pytest.param( + User(name="Bond", age=SI("${cast:str,'7'}")), + "age", + 7, + IntegerNode, + id="convert_str_to_int", + ), + ], +) +def test_interpolation_type_validated_ok( + cfg: Any, + key: str, + expected_value: Any, + expected_node_type: Any, + restore_resolvers: Any, +) -> Any: + def cast(t: Any, v: Any) -> Any: + return {"str": str, "int": int}[t](v) # cast `v` to type `t` + + def drop_last(s: str) -> str: + return s[0:-1] # drop last character from string `s` + + OmegaConf.register_new_resolver("cast", cast) + OmegaConf.register_new_resolver("drop_last", drop_last) + + cfg = OmegaConf.structured(cfg) + + val = cfg[key] + assert val == expected_value + + node = cfg._get_node(key) + assert isinstance(node, Node) + assert isinstance(node._dereference_node(), expected_node_type) + + +@pytest.mark.parametrize( + ("cfg", "key", "expected_error"), + [ + pytest.param( + User(name="Bond", age=SI("${cast:str,seven}")), + "age", + pytest.raises( + ValidationError, + match=re.escape( + "Value 'seven' could not be converted to Integer\n full_key: age" + ), + ), + id="type_mismatch_resolver", + ), + pytest.param( + User(name="Bond", age=SI("${name}")), + "age", + pytest.raises( + ValidationError, + match=re.escape( + "Value 'Bond' could not be converted to Integer\n full_key: age" + ), + ), + id="type_mismatch_node_interpolation", + ), + ], +) +def test_interpolation_type_validated_error( + cfg: Any, + key: str, + expected_error: Any, + restore_resolvers: Any, +) -> Any: + def cast(t: Any, v: Any) -> Any: + return {"str": str, "int": int}[t](v) # cast `v` to type `t` + + OmegaConf.register_new_resolver("cast", cast) + + cfg = OmegaConf.structured(cfg) + + with expected_error: + cfg[key] + + +def test_resolver_output_dictconfig(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver("dict", lambda: {"a": 0, "b": 1}) + cfg = OmegaConf.create({"x": "${dict:}"}) + assert isinstance(cfg.x, DictConfig) + assert cfg.x.a == 0 and cfg.x.b == 1 + + +def test_resolver_output_listconfig(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver("list", lambda: [0, 1]) + cfg = OmegaConf.create({"x": "${list:}"}) + assert isinstance(cfg.x, ListConfig) + assert cfg.x == [0, 1] diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 21007b374..6ef0309a9 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -176,7 +176,7 @@ def test_none_construction(self, node_type: Any, values: Any) -> None: def test_interpolation( self, node_type: Any, values: Any, restore_resolvers: Any, register_func: Any ) -> None: - resolver_output = 9999 + resolver_output = "9999" register_func("func", lambda: resolver_output) values = copy.deepcopy(values) for value in values: