Skip to content

Commit

Permalink
Translate LIKE predicate to Domain
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed May 7, 2020
1 parent f9dd802 commit 6aea881
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.PeekingIterator;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.OperatorNotFoundException;
Expand All @@ -34,6 +36,7 @@
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.RealType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.InterpretedFunctionInvoker;
import io.prestosql.sql.parser.SqlParser;
Expand All @@ -47,11 +50,14 @@
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.sql.tree.IsNotNullPredicate;
import io.prestosql.sql.tree.IsNullPredicate;
import io.prestosql.sql.tree.LikePredicate;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.NotExpression;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.StringLiteral;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.type.LikeFunctions;
import io.prestosql.type.TypeCoercion;

import javax.annotation.Nullable;
Expand All @@ -66,6 +72,10 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Iterators.peekingIterator;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.airlift.slice.SliceUtf8.getCodePointAt;
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
import static io.airlift.slice.SliceUtf8.setCodePointAt;
import static io.prestosql.spi.function.OperatorType.SATURATED_FLOOR_CAST;
import static io.prestosql.sql.ExpressionUtils.and;
import static io.prestosql.sql.ExpressionUtils.combineConjuncts;
Expand Down Expand Up @@ -857,6 +867,92 @@ protected ExtractionResult visitBetweenPredicate(BetweenPredicate node, Boolean
new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax())), complement);
}

@Override
protected ExtractionResult visitLikePredicate(LikePredicate node, Boolean complement)
{
Optional<ExtractionResult> result = tryVisitLikePredicate(node, complement);
if (result.isPresent()) {
return result.get();
}
return super.visitLikePredicate(node, complement);
}

private Optional<ExtractionResult> tryVisitLikePredicate(LikePredicate node, Boolean complement)
{
if (!(node.getValue() instanceof SymbolReference)) {
// LIKE not on a symbol
return Optional.empty();
}

if (!(node.getPattern() instanceof StringLiteral)) {
// dynamic pattern
return Optional.empty();
}

if (node.getEscape().isPresent() && !(node.getEscape().get() instanceof StringLiteral)) {
// dynamic escape
return Optional.empty();
}

Type type = typeAnalyzer.getType(session, types, node.getValue());
if (!(type instanceof VarcharType)) {
// TODO support CharType
return Optional.empty();
}
VarcharType varcharType = (VarcharType) type;

Symbol symbol = Symbol.from(node.getValue());
Slice pattern = ((StringLiteral) node.getPattern()).getSlice();
Optional<Slice> escape = node.getEscape()
.map(StringLiteral.class::cast)
.map(StringLiteral::getSlice);

int patternConstantPrefixBytes = LikeFunctions.patternConstantPrefixBytes(pattern, escape);
if (patternConstantPrefixBytes == pattern.length()) {
// This should not actually happen, constant LIKE pattern should be converted to equality predicate before DomainTranslator is invoked.

Slice literal = LikeFunctions.unescapeLiteralLikePattern(pattern, escape);
ValueSet valueSet;
if (varcharType.isUnbounded() || countCodePoints(literal) <= varcharType.getBoundedLength()) {
valueSet = ValueSet.of(type, literal);
}
else {
// impossible to satisfy
valueSet = ValueSet.none(type);
}
Domain domain = Domain.create(complementIfNecessary(valueSet, complement), false);
return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), TRUE_LITERAL));
}

if (complement || patternConstantPrefixBytes == 0) {
// TODO
return Optional.empty();
}

Slice constantPrefix = LikeFunctions.unescapeLiteralLikePattern(pattern.slice(0, patternConstantPrefixBytes), escape);

int lastIncrementable = -1;
for (int position = 0; position < constantPrefix.length(); position += lengthOfCodePoint(constantPrefix, position)) {
// Get last ASCII character to increment, so that character length in bytes does not change.
// Also prefer not to produce non-ASCII if input is all-ASCII, to be on the safe side with connectors.
// TODO remove those limitations
if (getCodePointAt(constantPrefix, position) < 127) {
lastIncrementable = position;
}
}

if (lastIncrementable == -1) {
return Optional.empty();
}

Slice lowerBound = constantPrefix;
Slice upperBound = Slices.copyOf(constantPrefix.slice(0, lastIncrementable + lengthOfCodePoint(constantPrefix, lastIncrementable)));
setCodePointAt(getCodePointAt(constantPrefix, lastIncrementable) + 1, upperBound, lastIncrementable);

