Skip to content

Commit

Permalink
Move the detection of \ to escape from lexer to visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed May 6, 2021
1 parent 1eb0dcc commit c5d3a07
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 40 deletions.
35 changes: 8 additions & 27 deletions omegaconf/grammar/OmegaConfGrammarLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ ANY_STR: ~[$]* ~[\\$];
// Escaped interpolation: '\${', optionally preceded by an even number of \
ESC_INTER: ESC_BACKSLASH* '\\${';

// Escaped backslashes: even number of \ followed by an interpolation.
// The interpolation must *not* be matched by this rule (this is why we use a predicate lookahead).
TOP_ESC: ESC_BACKSLASH+ { self._input.LA(1) == ord("$") and self._input.LA(2) == ord("{") }? -> type(ESC);
// Backslashes that *may* need un-escaping (even number): the visitor decides based on what follows.
TOP_ESC: ESC_BACKSLASH+;

// Other backslashes that will not need escaping.
// Other backslashes that will not need escaping (odd number due to not matching the previous rule).
BACKSLASHES: '\\'+ -> type(ANY_STR);

// The dollar sign must be singled out so that we can recognize interpolations.
Expand Down Expand Up @@ -111,19 +110,10 @@ QSINGLE_STR: ~['$]* ~['\\$] -> type(ANY_STR);

QSINGLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);

// In a quoted string we also have the following escape sequences:
// - \', optionally preceded by an even number of \ (escaped quote)
// - an even number of \ followed by either the closing quote (trailing backslashes) or by
// an interpolation (as in `DEFAULT_MODE`) -- which must not be matched by this rule
QSINGLE_ESC:
(
ESC_BACKSLASH* '\\\'' |
ESC_BACKSLASH+ {(
self._input.LA(1) == ord("'") or
(self._input.LA(1) == ord("$") and self._input.LA(2) == ord("{"))
)}?
) -> type(ESC);
// Escaped quote (optionally preceded by an even number of backslashes).
QSINGLE_ESC_QUOTE: ESC_BACKSLASH* '\\\'' -> type(ESC);

QUOTED_ESC: ESC_BACKSLASH+;
QSINGLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
QSINGLE_DOLLAR: '$' -> type(ANY_STR);

Expand All @@ -140,17 +130,8 @@ QDOUBLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE)
QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode;

QDOUBLE_STR: ~["$]* ~["\\$] -> type(ANY_STR);

QDOUBLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);

QDOUBLE_ESC:
(
ESC_BACKSLASH* '\\"' |
ESC_BACKSLASH+ {(
self._input.LA(1) == ord('"') or
(self._input.LA(1) == ord("$") and self._input.LA(2) == ord("{"))
)}?
) -> type(ESC);

QDOUBLE_ESC_QUOTE: ESC_BACKSLASH* '\\"' -> type(ESC);
QDOUBLE_ESC: ESC_BACKSLASH+ -> type(QUOTED_ESC);
QDOUBLE_BACKSLASHES: '\\'+ -> type(ANY_STR);
QDOUBLE_DOLLAR: '$' -> type(ANY_STR);
2 changes: 1 addition & 1 deletion omegaconf/grammar/OmegaConfGrammarParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ singleElement: element EOF;

// Composite text expression (may contain interpolations).

text: (interpolation | ESC | ESC_INTER | ANY_STR)+;
text: (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+;


// Elements.
Expand Down
48 changes: 36 additions & 12 deletions omegaconf/grammar_visitor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import sys
import warnings
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Expand Down Expand Up @@ -289,7 +289,7 @@ def visitSingleElement(
return self.visit(ctx.getChild(0))

def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
# (interpolation | ESC | ESC_INTER | ANY_STR)+
# (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+

# Single interpolation? If yes, return its resolved value "as is".
if ctx.getChildCount() == 1:
Expand All @@ -298,7 +298,7 @@ def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
return self.visitInterpolation(c)

# Otherwise, concatenate string representations together.
return self._unescape(ctx.getChildren())
return self._unescape(list(ctx.getChildren()))

def _createPrimitive(
self,
Expand Down Expand Up @@ -336,11 +336,11 @@ def _createPrimitive(
raise AssertionError("WS should never be reached")
assert False, symbol.type
# Concatenation of multiple items ==> un-escape the concatenation.
return self._unescape(ctx.getChildren())
return self._unescape(list(ctx.getChildren()))

def _unescape(
self,
seq: Iterable[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
) -> str:
"""
Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.
Expand All @@ -350,19 +350,43 @@ def _unescape(
resolving of the interpolation).
"""
chrs = []
for node in seq:
for node, next_node in zip_longest(seq, seq[1:]):
if isinstance(node, TerminalNode):
s = node.symbol
if s.type == OmegaConfGrammarLexer.ESC:
chrs.append(s.text[1::2])
elif s.type == OmegaConfGrammarLexer.ESC_INTER:
if s.type == OmegaConfGrammarLexer.ESC_INTER:
# `ESC_INTER` is of the form `\\...\${`: the formula below computes
# the number of characters to keep at the end of the string to remove
# the correct number of backslashes.
chrs.append(s.text[-(len(s.text) // 2 + 1) :])
text = s.text[-(len(s.text) // 2 + 1) :]
elif (
# Character sequence identified as requiring un-escaping.
s.type == OmegaConfGrammarLexer.ESC
or (
# At top level, we need to un-escape backslashes that precede
# an interpolation.
s.type == OmegaConfGrammarLexer.TOP_ESC
and isinstance(
next_node, OmegaConfGrammarParser.InterpolationContext
)
)
or (
# In a quoted string, we need to un-escape backslashes that
# either end the string, or are followed by an interpolation.
s.type == OmegaConfGrammarLexer.QUOTED_ESC
and (
next_node is None
or isinstance(
next_node, OmegaConfGrammarParser.InterpolationContext
)
)
)
):
text = s.text[1::2] # un-escape the sequence
else:
chrs.append(s.text)
text = s.text # keep the original text
else:
assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
chrs.append(str(self.visitInterpolation(node)))
text = str(self.visitInterpolation(node))
chrs.append(text)

return "".join(chrs)

0 comments on commit c5d3a07

Please sign in to comment.