# All characters that must be escaped (must match the ESC grammar lexer token).
_ESC = "\\()[]{}:=, \t"

# Regular expression that matches any sequence of characters in `_ESC`.
_ESC_REGEX = re.compile(f"[{re.escape(_ESC)}]+")


def escape_special_characters(s: str) -> str:
    """Escape special characters in `s`.

    Every character of `s` that belongs to `_ESC` is prefixed with a
    backslash; all other characters are copied unchanged.

    :param s: The raw string to escape.
    :return: The escaped string. A string without any special character
        (including the empty string) comes back with the same content.
    """
    # Single left-to-right substitution pass over the input. Each run of
    # special characters is rewritten character by character, so a backslash
    # that we insert is never seen (and re-escaped) by a later replacement:
    # no need to treat '\' differently from the other special characters.
    return _ESC_REGEX.sub(
        lambda match: "".join(f"\\{char}" for char in match.group(0)), s
    )
def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any:
    """Return the Python value of a dictionary key.

    The `dictKey` rule context is converted by `_createPrimitive`, the same
    helper that `visitPrimitive` delegates to.

    :param ctx: Parse-tree context of the `dictKey` rule.
    :return: Whatever `_createPrimitive` builds for `ctx`.
    """
    key = self._createPrimitive(ctx)
    return key