Domain domain = Domain.create(ValueSet.ofRanges(Range.range(type, lowerBound, true, upperBound, false)), false);
return Optional.of(new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), node));
}

@Override
protected ExtractionResult visitIsNullPredicate(IsNullPredicate node, Boolean complement)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.prestosql.sql.tree.InListExpression;
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.sql.tree.IsNullPredicate;
import io.prestosql.sql.tree.LikePredicate;
import io.prestosql.sql.tree.Literal;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.NotExpression;
Expand All @@ -56,6 +57,7 @@
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static io.airlift.slice.Slices.utf8Slice;
Expand All @@ -77,6 +79,7 @@
import static io.prestosql.spi.type.TinyintType.TINYINT;
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType;
import static io.prestosql.sql.ExpressionUtils.and;
import static io.prestosql.sql.ExpressionUtils.or;
import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType;
Expand Down Expand Up @@ -1457,6 +1460,128 @@ private void testNumericTypeTranslation(NumericValues<?> columnValues, NumericVa
}
}

@Test
public void testLikePredicate()
{
Type varcharType = createUnboundedVarcharType();

// constant
testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc")),
C_VARCHAR,
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc"))));

// starts with pattern
assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("_def")));
assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("%def")));

// _ pattern (unless escaped)
testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc_def")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc_def")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\\_def")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc\\_def")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc\\"), true, utf8Slice("abc]"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\\_def"), stringLiteral("\\")),
C_VARCHAR,
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc_def"))));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\\_def_"), stringLiteral("\\")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc\\_def_"), stringLiteral("\\")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc_def"), true, utf8Slice("abc_deg"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc^_def_"), stringLiteral("^")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc^_def_"), stringLiteral("^")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc_def"), true, utf8Slice("abc_deg"), false)), false));

// % pattern (unless escaped)
testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc%")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc%")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc%def")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc%def")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc"), true, utf8Slice("abd"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\\%def")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc\\%def")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc\\"), true, utf8Slice("abc]"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\\%def"), stringLiteral("\\")),
C_VARCHAR,
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc%def"))));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\\%def_"), stringLiteral("\\")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc\\%def_"), stringLiteral("\\")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc%def"), true, utf8Slice("abc%deg"), false)), false));

testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc^%def_"), stringLiteral("^")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc^%def_"), stringLiteral("^")),
Domain.create(ValueSet.ofRanges(Range.range(varcharType, utf8Slice("abc%def"), true, utf8Slice("abc%deg"), false)), false));

// non-ASCII literal
testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\u007f\u0123\udbfe")),
C_VARCHAR,
Domain.multipleValues(varcharType, ImmutableList.of(utf8Slice("abc\u007f\u0123\udbfe"))));

// non-ASCII prefix
testSimpleComparison(
like(C_VARCHAR, stringLiteral("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0%")),
C_VARCHAR,
like(C_VARCHAR, stringLiteral("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0%")),
Domain.create(
ValueSet.ofRanges(Range.range(varcharType,
utf8Slice("abc\u0123\ud83d\ude80def\u007e\u007f\u00ff\u0123\uccf0"), true,
utf8Slice("abc\u0123\ud83d\ude80def\u007f"), false)),
false));

// dynamic escape
assertUnsupportedPredicate(like(C_VARCHAR, stringLiteral("abc\\_def"), C_VARCHAR_1.toSymbolReference()));

// negation with literal
testSimpleComparison(
not(like(C_VARCHAR, stringLiteral("abcdef"))),
C_VARCHAR,
Domain.create(ValueSet.ofRanges(
Range.lessThan(varcharType, utf8Slice("abcdef")),
Range.greaterThan(varcharType, utf8Slice("abcdef"))),
false));

testSimpleComparison(
not(like(C_VARCHAR, stringLiteral("abc\\_def"), stringLiteral("\\"))),
C_VARCHAR,
Domain.create(ValueSet.ofRanges(
Range.lessThan(varcharType, utf8Slice("abc_def")),
Range.greaterThan(varcharType, utf8Slice("abc_def"))),
false));

// negation with pattern
assertUnsupportedPredicate(not(like(C_VARCHAR, stringLiteral("abc\\_def"))));
}

