diff --git a/news/600.feature b/news/600.feature new file mode 100644 index 000000000..1fcf30f56 --- /dev/null +++ b/news/600.feature @@ -0,0 +1 @@ +The dollar character '$' is now allowed in interpolated key names, e.g. ${$var} diff --git a/omegaconf/grammar/OmegaConfGrammarLexer.g4 b/omegaconf/grammar/OmegaConfGrammarLexer.g4 index 4025c3b5e..7a81073f2 100644 --- a/omegaconf/grammar/OmegaConfGrammarLexer.g4 +++ b/omegaconf/grammar/OmegaConfGrammarLexer.g4 @@ -76,4 +76,9 @@ INTER_CLOSE: WS? '}' -> popMode; DOT: '.'; INTER_ID: ID -> type(ID); -INTER_KEY: ~[\\${}()[\]:. \t'"]+; // interpolation key, may contain any non special character + +// Interpolation key, may contain any non special character. +// Note that we can allow '$' because the parser does not support interpolations that +// are only part of a key name, i.e., "${foo${bar}}" is not allowed. As a result, it +// is ok to "consume" all '$' characters within the `INTER_KEY` token. +INTER_KEY: ~[\\{}()[\]:. \t'"]+; diff --git a/omegaconf/grammar_parser.py b/omegaconf/grammar_parser.py index 457252e88..181e702c9 100644 --- a/omegaconf/grammar_parser.py +++ b/omegaconf/grammar_parser.py @@ -19,12 +19,14 @@ # Build regex pattern to efficiently identify typical interpolations. # See test `test_match_simple_interpolation_pattern` for examples. _id = "[a-zA-Z_]\\w*" # foo, foo_bar, abc123 -_dot_path = f"(\\.)*({_id}(\\.{_id})*)?" # foo, foo.bar3, foo_.b4r.b0z -_inter_node = f"\\${{\\s*{_dot_path}\\s*}}" # node interpolation +_config_key = f"({_id}|\\$)+" # foo, $bar, $foo$bar$ +_node_path = f"(\\.)*({_config_key}(\\.{_config_key})*)?" # foo, .foo.$bar +_node_inter = f"\\${{\\s*{_node_path}\\s*}}" # node interpolation ${foo.bar} +_resolver_name = f"({_id}(\\.{_id})*)?" # foo, ns.bar3, ns_1.ns_2.b0z _arg = "[a-zA-Z_0-9/\\-\\+.$%*@]+" # string representing a resolver argument _args = f"{_arg}(\\s*,\\s*{_arg})*" # list of resolver arguments -_inter_res = f"\\${{\\s*{_dot_path}\\s*:\\s*{_args}?\\s*}}" # resolver interpolation -_inter = f"({_inter_node}|{_inter_res})" # any kind of interpolation +_resolver_inter = f"\\${{\\s*{_resolver_name}\\s*:\\s*{_args}?\\s*}}" # ${foo:bar} +_inter = f"({_node_inter}|{_resolver_inter})" # any kind of interpolation _outer = "([^$]|\\$(?!{))+" # any character except $ (unless not followed by {) SIMPLE_INTERPOLATION_PATTERN = re.compile( f"({_outer})?({_inter}({_outer})?)+$", flags=re.ASCII diff --git a/tests/test_grammar.py b/tests/test_grammar.py index 85ec68d25..963b9b11d 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -21,6 +21,10 @@ UnsupportedInterpolationType, ) +# Characters that are not allowed by the grammar in config key names. +INVALID_CHARS_IN_KEY_NAMES = "\\{}()[].: '\"" + + # A fixed config that may be used (but not modified!) by tests. BASE_TEST_CFG = OmegaConf.create( { @@ -32,11 +36,12 @@ "list": [x - 1 for x in range(11)], "null": None, # Special cases. - "x@y": 123, # to test keys with @ in name - "0": 0, # to test keys with int names - "1": {"2": 12}, # to test dot-path with int keys - "FalsE": {"TruE": True}, # to test keys with bool names - "None": {"null": 1}, # to test keys with null-like names + "x@y": 123, # @ in name + "$x$y$z$": 456, # $ in name (beginning, middle and end) + "0": 0, # integer name + "FalsE": {"TruE": True}, # bool name + "None": {"null": 1}, # null-like name + "1": {"2": 12}, # dot-path with int keys # Used in nested interpolations. "str_test": "test", "ref_str": "str", @@ -200,6 +205,7 @@ ("null_like_key_quoted_2", "${'None.null'}", GrammarParseError), ("dotpath_bad_type", "${dict.${float}}", (None, InterpolationResolutionError)), ("at_in_key", "${x@y}", 123), + ("dollar_in_key", "${$x$y$z$}", 456), # Interpolations in dictionaries. ("dict_interpolation_value", "{hi: ${str}, int: ${int}}", {"hi": "hi", "int": 123}), ("dict_interpolation_key", "{${str}: 0, ${null}: 1", GrammarParseError), @@ -524,6 +530,7 @@ def visit() -> Any: "${foo:bar,0,a-b+c*d/$.%@}", "\\${foo}", "${foo.bar:boz}", + "${$foo.bar$.x$y}", # relative interpolations "${.}", "${..}", @@ -549,6 +556,8 @@ def test_match_simple_interpolation_pattern(expression: str) -> None: "\\${foo", "${foo . bar}", "${ns . f:var}", + "${$foo:bar}", + "${.foo:bar}", ], ) def test_do_not_match_simple_interpolation_pattern(expression: str) -> None: @@ -634,3 +643,46 @@ def callback(inter_key: Any) -> Any: ) ret = visitor.visit(tree) assert ret == expected + + +def test_custom_resolver_param_supported_chars() -> None: + supported_chars = "abc123_/:-\\+.$%*@" + c = OmegaConf.create({"dir1": "${copy:" + supported_chars + "}"}) + + OmegaConf.register_new_resolver("copy", lambda x: x) + assert c.dir1 == supported_chars + + +def test_valid_chars_in_interpolation() -> None: + valid_chars = "".join( + chr(i) for i in range(33, 128) if chr(i) not in INVALID_CHARS_IN_KEY_NAMES + ) + cfg_dict = {valid_chars: 123, "inter": f"${{{valid_chars}}}"} + cfg = OmegaConf.create(cfg_dict) + # Test that we can access the node made of all valid characters, both + # directly and through interpolations. + assert cfg[valid_chars] == 123 + assert cfg.inter == 123 + + +@mark.parametrize("c", list(INVALID_CHARS_IN_KEY_NAMES)) +def test_invalid_chars_in_interpolation(c: str) -> None: + def create() -> DictConfig: + return OmegaConf.create({"invalid": f"${{ab{c}de}}"}) + + # Test that all invalid characters trigger errors in interpolations. + if c in [".", "}"]: + # With '.', we try to access `${ab.de}`. + # With '}', we try to access `${ab}`. + cfg = create() + with raises(InterpolationKeyError): + cfg.invalid + elif c == ":": + # With ':', we try to run a resolver `${ab:de}` + cfg = create() + with raises(UnsupportedInterpolationType): + cfg.invalid + else: + # Other invalid characters should be detected at creation time. + with raises(GrammarParseError): + create() diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 287e3282e..bbf00e4d5 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -21,11 +21,9 @@ ) from omegaconf._utils import _ensure_container from omegaconf.errors import ( - GrammarParseError, InterpolationKeyError, InterpolationResolutionError, InterpolationValidationError, - UnsupportedInterpolationType, ) from . import MissingDict, MissingList, StructuredWithMissing, SubscriptedList, User @@ -35,9 +33,6 @@ # lines that do equality checks of the form # c.k == c.k -# Characters that are not allowed by the grammar in config key names. -INVALID_CHARS_IN_KEY_NAMES = "\\${}()[].: '\"" - def dereference(cfg: Container, key: Any) -> Node: node = cfg._get_node(key) @@ -605,49 +600,6 @@ def test_clear_cache(restore_resolvers: Any) -> None: assert old != c.k -def test_supported_chars() -> None: - supported_chars = "abc123_/:-\\+.$%*@" - c = OmegaConf.create({"dir1": "${copy:" + supported_chars + "}"}) - - OmegaConf.register_new_resolver("copy", lambda x: x) - assert c.dir1 == supported_chars - - -def test_valid_chars_in_key_names() -> None: - valid_chars = "".join( - chr(i) for i in range(33, 128) if chr(i) not in INVALID_CHARS_IN_KEY_NAMES - ) - cfg_dict = {valid_chars: 123, "inter": f"${{{valid_chars}}}"} - cfg = OmegaConf.create(cfg_dict) - # Test that we can access the node made of all valid characters, both - # directly and through interpolations. - assert cfg[valid_chars] == 123 - assert cfg.inter == 123 - - -@pytest.mark.parametrize("c", list(INVALID_CHARS_IN_KEY_NAMES)) -def test_invalid_chars_in_key_names(c: str) -> None: - def create() -> DictConfig: - return OmegaConf.create({"invalid": f"${{ab{c}de}}"}) - - # Test that all invalid characters trigger errors in interpolations. - if c in [".", "}"]: - # With '.', we try to access `${ab.de}`. - # With '}', we try to access `${ab}`. - cfg = create() - with pytest.raises(InterpolationKeyError): - cfg.invalid - elif c == ":": - # With ':', we try to run a resolver `${ab:de}` - cfg = create() - with pytest.raises(UnsupportedInterpolationType): - cfg.invalid - else: - # Other invalid characters should be detected at creation time. - with pytest.raises(GrammarParseError): - create() - - def test_interpolation_in_list_key_error() -> None: # Test that a KeyError is thrown if an str_interpolation key is not available c = OmegaConf.create(["${10}"])