def _createPrimitive(
    self, ctx: ParserRuleContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
    """Turn the tokens of a rule context into a single Python value.

    Helper behind both `visitPrimitive` and `visitDictKey`.

    A whitespace token in first and/or last position is ignored. If more than
    one token remains, the token texts are concatenated into one string (ESC
    tokens being un-escaped on the way). If exactly one token remains, it is
    converted according to its lexer type: quoted value -> `QuotedString`,
    ID / INTERPOLATION -> `str`, INT -> `int`, FLOAT -> `float`,
    NULL -> `None`, BOOL -> `bool`, ESC -> un-escaped `str`, anything else ->
    the raw token text.

    :param ctx: Rule context whose children are the tokens to convert.
    :return: The converted value.
    :raises HydraException: If the context consists of a single whitespace
        token.
    """
    child_count = ctx.getChildCount()
    start = 0
    stop = child_count
    # A whitespace token in first position is not part of the value.
    if self.is_ws(ctx.getChild(0)):
        if child_count == 1:
            # Only whitespaces => this is not allowed.
            raise HydraException(
                "Trying to parse a primitive that is all whitespaces"
            )
        start = 1
    # Same for a whitespace token in last position.
    if self.is_ws(ctx.getChild(-1)):
        stop -= 1

    if stop - start > 1:
        # Several tokens: concatenate them, un-escaping ESC tokens by keeping
        # every other character. Whitespace tokens are only dropped when they
        # sit outside of [start, stop).
        pieces = [
            child.symbol.text[1::2]
            if child.symbol.type == OverrideLexer.ESC
            else child.symbol.text
            for idx, child in enumerate(ctx.getChildren())
            if child.symbol.type != OverrideLexer.WS or start <= idx < stop
        ]
        return "".join(pieces)

    # Single token: the conversion depends on the lexer token type.
    node = ctx.getChild(start)
    if node.symbol.type == OverrideLexer.QUOTED_VALUE:
        raw = node.getText()
        quote_char = raw[0]
        inner = raw[1:-1]  # strip the surrounding quotes
        if quote_char == "'":
            quote = Quote.single
            inner = inner.replace("\\'", "'")
        elif quote_char == '"':
            quote = Quote.double
            inner = inner.replace('\\"', '"')
        else:
            assert False
        return QuotedString(text=inner, quote=quote)
    if node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
        return node.symbol.text
    if node.symbol.type == OverrideLexer.INT:
        return int(node.symbol.text)
    if node.symbol.type == OverrideLexer.FLOAT:
        return float(node.symbol.text)
    if node.symbol.type == OverrideLexer.NULL:
        return None
    if node.symbol.type == OverrideLexer.BOOL:
        lowered = node.getText().lower()
        if lowered == "true":
            flag = True
        elif lowered == "false":
            flag = False
        else:
            assert False
        return flag
    if node.symbol.type == OverrideLexer.ESC:
        # un-escape by skipping every other char
        return node.symbol.text[1::2]
    return node.getText()  # type: ignore
list): return [Override._convert_value(x) for x in value] elif isinstance(value, dict): - return {k: Override._convert_value(v) for k, v in value.items()} + + return { + # We ignore potential type mismatch here so as to let OmegaConf + # raise an explicit error in case of invalid type. + Override._convert_value(k): Override._convert_value(v) # type: ignore + for k, v in value.items() + } elif isinstance(value, QuotedString): return value.text else: @@ -411,17 +418,19 @@ def _get_value_element_as_str( ) return "[" + s + "]" elif isinstance(value, dict): - s = comma.join( - [ - f"{k}{colon}{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}" - for k, v in value.items() - ] - ) - return "{" + s + "}" - elif isinstance(value, (str, int, bool, float)): + str_items = [] + for k, v in value.items(): + str_key = Override._get_value_element_as_str(k) + str_value = Override._get_value_element_as_str( + v, space_after_sep=space_after_sep + ) + str_items.append(f"{str_key}{colon}{str_value}") + return "{" + comma.join(str_items) + "}" + elif isinstance(value, str): + return escape_special_characters(value) + elif isinstance(value, (int, bool, float)): return str(value) elif is_structured_config(value): - print(value) return Override._get_value_element_as_str( OmegaConf.to_container(OmegaConf.structured(value)) ) diff --git a/hydra/grammar/OverrideLexer.g4 b/hydra/grammar/OverrideLexer.g4 index 41fb9b606a2..8e25ff4d777 100644 --- a/hydra/grammar/OverrideLexer.g4 +++ b/hydra/grammar/OverrideLexer.g4 @@ -61,6 +61,8 @@ NULL: [Nn][Uu][Ll][Ll]; UNQUOTED_CHAR: [/\-\\+.$%*@]; // other characters allowed in unquoted strings ID: (CHAR|'_') (CHAR|DIGIT|'_')*; +// Note: when adding more characters to the ESC rule below, also add them to +// the `_ESC` string in `_internal/grammar/utils.py`. 
ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' | '\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+; WS: [ \t]+; diff --git a/hydra/grammar/OverrideParser.g4 b/hydra/grammar/OverrideParser.g4 index 8645a4147da..953c736cecc 100644 --- a/hydra/grammar/OverrideParser.g4 +++ b/hydra/grammar/OverrideParser.g4 @@ -31,8 +31,8 @@ value: element | simpleChoiceSweep; element: primitive - | listValue - | dictValue + | listContainer + | dictContainer | function ; @@ -47,12 +47,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE; // Data structures. -listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] +listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] (element(COMMA element)*)? BRACKET_CLOSE; -dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} -dictKeyValuePair: ID COLON element; +dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} +dictKeyValuePair: dictKey COLON element; // Primitive types. @@ -69,3 +69,16 @@ primitive: | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; + +// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed. 
+dictKey: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + )+; diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 9d6a91d756b..d26893d80b6 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -1152,14 +1152,12 @@ def test_apply_overrides_to_config( id="default_change", ), pytest.param( - # need to unset optimizer config group first, otherwise they get merged "config", ["optimizer={type:nesterov2,lr:1}"], {"optimizer": {"type": "nesterov2", "lr": 1}}, id="dict_merge", ), pytest.param( - # need to unset optimizer config group first, otherwise they get merged "config", ["+optimizer={foo:10}"], {"optimizer": {"type": "nesterov", "lr": 0.001, "foo": 10}}, diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index 5e7af0e2572..d6e5f2e9b46 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -9,6 +9,7 @@ from _pytest.python_api import RaisesContext from hydra._internal.grammar.functions import Functions +from hydra._internal.grammar.utils import escape_special_characters from hydra.core.override_parser.overrides_parser import ( OverridesParser, create_functions, @@ -206,8 +207,8 @@ def test_value(value: str, expected: Any) -> None: pytest.param("[1,[a]]", [1, ["a"]], id="list:simple_and_list_elements"), ], ) -def test_list_value(value: str, expected: Any) -> None: - ret = parse_rule(value, "listValue") +def test_list_container(value: str, expected: Any) -> None: + ret = parse_rule(value, "listContainer") assert ret == expected @@ -277,10 +278,40 @@ def test_shuffle_sequence(value: str, expected: Any) -> None: pytest.param("{a:10,b:20}", {"a": 10, "b": 20}, id="dict"), pytest.param("{a:10,b:{}}", 
{"a": 10, "b": {}}, id="dict"), pytest.param("{a:10,b:{c:[1,2]}}", {"a": 10, "b": {"c": [1, 2]}}, id="dict"), + pytest.param( + "{'0a': 0, \"1b\": 1}", + { + QuotedString(text="0a", quote=Quote.single): 0, + QuotedString(text="1b", quote=Quote.double): 1, + }, + id="dict_quoted_key", + ), + pytest.param("{null: 1}", {None: 1}, id="dict_null_key"), + pytest.param("{123: 1, 0: 2, -1: 3}", {123: 1, 0: 2, -1: 3}, id="dict_int_key"), + pytest.param("{3.14: 0, 1e3: 1}", {3.14: 0, 1000.0: 1}, id="dict_float_key"), + pytest.param("{true: 1, fAlSe: 0}", {True: 1, False: 0}, id="dict_bool_key"), + pytest.param("{/-\\+.$%*@: 1}", {"/-\\+.$%*@": 1}, id="dict_unquoted_char_key"), + pytest.param( + "{\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 1}", + {"\\()[]{}:= \t,": 1}, + id="dict_esc_key", + ), + pytest.param("{white spaces: 1}", {"white spaces": 1}, id="dict_ws_key"), + pytest.param( + "{'a:b': 1, ab 123.5 True: 2, null false: 3, 1: 4, null: 5}", + { + QuotedString(text="a:b", quote=Quote.single): 1, + "ab 123.5 True": 2, + "null false": 3, + 1: 4, + None: 5, + }, + id="dict_mixed_keys", + ), ], ) -def test_dict_value(value: str, expected: Any) -> None: - ret = parse_rule(value, "dictValue") +def test_dict_container(value: str, expected: Any) -> None: + ret = parse_rule(value, "dictContainer") assert ret == expected @@ -929,6 +960,12 @@ def test_get_key_element(override: str, expected: str) -> None: pytest.param("key='value'", "'value'", False, id="single_quoted"), pytest.param('key="value"', '"value"', False, id="double_quoted"), pytest.param("key='שלום'", "'שלום'", False, id="quoted_unicode"), + pytest.param( + "key=\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,", + "\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,", + False, + id="escaped_chars", + ), pytest.param("key=10", "10", False, id="int"), pytest.param("key=3.1415", "3.1415", False, id="float"), pytest.param("key=[]", "[]", False, id="list"), @@ -942,6 +979,30 @@ def test_get_key_element(override: str, expected: str) -> None: 
@pytest.mark.parametrize(  # type: ignore
    ("s", "expected"),
    [
        pytest.param("", "", id="empty"),
        pytest.param("abc", "abc", id="no_esc"),
        pytest.param("\\", "\\\\", id="esc_backslash"),
        pytest.param("\\\\\\", "\\\\\\\\\\\\", id="esc_backslash_x3"),
        pytest.param("()", "\\(\\)", id="esc_parentheses"),
        pytest.param("[]", "\\[\\]", id="esc_brackets"),
        pytest.param("{}", "\\{\\}", id="esc_braces"),
        pytest.param(":=,", "\\:\\=\\,", id="esc_symbols"),
        # Input is two spaces followed by a tab, so that it matches the two
        # escaped spaces and the escaped tab of the expected output.
        pytest.param("  \t", "\\ \\ \\\t", id="esc_ws"),
        pytest.param(
            "ab\\(cd{ef}[gh]): ij,kl\t",
            "ab\\\\\\(cd\\{ef\\}\\[gh\\]\\)\\:\\ ij\\,kl\\\t",
            id="esc_mixed",
        ),
    ],
)
def test_escape_special_characters(s: str, expected: str) -> None:
    """Check that `escape_special_characters(s)` returns `expected`."""
    escaped = escape_special_characters(s)
    assert escaped == expected
+dictKey: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + )+; ``` ## Elements