@Test
public void testCharComparedToVarcharExpression()
{
Expand Down Expand Up @@ -1568,6 +1693,16 @@ private static ComparisonExpression isDistinctFrom(Symbol symbol, Expression exp
return isDistinctFrom(symbol.toSymbolReference(), expression);
}

private static LikePredicate like(Symbol symbol, Expression expression)
{
return new LikePredicate(symbol.toSymbolReference(), expression, Optional.empty());
}

private static LikePredicate like(Symbol symbol, Expression expression, Expression escape)
{
return new LikePredicate(symbol.toSymbolReference(), expression, Optional.of(escape));
}

private static Expression isNotNull(Symbol symbol)
{
return isNotNull(symbol.toSymbolReference());
Expand Down Expand Up @@ -1733,14 +1868,19 @@ private void testSimpleComparison(Expression expression, Symbol symbol, Range ex
testSimpleComparison(expression, symbol, Domain.create(ValueSet.ofRanges(expectedDomainRange), false));
}

private void testSimpleComparison(Expression expression, Symbol symbol, Domain domain)
private void testSimpleComparison(Expression expression, Symbol symbol, Domain expectedDomain)
{
testSimpleComparison(expression, symbol, TRUE_LITERAL, expectedDomain);
}

private void testSimpleComparison(Expression expression, Symbol symbol, Expression expectedRemainingExpression, Domain expectedDomain)
{
ExtractionResult result = fromPredicate(expression);
assertEquals(result.getRemainingExpression(), TRUE_LITERAL);
assertEquals(result.getRemainingExpression(), expectedRemainingExpression);
TupleDomain<Symbol> actual = result.getTupleDomain();
TupleDomain<Symbol> expected = withColumnDomains(ImmutableMap.of(symbol, domain));
TupleDomain<Symbol> expected = withColumnDomains(ImmutableMap.of(symbol, expectedDomain));
if (!actual.equals(expected)) {
fail(format("for comparison [%s] expected %s but found %s", expression.toString(), expected.toString(SESSION), actual.toString(SESSION)));
fail(format("for comparison [%s] expected [%s] but found [%s]", expression.toString(), expected.toString(SESSION), actual.toString(SESSION)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.prestosql.Session;
import io.prestosql.plugin.tpch.TpchColumnHandle;
import io.prestosql.plugin.tpch.TpchTableHandle;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.predicate.Range;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.predicate.ValueSet;
import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType;
import io.prestosql.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
import io.prestosql.sql.planner.assertions.BasePlanTest;
Expand Down Expand Up @@ -51,10 +59,14 @@
import org.testng.annotations.Test;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Predicate;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.toOptional;
import static io.airlift.slice.Slices.utf8Slice;
import static io.prestosql.SystemSessionProperties.DISTRIBUTED_SORT;
import static io.prestosql.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT;
Expand Down Expand Up @@ -138,6 +150,37 @@ public void testAnalyze()
tableScan("orders", ImmutableMap.of()))))))))));
}

@Test
public void testLikePredicate()
{
assertPlan("SELECT type FROM part WHERE type LIKE 'LARGE PLATED %'",
anyTree(
tableScan(
tableHandle -> {
Map<ColumnHandle, Domain> domains = ((TpchTableHandle) tableHandle).getConstraint().getDomains()
.orElseThrow(() -> new AssertionError("Unexpected none TupleDomain"));

Domain domain = domains.entrySet().stream()
.filter(entry -> ((TpchColumnHandle) entry.getKey()).getColumnName().equals("type"))
.map(Entry::getValue)
.collect(toOptional())
.orElseThrow(() -> new AssertionError("No domain for 'type'"));

assertEquals(domain, Domain.multipleValues(
createVarcharType(25),
ImmutableList.of("LARGE PLATED BRASS", "LARGE PLATED COPPER", "LARGE PLATED NICKEL", "LARGE PLATED STEEL", "LARGE PLATED TIN").stream()
.map(Slices::utf8Slice)
.collect(toImmutableList())));
return true;
},
TupleDomain.withColumnDomains(ImmutableMap.of(
tableHandle -> ((TpchColumnHandle) tableHandle).getColumnName().equals("type"),
Domain.create(
ValueSet.ofRanges(Range.range(createVarcharType(25), utf8Slice("LARGE PLATED "), true, utf8Slice("LARGE PLATED!"), false)),
false))),
ImmutableMap.of())));
}

@Test
public void testAggregation()
{
Expand Down
Loading

0 comments on commit 6aea881

Please sign in to comment.