SQL: Remove CircuitBreaker from parser (#41835)
The CircuitBreaker was introduced as a means of preventing a
`StackOverflowError` while the parser builds the AST.

The ANTLR4 grammar causes odd behaviour in a parser listener: the
`enterEveryRule()` method is often called with a different parsing
context than the corresponding `exitEveryRule()`. This makes it
difficult to keep track of the tree's depth, so a custom Map was used
in an attempt to match the contexts as they are encountered during
`enter` and during `exit` of the rules.

This approach had 2 important drawbacks:
1. The custom Map is hard to maintain as the grammar changes.
2. The CircuitBreaker could often produce false positives, causing
valid queries to fail with an exception instead of executing.

So, this removes the CircuitBreaker completely and replaces it with
simple handling of the `StackOverflowError`.

Fixes: #41471
(cherry picked from commit 1559a8e)
matriv committed May 7, 2019
1 parent a22120d commit 04116c9
Showing 3 changed files with 60 additions and 278 deletions.
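
To see why the listener-based depth tracking described above was fragile, here is a minimal sketch (a hypothetical simplification, not the removed implementation) of a per-context depth counter. Because ANTLR4 can exit a rule with a different context class than it entered (e.g. enter `BooleanDefaultContext`, exit `BooleanExpressionContext`), the decrement can target the wrong key, so the entered class's count never comes back down and valid queries eventually trip the limit:

import java.util.HashMap;
import java.util.Map;

import org.antlr.v4.runtime.ParserRuleContext;

// Hypothetical, simplified depth tracker keyed by context class name.
class NaiveDepthListener extends SqlBaseBaseListener {

    private final Map<String, Integer> depth = new HashMap<>();

    @Override
    public void enterEveryRule(ParserRuleContext ctx) {
        depth.merge(ctx.getClass().getSimpleName(), 1, Integer::sum);
    }

    @Override
    public void exitEveryRule(ParserRuleContext ctx) {
        // The context class seen here may differ from the one seen in
        // enterEveryRule(), so this decrement can miss the earlier increment
        // and the counter drifts upward on perfectly valid queries.
        depth.merge(ctx.getClass().getSimpleName(), -1, Integer::sum);
    }
}

That drift is exactly what produced the false positives listed in the commit message.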
8 changes: 8 additions & 0 deletions docs/reference/sql/limitations.asciidoc
@@ -3,6 +3,14 @@
[[sql-limitations]]
== SQL Limitations

+[float]
+[[large-parsing-trees]]
+=== Large queries may throw `ParsingException`
+
+Extremely large queries can consume too much memory during the parsing phase, in which case the {es-sql} engine will
+abort parsing and throw an error. In such cases, consider reducing the query to a smaller size by potentially
+simplifying it or splitting it into smaller queries.
+
[float]
[[sys-columns-describe-table-nested-fields]]
=== Nested fields in `SYS COLUMNS` and `DESCRIBE TABLE`
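
The guidance above can be made concrete. A hypothetical client-side sketch (the `logs` table, `id` column, and batch size are all made up) that splits one oversized IN filter into several smaller statements, whose results are then merged by the caller; the diff that follows, in SqlParser.java, shows the engine-side change:

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical workaround: several small statements instead of one statement
// with tens of thousands of IN/OR terms.
static List<String> batchedStatements(List<Integer> ids, int batchSize) {
    return IntStream.range(0, (ids.size() + batchSize - 1) / batchSize)
            .mapToObj(b -> ids.subList(b * batchSize, Math.min((b + 1) * batchSize, ids.size())))
            .map(batch -> "SELECT * FROM logs WHERE id IN ("
                    + batch.stream().map(String::valueOf).collect(Collectors.joining(", "))
                    + ")")
            .collect(Collectors.toList());
}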
@@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.sql.parser;

-import com.carrotsearch.hppc.ObjectShortHashMap;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CommonToken;
@@ -26,16 +25,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.sql.expression.Expression;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BackQuotedIdentifierContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanDefaultContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.PrimaryExpressionContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QuoteIdentifierContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext;
-import org.elasticsearch.xpack.sql.parser.SqlBaseParser.UnquoteIdentifierContext;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;

@@ -50,7 +39,6 @@
import java.util.function.Function;

import static java.lang.String.format;
-import static org.elasticsearch.xpack.sql.parser.AbstractBuilder.source;

public class SqlParser {

@@ -100,45 +88,49 @@ private <T> T invokeParser(String sql,
                               List<SqlTypedParamValue> params,
                               Function<SqlBaseParser, ParserRuleContext> parseFunction,
                               BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
-        SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
+        try {
+            SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));

-        lexer.removeErrorListeners();
-        lexer.addErrorListener(ERROR_LISTENER);
+            lexer.removeErrorListeners();
+            lexer.addErrorListener(ERROR_LISTENER);

-        Map<Token, SqlTypedParamValue> paramTokens = new HashMap<>();
-        TokenSource tokenSource = new ParametrizedTokenSource(lexer, paramTokens, params);
+            Map<Token, SqlTypedParamValue> paramTokens = new HashMap<>();
+            TokenSource tokenSource = new ParametrizedTokenSource(lexer, paramTokens, params);

-        CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
-        SqlBaseParser parser = new SqlBaseParser(tokenStream);
+            CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
+            SqlBaseParser parser = new SqlBaseParser(tokenStream);

-        parser.addParseListener(new CircuitBreakerListener());
-        parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
+            parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));

-        parser.removeErrorListeners();
-        parser.addErrorListener(ERROR_LISTENER);
+            parser.removeErrorListeners();
+            parser.addErrorListener(ERROR_LISTENER);

-        parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
+            parser.getInterpreter().setPredictionMode(PredictionMode.SLL);

-        if (DEBUG) {
-            debug(parser);
-            tokenStream.fill();
+            if (DEBUG) {
+                debug(parser);
+                tokenStream.fill();

-            for (Token t : tokenStream.getTokens()) {
-                String symbolicName = SqlBaseLexer.VOCABULARY.getSymbolicName(t.getType());
-                String literalName = SqlBaseLexer.VOCABULARY.getLiteralName(t.getType());
-                log.info(format(Locale.ROOT, " %-15s '%s'",
+                for (Token t : tokenStream.getTokens()) {
+                    String symbolicName = SqlBaseLexer.VOCABULARY.getSymbolicName(t.getType());
+                    String literalName = SqlBaseLexer.VOCABULARY.getLiteralName(t.getType());
+                    log.info(format(Locale.ROOT, " %-15s '%s'",
                        symbolicName == null ? literalName : symbolicName,
                        t.getText()));
-            }
-        }
+                }
+            }

-        ParserRuleContext tree = parseFunction.apply(parser);
+            ParserRuleContext tree = parseFunction.apply(parser);

-        if (DEBUG) {
-            log.info("Parse tree {} " + tree.toStringTree());
-        }
+            if (DEBUG) {
+                log.info("Parse tree {} " + tree.toStringTree());
+            }

-        return visitor.apply(new AstBuilder(paramTokens), tree);
+            return visitor.apply(new AstBuilder(paramTokens), tree);
+        } catch (StackOverflowError e) {
+            throw new ParsingException("SQL statement is too large, " +
+                "causing stack overflow when generating the parsing tree: [{}]", sql);
+        }
    }

private static void debug(SqlBaseParser parser) {
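
The new guard above is straightforward to exercise. A minimal test sketch (assuming the Elasticsearch test framework's expectThrows and Hamcrest's containsString; this is not the test shipped with this commit):

// A pathologically long OR chain should now surface as a ParsingException
// instead of killing the node with a StackOverflowError.
public void testLargeStatementFailsWithParsingException() {
    StringBuilder sql = new StringBuilder("SELECT * FROM t WHERE a = 0");
    for (int i = 1; i < 100_000; i++) {
        sql.append(" OR a = ").append(i); // each term deepens the right-nested AST
    }
    ParsingException e = expectThrows(ParsingException.class,
            () -> new SqlParser().createStatement(sql.toString()));
    assertThat(e.getMessage(), containsString("SQL statement is too large"));
}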
@@ -221,93 +213,6 @@ public void exitNonReserved(SqlBaseParser.NonReservedContext context) {
}
}

-    /**
-     * Used to catch large expressions that can lead to stack overflows
-     */
-    static class CircuitBreakerListener extends SqlBaseBaseListener {
-
-        private static final short MAX_RULE_DEPTH = 200;
-
-        /**
-         * Due to the structure of the grammar and our custom handling in {@link ExpressionBuilder}
-         * some expressions can exit with a different class than they entered:
-         * e.g.: ValueExpressionContext can exit as ValueExpressionDefaultContext
-         */
-        private static final Map<String, String> ENTER_EXIT_RULE_MAPPING = new HashMap<>();
-
-        static {
-            ENTER_EXIT_RULE_MAPPING.put(StatementDefaultContext.class.getSimpleName(), StatementContext.class.getSimpleName());
-            ENTER_EXIT_RULE_MAPPING.put(QueryPrimaryDefaultContext.class.getSimpleName(), QueryTermContext.class.getSimpleName());
-            ENTER_EXIT_RULE_MAPPING.put(BooleanDefaultContext.class.getSimpleName(), BooleanExpressionContext.class.getSimpleName());
-        }
-
-        private boolean insideIn = false;
-
-        // Keep current depth for every rule visited.
-        // The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ...
-        // are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing
-        // while the stack call depth is, leading to a StackOverflowError.
-        private ObjectShortHashMap<String> depthCounts = new ObjectShortHashMap<>();
-
-        @Override
-        public void enterEveryRule(ParserRuleContext ctx) {
-            if (inDetected(ctx)) {
-                insideIn = true;
-            }
-
-            // Skip PrimaryExpressionContext for IN as it's not visited on exit due to
-            // the grammar's peculiarity rule with "predicated" and "predicate".
-            // Also skip the Identifiers as they are "cheap".
-            if (ctx.getClass() != UnquoteIdentifierContext.class &&
-                ctx.getClass() != QuoteIdentifierContext.class &&
-                ctx.getClass() != BackQuotedIdentifierContext.class &&
-                ctx.getClass() != SqlBaseParser.ConstantContext.class &&
-                ctx.getClass() != SqlBaseParser.NumberContext.class &&
-                ctx.getClass() != SqlBaseParser.ValueExpressionContext.class &&
-                (insideIn == false || ctx.getClass() != PrimaryExpressionContext.class)) {
-
-                int currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
-                if (currentDepth > MAX_RULE_DEPTH) {
-                    throw new ParsingException(source(ctx), "SQL statement too large; " +
-                        "halt parsing to prevent memory errors (stopped at depth {})", MAX_RULE_DEPTH);
-                }
-            }
-            super.enterEveryRule(ctx);
-        }
-
-        @Override
-        public void exitEveryRule(ParserRuleContext ctx) {
-            if (inDetected(ctx)) {
-                insideIn = false;
-            }
-
-            decrementCounter(ctx);
-            super.exitEveryRule(ctx);
-        }
-
-        ObjectShortHashMap<String> depthCounts() {
-            return depthCounts;
-        }
-
-        private void decrementCounter(ParserRuleContext ctx) {
-            String className = ctx.getClass().getSimpleName();
-            String classNameToDecrement = ENTER_EXIT_RULE_MAPPING.getOrDefault(className, className);
-
-            // Avoid having negative numbers
-            if (depthCounts.containsKey(classNameToDecrement)) {
-                depthCounts.putOrAdd(classNameToDecrement, (short) 0, (short) -1);
-            }
-        }
-
-        private boolean inDetected(ParserRuleContext ctx) {
-            if (ctx.getParent() != null && ctx.getParent().getClass() == SqlBaseParser.PredicateContext.class) {
-                SqlBaseParser.PredicateContext pc = (SqlBaseParser.PredicateContext) ctx.getParent();
-                return pc.kind != null && pc.kind.getType() == SqlBaseParser.IN;
-            }
-            return false;
-        }
-    }

private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
@Override
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,
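
The removed listener's comment about e1 OR e2 OR e3 deserves a concrete picture. A self-contained illustration (hypothetical Node type, not from the codebase): the chain is built as e1 OR (e2 OR (e3 OR (...))), so any recursive walk consumes one stack frame per term, which is why a plain node count was not a reliable proxy for stack depth:

// Hypothetical, self-contained model of a right-nested OR chain.
final class Node {
    final String label;
    final Node left;
    final Node right;

    Node(String label, Node left, Node right) {
        this.label = label;
        this.left = left;
        this.right = right;
    }

    // Builds e1 OR (e2 OR (... OR eN)).
    static Node orChain(int n) {
        Node tree = new Node("e" + n, null, null);
        for (int i = n - 1; i >= 1; i--) {
            tree = new Node("OR", new Node("e" + i, null, null), tree);
        }
        return tree;
    }

    // Recursion depth is proportional to n: with tens of thousands of terms,
    // a walk like this is exactly what overflows the stack.
    static int depth(Node node) {
        return node == null ? 0 : 1 + Math.max(depth(node.left), depth(node.right));
    }
}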