diff --git a/omegaconf/grammar/OmegaConfGrammarLexer.g4 b/omegaconf/grammar/OmegaConfGrammarLexer.g4 index e7325f8ad..73f616b7c 100644 --- a/omegaconf/grammar/OmegaConfGrammarLexer.g4 +++ b/omegaconf/grammar/OmegaConfGrammarLexer.g4 @@ -23,11 +23,10 @@ ANY_STR: ~[$]* ~[\\$]; // Escaped interpolation: '\${', optionally preceded by an even number of \ ESC_INTER: ESC_BACKSLASH* '\\${'; -// Escaped backslashes: even number of \ followed by an interpolation. -// The interpolation must *not* be matched by this rule (this is why we use a predicate lookahead). -TOP_ESC: ESC_BACKSLASH+ { self._input.LA(1) == ord("$") and self._input.LA(2) == ord("{") }? -> type(ESC); +// Backslashes that *may* be escaped (even number). +TOP_ESC: ESC_BACKSLASH+; -// Other backslashes that will not need escaping. +// Other backslashes that will not need escaping (odd number due to not matching the previous rule). BACKSLASHES: '\\'+ -> type(ANY_STR); // The dollar sign must be singled out so that we can recognize interpolations. @@ -111,19 +110,10 @@ QSINGLE_STR: ~['$]* ~['\\$] -> type(ANY_STR); QSINGLE_ESC_INTER: ESC_INTER -> type(ESC_INTER); -// In a quoted string we also have the following escape sequences: -// - \', optionally preceded by an even number of \ (escaped quote) -// - an even number of \ followed by either the closing quote (trailing backslashes) or by -// an interpolation (as in `DEFAULT_MODE`) -- which must not be matched by this rule -QSINGLE_ESC: - ( - ESC_BACKSLASH* '\\\'' | - ESC_BACKSLASH+ {( - self._input.LA(1) == ord("'") or - (self._input.LA(1) == ord("$") and self._input.LA(2) == ord("{")) - )}? - ) -> type(ESC); +// Escaped quote (optionally preceded by an even number of backslashes). +QSINGLE_ESC_QUOTE: ESC_BACKSLASH* '\\\'' -> type(ESC); +QUOTED_ESC: ESC_BACKSLASH+; QSINGLE_BACKSLASHES: '\\'+ -> type(ANY_STR); QSINGLE_DOLLAR: '$' -> type(ANY_STR); @@ -140,17 +130,8 @@ QDOUBLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE) QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode; QDOUBLE_STR: ~["$]* ~["\\$] -> type(ANY_STR); - QDOUBLE_ESC_INTER: ESC_INTER -> type(ESC_INTER); - -QDOUBLE_ESC: - ( - ESC_BACKSLASH* '\\"' | - ESC_BACKSLASH+ {( - self._input.LA(1) == ord('"') or - (self._input.LA(1) == ord("$") and self._input.LA(2) == ord("{")) - )}? - ) -> type(ESC); - +QDOUBLE_ESC_QUOTE: ESC_BACKSLASH* '\\"' -> type(ESC); +QDOUBLE_ESC: ESC_BACKSLASH+ -> type(QUOTED_ESC); QDOUBLE_BACKSLASHES: '\\'+ -> type(ANY_STR); QDOUBLE_DOLLAR: '$' -> type(ANY_STR); diff --git a/omegaconf/grammar/OmegaConfGrammarParser.g4 b/omegaconf/grammar/OmegaConfGrammarParser.g4 index 5ec70cf07..f0f256c9a 100644 --- a/omegaconf/grammar/OmegaConfGrammarParser.g4 +++ b/omegaconf/grammar/OmegaConfGrammarParser.g4 @@ -23,7 +23,7 @@ singleElement: element EOF; // Composite text expression (may contain interpolations). -text: (interpolation | ESC | ESC_INTER | ANY_STR)+; +text: (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+; // Elements. diff --git a/omegaconf/grammar_visitor.py b/omegaconf/grammar_visitor.py index 8df6bfa70..1771516dc 100644 --- a/omegaconf/grammar_visitor.py +++ b/omegaconf/grammar_visitor.py @@ -1,12 +1,12 @@ import sys import warnings +from itertools import zip_longest from typing import ( TYPE_CHECKING, Any, Callable, Dict, Generator, - Iterable, List, Optional, Set, @@ -289,7 +289,7 @@ def visitSingleElement( return self.visit(ctx.getChild(0)) def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any: - # (interpolation | ESC | ESC_INTER | ANY_STR)+ + # (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+ # Single interpolation? If yes, return its resolved value "as is". if ctx.getChildCount() == 1: @@ -298,7 +298,7 @@ def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any: return self.visitInterpolation(c) # Otherwise, concatenate string representations together. - return self._unescape(ctx.getChildren()) + return self._unescape(list(ctx.getChildren())) def _createPrimitive( self, @@ -336,11 +336,11 @@ def _createPrimitive( raise AssertionError("WS should never be reached") assert False, symbol.type # Concatenation of multiple items ==> un-escape the concatenation. - return self._unescape(ctx.getChildren()) + return self._unescape(list(ctx.getChildren())) def _unescape( self, - seq: Iterable[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]], + seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]], ) -> str: """ Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed. @@ -350,19 +350,43 @@ def _unescape( resolving of the interpolation). """ chrs = [] - for node in seq: + for node, next_node in zip_longest(seq, seq[1:]): if isinstance(node, TerminalNode): s = node.symbol - if s.type == OmegaConfGrammarLexer.ESC: - chrs.append(s.text[1::2]) - elif s.type == OmegaConfGrammarLexer.ESC_INTER: + if s.type == OmegaConfGrammarLexer.ESC_INTER: # `ESC_INTER` is of the form `\\...\${`: the formula below computes # the number of characters to keep at the end of the string to remove # the correct number of backslashes. - chrs.append(s.text[-(len(s.text) // 2 + 1) :]) + text = s.text[-(len(s.text) // 2 + 1) :] + elif ( + # Character sequence identified as requiring un-escaping. + s.type == OmegaConfGrammarLexer.ESC + or ( + # At top level, we need to un-escape backslashes that precede + # an interpolation. + s.type == OmegaConfGrammarLexer.TOP_ESC + and isinstance( + next_node, OmegaConfGrammarParser.InterpolationContext + ) + ) + or ( + # In a quoted sring, we need to un-escape backslashes that + # either end the string, or are followed by an interpolation. + s.type == OmegaConfGrammarLexer.QUOTED_ESC + and ( + next_node is None + or isinstance( + next_node, OmegaConfGrammarParser.InterpolationContext + ) + ) + ) + ): + text = s.text[1::2] # un-escape the sequence else: - chrs.append(s.text) + text = s.text # keep the original text else: assert isinstance(node, OmegaConfGrammarParser.InterpolationContext) - chrs.append(str(self.visitInterpolation(node))) + text = str(self.visitInterpolation(node)) + chrs.append(text) + return "".join(chrs)