From 6aea881752ef9a83521a4a20bccf510bf9a31d2c Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 7 May 2020 00:11:07 +0200 Subject: [PATCH] Translate LIKE predicate to Domain --- .../sql/planner/DomainTranslator.java | 96 ++++++++++++ .../sql/planner/TestDomainTranslator.java | 148 +++++++++++++++++- .../sql/planner/TestLogicalPlanner.java | 43 +++++ .../AbstractTestIntegrationSmokeTest.java | 16 ++ 4 files changed, 299 insertions(+), 4 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java index 7f0d44b84890..9f8199c04151 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.prestosql.Session; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.OperatorNotFoundException; @@ -34,6 +36,7 @@ import io.prestosql.spi.type.DoubleType; import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VarcharType; import io.prestosql.sql.ExpressionUtils; import io.prestosql.sql.InterpretedFunctionInvoker; import io.prestosql.sql.parser.SqlParser; @@ -47,11 +50,14 @@ import io.prestosql.sql.tree.InPredicate; import io.prestosql.sql.tree.IsNotNullPredicate; import io.prestosql.sql.tree.IsNullPredicate; +import io.prestosql.sql.tree.LikePredicate; import io.prestosql.sql.tree.LogicalBinaryExpression; import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.NullLiteral; +import io.prestosql.sql.tree.StringLiteral; import io.prestosql.sql.tree.SymbolReference; +import io.prestosql.type.LikeFunctions; import io.prestosql.type.TypeCoercion; import javax.annotation.Nullable; @@ -66,6 +72,10 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Iterators.peekingIterator; +import static io.airlift.slice.SliceUtf8.countCodePoints; +import static io.airlift.slice.SliceUtf8.getCodePointAt; +import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; +import static io.airlift.slice.SliceUtf8.setCodePointAt; import static io.prestosql.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static io.prestosql.sql.ExpressionUtils.and; import static io.prestosql.sql.ExpressionUtils.combineConjuncts; @@ -857,6 +867,92 @@ protected ExtractionResult visitBetweenPredicate(BetweenPredicate node, Boolean new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax())), complement); } + @Override + protected ExtractionResult visitLikePredicate(LikePredicate node, Boolean complement) + { + Optional result = tryVisitLikePredicate(node, complement); + if (result.isPresent()) { + return result.get(); + } + return super.visitLikePredicate(node, complement); + } + + private Optional tryVisitLikePredicate(LikePredicate node, Boolean complement) + { + if (!(node.getValue() instanceof SymbolReference)) { + // LIKE not on a symbol + return Optional.empty(); + } + + if (!(node.getPattern() instanceof StringLiteral)) { + // dynamic pattern + return Optional.empty(); + } + + if (node.getEscape().isPresent() && !(node.getEscape().get() instanceof StringLiteral)) { + // dynamic escape + return Optional.empty(); + } + + Type type = typeAnalyzer.getType(session, types, node.getValue()); + if (!(type instanceof VarcharType)) { + // TODO support CharType + return Optional.empty(); + } + VarcharType varcharType = (VarcharType) type; + + Symbol symbol = Symbol.from(node.getValue()); + Slice pattern = ((StringLiteral) node.getPattern()).getSlice(); + Optional escape = node.getEscape() + .map(StringLiteral.class::cast) + .map(StringLiteral::getSlice); + + int patternConstantPrefixBytes = LikeFunctions.patternConstantPrefixBytes(pattern, escape); + if (patternConstantPrefixBytes == pattern.length()) { + // This should not actually happen, constant LIKE pattern should be converted to equality predicate before DomainTranslator is invoked. + + Slice literal = LikeFunctions.unescapeLiteralLikePattern(pattern, escape); + ValueSet valueSet; + if (varcharType.isUnbounded() || countCodePoints(literal) <= varcharType.getBoundedLength()) { + valueSet = ValueSet.of(type, literal); + } + else { + // impossible to satisfy + valueSet = ValueSet.none(type); + } + Domain domain = Domain.create(complementIfNecessary(valueSet, complement), false); + return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), TRUE_LITERAL)); + } + + if (complement || patternConstantPrefixBytes == 0) { + // TODO + return Optional.empty(); + } + + Slice constantPrefix = LikeFunctions.unescapeLiteralLikePattern(pattern.slice(0, patternConstantPrefixBytes), escape); + + int lastIncrementable = -1; + for (int position = 0; position < constantPrefix.length(); position += lengthOfCodePoint(constantPrefix, position)) { + // Get last ASCII character to increment, so that character length in bytes does not change. + // Also prefer not to produce non-ASCII if input is all-ASCII, to be on the safe side with connectors. + // TODO remove those limitations + if (getCodePointAt(constantPrefix, position) < 127) { + lastIncrementable = position; + } + } + + if (lastIncrementable == -1) { + return Optional.empty(); + } + + Slice lowerBound = constantPrefix; + Slice upperBound = Slices.copyOf(constantPrefix.slice(0, lastIncrementable + lengthOfCodePoint(constantPrefix, lastIncrementable))); + setCodePointAt(getCodePointAt(constantPrefix, lastIncrementable) + 1, upperBound, lastIncrementable); + + Domain domain = Domain.create(ValueSet.ofRanges(Range.range(type, lowerBound, true, upperBound, false)), false); + return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), node)); + } + @Override protected ExtractionResult visitIsNullPredicate(IsNullPredicate node, Boolean complement) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java index 9c56361cf4f6..fffca4020a64 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java @@ -39,6 +39,7 @@ import io.prestosql.sql.tree.InListExpression; import io.prestosql.sql.tree.InPredicate; import io.prestosql.sql.tree.IsNullPredicate; +import io.prestosql.sql.tree.LikePredicate; import io.prestosql.sql.tree.Literal; import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.NotExpression; @@ -56,6 +57,7 @@ import java.math.BigDecimal; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.concurrent.TimeUnit; import static io.airlift.slice.Slices.utf8Slice; @@ -77,6 +79,7 @@ import static io.prestosql.spi.type.TinyintType.TINYINT; import static io.prestosql.spi.type.VarbinaryType.VARBINARY; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType; import static io.prestosql.sql.ExpressionUtils.and; import static io.prestosql.sql.ExpressionUtils.or; import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType; @@ -1457,6 +1460,128 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericVa } } + @Test + public void testLikePredicate() + { + Type varcharType = createUnboundedVarcharType(); + + // constant + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc")), + C_VARCHAR, + Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc")))); + + // starts with pattern + assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("_def"))); + assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("%def"))); + + // _ pattern (unless escaped) + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc_def")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc_def")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\\_def")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc\\_def")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc\\"), true, utf8Slice("abc]"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\\_def"), stringLiteral("\\")), + C_VARCHAR, + Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc_def")))); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\\_def_"), stringLiteral("\\")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc\\_def_"), stringLiteral("\\")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc_def"), true, utf8Slice("abc_deg"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc^_def_"), stringLiteral("^")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc^_def_"), stringLiteral("^")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc_def"), true, utf8Slice("abc_deg"), false)), false)); + + // % pattern (unless escaped) + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc%")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc%")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc%def")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc%def")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\\%def")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc\\%def")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc\\"), true, utf8Slice("abc]"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\\%def"), stringLiteral("\\")), + C_VARCHAR, + Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc%def")))); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\\%def_"), stringLiteral("\\")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc\\%def_"), stringLiteral("\\")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc%def"), true, utf8Slice("abc%deg"), false)), false)); + + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc^%def_"), stringLiteral("^")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc^%def_"), stringLiteral("^")), + Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc%def"), true, utf8Slice("abc%deg"), false)), false)); + + // non-ASCII literal + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\u007f\u0123\udbfe")), + C_VARCHAR, + Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc\u007f\u0123\udbfe")))); + + // non-ASCII prefix + testSimpleComparison( + like(C_VARCHAR, stringLiteral("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0%")), + C_VARCHAR, + like(C_VARCHAR, stringLiteral("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0%")), + Domain.create( + ValueSet.ofRanges(Range.range(varcharType, + utf8Slice("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0"), true, + utf8Slice("abc\u0123\ud83d\ude80def\u007f"), false)), + false)); + + // dynamic escape + assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("abc\\_def"), C_VARCHAR_1.toSymbolReference())); + + // negation with literal + testSimpleComparison( + not(like(C_VARCHAR, stringLiteral("abcdef"))), + C_VARCHAR, + Domain.create(ValueSet.ofRanges( + Range.lessThan(varcharType, utf8Slice("abcdef")), + Range.greaterThan(varcharType, utf8Slice("abcdef"))), + false)); + + testSimpleComparison( + not(like(C_VARCHAR, stringLiteral("abc\\_def"), stringLiteral("\\"))), + C_VARCHAR, + Domain.create(ValueSet.ofRanges( + Range.lessThan(varcharType, utf8Slice("abc_def")), + Range.greaterThan(varcharType, utf8Slice("abc_def"))), + false)); + + // negation with pattern + assertUnsupportedPredicate(not(like(C_VARCHAR, stringLiteral("abc\\_def")))); + } + @Test public void testCharComparedToVarcharExpression() { @@ -1568,6 +1693,16 @@ private static ComparisonExpression isDistinctFrom(Symbol symbol, Expression exp return isDistinctFrom(symbol.toSymbolReference(), expression); } + private static LikePredicate like(Symbol symbol, Expression expression) + { + return new LikePredicate(symbol.toSymbolReference(), expression, Optional.empty()); + } + + private static LikePredicate like(Symbol symbol, Expression expression, Expression escape) + { + return new LikePredicate(symbol.toSymbolReference(), expression, Optional.of(escape)); + } + private static Expression isNotNull(Symbol symbol) { return isNotNull(symbol.toSymbolReference()); @@ -1733,14 +1868,19 @@ private void testSimpleComparison(Expression expression, Symbol symbol, Range ex testSimpleComparison(expression, symbol, Domain.create(ValueSet.ofRanges(expectedDomainRange), false)); } - private void testSimpleComparison(Expression expression, Symbol symbol, Domain domain) + private void testSimpleComparison(Expression expression, Symbol symbol, Domain expectedDomain) + { + testSimpleComparison(expression, symbol, TRUE_LITERAL, expectedDomain); + } + + private void testSimpleComparison(Expression expression, Symbol symbol, Expression expectedRemainingExpression, Domain expectedDomain) { ExtractionResult result = fromPredicate(expression); - assertEquals(result.getRemainingExpression(), TRUE_LITERAL); + assertEquals(result.getRemainingExpression(), expectedRemainingExpression); TupleDomain actual = result.getTupleDomain(); - TupleDomain expected = withColumnDomains(ImmutableMap.of(symbol, domain)); + TupleDomain expected = withColumnDomains(ImmutableMap.of(symbol, expectedDomain)); if (!actual.equals(expected)) { - fail(format("for comparison [%s] expected %s but found %s", expression.toString(), expected.toString(SESSION), actual.toString(SESSION))); + fail(format("for comparison [%s] expected [%s] but found [%s]", expression.toString(), expected.toString(SESSION), actual.toString(SESSION))); } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java index c683b5c4673c..a9687a0e1707 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java @@ -15,8 +15,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slices; import io.prestosql.Session; +import io.prestosql.plugin.tpch.TpchColumnHandle; +import io.prestosql.plugin.tpch.TpchTableHandle; import io.prestosql.spi.block.SortOrder; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.predicate.Range; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.spi.predicate.ValueSet; import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType; import io.prestosql.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; import io.prestosql.sql.planner.assertions.BasePlanTest; @@ -51,10 +59,14 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.function.Consumer; import java.util.function.Predicate; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.MoreCollectors.toOptional; import static io.airlift.slice.Slices.utf8Slice; import static io.prestosql.SystemSessionProperties.DISTRIBUTED_SORT; import static io.prestosql.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT; @@ -138,6 +150,37 @@ public void testAnalyze() tableScan("orders", ImmutableMap.of())))))))))); } + @Test + public void testLikePredicate() + { + assertPlan("SELECT type FROM part WHERE type LIKE 'LARGE PLATED %'", + anyTree( + tableScan( + tableHandle -> { + Map domains = ((TpchTableHandle) tableHandle).getConstraint().getDomains() + .orElseThrow(() -> new AssertionError("Unexpected none TupleDomain")); + + Domain domain = domains.entrySet().stream() + .filter(entry -> ((TpchColumnHandle) entry.getKey()).getColumnName().equals("type")) + .map(Entry::getValue) + .collect(toOptional()) + .orElseThrow(() -> new AssertionError("No domain for 'type'")); + + assertEquals(domain, Domain.multipleValues( + createVarcharType(25), + ImmutableList.of("LARGE PLATED BRASS", "LARGE PLATED COPPER", "LARGE PLATED NICKEL", "LARGE PLATED STEEL", "LARGE PLATED TIN").stream() + .map(Slices::utf8Slice) + .collect(toImmutableList()))); + return true; + }, + TupleDomain.withColumnDomains(ImmutableMap.of( + tableHandle -> ((TpchColumnHandle) tableHandle).getColumnName().equals("type"), + Domain.create( + ValueSet.ofRanges(Range.range(createVarcharType(25), utf8Slice("LARGE PLATED "), true, utf8Slice("LARGE PLATED!"), false)), + false))), + ImmutableMap.of()))); + } + @Test public void testAggregation() { diff --git a/presto-testing/src/main/java/io/prestosql/testing/AbstractTestIntegrationSmokeTest.java b/presto-testing/src/main/java/io/prestosql/testing/AbstractTestIntegrationSmokeTest.java index e5d8c9a075e3..39062d17faf6 100644 --- a/presto-testing/src/main/java/io/prestosql/testing/AbstractTestIntegrationSmokeTest.java +++ b/presto-testing/src/main/java/io/prestosql/testing/AbstractTestIntegrationSmokeTest.java @@ -101,6 +101,22 @@ public void testIsNullPredicate() assertQuery("SELECT custkey FROM orders WHERE orderkey = 32 OR orderkey IS NULL", "VALUES (1301)"); } + @Test + public void testLikePredicate() + { + // filtered column is not selected + assertQuery("SELECT orderkey FROM orders WHERE orderpriority LIKE '5-L%'"); + + // filtered column is selected + assertQuery("SELECT orderkey, orderpriority FROM orders WHERE orderpriority LIKE '5-L%'"); + + // filtered column is not selected + assertQuery("SELECT orderkey FROM orders WHERE orderpriority LIKE '5-L__'"); + + // filtered column is selected + assertQuery("SELECT orderkey, orderpriority FROM orders WHERE orderpriority LIKE '5-L__'"); + } + @Test public void testLimit() {