From 8d6f162e9f01f74ec46a7058edf851df023d0f83 Mon Sep 17 00:00:00 2001 From: Assaf Bern Date: Thu, 6 May 2021 18:36:39 +0300 Subject: [PATCH 1/2] Support ConnectorExpression pushdown and introduce functions and LIKE pushdown --- .../main/java/io/trino/FeaturesConfig.java | 13 ++ .../io/trino/SystemSessionProperties.java | 11 ++ .../io/trino/metadata/MetadataManager.java | 5 +- .../ConnectorExpressionTranslator.java | 174 ++++++++++++++++-- .../trino/sql/planner/PartialTranslator.java | 10 +- .../rule/PushAggregationIntoTableScan.java | 2 +- .../rule/PushPredicateIntoTableScan.java | 42 ++++- .../rule/PushProjectionIntoTableScan.java | 7 +- .../sql/analyzer/TestFeaturesConfig.java | 3 + .../TestConnectorExpressionTranslator.java | 69 ++++++- .../sql/planner/TestPartialTranslator.java | 10 +- .../rule/TestPushProjectionIntoTableScan.java | 45 +++-- .../io/trino/spi/connector/Constraint.java | 54 +++++- .../ConstraintApplicationResult.java | 37 ++++ .../java/io/trino/spi/expression/Call.java | 86 +++++++++ .../io/trino/spi/expression/Constant.java | 4 + .../io/trino/spi/expression/FunctionName.java | 86 +++++++++ .../base/expression/ConnectorExpressions.java | 52 ++++++ 18 files changed, 655 insertions(+), 55 deletions(-) create mode 100644 core/trino-spi/src/main/java/io/trino/spi/expression/Call.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/expression/FunctionName.java create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java diff --git a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java index 52075e5b6385..32ec4277caa5 100644 --- a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java +++ b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java @@ -130,6 +130,7 @@ public class FeaturesConfig private boolean optimizeTopNRanking = true; private boolean lateMaterializationEnabled; private boolean skipRedundantSort 
= true; + private boolean complexExpressionPushdownEnabled = true; private boolean predicatePushdownUseTableProperties = true; private boolean ignoreDownstreamPreferences; private boolean rewriteFilteringSemiJoinToInnerJoin = true; @@ -948,6 +949,18 @@ public FeaturesConfig setSkipRedundantSort(boolean value) return this; } + public boolean isComplexExpressionPushdownEnabled() + { + return complexExpressionPushdownEnabled; + } + + @Config("optimizer.complex-expression-pushdown.enabled") + public FeaturesConfig setComplexExpressionPushdownEnabled(boolean complexExpressionPushdownEnabled) + { + this.complexExpressionPushdownEnabled = complexExpressionPushdownEnabled; + return this; + } + public boolean isPredicatePushdownUseTableProperties() { return predicatePushdownUseTableProperties; diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index c80eadc2b53b..321b49ef8e3b 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -122,6 +122,7 @@ public final class SystemSessionProperties public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled"; public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort"; public static final String ALLOW_PUSHDOWN_INTO_CONNECTORS = "allow_pushdown_into_connectors"; + public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown"; public static final String PREDICATE_PUSHDOWN_USE_TABLE_PROPERTIES = "predicate_pushdown_use_table_properties"; public static final String LATE_MATERIALIZATION = "late_materialization"; public static final String ENABLE_DYNAMIC_FILTERING = "enable_dynamic_filtering"; @@ -551,6 +552,11 @@ public SystemSessionProperties( // This is a diagnostic property true, true), + booleanProperty( + COMPLEX_EXPRESSION_PUSHDOWN, + "Allow complex expression pushdown 
into connectors", + featuresConfig.isComplexExpressionPushdownEnabled(), + true), booleanProperty( PREDICATE_PUSHDOWN_USE_TABLE_PROPERTIES, "Use table properties in predicate pushdown", @@ -1127,6 +1133,11 @@ public static boolean isAllowPushdownIntoConnectors(Session session) return session.getSystemProperty(ALLOW_PUSHDOWN_INTO_CONNECTORS, Boolean.class); } + public static boolean isComplexExpressionPushdown(Session session) + { + return session.getSystemProperty(COMPLEX_EXPRESSION_PUSHDOWN, Boolean.class); + } + public static boolean isPredicatePushdownUseTableProperties(Session session) { return session.getSystemProperty(PREDICATE_PUSHDOWN_USE_TABLE_PROPERTIES, Boolean.class); diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 021dd798eff9..d6c44b5c8e56 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -1722,10 +1722,7 @@ public Optional> applyFilter(Session se ConnectorSession connectorSession = session.toConnectorSession(catalogName); return metadata.applyFilter(connectorSession, table.getConnectorHandle(), constraint) - .map(result -> new ConstraintApplicationResult<>( - new TableHandle(catalogName, result.getHandle(), table.getTransaction()), - result.getRemainingFilter(), - result.isPrecalculateStatistics())); + .map(result -> result.transform(handle -> new TableHandle(catalogName, handle, table.getTransaction()))); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 1ab84420c680..b9a99e0ae7ee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -13,14 
+13,27 @@ */ package io.trino.sql.planner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.LiteralFunction; +import io.trino.metadata.ResolvedFunction; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; import io.trino.spi.expression.FieldDereference; +import io.trino.spi.expression.FunctionName; import io.trino.spi.expression.Variable; import io.trino.spi.type.Decimals; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.spi.type.TypeSignature; +import io.trino.spi.type.VarcharType; +import io.trino.sql.PlannerContext; +import io.trino.sql.analyzer.TypeSignatureProvider; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.BinaryLiteral; import io.trino.sql.tree.BooleanLiteral; @@ -28,71 +41,135 @@ import io.trino.sql.tree.DecimalLiteral; import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FunctionCall; +import io.trino.sql.tree.LikePredicate; import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; +import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SymbolReference; +import io.trino.type.JoniRegexp; +import io.trino.type.Re2JRegexp; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.SliceUtf8.countCodePoints; +import static io.trino.SystemSessionProperties.isComplexExpressionPushdown; +import static 
io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.type.LikeFunctions.LIKE_PATTERN_FUNCTION_NAME; import static java.util.Objects.requireNonNull; public final class ConnectorExpressionTranslator { private ConnectorExpressionTranslator() {} - public static Expression translate(Session session, ConnectorExpression expression, Map variableMappings, LiteralEncoder literalEncoder) + public static Expression translate(Session session, ConnectorExpression expression, PlannerContext plannerContext, Map variableMappings, LiteralEncoder literalEncoder) { - return new ConnectorToSqlExpressionTranslator(variableMappings, literalEncoder).translate(session, expression); + return new ConnectorToSqlExpressionTranslator(session, plannerContext, literalEncoder, variableMappings) + .translate(session, expression) + .orElseThrow(() -> new UnsupportedOperationException("Expression is not supported: " + expression.toString())); } - public static Optional translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes) + public static Optional translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes, PlannerContext plannerContext) { - return new SqlToConnectorExpressionTranslator(types.getTypes(session, inputTypes, expression)) + return new SqlToConnectorExpressionTranslator(session, types.getTypes(session, inputTypes, expression), plannerContext) .process(expression); } private static class ConnectorToSqlExpressionTranslator { - private final Map variableMappings; + private final Session session; + private final PlannerContext plannerContext; private final LiteralEncoder literalEncoder; + private final Map variableMappings; - public ConnectorToSqlExpressionTranslator(Map variableMappings, LiteralEncoder literalEncoder) + public ConnectorToSqlExpressionTranslator(Session session, PlannerContext plannerContext, LiteralEncoder literalEncoder, Map variableMappings) { - 
this.variableMappings = requireNonNull(variableMappings, "variableMappings is null"); + this.session = requireNonNull(session, "session is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.literalEncoder = requireNonNull(literalEncoder, "literalEncoder is null"); + this.variableMappings = requireNonNull(variableMappings, "variableMappings is null"); } - public Expression translate(Session session, ConnectorExpression expression) + public Optional translate(Session session, ConnectorExpression expression) { if (expression instanceof Variable) { - return variableMappings.get(((Variable) expression).getName()).toSymbolReference(); + String name = ((Variable) expression).getName(); + return Optional.of(variableMappings.get(name).toSymbolReference()); } if (expression instanceof Constant) { - return literalEncoder.toExpression(session, ((Constant) expression).getValue(), expression.getType()); + return Optional.of(literalEncoder.toExpression(session, ((Constant) expression).getValue(), expression.getType())); } if (expression instanceof FieldDereference) { FieldDereference dereference = (FieldDereference) expression; - return new SubscriptExpression(translate(session, dereference.getTarget()), new LongLiteral(Long.toString(dereference.getField() + 1))); + return translate(session, dereference.getTarget()) + .map(base -> new SubscriptExpression(base, new LongLiteral(Long.toString(dereference.getField() + 1)))); + } + + if (expression instanceof Call) { + return translateCall((Call) expression); } - throw new UnsupportedOperationException("Expression type not supported: " + expression.getClass().getName()); + return Optional.empty(); + } + + protected Optional translateCall(Call call) + { + if (call.getFunctionName().getCatalogSchema().isPresent()) { + return Optional.empty(); + } + QualifiedName name = QualifiedName.of(call.getFunctionName().getName()); + List argumentTypes = call.getArguments().stream() + .map(argument -> 
argument.getType().getTypeSignature()) + .collect(toImmutableList()); + ResolvedFunction resolved = plannerContext.getMetadata().resolveFunction(session, name, TypeSignatureProvider.fromTypeSignatures(argumentTypes)); + + // TODO Support ESCAPE character + if (LIKE_PATTERN_FUNCTION_NAME.equals(resolved.getSignature().getName()) && call.getArguments().size() == 2) { + return translateLike(call.getArguments().get(0), call.getArguments().get(1)); + } + + FunctionCallBuilder builder = FunctionCallBuilder.resolve(session, plannerContext.getMetadata()) + .setName(name); + for (int i = 0; i < call.getArguments().size(); i++) { + Type type = resolved.getSignature().getArgumentTypes().get(i); + Expression expression = ConnectorExpressionTranslator.translate(session, call.getArguments().get(i), plannerContext, variableMappings, literalEncoder); + builder.addArgument(type, expression); + } + return Optional.of(builder.build()); + } + + protected Optional translateLike(ConnectorExpression value, ConnectorExpression pattern) + { + Optional translatedValue = translate(session, value); + Optional translatedPattern = translate(session, pattern); + if (translatedValue.isPresent() && translatedPattern.isPresent()) { + return Optional.of(new LikePredicate(translatedValue.get(), translatedPattern.get(), Optional.empty())); + } + return Optional.empty(); } } - static class SqlToConnectorExpressionTranslator + public static class SqlToConnectorExpressionTranslator extends AstVisitor, Void> { + private final Session session; private final Map, Type> types; + private final PlannerContext plannerContext; - public SqlToConnectorExpressionTranslator(Map, Type> types) + public SqlToConnectorExpressionTranslator(Session session, Map, Type> types, PlannerContext plannerContext) { + this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); } @Override @@ -149,6 
+226,61 @@ protected Optional visitNullLiteral(NullLiteral node, Void return Optional.of(new Constant(null, typeOf(node))); } + @Override + protected Optional visitFunctionCall(FunctionCall node, Void context) + { + if (!isComplexExpressionPushdown(session)) { + return Optional.empty(); + } + + if (node.getFilter().isPresent() || node.getOrderBy().isPresent() || node.getWindow().isPresent() || node.getNullTreatment().isPresent() || node.isDistinct()) { + return Optional.empty(); + } + + String functionName = ResolvedFunction.extractFunctionName(node.getName()); + + if (LiteralFunction.LITERAL_FUNCTION_NAME.equalsIgnoreCase(functionName)) { + Object value = evaluateConstant(node); + if (value instanceof JoniRegexp) { + Slice pattern = ((JoniRegexp) value).pattern(); + return Optional.of(new Constant(pattern, VarcharType.createVarcharType(countCodePoints(pattern)))); + } + if (value instanceof Re2JRegexp) { + Slice pattern = Slices.utf8Slice(((Re2JRegexp) value).pattern()); + return Optional.of(new Constant(pattern, VarcharType.createVarcharType(countCodePoints(pattern)))); + } + return Optional.of(new Constant(value, types.get(NodeRef.of(node)))); + } + + ImmutableList.Builder arguments = ImmutableList.builder(); + for (Expression argumentExpression : node.getArguments()) { + Optional argument = process(argumentExpression); + if (argument.isEmpty()) { + return Optional.empty(); + } + arguments.add(argument.get()); + } + + // Currently, plugin-provided and runtime-added functions don't have a catalog/schema qualifier. + // TODO Translate catalog/schema qualifier when available. 
+ FunctionName name = new FunctionName(functionName); + return Optional.of(new Call(typeOf(node), name, arguments.build())); + } + + @Override + protected Optional visitLikePredicate(LikePredicate node, Void context) + { + // TODO Support ESCAPE character + if (node.getEscape().isEmpty()) { + Optional value = process(node.getValue()); + Optional pattern = process(node.getPattern()); + if (value.isPresent() && pattern.isPresent()) { + return Optional.of(new Call(typeOf(node), new FunctionName(LIKE_PATTERN_FUNCTION_NAME), List.of(value.get(), pattern.get()))); + } + } + return Optional.empty(); + } + @Override protected Optional visitSubscriptExpression(SubscriptExpression node, Void context) { @@ -174,5 +306,19 @@ private Type typeOf(Expression node) { return types.get(NodeRef.of(node)); } + + private Object evaluateConstant(Expression node) + { + Type type = typeOf(node); + Object value = evaluateConstantExpression( + node, + type, + plannerContext, + session, + new AllowAllAccessControl(), + ImmutableMap.of()); + verify(!(value instanceof Expression), "Expression %s did not evaluate to constant: %s", node, value); + return value; + } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java index 9deedf53b881..12ebe31e2ee0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java @@ -17,6 +17,7 @@ import io.trino.Session; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.Expression; import io.trino.sql.tree.LambdaExpression; @@ -40,7 +41,8 @@ public static Map, ConnectorExpression> extractPartialTransl Expression inputExpression, Session session, TypeAnalyzer typeAnalyzer, - TypeProvider typeProvider) + TypeProvider 
typeProvider, + PlannerContext plannerContext) { requireNonNull(inputExpression, "inputExpression is null"); requireNonNull(session, "session is null"); @@ -48,7 +50,7 @@ public static Map, ConnectorExpression> extractPartialTransl requireNonNull(typeProvider, "typeProvider is null"); Map, ConnectorExpression> partialTranslations = new HashMap<>(); - new Visitor(typeAnalyzer.getTypes(session, typeProvider, inputExpression), partialTranslations).process(inputExpression); + new Visitor(session, typeAnalyzer.getTypes(session, typeProvider, inputExpression), partialTranslations, plannerContext).process(inputExpression); return ImmutableMap.copyOf(partialTranslations); } @@ -58,11 +60,11 @@ private static class Visitor private final Map, ConnectorExpression> translatedSubExpressions; private final ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator; - Visitor(Map, Type> types, Map, ConnectorExpression> translatedSubExpressions) + Visitor(Session session, Map, Type> types, Map, ConnectorExpression> translatedSubExpressions, PlannerContext plannerContext) { requireNonNull(types, "types is null"); this.translatedSubExpressions = requireNonNull(translatedSubExpressions, "translatedSubExpressions is null"); - this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(types); + this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(session, types, plannerContext); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 9dd6453bb1ce..158d6a443053 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -179,7 +179,7 @@ public static Optional 
pushAggregationIntoTableScan( } List newProjections = result.getProjections().stream() - .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, variableMappings, new LiteralEncoder(plannerContext))) + .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, variableMappings, new LiteralEncoder(plannerContext))) .collect(toImmutableList()); verify(aggregationOutputSymbols.size() == newProjections.size()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index da28bd921c35..c9c89e6ef767 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -28,11 +28,16 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; +import io.trino.sql.ExpressionUtils; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.LayoutConstraintEvaluator; +import io.trino.sql.planner.LiteralEncoder; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeAnalyzer; @@ -42,15 +47,19 @@ import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.ValuesNode; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.NodeRef; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import 
static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors; import static io.trino.matching.Capture.newCapture; +import static io.trino.spi.expression.Constant.TRUE; import static io.trino.sql.ExpressionUtils.combineConjuncts; import static io.trino.sql.ExpressionUtils.filterDeterministicConjuncts; import static io.trino.sql.ExpressionUtils.filterNonDeterministicConjuncts; @@ -169,6 +178,16 @@ public static Optional pushFilterIntoTableScan( .transformKeys(node.getAssignments()::get) .intersect(node.getEnforcedConstraint()); + Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), decomposedPredicate.getRemainingExpression()); + Optional connectorExpression = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(session, remainingExpressionTypes, plannerContext) + .process(decomposedPredicate.getRemainingExpression()); + Map connectorExpressionAssignments = connectorExpression + .map(ignored -> + node.getAssignments() + .entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue))) + .orElse(ImmutableMap.of()); + Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); Constraint constraint; @@ -186,12 +205,12 @@ public static Optional pushFilterIntoTableScan( // Simplify the tuple domain to avoid creating an expression with too many nodes, // which would be expensive to evaluate in the call to isCandidate below. 
domainTranslator.toPredicate(session, newDomain.simplify().transformKeys(assignments::get)))); - constraint = new Constraint(newDomain, evaluator::isCandidate, evaluator.getArguments()); + constraint = new Constraint(newDomain, connectorExpression.orElse(TRUE), connectorExpressionAssignments, evaluator::isCandidate, evaluator.getArguments()); } else { // Currently, invoking the expression interpreter is very expensive. // TODO invoke the interpreter unconditionally when the interpreter becomes cheap enough. - constraint = new Constraint(newDomain); + constraint = new Constraint(newDomain, connectorExpression.orElse(TRUE), connectorExpressionAssignments); } // check if new domain is wider than domain already provided by table scan @@ -234,6 +253,7 @@ public static Optional pushFilterIntoTableScan( } TupleDomain remainingFilter = result.get().getRemainingFilter(); + Optional remainingConnectorExpression = result.get().getRemainingExpression(); boolean precalculateStatistics = result.get().isPrecalculateStatistics(); verifyTablePartitioning(session, plannerContext.getMetadata(), node, newTablePartitioning); @@ -249,6 +269,22 @@ public static Optional pushFilterIntoTableScan( node.isUpdateTarget(), node.getUseConnectorNodePartitioning()); + Expression remainingDecomposedPredicate; + if (remainingConnectorExpression.isEmpty() || remainingConnectorExpression.equals(connectorExpression)) { + remainingDecomposedPredicate = decomposedPredicate.getRemainingExpression(); + } + else { + Map variableMappings = assignments.values().stream() + .collect(toImmutableMap(Symbol::getName, Function.identity())); + Expression translatedExpression = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression.get(), plannerContext, variableMappings, new LiteralEncoder(plannerContext)); + if (connectorExpression.isEmpty()) { + remainingDecomposedPredicate = ExpressionUtils.combineConjuncts(plannerContext.getMetadata(), translatedExpression, 
decomposedPredicate.getRemainingExpression()); + } + else { + remainingDecomposedPredicate = translatedExpression; + } + } + Expression resultingPredicate = createResultingPredicate( plannerContext, session, @@ -256,7 +292,7 @@ public static Optional pushFilterIntoTableScan( typeAnalyzer, domainTranslator.toPredicate(session, remainingFilter.transformKeys(assignments::get)), nonDeterministicPredicate, - decomposedPredicate.getRemainingExpression()); + remainingDecomposedPredicate); if (!TRUE_LITERAL.equals(resultingPredicate)) { return Optional.of(new FilterNode(filterNode.getId(), tableScan, resultingPredicate)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index 80ccf484b8c3..b73a638049cd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -106,7 +106,8 @@ public Result apply(ProjectNode project, Captures captures, Context context) expression.getValue(), context.getSession(), typeAnalyzer, - context.getSymbolAllocator().getTypes()).entrySet().stream()) + context.getSymbolAllocator().getTypes(), + plannerContext).entrySet().stream()) // Avoid duplicates .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, (first, ignore) -> first)); @@ -143,7 +144,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) // Translate partial connector projections back to new partial projections List newPartialProjections = newConnectorPartialProjections.stream() - .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, variableMappings, new LiteralEncoder(plannerContext))) + .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, 
plannerContext, variableMappings, new LiteralEncoder(plannerContext))) .collect(toImmutableList()); // Map internal node references to new partial projections @@ -170,7 +171,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) continue; } String resultVariableName = ((Variable) resultConnectorExpression).getName(); - Expression inputExpression = ConnectorExpressionTranslator.translate(context.getSession(), inputConnectorExpression, inputVariableMappings, new LiteralEncoder(plannerContext)); + Expression inputExpression = ConnectorExpressionTranslator.translate(context.getSession(), inputConnectorExpression, plannerContext, inputVariableMappings, new LiteralEncoder(plannerContext)); SymbolStatsEstimate symbolStatistics = scalarStatsCalculator.calculate(inputExpression, statistics, context.getSession(), context.getSymbolAllocator().getTypes()); builder.addSymbolStatistics(variableMappings.get(resultVariableName), symbolStatistics); } diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java index 5eedcd802045..e4e3747a1049 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java @@ -102,6 +102,7 @@ public void testDefaults() .setMaxGroupingSets(2048) .setLateMaterializationEnabled(false) .setSkipRedundantSort(true) + .setComplexExpressionPushdownEnabled(true) .setPredicatePushdownUseTableProperties(true) .setIgnoreDownstreamPreferences(false) .setOmitDateTimeTypePrecision(false) @@ -180,6 +181,7 @@ public void testExplicitPropertyMappings() .put("analyzer.max-grouping-sets", "2047") .put("experimental.late-materialization.enabled", "true") .put("optimizer.skip-redundant-sort", "false") + .put("optimizer.complex-expression-pushdown.enabled", "false") .put("optimizer.predicate-pushdown-use-table-properties", "false") 
.put("optimizer.ignore-downstream-preferences", "true") .put("deprecated.omit-datetime-type-precision", "true") @@ -255,6 +257,7 @@ public void testExplicitPropertyMappings() .setDefaultFilterFactorEnabled(true) .setLateMaterializationEnabled(true) .setSkipRedundantSort(false) + .setComplexExpressionPushdownEnabled(false) .setPredicatePushdownUseTableProperties(false) .setIgnoreDownstreamPreferences(true) .setOmitDateTimeTypePrecision(true) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index 34b360e2030f..89671b883156 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -14,22 +14,33 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Constant; import io.trino.spi.expression.FieldDereference; +import io.trino.spi.expression.FunctionName; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.LikePredicate; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.QualifiedName; +import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SymbolReference; import io.trino.testing.TestingSession; +import io.trino.transaction.TestingTransactionManager; import org.testng.annotations.Test; +import java.util.List; import java.util.Map; import java.util.Optional; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.spi.type.BooleanType.BOOLEAN; import 
static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; @@ -38,6 +49,9 @@ import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; +import static io.trino.transaction.TransactionBuilder.transaction; +import static io.trino.type.LikeFunctions.LIKE_PATTERN_FUNCTION_NAME; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; public class TestConnectorExpressionTranslator @@ -45,11 +59,13 @@ public class TestConnectorExpressionTranslator private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT); private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5))); + private static final Type VARCHAR_TYPE = createVarcharType(25); private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); private static final Map symbols = ImmutableMap.builder() .put(new Symbol("double_symbol_1"), DOUBLE) .put(new Symbol("row_symbol_1"), ROW_TYPE) + .put(new Symbol("varchar_symbol_1"), VARCHAR_TYPE) .buildOrThrow(); private static final TypeProvider TYPE_PROVIDER = TypeProvider.copyOf(symbols); @@ -59,9 +75,10 @@ public class TestConnectorExpressionTranslator @Test public void testTranslationToConnectorExpression() { - assertTranslationToConnectorExpression(new SymbolReference("double_symbol_1"), Optional.of(new Variable("double_symbol_1", DOUBLE))); + assertTranslationToConnectorExpression(TEST_SESSION, new SymbolReference("double_symbol_1"), Optional.of(new Variable("double_symbol_1", DOUBLE))); assertTranslationToConnectorExpression( + TEST_SESSION, new SubscriptExpression( new 
SymbolReference("row_symbol_1"), new LongLiteral("1")), @@ -70,6 +87,31 @@ public void testTranslationToConnectorExpression() INTEGER, new Variable("row_symbol_1", ROW_TYPE), 0))); + + String pattern = "%pattern%"; + assertTranslationToConnectorExpression( + TEST_SESSION, + new LikePredicate( + new SymbolReference("varchar_symbol_1"), + new StringLiteral(pattern), + Optional.empty()), + Optional.of(new Call(BOOLEAN, + new FunctionName(LIKE_PATTERN_FUNCTION_NAME), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), + new Constant(Slices.wrappedBuffer(pattern.getBytes(UTF_8)), createVarcharType(pattern.length())))))); + + transaction(new TestingTransactionManager(), new AllowAllAccessControl()) + .readOnly() + .execute(TEST_SESSION, transactionSession -> { + assertTranslationToConnectorExpression(transactionSession, + FunctionCallBuilder.resolve(TEST_SESSION, PLANNER_CONTEXT.getMetadata()) + .setName(QualifiedName.of(("lower"))) + .addArgument(VARCHAR_TYPE, new SymbolReference("varchar_symbol_1")) + .build(), + Optional.of(new Call(VARCHAR_TYPE, + new FunctionName("lower"), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))))); + }); } @Test @@ -85,18 +127,37 @@ public void testTranslationFromConnectorExpression() new SubscriptExpression( new SymbolReference("row_symbol_1"), new LongLiteral("1"))); + + String pattern = "%pattern%"; + assertTranslationFromConnectorExpression( + new Call(VARCHAR_TYPE, + new FunctionName(LIKE_PATTERN_FUNCTION_NAME), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), + new Constant(Slices.wrappedBuffer(pattern.getBytes(UTF_8)), createVarcharType(pattern.length())))), + new LikePredicate(new SymbolReference("varchar_symbol_1"), + new StringLiteral(pattern), + Optional.empty())); + + assertTranslationFromConnectorExpression( + new Call(VARCHAR_TYPE, + new FunctionName("lower"), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))), + FunctionCallBuilder.resolve(TEST_SESSION, PLANNER_CONTEXT.getMetadata()) + 
.setName(QualifiedName.of(("lower"))) + .addArgument(VARCHAR_TYPE, new SymbolReference("varchar_symbol_1")) + .build()); } - private void assertTranslationToConnectorExpression(Expression expression, Optional connectorExpression) + private void assertTranslationToConnectorExpression(Session session, Expression expression, Optional connectorExpression) { - Optional translation = translate(TEST_SESSION, expression, TYPE_ANALYZER, TYPE_PROVIDER); + Optional translation = translate(session, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); assertEquals(connectorExpression.isPresent(), translation.isPresent()); translation.ifPresent(value -> assertEquals(value, connectorExpression.get())); } private void assertTranslationFromConnectorExpression(ConnectorExpression connectorExpression, Expression expected) { - Expression translation = translate(TEST_SESSION, connectorExpression, variableMappings, LITERAL_ENCODER); + Expression translation = ConnectorExpressionTranslator.translate(TEST_SESSION, connectorExpression, PLANNER_CONTEXT, variableMappings, LITERAL_ENCODER); assertEquals(translation, expected); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 103a18ba7fe5..27106281e000 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -79,22 +79,22 @@ public void testPartialTranslator() List functionArguments = ImmutableList.of(stringLiteral, dereferenceExpression2); Expression functionCallExpression = new FunctionCall(QualifiedName.of("concat"), functionArguments); - assertPartialTranslation(functionCallExpression, functionArguments); + assertFullTranslation(functionCallExpression); } private void assertPartialTranslation(Expression expression, List subexpressions) { - Map, ConnectorExpression> translation = 
extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER); + Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); assertEquals(subexpressions.size(), translation.size()); for (Expression subexpression : subexpressions) { - assertEquals(translation.get(NodeRef.of(subexpression)), translate(TEST_SESSION, subexpression, TYPE_ANALYZER, TYPE_PROVIDER).get()); + assertEquals(translation.get(NodeRef.of(subexpression)), translate(TEST_SESSION, subexpression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT).get()); } } private void assertFullTranslation(Expression expression) { - Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER); + Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); assertEquals(getOnlyElement(translation.keySet()), NodeRef.of(expression)); - assertEquals(getOnlyElement(translation.values()), translate(TEST_SESSION, expression, TYPE_ANALYZER, TYPE_PROVIDER).get()); + assertEquals(getOnlyElement(translation.values()), translate(TEST_SESSION, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT).get()); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index f26cefda50a8..f59c0116418a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -32,9 +32,9 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTablePartitioning; import io.trino.spi.connector.ConnectorTableProperties; -import 
io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; import io.trino.spi.expression.FieldDereference; @@ -48,14 +48,21 @@ import io.trino.sql.planner.iterative.rule.test.RuleTester; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.QualifiedName; +import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SymbolReference; +import io.trino.testing.TestingTransactionHandle; +import io.trino.transaction.TransactionId; import org.testng.annotations.Test; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -72,6 +79,7 @@ import static io.trino.sql.planner.iterative.rule.test.RuleTester.defaultRuleTester; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.util.Arrays.asList; +import static java.util.Locale.ENGLISH; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestPushProjectionIntoTableScan @@ -133,25 +141,31 @@ public void testPushProjection() Symbol identity = new Symbol("symbol_identity"); Symbol dereference = new Symbol("symbol_dereference"); Symbol constant = new Symbol("symbol_constant"); + Symbol call = new Symbol("symbol_call"); ImmutableMap types = ImmutableMap.of( baseColumn, ROW_TYPE, identity, ROW_TYPE, dereference, BIGINT, - constant, BIGINT); + constant, BIGINT, + call, VARCHAR); // Prepare project node assignments ImmutableMap 
inputProjections = ImmutableMap.of( identity, baseColumn.toSymbolReference(), dereference, new SubscriptExpression(baseColumn.toSymbolReference(), new LongLiteral("1")), - constant, new LongLiteral("5")); + constant, new LongLiteral("5"), + call, new FunctionCall(QualifiedName.of("STARTS_WITH"), ImmutableList.of(new StringLiteral("abc"), new StringLiteral("ab")))); // Compute expected symbols after applyProjection + TransactionId transactionId = ruleTester.getQueryRunner().getTransactionManager().beginTransaction(false); + Session session = MOCK_SESSION.beginTransactionId(transactionId, ruleTester.getQueryRunner().getTransactionManager(), ruleTester.getQueryRunner().getAccessControl()); ImmutableMap connectorNames = inputProjections.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, e -> translate(MOCK_SESSION, e.getValue(), typeAnalyzer, viewOf(types)).get().toString())); + .collect(toImmutableMap(Map.Entry::getKey, e -> translate(session, e.getValue(), typeAnalyzer, viewOf(types), ruleTester.getPlannerContext()).get().toString())); ImmutableMap newNames = ImmutableMap.of( identity, "projected_variable_" + connectorNames.get(identity), dereference, "projected_dereference_" + connectorNames.get(dereference), - constant, "projected_constant_" + connectorNames.get(constant)); + constant, "projected_constant_" + connectorNames.get(constant), + call, "projected_call_" + connectorNames.get(call)); Map expectedColumns = newNames.entrySet().stream() .collect(toImmutableMap( Map.Entry::getValue, @@ -160,20 +174,18 @@ dereference, new SubscriptExpression(baseColumn.toSymbolReference(), new LongLit ruleTester.assertThat(createRule(ruleTester)) .on(p -> { // Register symbols - Symbol columnSymbol = p.symbol(columnName, columnType); - p.symbol(identity.getName(), types.get(identity)); - p.symbol(dereference.getName(), types.get(dereference)); - p.symbol(constant.getName(), types.get(constant)); + types.forEach((symbol, type) -> p.symbol(symbol.getName(), type)); 
return p.project( new Assignments(inputProjections), p.tableScan(tableScan -> tableScan .setTableHandle(TEST_TABLE_HANDLE) - .setSymbols(ImmutableList.of(columnSymbol)) - .setAssignments(ImmutableMap.of(columnSymbol, columnHandle)) + .setSymbols(ImmutableList.copyOf(types.keySet())) + .setAssignments(types.keySet().stream() + .collect(Collectors.toMap(Function.identity(), v -> columnHandle))) .setStatistics(Optional.of(PlanNodeStatsEstimate.builder() .setOutputRowCount(42) - .addSymbolStatistics(columnSymbol, SymbolStatsEstimate.builder().setNullsFraction(0).setDistinctValuesCount(33).build()) + .addSymbolStatistics(baseColumn, SymbolStatsEstimate.builder().setNullsFraction(0).setDistinctValuesCount(33).build()) .build())))); }) .withSession(MOCK_SESSION) @@ -198,6 +210,10 @@ dereference, new SubscriptExpression(baseColumn.toSymbolReference(), new LongLit .setLowValue(5) .setHighValue(5) .build()) + .addSymbolStatistics(new Symbol(newNames.get(call).toLowerCase(ENGLISH)), SymbolStatsEstimate.builder() + .setDistinctValuesCount(1) + .setNullsFraction(0) + .build()) .addSymbolStatistics(new Symbol(newNames.get(identity)), SymbolStatsEstimate.builder() .setDistinctValuesCount(33) .setNullsFraction(0) @@ -292,6 +308,9 @@ else if (projection instanceof FieldDereference) { else if (projection instanceof Constant) { variablePrefix = "projected_constant_"; } + else if (projection instanceof Call) { + variablePrefix = "projected_call_"; + } else { throw new UnsupportedOperationException(); } @@ -326,7 +345,7 @@ private static TableHandle createTableHandle(String schemaName, String tableName return new TableHandle( new CatalogName(MOCK_CATALOG), new MockConnectorTableHandle(new SchemaTableName(schemaName, tableName)), - new ConnectorTransactionHandle() {}); + TestingTransactionHandle.create()); } private static SymbolReference symbolReference(String name) diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/Constraint.java 
b/core/trino-spi/src/main/java/io/trino/spi/connector/Constraint.java index f346b38cd61a..d200cff7ff67 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/Constraint.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/Constraint.java @@ -13,6 +13,7 @@ */ package io.trino.spi.connector; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; @@ -21,11 +22,14 @@ import java.util.Set; import java.util.function.Predicate; +import static io.trino.spi.expression.Constant.TRUE; import static java.util.Objects.requireNonNull; public class Constraint { private final TupleDomain summary; + private final ConnectorExpression expression; + private final Map assignments; private final Optional>> predicate; private final Optional> predicateColumns; @@ -41,17 +45,39 @@ public static Constraint alwaysFalse() public Constraint(TupleDomain summary) { - this(summary, Optional.empty(), Optional.empty()); + this(summary, TRUE, Map.of(), Optional.empty(), Optional.empty()); } public Constraint(TupleDomain summary, Predicate> predicate, Set predicateColumns) { - this(summary, Optional.of(predicate), Optional.of(predicateColumns)); + this(summary, TRUE, Map.of(), Optional.of(predicate), Optional.of(predicateColumns)); } - private Constraint(TupleDomain summary, Optional>> predicate, Optional> predicateColumns) + public Constraint(TupleDomain summary, ConnectorExpression expression, Map assignments) + { + this(summary, expression, assignments, Optional.empty(), Optional.empty()); + } + + public Constraint( + TupleDomain summary, + ConnectorExpression expression, + Map assignments, + Predicate> predicate, + Set predicateColumns) + { + this(summary, expression, assignments, Optional.of(predicate), Optional.of(predicateColumns)); + } + + private Constraint( + TupleDomain summary, + ConnectorExpression expression, + Map assignments, + Optional>> predicate, + Optional> predicateColumns) { 
this.summary = requireNonNull(summary, "summary is null"); + this.expression = requireNonNull(expression, "expression is null"); + this.assignments = Map.copyOf(requireNonNull(assignments, "assignments is null")); this.predicate = requireNonNull(predicate, "predicate is null"); this.predicateColumns = requireNonNull(predicateColumns, "predicateColumns is null").map(Set::copyOf); @@ -63,13 +89,33 @@ private Constraint(TupleDomain summary, Optional getSummary() { return summary; } /** - * A predicate that can be used to filter data. If present, it is equivalent to, or stricter than, {@link #getSummary()}. + * @return an expression predicate which is different from, and should be AND-ed with, {@link #getSummary} or {@link #predicate} (if present). + */ + public ConnectorExpression getExpression() + { + return expression; + } + + /** + * @return mappings from variable names to table column handles + * It is guaranteed that all the required mappings for {@link #getExpression} will be provided but not necessarily *all* the column handles of the table + */ + public Map getAssignments() + { + return assignments; + } + + /** + * A predicate that can be used to filter data. If present, it is equivalent to, or stricter than, {@link #getSummary()} and different from, and should be AND-ed with, {@link #getExpression()}. *

* For Constraint provided in {@link ConnectorMetadata#applyFilter(ConnectorSession, ConnectorTableHandle, Constraint)}, * the predicate cannot be held on to after the call returns. diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConstraintApplicationResult.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConstraintApplicationResult.java index ad47a4521c19..04dd3122d5c1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConstraintApplicationResult.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConstraintApplicationResult.java @@ -13,14 +13,19 @@ */ package io.trino.spi.connector; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.predicate.TupleDomain; +import java.util.Optional; +import java.util.function.Function; + import static java.util.Objects.requireNonNull; public class ConstraintApplicationResult { private final T handle; private final TupleDomain remainingFilter; + private final Optional remainingExpression; private final boolean precalculateStatistics; /** @@ -28,9 +33,31 @@ public class ConstraintApplicationResult * as the connector may be unable to provide good table statistics for {@code handle}. */ public ConstraintApplicationResult(T handle, TupleDomain remainingFilter, boolean precalculateStatistics) + { + this(handle, remainingFilter, Optional.empty(), precalculateStatistics); + } + + /** + * @param remainingExpression the remaining expression, which will be AND-ed with {@code remainingFilter}, + * @param precalculateStatistics Indicates whether engine should consider calculating statistics based on the plan before pushdown, + * as the connector may be unable to provide good table statistics for {@code handle}. 
+ */ + public ConstraintApplicationResult(T handle, TupleDomain remainingFilter, ConnectorExpression remainingExpression, boolean precalculateStatistics) + { + this(handle, remainingFilter, Optional.of(remainingExpression), precalculateStatistics); + } + + /** + * @param remainingExpression the remaining expression, which will be AND-ed with {@code remainingFilter}, + * or {@link Optional#empty()} if the remaining expression is equal to the original expression. + * @param precalculateStatistics Indicates whether engine should consider calculating statistics based on the plan before pushdown, + * as the connector may be unable to provide good table statistics for {@code handle}. + */ + private ConstraintApplicationResult(T handle, TupleDomain remainingFilter, Optional remainingExpression, boolean precalculateStatistics) { this.handle = requireNonNull(handle, "handle is null"); this.remainingFilter = requireNonNull(remainingFilter, "remainingFilter is null"); + this.remainingExpression = requireNonNull(remainingExpression, "remainingExpression is null"); this.precalculateStatistics = precalculateStatistics; } @@ -44,8 +71,18 @@ public TupleDomain getRemainingFilter() return remainingFilter; } + public Optional getRemainingExpression() + { + return remainingExpression; + } + public boolean isPrecalculateStatistics() { return precalculateStatistics; } + + public ConstraintApplicationResult transform(Function transformHandle) + { + return new ConstraintApplicationResult<>(transformHandle.apply(handle), remainingFilter, remainingExpression, precalculateStatistics); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/Call.java b/core/trino-spi/src/main/java/io/trino/spi/expression/Call.java new file mode 100644 index 000000000000..f1ed7aed7aed --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/Call.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.expression; + +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Objects; +import java.util.StringJoiner; + +import static java.util.Objects.requireNonNull; + +public final class Call + extends ConnectorExpression +{ + private final FunctionName functionName; + private final List arguments; + + public Call( + Type type, + FunctionName functionName, + List arguments) + { + super(type); + this.functionName = requireNonNull(functionName, "functionName is null"); + this.arguments = List.copyOf(requireNonNull(arguments, "arguments is null")); + } + + public FunctionName getFunctionName() + { + return functionName; + } + + public List getArguments() + { + return arguments; + } + + @Override + public List getChildren() + { + return arguments; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Call call = (Call) o; + return Objects.equals(functionName, call.functionName) && + Objects.equals(arguments, call.arguments) && + Objects.equals(getType(), call.getType()); + } + + @Override + public int hashCode() + { + return Objects.hash(functionName, arguments, getType()); + } + + @Override + public String toString() + { + StringJoiner stringJoiner = new StringJoiner(", ", Call.class.getSimpleName() + "[", "]"); + return stringJoiner + .add("functionName=" + functionName) + .add("arguments=" + arguments) + .toString(); + } +} diff --git 
a/core/trino-spi/src/main/java/io/trino/spi/expression/Constant.java b/core/trino-spi/src/main/java/io/trino/spi/expression/Constant.java index c8a4d9483dce..0b9f08293158 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/Constant.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/Constant.java @@ -13,6 +13,7 @@ */ package io.trino.spi.expression; +import io.trino.spi.type.BooleanType; import io.trino.spi.type.Type; import java.util.List; @@ -23,6 +24,9 @@ public class Constant extends ConnectorExpression { + public static final Constant TRUE = new Constant(true, BooleanType.BOOLEAN); + public static final Constant FALSE = new Constant(false, BooleanType.BOOLEAN); + private final Object value; /** diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/FunctionName.java b/core/trino-spi/src/main/java/io/trino/spi/expression/FunctionName.java new file mode 100644 index 000000000000..3e848d567697 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/FunctionName.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.trino.spi.expression; + +import io.trino.spi.connector.CatalogSchemaName; + +import java.util.Objects; +import java.util.Optional; +import java.util.StringJoiner; + +import static java.util.Objects.requireNonNull; + +public class FunctionName +{ + private final Optional catalogSchema; + private final String name; + + public FunctionName(String name) + { + this(Optional.empty(), name); + } + + public FunctionName(Optional catalogSchema, String name) + { + this.catalogSchema = requireNonNull(catalogSchema, "catalogSchema is null"); + this.name = requireNonNull(name, "name is null"); + } + + /** + * @return the catalog and schema of this function, or {@link Optional#empty()} if this is a built-in function + */ + public Optional getCatalogSchema() + { + return catalogSchema; + } + + /** + * @return the function's name + */ + public String getName() + { + return name; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FunctionName that = (FunctionName) o; + return Objects.equals(catalogSchema, that.catalogSchema) && + Objects.equals(name, that.name); + } + + @Override + public int hashCode() + { + return Objects.hash(catalogSchema, name); + } + + @Override + public String toString() + { + StringJoiner stringJoiner = new StringJoiner(", "); + catalogSchema.ifPresent(value -> stringJoiner.add("catalogSchema=" + value)); + return stringJoiner + .add("name='" + name + "'") + .toString(); + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java new file mode 100644 index 000000000000..09267d28b590 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.spi.expression.Constant.TRUE; + +public final class ConnectorExpressions +{ + private ConnectorExpressions() {} + + public static List extractConjuncts(ConnectorExpression expression) + { + // TODO: Implement after supporting the conversion of io.trino.sql.tree.LogicalExpression to io.trino.spi.expression.Call + return ImmutableList.of(expression); + } + + public static ConnectorExpression and(ConnectorExpression... 
expressions) + { + return and(Arrays.asList(expressions)); + } + + public static ConnectorExpression and(Collection expressions) + { + // TODO: Implement after supporting the conversion of io.trino.sql.tree.LogicalExpression to io.trino.spi.expression.Call + if (expressions.size() > 1) { + throw new RuntimeException("Only single expression is currently supported"); + } + if (expressions.isEmpty()) { + return TRUE; + } + return getOnlyElement(expressions); + } +} From a96f5eef5d15dadc74a1c02d0457384e64d26939 Mon Sep 17 00:00:00 2001 From: Assaf Bern Date: Tue, 5 Oct 2021 14:11:11 +0300 Subject: [PATCH 2/2] Use regexp_like function pushdown on Elasticsearch connector --- .../elasticsearch/CountQueryPageSource.java | 2 +- .../elasticsearch/ElasticsearchMetadata.java | 59 ++++++++++++++++++- .../ElasticsearchQueryBuilder.java | 6 +- .../ElasticsearchTableHandle.java | 23 +++++++- .../elasticsearch/ScanQueryPageSource.java | 2 +- .../BaseElasticsearchConnectorTest.java | 44 ++++++++++++++ .../TestElasticsearchMetadata.java | 31 ++++++++++ 7 files changed, 161 insertions(+), 6 deletions(-) create mode 100644 plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchMetadata.java diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java index 7c84bf29a160..f76b86a47658 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java @@ -42,7 +42,7 @@ public CountQueryPageSource(ElasticsearchClient client, ElasticsearchTableHandle long count = client.count( split.getIndex(), split.getShard(), - buildSearchQuery(table.getConstraint().transformKeys(ElasticsearchColumnHandle.class::cast), table.getQuery())); + 
buildSearchQuery(table.getConstraint().transformKeys(ElasticsearchColumnHandle.class::cast), table.getQuery(), table.getRegexes())); readTimeNanos = System.nanoTime() - start; if (table.getLimit().isPresent()) { diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java index 09fc08509ad9..78dbc0f37df7 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchMetadata.java @@ -20,6 +20,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.io.BaseEncoding; import io.airlift.json.ObjectMapperProvider; +import io.airlift.slice.Slice; +import io.trino.plugin.base.expression.ConnectorExpressions; import io.trino.plugin.elasticsearch.client.ElasticsearchClient; import io.trino.plugin.elasticsearch.client.IndexMetadata; import io.trino.plugin.elasticsearch.client.IndexMetadata.DateTimeType; @@ -53,6 +55,11 @@ import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Constant; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.ArrayType; @@ -64,12 +71,15 @@ import javax.inject.Inject; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import static 
com.google.common.base.Preconditions.checkArgument; @@ -115,6 +125,9 @@ public class ElasticsearchMetadata new VarcharDecoder.Descriptor(PASSTHROUGH_QUERY_RESULT_COLUMN_NAME), false)); + // See https://www.elastic.co/guide/en/elasticsearch/reference/current/regexp-syntax.html + private static final char[] REGEXP_RESERVED_CHARACTERS = new char[] {'.', '?', '+', '*', '|', '{', '}', '[', ']', '(', ')', '"', '#', '@', '&', '<', '>', '~'}; + private final Type ipAddressType; private final ElasticsearchClient client; private final String schemaName; @@ -488,6 +501,7 @@ public Optional> applyLimit(Connect handle.getSchema(), handle.getIndex(), handle.getConstraint(), + handle.getRegexes(), handle.getQuery(), OptionalLong.of(limit)); @@ -521,7 +535,35 @@ public Optional> applyFilter(C TupleDomain oldDomain = handle.getConstraint(); TupleDomain newDomain = oldDomain.intersect(TupleDomain.withColumnDomains(supported)); - if (oldDomain.equals(newDomain)) { + + ConnectorExpression oldExpression = constraint.getExpression(); + Map newRegexes = new HashMap<>(handle.getRegexes()); + List expressions = ConnectorExpressions.extractConjuncts(constraint.getExpression()); + List notHandledExpressions = new ArrayList<>(); + for (ConnectorExpression expression : expressions) { + if (expression instanceof Call) { + Call call = (Call) expression; + // TODO Support ESCAPE character when it's pushed down by the engine + if (new FunctionName("$like_pattern").equals(call.getFunctionName()) && call.getArguments().size() == 2 && + call.getArguments().get(0) instanceof Variable && call.getArguments().get(1) instanceof Constant) { + String columnName = ((Variable) call.getArguments().get(0)).getName(); + Object pattern = ((Constant) call.getArguments().get(1)).getValue(); + if (!newRegexes.containsKey(columnName) && pattern instanceof Slice) { + IndexMetadata metadata = client.getIndexMetadata(handle.getIndex()); + if (metadata.getSchema() + .getFields().stream() + .anyMatch(field -> 
columnName.equals(field.getName()) && field.getType() instanceof PrimitiveType && "keyword".equals(((PrimitiveType) field.getType()).getName()))) { + newRegexes.put(columnName, likeToRegexp(((Slice) pattern).toStringUtf8())); + continue; + } + } + } + } + notHandledExpressions.add(expression); + } + + ConnectorExpression newExpression = ConnectorExpressions.and(notHandledExpressions); + if (oldDomain.equals(newDomain) && oldExpression.equals(newExpression)) { return Optional.empty(); } @@ -530,10 +572,23 @@ public Optional> applyFilter(C handle.getSchema(), handle.getIndex(), newDomain, + newRegexes, handle.getQuery(), handle.getLimit()); - return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported), false)); + return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported), newExpression, false)); + } + + protected static String likeToRegexp(String like) + { + // TODO: This can be done more efficiently by using a state machine and iterating over characters (See io.trino.type.LikeFunctions.likePattern(String, char, boolean)) + String regexp = like.replaceAll(Pattern.quote("\\"), Matcher.quoteReplacement("\\\\")); // first, escape regexp's escape character + for (char c : REGEXP_RESERVED_CHARACTERS) { + regexp = regexp.replaceAll(Pattern.quote(String.valueOf(c)), Matcher.quoteReplacement("\\" + c)); + } + return regexp + .replaceAll("%", ".*") + .replaceAll("_", "."); } private static boolean isPassthroughQuery(ElasticsearchTableHandle table) diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java index 4e8cdef67c4a..c474c0a334b3 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java +++ 
b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchQueryBuilder.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryStringQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; +import org.elasticsearch.index.query.RegexpQueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; import java.time.Instant; @@ -54,7 +55,7 @@ public final class ElasticsearchQueryBuilder { private ElasticsearchQueryBuilder() {} - public static QueryBuilder buildSearchQuery(TupleDomain constraint, Optional query) + public static QueryBuilder buildSearchQuery(TupleDomain constraint, Optional query, Map regexes) { BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); if (constraint.getDomains().isPresent()) { @@ -68,6 +69,9 @@ public static QueryBuilder buildSearchQuery(TupleDomain queryBuilder.filter(new BoolQueryBuilder().must(((new RegexpQueryBuilder(name, value)))))); + query.map(QueryStringQueryBuilder::new) .ifPresent(queryBuilder::must); diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java index e7583ebb8dbf..a381ef99b34d 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ElasticsearchTableHandle.java @@ -15,13 +15,16 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.predicate.TupleDomain; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.OptionalLong; +import java.util.stream.Collectors; import 
static java.util.Objects.requireNonNull; @@ -37,6 +40,7 @@ public enum Type private final String schema; private final String index; private final TupleDomain constraint; + private final Map regexes; private final Optional query; private final OptionalLong limit; @@ -48,6 +52,7 @@ public ElasticsearchTableHandle(Type type, String schema, String index, Optional this.query = requireNonNull(query, "query is null"); constraint = TupleDomain.all(); + regexes = ImmutableMap.of(); limit = OptionalLong.empty(); } @@ -57,6 +62,7 @@ public ElasticsearchTableHandle( @JsonProperty("schema") String schema, @JsonProperty("index") String index, @JsonProperty("constraint") TupleDomain constraint, + @JsonProperty("regexes") Map regexes, @JsonProperty("query") Optional query, @JsonProperty("limit") OptionalLong limit) { @@ -64,6 +70,7 @@ public ElasticsearchTableHandle( this.schema = requireNonNull(schema, "schema is null"); this.index = requireNonNull(index, "index is null"); this.constraint = requireNonNull(constraint, "constraint is null"); + this.regexes = ImmutableMap.copyOf(requireNonNull(regexes, "regexes is null")); this.query = requireNonNull(query, "query is null"); this.limit = requireNonNull(limit, "limit is null"); } @@ -92,6 +99,12 @@ public TupleDomain getConstraint() return constraint; } + @JsonProperty + public Map getRegexes() + { + return regexes; + } + @JsonProperty public OptionalLong getLimit() { @@ -118,6 +131,7 @@ public boolean equals(Object o) schema.equals(that.schema) && index.equals(that.index) && constraint.equals(that.constraint) && + regexes.equals(that.regexes) && query.equals(that.query) && limit.equals(that.limit); } @@ -125,7 +139,7 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(type, schema, index, constraint, query, limit); + return Objects.hash(type, schema, index, constraint, regexes, query, limit); } @Override @@ -135,6 +149,13 @@ public String toString() builder.append(type + ":" + index); 
StringBuilder attributes = new StringBuilder(); + if (!regexes.isEmpty()) { + attributes.append("regexes=["); + attributes.append(regexes.entrySet().stream() + .map(regex -> regex.getKey() + ":" + regex.getValue()) + .collect(Collectors.joining(", "))); + attributes.append("]"); + } limit.ifPresent(value -> attributes.append("limit=" + value)); query.ifPresent(value -> attributes.append("query" + value)); diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java index 4baa34908f60..8bb5e60e9209 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java @@ -111,7 +111,7 @@ public ScanQueryPageSource( SearchResponse searchResponse = client.beginSearch( split.getIndex(), split.getShard(), - buildSearchQuery(table.getConstraint().transformKeys(ElasticsearchColumnHandle.class::cast), table.getQuery()), + buildSearchQuery(table.getConstraint().transformKeys(ElasticsearchColumnHandle.class::cast), table.getQuery(), table.getRegexes()), needAllFields ? 
Optional.empty() : Optional.of(requiredFields), documentFields, sort, diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java index 7b21d3f5455d..bfd991e56a2c 100644 --- a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java +++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/BaseElasticsearchConnectorTest.java @@ -1041,6 +1041,50 @@ public void testNestedVariants() "VALUES 'value1', 'value2', 'value3', 'value4'"); } + @Test + public void testLike() + throws IOException + { + String indexName = "like_test"; + + @Language("JSON") + String mappings = "" + + "{" + + " \"properties\": { " + + " \"keyword_column\": { \"type\": \"keyword\" }," + + " \"text_column\": { \"type\": \"text\" }" + + " }" + + "}"; + + createIndex(indexName, mappings); + + index(indexName, ImmutableMap.builder() + .put("keyword_column", "so.me tex\\t") + .put("text_column", "so.me tex\\t") + .buildOrThrow()); + + // Add another document to make sure '.' 
is escaped and not treated as any character + index(indexName, ImmutableMap.builder() + .put("keyword_column", "soome tex\\t") + .put("text_column", "soome tex\\t") + .buildOrThrow()); + + assertThat(query("" + + "SELECT " + + "keyword_column " + + "FROM " + indexName + " " + + "WHERE keyword_column LIKE 's_.m%ex\\t'")) + .matches("VALUES VARCHAR 'so.me tex\\t'") + .isFullyPushedDown(); + + assertThat(query("" + + "SELECT " + + "text_column " + + "FROM " + indexName + " " + + "WHERE text_column LIKE 's_.m%ex\\t'")) + .matches("VALUES VARCHAR 'so.me tex\\t'"); + } + @Test public void testDataTypes() throws IOException diff --git a/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchMetadata.java b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchMetadata.java new file mode 100644 index 000000000000..0d967f0001f5 --- /dev/null +++ b/plugin/trino-elasticsearch/src/test/java/io/trino/plugin/elasticsearch/TestElasticsearchMetadata.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.elasticsearch; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestElasticsearchMetadata +{ + @Test + public void testLikeToRegexp() + { + assertEquals(ElasticsearchMetadata.likeToRegexp("a_b_c"), "a.b.c"); + assertEquals(ElasticsearchMetadata.likeToRegexp("a%b%c"), "a.*b.*c"); + assertEquals(ElasticsearchMetadata.likeToRegexp("a%b_c"), "a.*b.c"); + assertEquals(ElasticsearchMetadata.likeToRegexp("a[b"), "a\\[b"); + assertEquals(ElasticsearchMetadata.likeToRegexp("a_\\b"), "a.\\\\b"); + } +}