From 984014ce77a8dd9a1b8a0eb1e399141262c0f557 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Mon, 3 May 2021 13:59:49 -0400 Subject: [PATCH] Fix escaping in quoted values Fixes #1600 --- .../core/override_parser/overrides_visitor.py | 51 +++++-- hydra/core/override_parser/types.py | 9 +- hydra/grammar/OverrideLexer.g4 | 41 +++++- hydra/grammar/OverrideParser.g4 | 8 +- tests/test_overrides_parser.py | 133 +++++++++++++++++- 5 files changed, 215 insertions(+), 27 deletions(-) diff --git a/hydra/core/override_parser/overrides_visitor.py b/hydra/core/override_parser/overrides_visitor.py index bd1ea75617e..80bc7daeadf 100644 --- a/hydra/core/override_parser/overrides_visitor.py +++ b/hydra/core/override_parser/overrides_visitor.py @@ -240,6 +240,42 @@ def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any: f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}" ) from e + def visitQuotedValue(self, ctx: OverrideParser.QuotedValueContext) -> QuotedString: + children = list(ctx.getChildren()) + assert len(children) >= 2 + + # Identity quote type. + first_quote = children[0].getText() + if first_quote == "'": + quote = Quote.single + else: + assert first_quote == '"' + quote = Quote.double + + # Inspect string content. + tokens = [] + is_interpolation = False + for child in children[1:-1]: + assert isinstance(child, TerminalNode) + symbol = child.symbol + text = symbol.text + if symbol.type == OverrideLexer.ESC_QUOTE: + # Always un-escape quotes. + text = text[1] + elif symbol.type == OverrideLexer.INTERPOLATION: + is_interpolation = True + tokens.append(text) + + # Contactenate string fragments. + ret = "".join(tokens) + + # If it is an interpolation, then OmegaConf will take care of un-escaping + # the `\\`. But if it is not, then we need to do it here. + if not is_interpolation: + ret = ret.replace("\\\\", "\\") + + return QuotedString(text=ret, quote=quote, esc_backslash=not is_interpolation) + def _createPrimitive( self, ctx: ParserRuleContext ) -> Optional[Union[QuotedString, int, bool, float, str]]: @@ -274,19 +310,8 @@ def _createPrimitive( ret = "".join(tokens) else: node = ctx.getChild(first_idx) - if node.symbol.type == OverrideLexer.QUOTED_VALUE: - text = node.getText() - qc = text[0] - text = text[1:-1] - if qc == "'": - quote = Quote.single - text = text.replace("\\'", "'") - elif qc == '"': - quote = Quote.double - text = text.replace('\\"', '"') - else: - assert False - return QuotedString(text=text, quote=quote) + if isinstance(node, OverrideParser.QuotedValueContext): + return self.visitQuotedValue(node) elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION): ret = node.symbol.text elif node.symbol.type == OverrideLexer.INT: diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index e9e3480838c..d1c6ff54a3c 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -26,16 +26,19 @@ class Quote(Enum): @dataclass(frozen=True) class QuotedString: text: str - quote: Quote + esc_backslash: bool = True def with_quotes(self) -> str: + text = self.text + if self.esc_backslash: + text = text.replace("\\", "\\\\") if self.quote == Quote.single: q = "'" - text = self.text.replace("'", "\\'") + text = text.replace("'", "\\'") elif self.quote == Quote.double: q = '"' - text = self.text.replace('"', '\\"') + text = text.replace('"', '\\"') else: assert False return f"{q}{text}{q}" diff --git a/hydra/grammar/OverrideLexer.g4 b/hydra/grammar/OverrideLexer.g4 index 93dff89326a..3d44211e9ad 100644 --- a/hydra/grammar/OverrideLexer.g4 +++ b/hydra/grammar/OverrideLexer.g4 @@ -33,6 +33,9 @@ DOT_PATH: (ID | INT_UNSIGNED) ('.' (ID | INT_UNSIGNED))+; mode VALUE_MODE; +QUOTE_OPEN_SINGLE: '\'' -> pushMode(QUOTED_SINGLE_MODE); +QUOTE_OPEN_DOUBLE: '"' -> pushMode(QUOTED_DOUBLE_MODE); + POPEN: WS? '(' WS?; // whitespaces before to allow `func (x)` COMMA: WS? ',' WS?; PCLOSE: WS? ')'; @@ -66,8 +69,38 @@ ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' | '\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+; WS: [ \t]+; -QUOTED_VALUE: - '\'' ('\\\''|.)*? '\'' // Single quotes, can contain escaped single quote : /' - | '"' ('\\"'|.)*? '"' ; // Double quotes, can contain escaped double quote : /" - INTERPOLATION: '${' ~('}')+ '}'; + + +//////////////////////// +// QUOTED_SINGLE_MODE // +//////////////////////// + +mode QUOTED_SINGLE_MODE; + +MATCHING_QUOTE_CLOSE: '\'' -> popMode; + +ESC_QUOTE: '\\\''; +QSINGLE_ESC_BACKSLASH: ESC_BACKSLASH -> type(ESC); + +QSINGLE_INTERPOLATION: INTERPOLATION -> type(INTERPOLATION); +SPECIAL_CHAR: [\\$]; +ANY_STR: ~['\\$]+; + + +//////////////////////// +// QUOTED_DOUBLE_MODE // +//////////////////////// + +mode QUOTED_DOUBLE_MODE; + +// Same as `QUOTED_SINGLE_MODE` but for double quotes. + +QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode; + +QDOUBLE_ESC_QUOTE: '\\"' -> type(ESC_QUOTE); +QDOUBLE_ESC_BACKSLASH: ESC_BACKSLASH -> type(ESC); + +QDOUBLE_INTERPOLATION: INTERPOLATION -> type(INTERPOLATION); +QDOUBLE_SPECIAL_CHAR: SPECIAL_CHAR -> type(SPECIAL_CHAR); +QDOUBLE_STR: ~["\\$]+ -> type(ANY_STR); diff --git a/hydra/grammar/OverrideParser.g4 b/hydra/grammar/OverrideParser.g4 index bd9006010f2..5503042cf91 100644 --- a/hydra/grammar/OverrideParser.g4 +++ b/hydra/grammar/OverrideParser.g4 @@ -51,8 +51,14 @@ dictKeyValuePair: dictKey COLON element; // Primitive types. +// Ex: "hello world", 'hello ${world}' +quotedValue: + (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) + (INTERPOLATION | ESC | ESC_QUOTE | SPECIAL_CHAR | ANY_STR)* + MATCHING_QUOTE_CLOSE; + primitive: - QUOTED_VALUE // 'hello world', "hello world" + quotedValue // 'hello world', "hello world" | ( ID // foo_10 | NULL // null, NULL | INT // 0, 10, -20, 1_000_000 diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index f06f4cb8243..ea10cb742b3 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -200,6 +200,22 @@ def test_value(value: str, expected: Any) -> None: param("[[a]]", [["a"]], id="list:nested_list"), param("[[[a]]]", [[["a"]]], id="list:double_nested_list"), param("[1,[a]]", [1, ["a"]], id="list:simple_and_list_elements"), + param( + r"['a\\', 'b\\']", + [ + QuotedString(text="a\\", quote=Quote.single), + QuotedString(text="b\\", quote=Quote.single), + ], + id="list:str_trailing_backslash_single", + ), + param( + r'["a\\", "b\\"]', + [ + QuotedString(text="a\\", quote=Quote.double), + QuotedString(text="b\\", quote=Quote.double), + ], + id="list:str_trailing_backslash_double", + ), ], ) def test_list_container(value: str, expected: Any) -> None: @@ -295,6 +311,22 @@ def test_shuffle_sequence(value: str, expected: Any) -> None: }, id="dict_mixed_keys", ), + param( + r"{a: 'a\\', b: 'b\\'}", + { + "a": QuotedString(text="a\\", quote=Quote.single), + "b": QuotedString(text="b\\", quote=Quote.single), + }, + id="dict_str_trailing_backslash_single", + ), + param( + r'{a: "a\\", b: "b\\"}', + { + "a": QuotedString(text="a\\", quote=Quote.double), + "b": QuotedString(text="b\\", quote=Quote.double), + }, + id="dict_str_trailing_backslash_double", + ), ], ) def test_dict_container(value: str, expected: Any) -> None: @@ -426,13 +458,15 @@ def test_interval_sweep(value: str, expected: Any) -> None: param( "override", "key=[1,2,3]'", - raises(HydraException, match=re.escape("token recognition error at: '''")), + raises( + HydraException, match=re.escape("extraneous input ''' expecting ") + ), id="error:left_overs", ), param( "dictContainer", "{'0a': 0, \"1b\": 1}", - raises(HydraException, match=re.escape("mismatched input ''0a''")), + raises(HydraException, match=re.escape("mismatched input '''")), id="error:dict_quoted_key_dictContainer", ), param( @@ -440,7 +474,7 @@ def test_interval_sweep(value: str, expected: Any) -> None: "key={' abc ': 0}", raises( HydraException, - match=re.escape("no viable alternative at input '{' abc ''"), + match=re.escape("no viable alternative at input '{''"), ), id="error:dict_quoted_key_override_single", ), @@ -449,7 +483,7 @@ def test_interval_sweep(value: str, expected: Any) -> None: 'key={" abc ": 0}', raises( HydraException, - match=re.escape("""no viable alternative at input '{" abc "'"""), + match=re.escape("""no viable alternative at input '{"'"""), ), id="error:dict_quoted_key_override_double", ), @@ -561,15 +595,41 @@ def test_key(value: str, expected: Any) -> None: param("false", False, id="primitive:bool"), # quoted string param( - "'foo \\'bar'", + r"'foo \'bar'", QuotedString(text="foo 'bar", quote=Quote.single), id="value:escape_single_quote", ), param( - '"foo \\"bar"', + r'"foo \"bar"', QuotedString(text='foo "bar', quote=Quote.double), id="value:escape_double_quote", ), + param( + r"'foo \\\'bar'", + QuotedString(text=r"foo \'bar", quote=Quote.single), + id="value:escape_single_quote_x3", + ), + param( + r'"foo \\\"bar"', + QuotedString(text=r"foo \"bar", quote=Quote.double), + id="value:escape_double_quote_x3", + ), + param( + r"'foo\\bar'", + QuotedString(text=r"foo\bar", quote=Quote.single), + id="value:escape_backslash", + ), + param( + r"'foo\\\\bar'", + QuotedString(text=r"foo\\bar", quote=Quote.single), + id="value:escape_backslash_x4", + ), + param( + r"'foo bar\\'", + # Note: raw strings do not allow trailing \, adding a space and stripping it. + QuotedString(text=r" foo bar\ ".strip(), quote=Quote.single), + id="value:escape_backslash_trailing", + ), param( "'\t []{},=+~'", QuotedString(text="\t []{},=+~", quote=Quote.single), @@ -643,6 +703,41 @@ def test_key(value: str, expected: Any) -> None: QuotedString(text="false", quote=Quote.single), id="value:bool:quoted", ), + param( + "'a ${b}'", + QuotedString(text="a ${b}", quote=Quote.single, esc_backslash=False), + id="value:interpolation:quoted", + ), + param( + r"'a \${b}'", + QuotedString(text=r"a \${b}", quote=Quote.single, esc_backslash=False), + id="value:esc_interpolation:quoted", + ), + param( + r"'a \\${b}'", + QuotedString(text=r"a \\${b}", quote=Quote.single, esc_backslash=False), + id="value:backslash_and_interpolation:quoted", + ), + param( + r"'a \'${b}\''", + QuotedString(text=r"a '${b}'", quote=Quote.single, esc_backslash=False), + id="value:quotes_and_interpolation:quoted", + ), + param( + r"'a \'\${b}\''", + QuotedString(text=r"a '\${b}'", quote=Quote.single, esc_backslash=False), + id="value:quotes_and_esc_interpolation:quoted", + ), + param( + r"'a \'\\${b}\''", + QuotedString(text=r"a '\\${b}'", quote=Quote.single, esc_backslash=False), + id="value:quotes_backslash_and_interpolation:quoted", + ), + param( + r"'a \\\'${b}\\\''", + QuotedString(text=r"a \\'${b}\\'", quote=Quote.single, esc_backslash=False), + id="value:backaslash_quotes_and_interpolation:quoted", + ), # interpolations: param("${a}", "${a}", id="primitive:interpolation"), param("${a.b.c}", "${a.b.c}", id="primitive:interpolation"), @@ -659,6 +754,32 @@ def test_primitive(value: str, expected: Any) -> None: assert eq(ret, expected) +@mark.parametrize( + ("value", "expected", "with_quotes"), + [ + param( + r"'foo\bar'", + QuotedString(text=r"foo\bar", quote=Quote.single), + r"'foo\\bar'", + id="value:one_backslash_single", + ), + param( + r'"foo\bar"', + QuotedString(text=r"foo\bar", quote=Quote.double), + r'"foo\\bar"', + id="value:one_backslash_double", + ), + ], +) +def test_with_quotes_one_backslash(value: str, expected: Any, with_quotes: str) -> None: + # This test's objective is to test the case where a quoted string contains a single + # (i.e., non-escaped) backslash. This case can't be included in `test_primitive()` + # because the backslash is escaped by `with_quotes()` => value != ret.with_quotes() + ret = parse_rule(value, "primitive") + assert eq(ret, expected) + assert ret.with_quotes() == with_quotes + + @mark.parametrize( "prefix,override_type", [