From 7e09985f73ce3d8d61b99c81b5ba02cf6dc1e657 Mon Sep 17 00:00:00 2001 From: Marios Trivyzas Date: Fri, 21 Sep 2018 14:53:32 +0200 Subject: [PATCH] Full implementation of tree depth Circuit Breaker --- .../xpack/sql/parser/SqlParser.java | 27 ++++++--- .../xpack/sql/parser/SqlParserTests.java | 55 +++++++++++++++++-- 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java index 4f9538fe919aa..2752b06d2a239 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.sql.parser; +import com.carrotsearch.hppc.ObjectIntHashMap; import org.antlr.v4.runtime.BaseErrorListener; import org.antlr.v4.runtime.CharStream; import org.antlr.v4.runtime.CommonToken; @@ -99,7 +100,7 @@ private T invokeParser(String sql, CommonTokenStream tokenStream = new CommonTokenStream(tokenSource); SqlBaseParser parser = new SqlBaseParser(tokenStream); - parser.addParseListener(new CircuitBreakerProcessor()); + parser.addParseListener(new CircuitBreakerListener()); parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames()))); parser.removeErrorListeners(); @@ -212,18 +213,26 @@ public void exitNonReserved(SqlBaseParser.NonReservedContext context) { /** * Used to catch large expressions that can lead to stack overflows */ - private class CircuitBreakerProcessor extends SqlBaseBaseListener { + private class CircuitBreakerListener extends SqlBaseBaseListener { - private static final short MAX_BOOLEAN_ELEMENTS = 1000; - private short countElementsInBooleanExpressions = 0; + private static final short MAX_DEPTH = 100; + + // Keep current depth for every rule visited + ObjectIntHashMap depthCounts = new ObjectIntHashMap<>(100); @Override - public void enterLogicalBinary(SqlBaseParser.LogicalBinaryContext ctx) { - if (++countElementsInBooleanExpressions == MAX_BOOLEAN_ELEMENTS) { - throw new ParsingException("boolean expression is too large to parse, (exceeds {} elements)", - MAX_BOOLEAN_ELEMENTS); + public void enterEveryRule(ParserRuleContext ctx) { + int currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), 1, 1); + if (currentDepth > MAX_DEPTH) { + throw new ParsingException("expression is too large to parse, (tree's depth exceeds {})", MAX_DEPTH); } - super.enterLogicalBinary(ctx); + super.enterEveryRule(ctx); + } + + @Override + public void exitEveryRule(ParserRuleContext ctx) { + depthCounts.putOrAdd(ctx.getClass().getSimpleName(), 0, -1); + super.exitEveryRule(ctx); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java index 1099e3b304d97..5dbb89725f37c 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java @@ -138,14 +138,57 @@ public void testMultiMatchQuery() { assertThat(mmqp.optionMap(), hasEntry("fuzzy_rewrite", "scoring_boolean")); } - public void testLimitToPreventStackOverflowFromLargeBooleanExpression() { - // 1000 elements is ok - new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(1000, "a = b"))); + public void testLimitToPreventStackOverflowFromLargeUnaryBooleanExpression() { + // 100 elements is ok + new SqlParser().createExpression( + Joiner.on("NOT(").join(nCopies(100, "true")).concat(Joiner.on("").join(nCopies(99, ")")))); - // 1001 elements parser's "circuit breaker" is triggered + // 500 elements parser's "circuit breaker" is triggered + ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression( + Joiner.on("NOT(").join(nCopies(101, "true")).concat(Joiner.on("").join(nCopies(100, ")"))))); + assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage()); + } + + public void testLimitToPreventStackOverflowFromLargeBinaryBooleanExpression() { + // 100 elements is ok + new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(100, "true"))); + + // 101 elements parser's "circuit breaker" is triggered ParsingException e = expectThrows(ParsingException.class, () -> - new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(1001, "a = b")))); - assertEquals("boolean expression is too large to parse, (exceeds 1000 elements)", e.getErrorMessage()); + new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(101, "a = b")))); + assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage()); + } + + public void testLimitToPreventStackOverflowFromLargeUnaryArithmeticExpression() { + // 100 elements is ok + new SqlParser().createExpression( + Joiner.on("abs(").join(nCopies(100, "i")).concat(Joiner.on("").join(nCopies(99, ")")))); + + // 101 elements parser's "circuit breaker" is triggered + ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression( + Joiner.on("abs(").join(nCopies(101, "i")).concat(Joiner.on("").join(nCopies(100, ")"))))); + assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage()); + } + + public void testLimitToPreventStackOverflowFromLargeBinaryArithmeticExpression() { + // 100 elements is ok + new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(100, "a"))); + + // 101 elements parser's "circuit breaker" is triggered + ParsingException e = expectThrows(ParsingException.class, () -> + new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(101, "a")))); + assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage()); + } + + public void testLimitToPreventStackOverflowFromLargeSubselectTree() { + // 100 elements is ok + new SqlParser().createStatement( + Joiner.on(" (").join(nCopies(100, "SELECT * FROM")).concat("t").concat(Joiner.on("").join(nCopies(99, ")")))); + + // 101 elements parser's "circuit breaker" is triggered + ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement( + Joiner.on(" (").join(nCopies(101, "SELECT * FROM")).concat("t").concat(Joiner.on("").join(nCopies(100, ")"))))); + assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage()); } private LogicalPlan parseStatement(String sql) {