diff --git a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java index d2c5363961dc..16922f455ac9 100644 --- a/core/trino-main/src/main/java/io/trino/FeaturesConfig.java +++ b/core/trino-main/src/main/java/io/trino/FeaturesConfig.java @@ -122,6 +122,8 @@ public class FeaturesConfig private boolean faultTolerantExecutionExchangeEncryptionEnabled = true; + private boolean pushFieldDereferenceLambdaIntoScanEnabled; + public enum DataIntegrityVerification { NONE, @@ -517,4 +519,17 @@ public void applyFaultTolerantExecutionDefaults() { exchangeCompressionCodec = LZ4; } + + public boolean isPushFieldDereferenceLambdaIntoScanEnabled() + { + return pushFieldDereferenceLambdaIntoScanEnabled; + } + + @Config("experimental.enable-push-field-dereference-lambda-into-scan.enabled") + @ConfigDescription("Enables pushing field dereferences in lambda into table scan") + public FeaturesConfig setPushFieldDereferenceLambdaIntoScanEnabled(boolean pushFieldDereferenceLambdaIntoScanEnabled) + { + this.pushFieldDereferenceLambdaIntoScanEnabled = pushFieldDereferenceLambdaIntoScanEnabled; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 762ef0526a77..241d528f09eb 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -216,6 +216,7 @@ public final class SystemSessionProperties public static final String IDLE_WRITER_MIN_DATA_SIZE_THRESHOLD = "idle_writer_min_data_size_threshold"; public static final String CLOSE_IDLE_WRITERS_TRIGGER_DURATION = "close_idle_writers_trigger_duration"; public static final String COLUMNAR_FILTER_EVALUATION_ENABLED = "columnar_filter_evaluation_enabled"; + public static final String ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN = "enable_push_field_dereference_lambda_into_scan"; private final List> sessionProperties; @@ -1103,7 +1104,12 @@ public SystemSessionProperties( ALLOW_UNSAFE_PUSHDOWN, "Allow pushing down expressions that may fail for some inputs", optimizerConfig.isUnsafePushdownAllowed(), - true)); + true), + booleanProperty( + ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN, + "Enable pushing field dereferences in lambda into scan", + featuresConfig.isPushFieldDereferenceLambdaIntoScanEnabled(), + false)); } @Override @@ -1982,4 +1988,9 @@ public static boolean isUnsafePushdownAllowed(Session session) { return session.getSystemProperty(ALLOW_UNSAFE_PUSHDOWN, Boolean.class); } + + public static boolean isPushFieldDereferenceLambdaIntoScanEnabled(Session session) + { + return session.getSystemProperty(ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN, Boolean.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 73453ec31a5e..ac199192222d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -24,6 +24,7 @@ import io.trino.plugin.base.expression.ConnectorExpressions; import io.trino.security.AllowAllAccessControl; import io.trino.spi.connector.CatalogSchemaName; +import io.trino.spi.expression.ArrayFieldDereference; import io.trino.spi.expression.ConnectorExpression; import 
io.trino.spi.expression.FieldDereference;
 import io.trino.spi.expression.FunctionName;
@@ -46,9 +47,11 @@
 import io.trino.sql.ir.In;
 import io.trino.sql.ir.IrVisitor;
 import io.trino.sql.ir.IsNull;
+import io.trino.sql.ir.Lambda;
 import io.trino.sql.ir.Logical;
 import io.trino.sql.ir.NullIf;
 import io.trino.sql.ir.Reference;
+import io.trino.sql.ir.Row;
 import io.trino.sql.tree.QualifiedName;
 import io.trino.type.JoniRegexp;
 import io.trino.type.JsonPathType;
@@ -69,6 +72,7 @@
 import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
 import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName;
 import static io.trino.metadata.LanguageFunctionManager.isInlineFunction;
+import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME;
 import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME;
 import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME;
 import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME;
@@ -127,7 +131,13 @@ public static Expression translate(Session session, ConnectorExpression expressi
     public static Optional<ConnectorExpression> translate(Session session, Expression expression)
     {
-        return new SqlToConnectorExpressionTranslator(session)
+        return new SqlToConnectorExpressionTranslator(session, false)
                 .process(expression);
+    }
+
+    public static Optional<ConnectorExpression> translate(Session session, Expression expression, boolean translateArrayFieldReference)
+    {
+        return new SqlToConnectorExpressionTranslator(session, translateArrayFieldReference)
+                .process(expression);
     }
 
@@ -135,7 +145,7 @@ public static ConnectorExpressionTranslation translateConjuncts(
             Session session,
             Expression expression)
     {
-        SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator(session);
+        SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator(session, false);
         List<Expression> conjuncts = extractConjuncts(expression);
         List<Expression> remaining = new ArrayList<>();
@@ -562,10 +572,12 @@ public static class SqlToConnectorExpressionTranslator
             extends IrVisitor<Optional<ConnectorExpression>, Void>
     {
         private final Session session;
+        private final boolean translateArrayFieldReference;
 
-        public SqlToConnectorExpressionTranslator(Session session)
+        public SqlToConnectorExpressionTranslator(Session session, boolean translateArrayFieldReference)
         {
             this.session = requireNonNull(session, "session is null");
+            this.translateArrayFieldReference = translateArrayFieldReference;
         }
 
         @Override
@@ -694,6 +706,37 @@ else if (functionName.equals(builtinFunctionName(MODULUS))) {
                             new io.trino.spi.expression.Call(node.type(), MODULUS_FUNCTION_NAME, ImmutableList.of(left, right))));
             }
 
+            // Very narrow case that only tries to extract a particular type of lambda expression
+            // TODO: Expand the scope
+            if (translateArrayFieldReference && functionName.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) {
+                List<Expression> allNodeArgument = node.arguments();
+                // At this point, SubscriptExpression should already have been pushed down by PushProjectionIntoTableScan;
+                // if not, it means it is referenced by other expressions;
we only care about SymbolReference at this moment + List inputExpressions = allNodeArgument.stream().filter(Reference.class::isInstance) + .collect(toImmutableList()); + List lambdaExpressions = allNodeArgument.stream().filter(e -> e instanceof Lambda lambda + && lambda.arguments().size() == 1) + .map(Lambda.class::cast) + .collect(toImmutableList()); + if (inputExpressions.size() == 1 && lambdaExpressions.size() == 1) { + Optional inputVariable = process(inputExpressions.get(0)); + if (lambdaExpressions.get(0).body() instanceof Row row) { + List rowFields = row.items(); + List translatedRowFields = + rowFields.stream().map(e -> process(e)).filter(Optional::isPresent).map(Optional::get).collect(toImmutableList()); + if (inputVariable.isPresent() && translatedRowFields.size() == rowFields.size()) { + return Optional.of(new ArrayFieldDereference(node.type(), inputVariable.get(), translatedRowFields)); + } + } + if (lambdaExpressions.get(0).body() instanceof FieldReference fieldReference) { + Optional fieldReferenceConnectorExpr = process(fieldReference); + if (inputVariable.isPresent() && fieldReferenceConnectorExpr.isPresent() && fieldReferenceConnectorExpr.get() instanceof FieldDereference expr) { + return Optional.of(new ArrayFieldDereference(node.type(), inputVariable.get(), ImmutableList.of(expr))); + } + } + } + } + ImmutableList.Builder arguments = ImmutableList.builder(); for (Expression argumentExpression : node.arguments()) { Optional argument = process(argumentExpression); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java index 32ed719e033a..cd334116c0bd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java @@ -37,13 +37,14 @@ private PartialTranslator() {} */ public static Map, ConnectorExpression> extractPartialTranslations( Expression inputExpression, - Session session) + Session session, + boolean translateArrayFieldReference) { requireNonNull(inputExpression, "inputExpression is null"); requireNonNull(session, "session is null"); Map, ConnectorExpression> partialTranslations = new HashMap<>(); - new Visitor(session, partialTranslations).process(inputExpression); + new Visitor(session, partialTranslations, translateArrayFieldReference).process(inputExpression); return ImmutableMap.copyOf(partialTranslations); } @@ -53,10 +54,10 @@ private static class Visitor private final Map, ConnectorExpression> translatedSubExpressions; private final ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator; - Visitor(Session session, Map, ConnectorExpression> translatedSubExpressions) + Visitor(Session session, Map, ConnectorExpression> translatedSubExpressions, boolean translateArrayFieldReference) { this.translatedSubExpressions = requireNonNull(translatedSubExpressions, "translatedSubExpressions is null"); - this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(session); + this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(session, translateArrayFieldReference); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index a6232ca83919..1fc95acedef7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ 
b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
@@ -151,7 +151,11 @@
 import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughTopN;
 import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughTopNRanking;
 import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughWindow;
+import io.trino.sql.planner.iterative.rule.PushDownFieldReferenceLambdaThroughFilter;
+import io.trino.sql.planner.iterative.rule.PushDownFieldReferenceLambdaThroughProject;
 import io.trino.sql.planner.iterative.rule.PushDownProjectionsFromPatternRecognition;
+import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaIntoTableScan;
+import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaThroughFilterIntoTableScan;
 import io.trino.sql.planner.iterative.rule.PushFilterIntoValues;
 import io.trino.sql.planner.iterative.rule.PushFilterThroughBoolOrAggregation;
 import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation;
@@ -346,7 +350,9 @@ public PlanOptimizers(
                 new PushDownDereferencesThroughWindow(),
                 new PushDownDereferencesThroughTopN(),
                 new PushDownDereferencesThroughRowNumber(),
-                new PushDownDereferencesThroughTopNRanking());
+                new PushDownDereferencesThroughTopNRanking(),
+                new PushDownFieldReferenceLambdaThroughProject(),
+                new PushDownFieldReferenceLambdaThroughFilter());
 
         Set<Rule<?>> limitPushdownRules = ImmutableSet.of(
                 new PushLimitThroughOffset(),
@@ -999,6 +1005,15 @@ public PlanOptimizers(
                         new PushPartialAggregationThroughExchange(plannerContext),
                         new PruneJoinColumns(),
                         new PruneJoinChildrenColumns())));
+        // This rule does not change the plan shape; it only adds subfields to the table scan when necessary, so it is triggered near the end.
+        // Keeping this iterative, as it could be combined with PushProjectionIntoTableScan in the future
+        builder.add(new IterativeOptimizer(
+                plannerContext,
+                ruleStats,
+                statsCalculator,
+                costCalculator,
+                ImmutableSet.of(new PushFieldReferenceLambdaIntoTableScan(plannerContext),
+                        new PushFieldReferenceLambdaThroughFilterIntoTableScan(plannerContext))));
         builder.add(new IterativeOptimizer(
                 plannerContext,
                 ruleStats,
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java
index 80d73f883dcd..df50c38855ed 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java
@@ -14,21 +14,32 @@
 package io.trino.sql.planner.iterative.rule;
 
 import com.google.common.collect.ImmutableList;
+import io.trino.spi.function.CatalogSchemaFunctionName;
 import io.trino.spi.type.RowType;
+import io.trino.sql.ir.Call;
 import io.trino.sql.ir.DefaultTraversalVisitor;
 import io.trino.sql.ir.Expression;
 import io.trino.sql.ir.FieldReference;
 import io.trino.sql.ir.Lambda;
 import io.trino.sql.ir.Reference;
+import io.trino.sql.ir.Row;
 import io.trino.sql.planner.Symbol;
+import io.trino.sql.tree.FunctionCall;
 
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static com.google.common.collect.Iterables.getOnlyElement;
+import static
io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
+import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME;
 import static io.trino.sql.planner.SymbolsExtractor.extractAll;
 
 /**
@@ -132,4 +143,113 @@ private static boolean prefixExists(Expression expression, Set expre
         verify(current instanceof Reference);
         return false;
     }
+
+    // Common methods for subscript lambda pushdown
+
+    /**
+     * Extract the sub-expressions of type subscript lambda {@link FunctionCall} from the given {@code expressions}
+     */
+    public static Map<Call, Expression> extractSubscriptLambdas(Collection<Expression> expressions)
+    {
+        List<Map<Expression, Expression>> referencesAndFieldDereferenceLambdas =
+                expressions.stream()
+                        .map(expression -> getSymbolReferencesAndSubscriptLambdas(expression))
+                        .collect(toImmutableList());
+
+        Set<Reference> symbolReferences =
+                referencesAndFieldDereferenceLambdas.stream()
+                        .flatMap(m -> m.keySet().stream())
+                        .filter(Reference.class::isInstance)
+                        .map(Reference.class::cast)
+                        .collect(Collectors.toSet());
+
+        // Returns the subscript expression and its target input expression
+        Map<Call, Expression> subscriptLambdas =
+                referencesAndFieldDereferenceLambdas.stream()
+                        .flatMap(m -> m.entrySet().stream())
+                        .filter(e -> e.getKey() instanceof Call && !symbolReferences.contains(e.getValue()))
+                        .collect(Collectors.toMap(e -> (Call) e.getKey(), e -> e.getValue()));
+
+        return subscriptLambdas;
+    }
+
+    /**
+     * Extract the sub-expressions of type {@link Reference} and subscript lambda {@link FunctionCall} from the given {@code expression}
+     */
+    private static Map<Expression, Expression> getSymbolReferencesAndSubscriptLambdas(Expression expression)
+    {
+        Map<Expression, Expression> symbolMappings = new HashMap<>();
+
+        new DefaultTraversalVisitor<Map<Expression, Expression>>()
+        {
+            @Override
+            protected Void visitReference(Reference node, Map<Expression, Expression> context)
+            {
+                context.put(node, node);
+                return null;
+            }
+
+            @Override
+            protected Void visitCall(Call node, Map<Expression, Expression> context)
+            {
+                Optional<Expression> inputExpression = getSubscriptLambdaInputExpression(node);
+                if (inputExpression.isPresent()) {
+                    context.put(node, inputExpression.get());
+                }
+
+                return null;
+            }
+        }.process(expression, symbolMappings);
+
+        return symbolMappings;
+    }
+
+    /**
+     * Extract the sub-expressions of type {@link Reference} from the given {@code expression}
+     */
+    public static List<Reference> getReferences(Expression expression)
+    {
+        ImmutableList.Builder<Reference> builder = ImmutableList.builder();
+
+        new DefaultTraversalVisitor<ImmutableList.Builder<Reference>>()
+        {
+            @Override
+            protected Void visitReference(Reference node, ImmutableList.Builder<Reference> context)
+            {
+                context.add(node);
+                return null;
+            }
+        }.process(expression, builder);
+
+        return builder.build();
+    }
+
+    /**
+     * Common pattern-matching utility to look for a subscript lambda function
+     */
+    public static Optional<Expression> getSubscriptLambdaInputExpression(Expression expression)
+    {
+        if (expression instanceof Call functionCall) {
+            CatalogSchemaFunctionName functionName = functionCall.function().name();
+
+            if (functionName.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) {
+                List<Expression> allNodeArgument = functionCall.arguments();
+                // At this point, the FieldDereference expression should already have been replaced with a reference expression;
+                // if not, it means it is referenced by other expressions;
we only care about FieldReference at this moment + List inputExpressions = allNodeArgument.stream() + .filter(Reference.class::isInstance) + .map(Reference.class::cast) + .collect(toImmutableList()); + List lambdaExpressions = allNodeArgument.stream().filter(e -> e instanceof Lambda lambda + && lambda.arguments().size() == 1) + .map(Lambda.class::cast) + .collect(toImmutableList()); + if (inputExpressions.size() == 1 && lambdaExpressions.size() == 1 && + ((lambdaExpressions.get(0).body() instanceof Row row && + row.items().stream().allMatch(FieldReference.class::isInstance)) || (lambdaExpressions.get(0).body() instanceof FieldReference))) { + return Optional.of(inputExpressions.get(0)); + } + } + } + return Optional.empty(); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java index fa1898a485cf..be6a813883e2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java @@ -39,6 +39,7 @@ import static io.trino.matching.Capture.newCapture; import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static io.trino.sql.planner.iterative.rule.DereferencePushdown.getSubscriptLambdaInputExpression; import static io.trino.sql.planner.plan.Patterns.project; import static io.trino.sql.planner.plan.Patterns.source; import static java.util.stream.Collectors.toSet; @@ -187,6 +188,10 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod return true; }) + .filter(entry -> { + // skip subscript lambdas, otherwise, inlining can cause conflicts with PushdownDereferences + return getSubscriptLambdaInputExpression(child.getAssignments().get(entry.getKey())).isEmpty(); + }) .map(Map.Entry::getKey) .collect(toSet()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownFieldReferenceLambdaThroughFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownFieldReferenceLambdaThroughFilter.java new file mode 100644 index 000000000000..5f7d3220e6a1 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownFieldReferenceLambdaThroughFilter.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.sql.planner.iterative.rule;
+
+import com.google.common.collect.HashBiMap;
+import io.trino.Session;
+import io.trino.matching.Capture;
+import io.trino.matching.Captures;
+import io.trino.matching.Pattern;
+import io.trino.sql.ir.Call;
+import io.trino.sql.ir.Expression;
+import io.trino.sql.ir.Reference;
+import io.trino.sql.planner.iterative.Rule;
+import io.trino.sql.planner.plan.Assignments;
+import io.trino.sql.planner.plan.FilterNode;
+import io.trino.sql.planner.plan.PlanNode;
+import io.trino.sql.planner.plan.ProjectNode;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
+import static io.trino.SystemSessionProperties.isPushFieldDereferenceLambdaIntoScanEnabled;
+import static io.trino.matching.Capture.newCapture;
+import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression;
+import static io.trino.sql.planner.iterative.rule.DereferencePushdown.extractSubscriptLambdas;
+import static io.trino.sql.planner.iterative.rule.DereferencePushdown.getReferences;
+import static io.trino.sql.planner.plan.Patterns.filter;
+import static io.trino.sql.planner.plan.Patterns.project;
+import static io.trino.sql.planner.plan.Patterns.source;
+
+/**
+ * This rule pushes field reference lambdas into the projection below, through the filter. It increases the
+ * chances of subscript lambdas reaching the table scan.
+ */
+
+/**
+ * Transforms:
+ * <pre>
+ *  Project(c := f(a, x -> x[1]), d := g(b))
+ *    Filter(b = 3)
+ *      Project(a, b)
+ *  </pre>
+ * to:
+ * <pre>
+ *  Project(c := expr, d := g(b))
+ *    Filter(b = 3)
+ *      Project(expr := f(a, x -> x[1]), b)
+ * </pre>
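+ *
+ * Illustrative example only, not part of the original change (table {@code t} is hypothetical;
+ * {@code a} and {@code b} echo the symbols above): a query shaped like
+ * {@code SELECT transform(a, x -> x[1]) FROM t WHERE b = 3} can take the form above once earlier
+ * dereference pushdown has run, so this rewrite moves the lambda below the filter and closer to
+ * the table scan.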
+ */ +public class PushDownFieldReferenceLambdaThroughFilter + implements Rule +{ + private static final Capture CHILD = newCapture(); + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(filter().capturedAs(CHILD))); + } + + @Override + public boolean isEnabled(Session session) + { + return isAllowPushdownIntoConnectors(session) + && isPushFieldDereferenceLambdaIntoScanEnabled(session); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Rule.Context context) + { + FilterNode filterNode = captures.get(CHILD); + + // Extract subscript lambdas from project node assignments for pushdown + Map subscriptLambdas = extractSubscriptLambdas(node.getAssignments().getExpressions()); + + if (subscriptLambdas.isEmpty()) { + return Result.empty(); + } + + // If filter has same references as subscript inputs, skip to be safe, extending the scope later + List filterSymbolReferences = getReferences(filterNode.getPredicate()); + subscriptLambdas = subscriptLambdas.entrySet().stream() + .filter(e -> !filterSymbolReferences.contains(e.getValue())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (subscriptLambdas.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for subscript lambda expressions + Assignments subscriptLambdaAssignments = Assignments.of(subscriptLambdas.keySet(), context.getSymbolAllocator()); + + // Rewrite project node assignments using new symbols for subscript lambda expressions + Map mappings = HashBiMap.create(subscriptLambdaAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments assignments = node.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + PlanNode source = filterNode.getSource(); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new FilterNode( + context.getIdAllocator().getNextId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + source, + Assignments.builder() + .putIdentities(source.getOutputSymbols()) + .putAll(subscriptLambdaAssignments) + .build()), + filterNode.getPredicate()), + assignments)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownFieldReferenceLambdaThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownFieldReferenceLambdaThroughProject.java new file mode 100644 index 000000000000..4b630aab73b5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownFieldReferenceLambdaThroughProject.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.sql.planner.iterative.rule;
+
+import com.google.common.collect.HashBiMap;
+import io.trino.Session;
+import io.trino.matching.Capture;
+import io.trino.matching.Captures;
+import io.trino.matching.Pattern;
+import io.trino.sql.ir.Call;
+import io.trino.sql.ir.Expression;
+import io.trino.sql.ir.Reference;
+import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.iterative.Rule;
+import io.trino.sql.planner.plan.Assignments;
+import io.trino.sql.planner.plan.ProjectNode;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
+import static io.trino.SystemSessionProperties.isPushFieldDereferenceLambdaIntoScanEnabled;
+import static io.trino.matching.Capture.newCapture;
+import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression;
+import static io.trino.sql.planner.iterative.rule.DereferencePushdown.extractSubscriptLambdas;
+import static io.trino.sql.planner.plan.Patterns.project;
+import static io.trino.sql.planner.plan.Patterns.source;
+
+/**
+ * This rule pushes field reference lambdas into the projection below. It increases the
+ * chances of subscript lambdas reaching the table scan.
+ */
+
+/**
+ * Transforms:
+ * <pre>
+ *  Project(c := f(a, x -> x[1]), d := g(b))
+ *    Project(a, b)
+ *  </pre>
+ * to:
+ * <pre>
+ *  Project(c := expr, d := g(b))
+ *    Project(a, b, expr := f(a, x -> x[1]))
+ * </pre>
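+ *
+ * Illustrative example only, not part of the original change: with the session property
+ * {@code enable_push_field_dereference_lambda_into_scan} enabled, the synthesized
+ * {@code expr} assignment sits directly above the source of the lower project, where
+ * PushFieldReferenceLambdaIntoTableScan can later absorb it into the table scan.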
+ */ +public class PushDownFieldReferenceLambdaThroughProject + implements Rule +{ + private static final Capture CHILD = newCapture(); + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(project().capturedAs(CHILD))); + } + + @Override + public boolean isEnabled(Session session) + { + return isAllowPushdownIntoConnectors(session) + && isPushFieldDereferenceLambdaIntoScanEnabled(session); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + ProjectNode child = captures.get(CHILD); + + // Extract subscript lambdas from project node assignments for pushdown + Map subscriptLambdas = extractSubscriptLambdas(node.getAssignments().getExpressions()); + + // Exclude subscript lambdas on symbols being synthesized within child + subscriptLambdas = subscriptLambdas.entrySet().stream() + .filter(e -> child.getSource().getOutputSymbols().contains(Symbol.from(e.getValue()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (subscriptLambdas.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for subscript lambda expressions + Assignments subscriptLambdaAssignments = Assignments.of(subscriptLambdas.keySet(), context.getSymbolAllocator()); + + // Rewrite project node assignments using new symbols for subscript lambda expressions + Map mappings = HashBiMap.create(subscriptLambdaAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments assignments = node.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + child.getSource(), + Assignments.builder() + .putAll(child.getAssignments()) + .putAll(subscriptLambdaAssignments) + .build()), + assignments)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFieldReferenceLambdaIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFieldReferenceLambdaIntoTableScan.java new file mode 100644 index 000000000000..58a8febdbe54 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFieldReferenceLambdaIntoTableScan.java @@ -0,0 +1,178 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.sql.planner.iterative.rule;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.log.Logger;
+import io.trino.Session;
+import io.trino.matching.Capture;
+import io.trino.matching.Captures;
+import io.trino.matching.Pattern;
+import io.trino.metadata.TableHandle;
+import io.trino.spi.connector.Assignment;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.connector.ProjectionApplicationResult;
+import io.trino.spi.expression.ArrayFieldDereference;
+import io.trino.spi.expression.ConnectorExpression;
+import io.trino.spi.expression.Variable;
+import io.trino.sql.PlannerContext;
+import io.trino.sql.ir.Expression;
+import io.trino.sql.ir.NodeRef;
+import io.trino.sql.ir.Reference;
+import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.iterative.Rule;
+import io.trino.sql.planner.plan.ProjectNode;
+import io.trino.sql.planner.plan.TableScanNode;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
+import static io.trino.SystemSessionProperties.isPushFieldDereferenceLambdaIntoScanEnabled;
+import static io.trino.matching.Capture.newCapture;
+import static io.trino.sql.planner.PartialTranslator.extractPartialTranslations;
+import static io.trino.sql.planner.iterative.rule.DereferencePushdown.getReferences;
+import static io.trino.sql.planner.plan.Patterns.project;
+import static io.trino.sql.planner.plan.Patterns.source;
+import static io.trino.sql.planner.plan.Patterns.tableScan;
+import static java.util.function.Function.identity;
+
+/**
+ * This rule tries to retrieve field reference expressions within a lambda function and generates subfields in the table scan.
+ * The rule is purposely very narrow, for a few reasons:
+ * 1. Waiting on the decision to accept Subfield to replace the list of dereference names that is currently being used
+ * 2. This serves as a starting point to push lambda expressions into the table scan, and to push them through other operators in the future
+ * 3.
The PruneUnnestMappings has NOT been accepted yet, which this rule is relying on, there is risk that this need to be rewritten + * + * TODO: Remove lambda expression after subfields are pushed down + */ +public class PushFieldReferenceLambdaIntoTableScan + implements Rule +{ + private static final Logger LOG = Logger.get(PushFieldReferenceLambdaIntoTableScan.class); + private static final Capture TABLE_SCAN = newCapture(); + private static final Pattern PATTERN = project().with(source().matching( + tableScan().capturedAs(TABLE_SCAN))); + + private final PlannerContext plannerContext; + + public PushFieldReferenceLambdaIntoTableScan(PlannerContext plannerContext) + { + this.plannerContext = plannerContext; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(Session session) + { + return isAllowPushdownIntoConnectors(session) + && isPushFieldDereferenceLambdaIntoScanEnabled(session); + } + + @Override + public Result apply(ProjectNode project, Captures captures, Context context) + { + TableScanNode tableScan = captures.get(TABLE_SCAN); + + Session session = context.getSession(); + + // Extract only ArrayFieldDereference expressions from projection expressions, other expressions have been applied + Map, ConnectorExpression> partialTranslations = project.getAssignments().getMap().entrySet().stream() + .flatMap(expression -> + extractPartialTranslations( + expression.getValue(), + session, + true + ).entrySet().stream().filter(entry -> (entry.getValue() instanceof ArrayFieldDereference))) + // Avoid duplicates + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, (first, ignore) -> first)); + + if (partialTranslations.isEmpty()) { + return Result.empty(); + } + + Map inputVariableMappings = tableScan.getAssignments().keySet().stream() + .collect(toImmutableMap(Symbol::name, identity())); + Map assignments = inputVariableMappings.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> tableScan.getAssignments().get(entry.getValue()))); + + // Because we will not replace any symbol references but prune the data, we want to make sure same table scan symbol + // is not used anywhere else just to be safe, we will revisit this once we need to expand the scope of this optimization. 
+        // As a result, we only support limited cases for now, where the symbol reference has to be uniquely referenced
+        List<Expression> expressions = ImmutableList.copyOf(project.getAssignments().getExpressions());
+        Map<String, Long> referenceNamesCount = expressions.stream()
+                .flatMap(expression -> getReferences(expression).stream())
+                .map(Reference::name)
+                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
+
+        partialTranslations = partialTranslations.entrySet().stream().filter(entry -> {
+            ArrayFieldDereference arrayFieldDereference = (ArrayFieldDereference) entry.getValue();
+            return arrayFieldDereference.getTarget() instanceof Variable variable
+                    && referenceNamesCount.get(variable.getName()) == 1;
+        }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+        if (partialTranslations.isEmpty()) {
+            return Result.empty();
+        }
+
+        // At this point, only the Hive connector understands how to deal with the ArrayFieldDereference expression
+        Optional<ProjectionApplicationResult<TableHandle>> result =
+                plannerContext.getMetadata().applyProjection(session,
+                        tableScan.getTable(),
+                        ImmutableList.copyOf(partialTranslations.values()),
+                        assignments);
+
+        if (result.isEmpty()) {
+            return Result.empty();
+        }
+
+        Map<Symbol, ColumnHandle> newTableAssignments = new HashMap<>();
+        for (Assignment assignment : result.get().getAssignments()) {
+            newTableAssignments.put(inputVariableMappings.get(assignment.getVariable()), assignment.getColumn());
+        }
+
+        verify(assignments.size() == newTableAssignments.size(),
+                "Assignments size mismatch after PushFieldReferenceLambdaIntoTableScan: %s instead of %s",
+                newTableAssignments.size(),
+                assignments.size());
+
+        LOG.info("PushFieldReferenceLambdaIntoTableScan is effectively triggered on %d expressions", partialTranslations.size());
+
+        // Only update the tableHandle and the TableScan assignments which have new columnHandles
+        return Result.ofPlanNode(
+                new ProjectNode(
+                        context.getIdAllocator().getNextId(),
+                        new TableScanNode(
+                                tableScan.getId(),
+                                result.get().getHandle(),
+                                tableScan.getOutputSymbols(),
+                                newTableAssignments,
+                                tableScan.getEnforcedConstraint(),
+                                tableScan.getStatistics(),
+                                tableScan.isUpdateTarget(),
+                                tableScan.getUseConnectorNodePartitioning()),
+                        project.getAssignments()));
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFieldReferenceLambdaThroughFilterIntoTableScan.java
new file mode 100644
index 000000000000..cecf80f69c5e
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFieldReferenceLambdaThroughFilterIntoTableScan.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.planner.iterative.rule;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.log.Logger;
+import io.trino.Session;
+import io.trino.matching.Capture;
+import io.trino.matching.Captures;
+import io.trino.matching.Pattern;
+import io.trino.metadata.TableHandle;
+import io.trino.spi.connector.Assignment;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.connector.ProjectionApplicationResult;
+import io.trino.spi.expression.ArrayFieldDereference;
+import io.trino.spi.expression.ConnectorExpression;
+import io.trino.sql.PlannerContext;
+import io.trino.sql.ir.Call;
+import io.trino.sql.ir.Expression;
+import io.trino.sql.ir.NodeRef;
+import io.trino.sql.ir.Reference;
+import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.iterative.Rule;
+import io.trino.sql.planner.plan.FilterNode;
+import io.trino.sql.planner.plan.ProjectNode;
+import io.trino.sql.planner.plan.TableScanNode;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
+import static io.trino.SystemSessionProperties.isPushFieldDereferenceLambdaIntoScanEnabled;
+import static io.trino.matching.Capture.newCapture;
+import static io.trino.sql.planner.PartialTranslator.extractPartialTranslations;
+import static io.trino.sql.planner.iterative.rule.DereferencePushdown.extractSubscriptLambdas;
+import static io.trino.sql.planner.iterative.rule.DereferencePushdown.getReferences;
+import static io.trino.sql.planner.plan.Patterns.filter;
+import static io.trino.sql.planner.plan.Patterns.project;
+import static io.trino.sql.planner.plan.Patterns.source;
+import static io.trino.sql.planner.plan.Patterns.tableScan;
+import static java.util.function.Function.identity;
+
+/**
+ * This rule is similar to PushFieldReferenceLambdaIntoTableScan, but handles the case where a filter node
+ * is above the table scan after predicate pushdown rules
+ *
+ * TODO: Remove lambda expression after subfields are pushed down
+ */
+public class PushFieldReferenceLambdaThroughFilterIntoTableScan
+        implements Rule<ProjectNode>
+{
+    private static final Logger LOG = Logger.get(PushFieldReferenceLambdaThroughFilterIntoTableScan.class);
+    private static final Capture<FilterNode> FILTER = newCapture();
+    private static final Capture<TableScanNode> TABLE_SCAN = newCapture();
+
+    private final PlannerContext plannerContext;
+
+    public PushFieldReferenceLambdaThroughFilterIntoTableScan(PlannerContext plannerContext)
+    {
+        this.plannerContext = plannerContext;
+    }
+
+    @Override
+    public Pattern<ProjectNode> getPattern()
+    {
+        return project()
+                .with(source().matching(filter().capturedAs(FILTER)
+                        .with(source().matching((tableScan().capturedAs(TABLE_SCAN))))));
+    }
+
+    @Override
+    public boolean isEnabled(Session session)
+    {
+        return isAllowPushdownIntoConnectors(session)
+                && isPushFieldDereferenceLambdaIntoScanEnabled(session);
+    }
+
+    @Override
+    public Result apply(ProjectNode project, Captures captures, Context context)
+    {
+        FilterNode filterNode = captures.get(FILTER);
+        TableScanNode tableScanNode = captures.get(TABLE_SCAN);
+
+        Map<Call, Expression> subscriptLambdas = extractSubscriptLambdas(project.getAssignments().getExpressions());
+
+        if (subscriptLambdas.isEmpty()) {
+            return Result.empty();
+        }
+
+        // If the filter has the same reference as the subscript input, skip to be safe for now
+        List<Reference>
filterSymbolReferences = getReferences(filterNode.getPredicate()); + subscriptLambdas = subscriptLambdas.entrySet().stream() + .filter(e -> !filterSymbolReferences.contains(e.getValue())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (subscriptLambdas.isEmpty()) { + return Result.empty(); + } + + Session session = context.getSession(); + // Extract only ArrayFieldDereference expressions from projection expressions, other expressions have been applied + Map, ConnectorExpression> partialTranslations = subscriptLambdas.entrySet().stream() + .flatMap(expression -> + extractPartialTranslations( + expression.getKey(), + session, + true + ).entrySet().stream().filter(entry -> (entry.getValue() instanceof ArrayFieldDereference))) + .filter(entry -> !(entry.getValue() instanceof io.trino.spi.expression.Constant)) + // Avoid duplicates + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, (first, ignore) -> first)); + + if (partialTranslations.isEmpty()) { + return Result.empty(); + } + + Map inputVariableMappings = tableScanNode.getAssignments().keySet().stream() + .collect(toImmutableMap(Symbol::name, identity())); + Map assignments = inputVariableMappings.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> tableScanNode.getAssignments().get(entry.getValue()))); + + // Apply projections handled by connectors + Optional> result = + plannerContext.getMetadata().applyProjection(session, + tableScanNode.getTable(), + ImmutableList.copyOf(partialTranslations.values()), + assignments); + + if (result.isEmpty()) { + return Result.empty(); + } + + Map newTableAssignments = new HashMap<>(); + for (Assignment assignment : result.get().getAssignments()) { + newTableAssignments.put(inputVariableMappings.get(assignment.getVariable()), assignment.getColumn()); + } + + verify(assignments.size() == newTableAssignments.size(), + "Assignments size mis-match after PushSubscriptLambdaThroughFilterIntoTableScan: %d instead of %d", + newTableAssignments.size(), + assignments.size()); + + LOG.info("PushSubscriptLambdaThroughFilterIntoTableScan is effectively triggered on %d expressions", partialTranslations.size()); + + // Only update tableHandle and TableScan assignments which have new columnHandles + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new FilterNode( + context.getIdAllocator().getNextId(), + new TableScanNode( + tableScanNode.getId(), + result.get().getHandle(), + tableScanNode.getOutputSymbols(), + newTableAssignments, + tableScanNode.getEnforcedConstraint(), + tableScanNode.getStatistics(), + tableScanNode.isUpdateTarget(), + tableScanNode.getUseConnectorNodePartitioning()), + filterNode.getPredicate()), + project.getAssignments())); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index f3c9879b12ef..62c7870e2232 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -103,7 +103,8 @@ public Result apply(ProjectNode project, Captures captures, Context context) .flatMap(expression -> extractPartialTranslations( expression.getValue(), - session + session, + false // In the future, we want to rewrite this class to translate ArrayFieldReference here as well 
).entrySet().stream()) // Filter out constant expressions. Constant expressions should not be pushed to the connector. .filter(entry -> !(entry.getValue() instanceof io.trino.spi.expression.Constant)) diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java index bd7a6d271d95..d942f70b3604 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestFeaturesConfig.java @@ -66,6 +66,7 @@ public void testDefaults() .setHideInaccessibleColumns(false) .setForceSpillingJoin(false) .setColumnarFilterEvaluationEnabled(true) + .setPushFieldDereferenceLambdaIntoScanEnabled(false) .setFaultTolerantExecutionExchangeEncryptionEnabled(true)); } @@ -100,6 +101,7 @@ public void testExplicitPropertyMappings() .put("hide-inaccessible-columns", "true") .put("force-spilling-join-operator", "true") .put("experimental.columnar-filter-evaluation.enabled", "false") + .put("experimental.enable-push-field-dereference-lambda-into-scan.enabled", "true") .put("fault-tolerant-execution.exchange-encryption-enabled", "false") .buildOrThrow(); @@ -131,6 +133,7 @@ public void testExplicitPropertyMappings() .setHideInaccessibleColumns(true) .setForceSpillingJoin(true) .setColumnarFilterEvaluationEnabled(false) + .setPushFieldDereferenceLambdaIntoScanEnabled(true) .setFaultTolerantExecutionExchangeEncryptionEnabled(false); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index 054f504d9a8e..8d440f58e829 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -23,6 +23,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.scalar.JsonPath; import io.trino.security.AllowAllAccessControl; +import io.trino.spi.expression.ArrayFieldDereference; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.FunctionName; @@ -43,12 +44,15 @@ import io.trino.sql.ir.FieldReference; import io.trino.sql.ir.In; import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Logical; import io.trino.sql.ir.NullIf; import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Row; import io.trino.testing.TestingSession; import io.trino.transaction.TestingTransactionManager; import io.trino.transaction.TransactionManager; +import io.trino.type.FunctionType; import io.trino.type.LikeFunctions; import org.junit.jupiter.api.Test; @@ -60,6 +64,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; +import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME; import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp; import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME; import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME; @@ -96,6 +101,7 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import 
static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.IrExpressions.not; import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; @@ -114,6 +120,10 @@ public class TestConnectorExpressionTranslator private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5))); private static final VarcharType VARCHAR_TYPE = createUnboundedVarcharType(); private static final ArrayType VARCHAR_ARRAY_TYPE = new ArrayType(VARCHAR_TYPE); + private static final Type ANONYMOUS_ROW_TYPE = RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(5))); + private static final ArrayType ARRAY_ROW_TYPE = new ArrayType(ANONYMOUS_ROW_TYPE); + private static final Type PRUNED_ROW_TYPE = RowType.anonymous(ImmutableList.of(INTEGER)); + private static final ArrayType PRUNED_ARRAY_ROW_TYPE = new ArrayType(PRUNED_ROW_TYPE); private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction NEGATION_DOUBLE = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(DOUBLE)); @@ -556,6 +566,69 @@ public void testTranslateCastPlusJsonParse() }); } + @Test + public void testTranslateFieldDereferenceInLambda() + { + TransactionManager transactionManager = new TestingTransactionManager(); + Metadata metadata = MetadataManager.testMetadataManagerBuilder().withTransactionManager(transactionManager).build(); + transaction(transactionManager, metadata, new AllowAllAccessControl()) + .readOnly() + .execute(TEST_SESSION, transactionSession -> { + // Base case, returns proper ArrayFieldDereference expression + ArrayFieldDereference translated = new ArrayFieldDereference(PRUNED_ARRAY_ROW_TYPE, + new Variable("array_of_struct", ARRAY_ROW_TYPE), + List.of(new FieldDereference( + INTEGER, + new Variable("transformarray$element", ANONYMOUS_ROW_TYPE), + 0))); + + // Test wrap with row + assertTranslationToConnectorExpressionWithArrayDereference( + transactionSession, + new Call(FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ANONYMOUS_ROW_TYPE), PRUNED_ROW_TYPE))), ImmutableList.of(new Reference(ARRAY_ROW_TYPE, "array_of_struct"), + new Lambda(ImmutableList.of(new Symbol(ANONYMOUS_ROW_TYPE, "transformarray$element")), + new Row(ImmutableList.of(new FieldReference( + new Reference(ANONYMOUS_ROW_TYPE, "transformarray$element"), + 0)))))), + Optional.of(translated)); + + // Test FieldReference only + assertTranslationToConnectorExpressionWithArrayDereference( + transactionSession, + new Call(FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ANONYMOUS_ROW_TYPE), INTEGER))), ImmutableList.of(new Reference(ARRAY_ROW_TYPE, "array_of_struct"), + new Lambda(ImmutableList.of(new Symbol(ANONYMOUS_ROW_TYPE, "transformarray$element")), + new FieldReference( + new Reference(ANONYMOUS_ROW_TYPE, "transformarray$element"), + 0)))), + Optional.of(new ArrayFieldDereference(new ArrayType(INTEGER), + new Variable("array_of_struct", ARRAY_ROW_TYPE), + List.of(new FieldDereference( + INTEGER, + new Variable("transformarray$element", ANONYMOUS_ROW_TYPE), + 0))))); + + // Multiple subscript expressions available, multiple FieldReferences within ArrayFieldDereference translated + 
assertTranslationToConnectorExpressionWithArrayDereference( + transactionSession, + new Call(FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ANONYMOUS_ROW_TYPE), ANONYMOUS_ROW_TYPE))), ImmutableList.of(new Reference(ARRAY_ROW_TYPE, "array_of_struct"), + new Lambda(ImmutableList.of(new Symbol(ANONYMOUS_ROW_TYPE, "transformarray$element")), + new Row(ImmutableList.of(new FieldReference( + new Reference(ANONYMOUS_ROW_TYPE, "transformarray$element"), + 0), + new FieldReference( + new Reference(ANONYMOUS_ROW_TYPE, "transformarray$element"), + 1)))))), + Optional.of(new ArrayFieldDereference(ARRAY_ROW_TYPE, + new Variable("array_of_struct", ARRAY_ROW_TYPE), + List.of(new FieldDereference( + INTEGER, + new Variable("transformarray$element", ANONYMOUS_ROW_TYPE), 0), + new FieldDereference( + createVarcharType(5), + new Variable("transformarray$element", ANONYMOUS_ROW_TYPE), 1))))); + }); + } + private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression) { assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression); @@ -584,4 +657,11 @@ private void assertTranslationFromConnectorExpression(Session session, Connector Expression translation = ConnectorExpressionTranslator.translate(session, connectorExpression, PLANNER_CONTEXT, variableMappings); assertThat(translation).isEqualTo(expected); } + + private void assertTranslationToConnectorExpressionWithArrayDereference(Session session, Expression expression, Optional connectorExpression) + { + Optional translation = translate(session, expression, true); + assertThat(connectorExpression.isPresent()).isEqualTo(translation.isPresent()); + translation.ifPresent(value -> assertThat(value).isEqualTo(connectorExpression.get())); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 9cdd6fc748fe..7e3f22cb826f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -74,7 +74,7 @@ public void testPartialTranslator() private void assertFullTranslation(Expression expression) { - Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION); + Map, ConnectorExpression> translation = extractPartialTranslations(expression, TEST_SESSION, false); assertThat(getOnlyElement(translation.keySet())).isEqualTo(NodeRef.of(expression)); assertThat(getOnlyElement(translation.values())).isEqualTo(translate(TEST_SESSION, expression).get()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java index 5ee85817ae78..16cd1518cc6b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionMatcher.java @@ -34,11 +34,13 @@ public class ExpressionMatcher { private final String sql; private final Expression expression; + private final Optional extraSymbolAliases; - ExpressionMatcher(Expression expression) + ExpressionMatcher(Expression expression, Optional extraSymbolAliases) { this.expression = requireNonNull(expression, "expression is null"); this.sql = ExpressionFormatter.formatExpression(expression); + this.extraSymbolAliases = 
requireNonNull(extraSymbolAliases, "extraSymbolAliases is null"); } @Override @@ -52,7 +54,8 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada return result; } - ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); + // Temporary solution, there is a separate PR to support lambda expression verifier + ExpressionVerifier verifier = new ExpressionVerifier(extraSymbolAliases.isPresent() ? SymbolAliases.builder().putAll(symbolAliases).putAll(extraSymbolAliases.get()).build() : symbolAliases); for (Map.Entry assignment : assignments.entrySet()) { if (verifier.process(assignment.getValue(), expression)) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 18026e6f1b5b..c4459799527f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -1028,7 +1028,12 @@ public static SetExpressionMatcher setExpression(ApplyNode.SetExpression express public static ExpressionMatcher expression(Expression expression) { - return new ExpressionMatcher(expression); + return expression(expression, Optional.empty()); + } + + public static ExpressionMatcher expression(Expression expression, Optional extraSymbolAliases) + { + return new ExpressionMatcher(expression, extraSymbolAliases); } public PlanMatchPattern withOutputs(List aliases) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownSubscriptLambdaRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownSubscriptLambdaRules.java new file mode 100644 index 000000000000..6940efb6f552 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownSubscriptLambdaRules.java @@ -0,0 +1,193 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.FieldReference; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Row; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.assertions.SymbolAliases; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.RuleTester; +import io.trino.sql.planner.plan.Assignments; +import io.trino.type.FunctionType; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static io.trino.SystemSessionProperties.ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN; +import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPushDownSubscriptLambdaRules + extends BaseRuleTest +{ + private static final Type ROW_TYPE = RowType.anonymous(ImmutableList.of(BIGINT, BIGINT)); + private static final Type PRUNED_ROW_TYPE = RowType.anonymous(ImmutableList.of(BIGINT)); + private static final Type ARRAY_ROW_TYPE = new ArrayType(ROW_TYPE); + private static final Type PRUNED_ARRAY_ROW_TYPE = new ArrayType(PRUNED_ROW_TYPE); + private static final Reference LAMBDA_ELEMENT_REFERENCE = new Reference(ROW_TYPE, "transformarray$element"); + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction TRANSFORM = FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ROW_TYPE), PRUNED_ROW_TYPE))); + private static final Call dereferenceFunctionCall = new Call(TRANSFORM, ImmutableList.of(new Reference(ARRAY_ROW_TYPE, "array_of_struct"), + new Lambda(ImmutableList.of(new Symbol(ROW_TYPE, "transformarray$element")), + new Row(ImmutableList.of(new FieldReference( + LAMBDA_ELEMENT_REFERENCE, + 0)))))); + + @Test + public void testPushDownSubscriptLambdaThroughProject() + { + PushDownFieldReferenceLambdaThroughProject pushDownFieldReferenceLambdaThroughProject = new PushDownFieldReferenceLambdaThroughProject(); + + try (RuleTester ruleTester = RuleTester.builder().addSessionProperty(ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN, "true").build()) { + // Base symbol referenced by other assignments, skip the optimization + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughProject) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + dereferenceFunctionCall, arrayOfRow, new Reference(ARRAY_ROW_TYPE, "array_of_struct")), + 
p.project( + Assignments.identity(arrayOfRow), + p.values(arrayOfRow))); + }).doesNotFire(); + + // Dereference Lambda being pushed down to the lower projection + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughProject) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + dereferenceFunctionCall), + p.project( + Assignments.identity(arrayOfRow), + p.values(arrayOfRow))); + }) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(new Reference(PRUNED_ARRAY_ROW_TYPE, "expr"))), + project( + ImmutableMap.of("expr", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + values("array_of_struct")))); + + // Dereference Lambda being pushed down to the lower projection, with other symbols kept in projections + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughProject) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + Symbol e = p.symbol("e", ARRAY_ROW_TYPE); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + dereferenceFunctionCall, e, new Reference(ARRAY_ROW_TYPE, "e")), + p.project( + Assignments.identity(arrayOfRow, e), + p.values(arrayOfRow, e))); + }) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(new Reference(PRUNED_ARRAY_ROW_TYPE, "expr")), "e", expression(new Reference(ARRAY_ROW_TYPE, "e"))), + project( + ImmutableMap.of("expr", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build())), "e", expression(new Reference(ARRAY_ROW_TYPE, "e"))), + values("array_of_struct", "e")))); + } + } + + @Test + public void testPushDownSubscriptLambdaThroughFilter() + { + PushDownFieldReferenceLambdaThroughFilter pushDownFieldReferenceLambdaThroughFilter = new PushDownFieldReferenceLambdaThroughFilter(); + + try (RuleTester ruleTester = RuleTester.builder().addSessionProperty(ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN, "true").build()) { + // Base symbol referenced by other assignments, skip the optimization + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughFilter) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + dereferenceFunctionCall, arrayOfRow, new Reference(ARRAY_ROW_TYPE, "array_of_struct")), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.project( + Assignments.identity(arrayOfRow, e), + p.values(arrayOfRow, e)))); + }).doesNotFire(); + + // No filter node, skip the rule + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughFilter) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + dereferenceFunctionCall), + p.project( + Assignments.identity(arrayOfRow), + p.values(arrayOfRow))); + }).doesNotFire(); + + // Base symbol referenced by predicate in filter node, skip the optimization + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughFilter) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + 
dereferenceFunctionCall), + p.filter(new Comparison(NOT_EQUAL, new Reference(ARRAY_ROW_TYPE, "array_of_struct"), new Constant(ARRAY_ROW_TYPE, null)), + p.project( + Assignments.identity(arrayOfRow), + p.values(arrayOfRow)))); + }).doesNotFire(); + + // FieldDereference Lambda being pushed down to the lower projection through filter, with other symbols kept in projections + ruleTester.assertThat(pushDownFieldReferenceLambdaThroughFilter) + .on(p -> { + Symbol arrayOfRow = p.symbol("array_of_struct", ARRAY_ROW_TYPE); + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + dereferenceFunctionCall, e, new Reference(BIGINT, "e")), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.project( + Assignments.identity(arrayOfRow, e), + p.values(arrayOfRow, e)))); + }) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(new Reference(ARRAY_ROW_TYPE, "expr")), "e", expression(new Reference(BIGINT, "e"))), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + project( + ImmutableMap.of("expr", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build())), "e", expression(new Reference(BIGINT, "e"))), + project( + ImmutableMap.of("array_of_struct", expression(new Reference(ARRAY_ROW_TYPE, "array_of_struct")), "e", expression(new Reference(BIGINT, "e"))), + values("array_of_struct", "e")))))); + } + } +} diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index ebf18ba7e603..be91cbdd480e 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -46,6 +46,12 @@ true + + com.google.guava + guava + provided + + io.opentelemetry opentelemetry-context @@ -70,12 +76,6 @@ test - - com.google.guava - guava - test - - com.google.inject guice diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/ArrayFieldDereference.java b/core/trino-spi/src/main/java/io/trino/spi/expression/ArrayFieldDereference.java new file mode 100644 index 000000000000..f4d79c3d9197 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/ArrayFieldDereference.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +// This class is used to represent expression with dereferences into Array +// Target is the actual reference to the array. 
elementFieldDereferences are the field dereferences +public class ArrayFieldDereference + extends ConnectorExpression +{ + private final ConnectorExpression target; + private final List elementFieldDereferences; + + public ArrayFieldDereference(Type type, ConnectorExpression target, List elementFieldDereference) + { + super(type); + checkArgument(type instanceof ArrayType, "wrong input type for ArrayFieldDereference"); + this.target = requireNonNull(target, "target is null"); + this.elementFieldDereferences = ImmutableList.copyOf(requireNonNull(elementFieldDereference, "elementFieldDereference is null")); + } + + public ConnectorExpression getTarget() + { + return target; + } + + public List getElementFieldDereferences() + { + return elementFieldDereferences; + } + + @Override + public List getChildren() + { + return singletonList(target); + } + + @Override + public int hashCode() + { + return Objects.hash(target, elementFieldDereferences, getType()); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + ArrayFieldDereference that = (ArrayFieldDereference) o; + return Objects.equals(target, that.target) + && Objects.equals(elementFieldDereferences, that.elementFieldDereferences) + && Objects.equals(getType(), that.getType()); + } + + @Override + public String toString() + { + return format("(%s).#[%s]", target, elementFieldDereferences.stream() + .map(item -> "(" + item + ")") + .collect(joining(" "))); + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java index c5f6ef60ef17..6778b812b323 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java @@ -15,6 +15,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import io.trino.plugin.base.subfield.Subfield; +import io.trino.spi.expression.ArrayFieldDereference; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; @@ -24,9 +26,12 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.function.Predicate; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; public final class ApplyProjectionUtil @@ -123,6 +128,29 @@ public static ConnectorExpression replaceWithNewVariables(ConnectorExpression ex throw new UnsupportedOperationException("Unsupported expression: " + expression); } + public static Optional> validArrayFieldDereferences(List projections) + { + Set arrayFieldDereferences = projections.stream() + .filter(ArrayFieldDereference.class::isInstance) + .map(ArrayFieldDereference.class::cast) + .filter(e -> e.getTarget() instanceof Variable) + .collect(toImmutableSet()); + if (arrayFieldDereferences.size() != projections.size()) { + // Sanity check to prevent any unhandled edge cases + // Currently very narrow cases can generate ArrayFieldDereference and can be passed down here + // If ArrayFieldDereference at this point can not be translated to new projections, then we 
need to dig on the reasons + return Optional.empty(); + } + return Optional.of(arrayFieldDereferences); + } + + public static void generatePathElementWithNames(ImmutableList.Builder pathElements, List pathNamesWithinArray) + { + for (String pathName : pathNamesWithinArray) { + pathElements.add(new Subfield.NestedField(pathName)); + } + } + public static class ProjectedColumnRepresentation { private final Variable variable; diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/subfield/Subfield.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/subfield/Subfield.java new file mode 100644 index 000000000000..900f4cce9138 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/subfield/Subfield.java @@ -0,0 +1,337 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.subfield; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static java.util.Objects.requireNonNull; + +// Class to represent subfield. Direct referenced from Presto +public class Subfield +{ + public sealed interface PathElement permits AllSubscripts, NoSubfield, NestedField, LongSubscript, StringSubscript + { + boolean isSubscript(); + } + + public static final class AllSubscripts + implements PathElement + { + private static final AllSubscripts ALL_SUBSCRIPTS = new AllSubscripts(); + + private AllSubscripts() {} + + public static AllSubscripts getInstance() + { + return ALL_SUBSCRIPTS; + } + + @Override + public boolean isSubscript() + { + return true; + } + + @Override + public String toString() + { + return "[*]"; + } + } + + public static final class NoSubfield + implements PathElement + { + private static final NoSubfield NO_SUBFIELD = new NoSubfield(); + + public static NoSubfield getInstance() + { + return NO_SUBFIELD; + } + + @Override + public boolean isSubscript() + { + return false; + } + + @Override + public String toString() + { + return ".$"; + } + } + + public static final class NestedField + implements PathElement + { + private final String name; + + public NestedField(String name) + { + this.name = requireNonNull(name, "name is null"); + } + + public String getName() + { + return name; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + NestedField that = (NestedField) o; + return Objects.equals(name, that.name); + } + + @Override + public int hashCode() + { + return Objects.hash(name); + } + + @Override + public String toString() + { + return "." 
+ name; + } + + @Override + public boolean isSubscript() + { + return false; + } + } + + public static final class LongSubscript + implements PathElement + { + private final long index; + + public LongSubscript(long index) + { + this.index = index; + } + + public long getIndex() + { + return index; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + LongSubscript that = (LongSubscript) o; + return index == that.index; + } + + @Override + public int hashCode() + { + return Objects.hash(index); + } + + @Override + public String toString() + { + return "[" + index + "]"; + } + + @Override + public boolean isSubscript() + { + return true; + } + } + + public static final class StringSubscript + implements PathElement + { + private final String index; + + public StringSubscript(String index) + { + this.index = requireNonNull(index, "index is null"); + } + + public String getIndex() + { + return index; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + StringSubscript that = (StringSubscript) o; + return Objects.equals(index, that.index); + } + + @Override + public int hashCode() + { + return Objects.hash(index); + } + + @Override + public String toString() + { + return "[\"" + index.replace("\"", "\\\"") + "\"]"; + } + + @Override + public boolean isSubscript() + { + return true; + } + } + + private final String name; + private final List path; + + public static PathElement allSubscripts() + { + return AllSubscripts.getInstance(); + } + + public static PathElement noSubfield() + { + return NoSubfield.getInstance(); + } + + @JsonCreator + public Subfield(String path) + { + requireNonNull(path, "path is null"); + + SubfieldTokenizer tokenizer = new SubfieldTokenizer(path); + checkArgument(tokenizer.hasNext(), "Column name is missing: " + path); + + PathElement firstElement = tokenizer.next(); + checkArgument(firstElement instanceof NestedField, "Subfield path must start with a name: " + path); + + this.name = ((NestedField) firstElement).getName(); + + List pathElements = new ArrayList<>(); + tokenizer.forEachRemaining(pathElements::add); + this.path = Collections.unmodifiableList(pathElements); + } + + private static void checkArgument(boolean expression, String errorMessage) + { + if (!expression) { + throw new IllegalArgumentException(errorMessage); + } + } + + public Subfield(String name, List path) + { + this.name = requireNonNull(name, "name is null"); + this.path = requireNonNull(path, "path is null"); + } + + public String getRootName() + { + return name; + } + + public List getPath() + { + return path; + } + + public boolean isPrefix(Subfield other) + { + if (!other.name.equals(name)) { + return false; + } + + if (path.size() < other.path.size()) { + return Objects.equals(path, other.path.subList(0, path.size())); + } + + return false; + } + + public Subfield tail(String name) + { + if (path.isEmpty()) { + throw new IllegalStateException("path is empty"); + } + return new Subfield(name, path.subList(1, path.size())); + } + + @JsonValue + public String serialize() + { + return name + path.stream() + .map(PathElement::toString) + .collect(Collectors.joining()); + } + + @Override + public String toString() + { + return serialize(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + + if (o == null || getClass() != 
o.getClass()) { + return false; + } + + Subfield other = (Subfield) o; + return Objects.equals(name, other.name) && + Objects.equals(path, other.path); + } + + @Override + public int hashCode() + { + return Objects.hash(name, path); + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/subfield/SubfieldTokenizer.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/subfield/SubfieldTokenizer.java new file mode 100644 index 000000000000..6b3419337bfc --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/subfield/SubfieldTokenizer.java @@ -0,0 +1,293 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.subfield; + +import io.trino.spi.TrinoException; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static java.lang.Character.isLetterOrDigit; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +// Class to translate subfield. Direct referenced from Presto +public class SubfieldTokenizer + implements Iterator +{ + private static final char QUOTE = '\"'; + private static final char BACKSLASH = '\\'; + private static final char DOT = '.'; + private static final char OPEN_BRACKET = '['; + private static final char CLOSE_BRACKET = ']'; + private static final char UNICODE_CARET = '\u2038'; + private static final char WILDCARD = '*'; + private static final char DOLLAR = '$'; + + private final String path; + private State state = State.NOT_READY; + private int index; + private boolean firstSegment = true; + private Subfield.PathElement next; + + public SubfieldTokenizer(String path) + { + this.path = requireNonNull(path, "path is null"); + + if (path.isEmpty()) { + throw invalidSubfieldPath(); + } + } + + @Override + public final boolean hasNext() + { + if (state == State.FAILED) { + throw new IllegalStateException(); + } + switch (state) { + case DONE: + return false; + case READY: + return true; + default: + } + return tryToComputeNext(); + } + + private boolean tryToComputeNext() + { + state = State.FAILED; // temporary pessimism + next = computeNext(); + if (state != State.DONE) { + state = State.READY; + return true; + } + return false; + } + + @Override + public final Subfield.PathElement next() + { + if (!hasNext()) { + throw new NoSuchElementException(); + } + state = State.NOT_READY; + Subfield.PathElement result = next; + next = null; + return result; + } + + @Override + public final void remove() + { + throw new UnsupportedOperationException(); + } + + private Subfield.PathElement computeNext() + { + if (!hasNextCharacter()) { + state = State.DONE; + return null; + } + + if (tryMatch(DOT)) { + Subfield.PathElement token = tryMatch(DOLLAR) ? matchDollarPathElement() : matchPathSegment(); + firstSegment = false; + return token; + } + + if (tryMatch(OPEN_BRACKET)) { + Subfield.PathElement token = tryMatch(QUOTE) ? 
matchQuotedSubscript() : tryMatch(WILDCARD) ? matchWildcardSubscript() : matchUnquotedSubscript(); + + match(CLOSE_BRACKET); + firstSegment = false; + return token; + } + + if (firstSegment) { + Subfield.PathElement token = matchPathSegment(); + firstSegment = false; + return token; + } + + throw invalidSubfieldPath(); + } + + private Subfield.PathElement matchPathSegment() + { + // seek until we see a special character or whitespace + int start = index; + while (hasNextCharacter() && isUnquotedPathCharacter(peekCharacter())) { + nextCharacter(); + } + int end = index; + + String token = path.substring(start, end); + + // an empty unquoted token is not allowed + if (token.isEmpty()) { + throw invalidSubfieldPath(); + } + + return new Subfield.NestedField(token); + } + + private Subfield.PathElement matchWildcardSubscript() + { + return Subfield.allSubscripts(); + } + + private Subfield.PathElement matchDollarPathElement() + { + return Subfield.noSubfield(); + } + + private static boolean isUnquotedPathCharacter(char c) + { + return c == ':' || c == '$' || c == '-' || c == '/' || c == '@' || c == '|' || c == '#' || c == ' ' || isUnquotedSubscriptCharacter(c); + } + + private Subfield.PathElement matchUnquotedSubscript() + { + // seek until we see a special character or whitespace + int start = index; + while (hasNextCharacter() && isUnquotedSubscriptCharacter(peekCharacter())) { + nextCharacter(); + } + int end = index; + + String token = path.substring(start, end); + + // an empty unquoted token is not allowed + if (token.isEmpty()) { + throw invalidSubfieldPath(); + } + + long index; + try { + index = Long.valueOf(token); + } + catch (NumberFormatException e) { + throw invalidSubfieldPath(); + } + + return new Subfield.LongSubscript(index); + } + + private static boolean isUnquotedSubscriptCharacter(char c) + { + return c == '-' || c == '_' || isLetterOrDigit(c); + } + + private Subfield.PathElement matchQuotedSubscript() + { + // quote has already been matched + + // seek until we see the close quote + StringBuilder token = new StringBuilder(); + boolean escaped = false; + + while (hasNextCharacter() && (escaped || peekCharacter() != QUOTE)) { + if (escaped) { + switch (peekCharacter()) { + case QUOTE: + case BACKSLASH: + token.append(peekCharacter()); + break; + default: + throw invalidSubfieldPath(); + } + escaped = false; + } + else { + if (peekCharacter() == BACKSLASH) { + escaped = true; + } + else { + token.append(peekCharacter()); + } + } + nextCharacter(); + } + if (escaped) { + throw invalidSubfieldPath(); + } + + match(QUOTE); + + String index = token.toString(); + if (index.equals(String.valueOf(WILDCARD))) { + return Subfield.allSubscripts(); + } + return new Subfield.StringSubscript(index); + } + + private boolean hasNextCharacter() + { + return index < path.length(); + } + + private void match(char expected) + { + if (!tryMatch(expected)) { + throw invalidSubfieldPath(); + } + } + + private boolean tryMatch(char expected) + { + if (!hasNextCharacter() || peekCharacter() != expected) { + return false; + } + index++; + return true; + } + + private void nextCharacter() + { + index++; + } + + private char peekCharacter() + { + return path.charAt(index); + } + + private TrinoException invalidSubfieldPath() + { + return new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Invalid subfield path: '%s'", this)); + } + + @Override + public String toString() + { + return path.substring(0, index) + UNICODE_CARET + path.substring(index); + } + + private enum State { + /** We have 
computed the next element and haven't returned it yet. */ + READY, + + /** We haven't yet computed or have already returned the element. */ + NOT_READY, + + /** We have reached the end of the data and are finished. */ + DONE, + + /** We've suffered an exception and are kaput. */ + FAILED, + } +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/subfield/TestSubfieldTokenizer.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/subfield/TestSubfieldTokenizer.java new file mode 100644 index 000000000000..81a0ecb178e8 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/subfield/TestSubfieldTokenizer.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.subfield; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; +import io.trino.spi.TrinoException; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatException; + +// Class to test SubfieldTokenizer. 
Ported from Presto
+public class TestSubfieldTokenizer
+{
+    @Test
+    public void test()
+    {
+        List<Subfield.PathElement> elements = ImmutableList.of(
+                new Subfield.NestedField("b"),
+                new Subfield.LongSubscript(2),
+                new Subfield.LongSubscript(-1),
+                new Subfield.StringSubscript("z"),
+                Subfield.allSubscripts(),
+                new Subfield.StringSubscript("34"),
+                new Subfield.StringSubscript("b \"test\""),
+                new Subfield.StringSubscript("\"abc"),
+                new Subfield.StringSubscript("abc\""),
+                new Subfield.StringSubscript("ab\"cde"),
+                new Subfield.StringSubscript("a.b[\"hello\uDBFF\"]"));
+
+        for (Subfield.PathElement element : elements) {
+            assertPath(new Subfield("a", ImmutableList.of(element)));
+        }
+
+        for (Subfield.PathElement element : elements) {
+            for (Subfield.PathElement secondElement : elements) {
+                assertPath(new Subfield("a", ImmutableList.of(element, secondElement)));
+            }
+        }
+
+        for (Subfield.PathElement element : elements) {
+            for (Subfield.PathElement secondElement : elements) {
+                for (Subfield.PathElement thirdElement : elements) {
+                    assertPath(new Subfield("a", ImmutableList.of(element, secondElement, thirdElement)));
+                }
+            }
+        }
+    }
+
+    private static void assertPath(Subfield path)
+    {
+        SubfieldTokenizer tokenizer = new SubfieldTokenizer(path.serialize());
+        assertThat(tokenizer.hasNext()).isTrue();
+        assertThat(new Subfield(((Subfield.NestedField) tokenizer.next()).getName(), Streams.stream(tokenizer).collect(toImmutableList())))
+                .isEqualTo(path);
+    }
+
+    @Test
+    public void testColumnNames()
+    {
+        assertPath(new Subfield("#bucket", ImmutableList.of()));
+        assertPath(new Subfield("$bucket", ImmutableList.of()));
+        assertPath(new Subfield("apollo-11", ImmutableList.of()));
+        assertPath(new Subfield("a/b/c:12", ImmutableList.of()));
+        assertPath(new Subfield("@basis", ImmutableList.of()));
+        assertPath(new Subfield("@basis|city_id", ImmutableList.of()));
+        assertPath(new Subfield("a and b", ImmutableList.of()));
+    }
+
+    @Test
+    public void testInvalidPaths()
+    {
+        assertInvalidPath("a[b]");
+        assertInvalidPath("a[2");
+        assertInvalidPath("a.*");
+        assertInvalidPath("a[2].[3].");
+    }
+
+    private static void assertInvalidPath(String path)
+    {
+        SubfieldTokenizer tokenizer = new SubfieldTokenizer(path);
+        assertThatException()
+                .isThrownBy(() -> Streams.stream(tokenizer).collect(toImmutableList()))
+                .isInstanceOf(TrinoException.class)
+                .matches(e -> ((TrinoException) e).getErrorCode().equals(INVALID_FUNCTION_ARGUMENT.toErrorCode()));
+    }
+}
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnProjectionInfo.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnProjectionInfo.java
index f271638fdf74..96e451f3b563 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnProjectionInfo.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnProjectionInfo.java
@@ -89,7 +89,9 @@ public long getRetainedSizeInBytes()
 
     public HiveColumnProjectionInfo toHiveColumnProjectionInfo()
     {
-        return new HiveColumnProjectionInfo(dereferenceIndices, dereferencePhysicalNames, toHiveType(type), type);
+        // Subfield pushdown is not supported for Delta Lake yet
+        // TODO: once Subfield is accepted, extend it to Delta Lake, and possibly Iceberg, which does not support ColumnProjectionInfo at all today
+        return new HiveColumnProjectionInfo(dereferenceIndices, dereferencePhysicalNames, toHiveType(type), type, ImmutableList.of());
     }
 
     @Override
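Aside on the Subfield grammar exercised by TestSubfieldTokenizer above: a structured path round-trips through serialize() and the parsing constructor. A minimal sketch (illustrative only, not part of the change; it assumes just the Subfield API added in this PR):

    Subfield subfield = new Subfield("a", ImmutableList.of(
            Subfield.allSubscripts(),           // serializes as "[*]"
            new Subfield.NestedField("b")));    // serializes as ".b"
    String serialized = subfield.serialize();   // "a[*].b"
    Subfield reparsed = new Subfield(serialized);
    // reparsed.equals(subfield) is true: the tokenizer restores the same root name and path

diff --git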
a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java
index 0201735b36f3..21829dda462a 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java
@@ -371,7 +371,8 @@ private TupleDomain buildTupleDomainColumnHandle(EntryType ent
                         ImmutableList.of(0), // hiveColumnIndex; we provide fake value because we always find columns by name
                         ImmutableList.of(field),
                         toHiveType(type),
-                        type)),
+                        type,
+                        ImmutableList.of())),
                 ColumnType.REGULAR,
                 column.getComment());
@@ -395,7 +396,8 @@ private static HiveColumnHandle toPartitionValuesParsedField(HiveColumnHandle ad
                         ImmutableList.of(0, 0), // hiveColumnIndex; we provide fake value because we always find columns by name
                         ImmutableList.of("partitionvalues_parsed", partitionColumn.columnName()),
                         DeltaHiveTypeTranslator.toHiveType(partitionColumn.type()),
-                        partitionColumn.type())),
+                        partitionColumn.type(),
+                        ImmutableList.of())),
                 HiveColumnHandle.ColumnType.REGULAR,
                 addColumn.getComment());
     }
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnProjectionInfo.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnProjectionInfo.java
index 85be14f7ad1b..3493464ee536 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnProjectionInfo.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnProjectionInfo.java
@@ -17,6 +17,7 @@
 import com.fasterxml.jackson.annotation.JsonProperty;
 import io.airlift.slice.SizeOf;
 import io.trino.metastore.HiveType;
+import io.trino.plugin.base.subfield.Subfield;
 import io.trino.spi.type.Type;
 
 import java.util.List;
@@ -37,23 +38,26 @@ public class HiveColumnProjectionInfo
     private final HiveType hiveType;
     private final Type type;
     private final String partialName;
+    private final List<Subfield> subfields;
 
     @JsonCreator
     public HiveColumnProjectionInfo(
             @JsonProperty("dereferenceIndices") List<Integer> dereferenceIndices,
             @JsonProperty("dereferenceNames") List<String> dereferenceNames,
             @JsonProperty("hiveType") HiveType hiveType,
-            @JsonProperty("type") Type type)
+            @JsonProperty("type") Type type,
+            @JsonProperty("subfields") List<Subfield> subfields)
     {
         this.dereferenceIndices = requireNonNull(dereferenceIndices, "dereferenceIndices is null");
         this.dereferenceNames = requireNonNull(dereferenceNames, "dereferenceNames is null");
-        checkArgument(dereferenceIndices.size() > 0, "dereferenceIndices should not be empty");
+        // dereferenceIndices may now be empty: a subfield-only projection does not dereference the base column
         checkArgument(dereferenceIndices.size() == dereferenceNames.size(), "dereferenceIndices and dereferenceNames should have the same sizes");
         this.hiveType = requireNonNull(hiveType, "hiveType is null");
         this.type = requireNonNull(type, "type is null");
         this.partialName = generatePartialName(dereferenceNames);
+        this.subfields = requireNonNull(subfields, "subfields is null");
     }
 
     public String getPartialName()
@@ -85,10 +89,16 @@ public Type getType()
         return type;
     }
 
+    @JsonProperty
+    public List<Subfield> getSubfields()
+    {
+        return subfields;
+    }
+
     @Override
     public int hashCode()
     {
-        return Objects.hash(dereferenceIndices, dereferenceNames, hiveType, type);
+        return Objects.hash(dereferenceIndices, dereferenceNames, hiveType, type, subfields);
     }
 
     @Override
@@ -105,7
+115,8 @@ public boolean equals(Object obj) return Objects.equals(this.dereferenceIndices, other.dereferenceIndices) && Objects.equals(this.dereferenceNames, other.dereferenceNames) && Objects.equals(this.hiveType, other.hiveType) && - Objects.equals(this.type, other.type); + Objects.equals(this.type, other.type) && + Objects.equals(this.subfields, other.subfields); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java index 9b8ac1488af2..ec1fb469c848 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java @@ -48,8 +48,10 @@ import io.trino.metastore.StorageFormat; import io.trino.metastore.Table; import io.trino.metastore.TableInfo; +import io.trino.metastore.type.ListTypeInfo; import io.trino.plugin.base.projection.ApplyProjectionUtil; import io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation; +import io.trino.plugin.base.subfield.Subfield; import io.trino.plugin.hive.HiveSessionProperties.InsertExistingPartitionsBehavior; import io.trino.plugin.hive.HiveWritableTableHandle.BucketInfo; import io.trino.plugin.hive.LocationService.WriteInfo; @@ -106,7 +108,9 @@ import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.ViewNotFoundException; import io.trino.spi.connector.WriterScalingOptions; +import io.trino.spi.expression.ArrayFieldDereference; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.Variable; import io.trino.spi.function.LanguageFunction; import io.trino.spi.function.SchemaFunctionName; @@ -173,8 +177,12 @@ import static io.trino.metastore.StorageFormat.VIEW_STORAGE_FORMAT; import static io.trino.metastore.type.Category.PRIMITIVE; import static io.trino.parquet.writer.ParquetWriter.SUPPORTED_BLOOM_FILTER_TYPES; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.createProjectedColumnRepresentation; import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.generatePathElementWithNames; import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.validArrayFieldDereferences; +import static io.trino.plugin.base.subfield.Subfield.allSubscripts; import static io.trino.plugin.hive.HiveAnalyzeProperties.getColumnNames; import static io.trino.plugin.hive.HiveAnalyzeProperties.getPartitionList; import static io.trino.plugin.hive.HiveApplyProjectionUtil.find; @@ -294,6 +302,7 @@ import static io.trino.plugin.hive.util.HiveBucketing.isSupportedBucketing; import static io.trino.plugin.hive.util.HiveTypeTranslator.toHiveType; import static io.trino.plugin.hive.util.HiveTypeUtil.getHiveDereferenceNames; +import static io.trino.plugin.hive.util.HiveTypeUtil.getHiveDereferenceNamesWithinArray; import static io.trino.plugin.hive.util.HiveTypeUtil.getHiveTypeForDereferences; import static io.trino.plugin.hive.util.HiveTypeUtil.getType; import static io.trino.plugin.hive.util.HiveTypeUtil.getTypeSignature; @@ -3124,6 +3133,17 @@ public Optional> applyProjecti return Optional.empty(); } + // Reuse the applyProjection interface while handling ArrayFieldDereference separately. 
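+        // For example (illustrative, not taken from this change), a projection such as
+        //   transform(array_of_struct, x -> ROW(x.a))
+        // arrives here as an ArrayFieldDereference whose target is the base column variable.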
+        // ArrayFieldDereference can currently only be produced by the PushSubscriptLambdaIntoTableScan rule
+        // TODO: fully integrate PushSubscriptLambdaIntoTableScan with the PushProjectionIntoTableScan rule
+        if (projections.stream().anyMatch(ArrayFieldDereference.class::isInstance)) {
+            Optional<Set<ArrayFieldDereference>> arrayFieldDereferences = validArrayFieldDereferences(projections);
+            return arrayFieldDereferences.isPresent() ?
+                    applyArrayFieldDereferences((HiveTableHandle) handle, arrayFieldDereferences.get(), assignments) :
+                    Optional.empty();
+        }
+
         // Create projected column representations for supported sub expressions. Simple column references and chain of
         // dereferences on a variable are supported right now.
         Set<ConnectorExpression> projectedExpressions = projections.stream()
@@ -3202,6 +3222,131 @@ public Optional> applyProjecti
             false));
     }
 
+    private Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyArrayFieldDereferences(
+            HiveTableHandle handle,
+            Set<ArrayFieldDereference> arrayFieldDereferences,
+            Map<String, ColumnHandle> assignments)
+    {
+        ImmutableMap.Builder<String, HiveColumnHandle> newAssignmentsMapBuilder = ImmutableMap.builder();
+        for (ArrayFieldDereference arrayFieldDereference : arrayFieldDereferences) {
+            Variable inputVariable = (Variable) arrayFieldDereference.getTarget();
+            String inputSymbolName = inputVariable.getName();
+            HiveColumnHandle previousColumnHandle = (HiveColumnHandle) assignments.get(inputSymbolName);
+            List<ConnectorExpression> subscriptExpressions = arrayFieldDereference.getElementFieldDereferences();
+            if (!subscriptExpressions.stream().allMatch(FieldDereference.class::isInstance) || previousColumnHandle == null) {
+                // Do not create subfields from anything other than FieldDereference expressions
+                continue;
+            }
+
+            Optional<HiveColumnProjectionInfo> previousProjectionInfo = previousColumnHandle.getHiveColumnProjectionInfo();
+            List<String> existingPathNames =
+                    previousProjectionInfo.isPresent() ? previousProjectionInfo.get().getDereferenceNames() : ImmutableList.of();
+
+            // Generate new subfields
+            ImmutableList.Builder<Subfield> subfields = ImmutableList.builder();
+            generateNewSubfields(previousProjectionInfo.isPresent() ?
+                    previousProjectionInfo.get().getHiveType() : previousColumnHandle.getHiveType(),
+                    previousColumnHandle.getBaseColumnName(), subscriptExpressions, existingPathNames, subfields);
+            ImmutableList<Subfield> newSubfields = subfields.build();
+
+            if (newSubfields.size() != subscriptExpressions.size()) {
+                // Do not overwrite the projection if any subscript expression failed to be handled; doing so could cause missing data.
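+                // (Each FieldDereference is expected to map to exactly one generated Subfield; see generateNewSubfields below.)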
+                continue;
+            }
+
+            HiveColumnProjectionInfo newHiveColumnProjectionInfo;
+            if (previousProjectionInfo.isPresent() && !previousProjectionInfo.get().getSubfields().equals(newSubfields)) {
+                // Overwrite the previous subfields, since they are no longer eligible
+                HiveColumnProjectionInfo previous = previousProjectionInfo.get();
+                newHiveColumnProjectionInfo = new HiveColumnProjectionInfo(
+                        previous.getDereferenceIndices(),
+                        previous.getDereferenceNames(),
+                        previous.getHiveType(),
+                        previous.getType(),
+                        newSubfields);
+            }
+            else if (!previousProjectionInfo.isPresent() && !newSubfields.isEmpty()) {
+                // Create a new HiveColumnProjectionInfo carrying only the subfields
+                newHiveColumnProjectionInfo = new HiveColumnProjectionInfo(
+                        ImmutableList.of(),
+                        ImmutableList.of(),
+                        previousColumnHandle.getBaseHiveType(),
+                        previousColumnHandle.getBaseType(),
+                        newSubfields);
+            }
+            else {
+                // No new projection was added
+                continue;
+            }
+
+            // Create the new HiveColumnHandle
+            HiveColumnHandle newHiveColumnHandle = new HiveColumnHandle(
+                    previousColumnHandle.getBaseColumnName(),
+                    previousColumnHandle.getBaseHiveColumnIndex(),
+                    previousColumnHandle.getBaseHiveType(),
+                    previousColumnHandle.getBaseType(),
+                    Optional.of(newHiveColumnProjectionInfo),
+                    previousColumnHandle.getColumnType(),
+                    previousColumnHandle.getComment());
+            newAssignmentsMapBuilder.put(inputSymbolName, newHiveColumnHandle);
+        }
+
+        Map<String, HiveColumnHandle> newAssignmentsMap = newAssignmentsMapBuilder.buildOrThrow();
+        // Only return a new table handle if the rule actually created different subfields
+        if (!newAssignmentsMap.isEmpty()) {
+            return getProjectionApplicationResult(handle, arrayFieldDereferences, assignments, newAssignmentsMap);
+        }
+        return Optional.empty();
+    }
+
+    private void generateNewSubfields(HiveType hiveType, String baseColumnName, List<ConnectorExpression> subscriptExpressions,
+            List<String> existingPathNames, ImmutableList.Builder<Subfield> subfields)
+    {
+        for (ConnectorExpression fieldDereferenceExpression : subscriptExpressions) {
+            ProjectedColumnRepresentation subscriptRepresentation = createProjectedColumnRepresentation(fieldDereferenceExpression);
+            ImmutableList.Builder<Subfield.PathElement> pathElements = ImmutableList.builder();
+            // Start with the existing dereference prefix
+            generatePathElementWithNames(pathElements, existingPathNames);
+
+            List<Integer> dereferenceIndices = subscriptRepresentation.getDereferenceIndices();
+
+            // Append the dereferences applied to the array elements
+            if (!dereferenceIndices.isEmpty() && hiveType.getTypeInfo() instanceof ListTypeInfo typeInfo) {
+                pathElements.add(allSubscripts());
+                generatePathElementWithNames(pathElements, getHiveDereferenceNamesWithinArray(typeInfo, dereferenceIndices));
+            }
+
+            ImmutableList<Subfield.PathElement> pathElementList = pathElements.build();
+            if (!pathElementList.isEmpty()) {
+                subfields.add(new Subfield(baseColumnName, pathElementList));
+            }
+        }
+    }
+
+    private Optional<ProjectionApplicationResult<ConnectorTableHandle>> getProjectionApplicationResult(
+            HiveTableHandle handle,
+            Set<ArrayFieldDereference> arrayFieldDereferences,
+            Map<String, ColumnHandle> assignments,
+            Map<String, HiveColumnHandle> newAssignmentsMap)
+    {
+        ImmutableSet.Builder<HiveColumnHandle> projectedColumnsBuilder = ImmutableSet.builder();
+        ImmutableList.Builder<Assignment> newAssignmentBuilder = ImmutableList.builder();
+        // The assignment type does not matter: only the column handle is used downstream; the type is kept to avoid
+        // disturbing the existing applyProjection validations
+        for (Map.Entry<String, ColumnHandle> m : assignments.entrySet()) {
+            String originalSymbol = m.getKey();
+            HiveColumnHandle originalColumnHandle = (HiveColumnHandle) m.getValue();
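+            // Prefer the rewritten handle (carrying the new subfields) when this symbol was updated; otherwise keep the original column handle.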
projectedColumnsBuilder.add(newAssignmentsMap.containsKey(originalSymbol) ? newAssignmentsMap.get(originalSymbol) : originalColumnHandle); + newAssignmentBuilder.add(newAssignmentsMap.containsKey(originalSymbol) ? + new Assignment(originalSymbol, newAssignmentsMap.get(originalSymbol), ((HiveColumnHandle) originalColumnHandle).getType()) + : new Assignment(originalSymbol, originalColumnHandle, ((HiveColumnHandle) originalColumnHandle).getType())); + } + return Optional.of(new ProjectionApplicationResult<>( + handle.withProjectedColumns(projectedColumnsBuilder.build()), + ImmutableList.copyOf(arrayFieldDereferences), + newAssignmentBuilder.build(), + false)); + } + private HiveColumnHandle createProjectedColumnHandle(HiveColumnHandle column, List indices) { HiveType oldHiveType = column.getHiveType(); @@ -3223,7 +3368,8 @@ private HiveColumnHandle createProjectedColumnHandle(HiveColumnHandle column, Li .addAll(getHiveDereferenceNames(oldHiveType, indices)) .build(), newHiveType, - typeManager.getType(getTypeSignature(newHiveType))); + typeManager.getType(getTypeSignature(newHiveType)), + ImmutableList.of()); // If prefixes were updated, this will need to be re-created. However this will never happen due to order of rules return new HiveColumnHandle( column.getBaseColumnName(), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java index a0b815a41b88..817495519726 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java @@ -494,7 +494,8 @@ public static List toColumnHandles(List regular projectedColumn.getDereferenceIndices(), projectedColumn.getDereferenceNames(), fromHiveType, - createTypeFromCoercer(typeManager, fromHiveType, columnHandle.getHiveType(), coercionContext)); + createTypeFromCoercer(typeManager, fromHiveType, columnHandle.getHiveType(), coercionContext), + projectedColumn.getSubfields()); }); return new HiveColumnHandle( diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index f16399c17d38..ae806416b932 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.hive.parquet; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -36,6 +37,7 @@ import io.trino.parquet.reader.MetadataReader; import io.trino.parquet.reader.ParquetReader; import io.trino.parquet.reader.RowGroupInfo; +import io.trino.plugin.base.subfield.Subfield; import io.trino.plugin.hive.AcidInfo; import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.HiveColumnHandle; @@ -56,6 +58,7 @@ import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; import org.joda.time.DateTimeZone; @@ -343,6 +346,34 @@ public static Optional getColumnType(HiveColumnH if 
(baseColumnType.isEmpty() || column.getHiveColumnProjectionInfo().isEmpty()) {
             return baseColumnType;
         }
+
+        // When present, subfields are the source of truth, because they contain the full prefixes from the root.
+        // Subfields are currently created only by the PushSubscriptLambdaIntoTableScan rule, so disabling that rule
+        // automatically skips the logic below if anything goes wrong
+        // TODO: this logic can be consolidated after switching completely to Subfield
+        if (useParquetColumnNames && !column.getHiveColumnProjectionInfo().get().getSubfields().isEmpty()) {
+            try {
+                MessageType combinedPrunedTypes = null;
+                for (Subfield subfield : column.getHiveColumnProjectionInfo().get().getSubfields()) {
+                    MessageType prunedType = pruneRootTypeForSubfield(messageType, subfield);
+                    if (combinedPrunedTypes == null) {
+                        combinedPrunedTypes = prunedType;
+                    }
+                    else {
+                        combinedPrunedTypes = combinedPrunedTypes.union(prunedType);
+                    }
+                }
+                return Optional.of(combinedPrunedTypes) // Should never be null since subfields is non-empty.
+                        .map(type -> getParquetTypeByName(column.getBaseColumnName(), type));
+            }
+            catch (Exception e) {
+                // Fall back to the unpruned base column type if subfield-based pruning fails for any reason
+                return baseColumnType;
+            }
+        }
+        else if (column.getHiveColumnProjectionInfo().get().getDereferenceNames().isEmpty()) {
+            return baseColumnType;
+        }
+
         GroupType baseType = baseColumnType.get().asGroupType();
         Optional<List<org.apache.parquet.schema.Type>> subFieldTypesOptional = dereferenceSubFieldTypes(baseType, column.getHiveColumnProjectionInfo().get());
@@ -494,6 +525,11 @@ private static Optional> dereferenceSubFiel
         checkArgument(baseType != null, "base type cannot be null when dereferencing");
         checkArgument(projectionInfo != null, "hive column projection info cannot be null when doing dereferencing");
 
+        // dereferenceNames can be empty if subfields are available
+        if (projectionInfo.getDereferenceNames().isEmpty()) {
+            return Optional.empty();
+        }
+
         ImmutableList.Builder<org.apache.parquet.schema.Type> typeBuilder = ImmutableList.builder();
         org.apache.parquet.schema.Type parentType = baseType;
 
@@ -508,4 +544,84 @@
         return Optional.of(typeBuilder.build());
     }
+
+    // The methods below are ported from Presto with no further changes, to minimize the testing needed
+    private static MessageType pruneRootTypeForSubfield(MessageType rootType, Subfield subfield)
+    {
+        org.apache.parquet.schema.Type columnType = getParquetTypeByName(subfield.getRootName(), rootType);
+        if (columnType == null) {
+            return new MessageType(rootType.getName(), ImmutableList.of());
+        }
+        org.apache.parquet.schema.Type prunedColumnType = pruneColumnTypeForPath(columnType, subfield.getPath());
+        return new MessageType(rootType.getName(), prunedColumnType);
+    }
+
+    @VisibleForTesting
+    static org.apache.parquet.schema.Type pruneColumnTypeForPath(org.apache.parquet.schema.Type columnType, List<Subfield.PathElement> pathElements)
+    {
+        try {
+            return pruneTypeForPath(columnType, pathElements);
+        }
+        catch (Exception e) {
+            // Pruning is best-effort: fall back to the unpruned column type on any failure
+            return columnType;
+        }
+    }
+
+    private static org.apache.parquet.schema.Type pruneTypeForPath(org.apache.parquet.schema.Type type, List<Subfield.PathElement> path)
+    {
+        if (path.isEmpty()) {
+            return type;
+        }
+
+        // Path is not empty, so the type must be a group type
+        GroupType groupType = type.asGroupType();
+
+        Subfield.PathElement pathElement = path.get(0);
+        OriginalType originalType = type.getOriginalType();
+        if (pathElement.isSubscript()) { // Accessing an array or map
+
+            // If the path element is a subscript and it's the last element, no more pruning is possible; only nested
+            // fields result in pruning
+            if (path.size() == 1) {
+                return type;
+            }
+
+            GroupType firstFieldType = groupType.getType(0).asGroupType();
+
+            if (originalType == OriginalType.MAP
+                    || originalType == OriginalType.MAP_KEY_VALUE) { // Backwards-compatibility case
+                org.apache.parquet.schema.Type keyType = firstFieldType.asGroupType().getType(0);
+                org.apache.parquet.schema.Type valueType = firstFieldType.asGroupType().getType(1);
+                org.apache.parquet.schema.Type newValueType = pruneTypeForPath(valueType,
+                        path.subList(1, path.size()));
+                return groupType.withNewFields(firstFieldType.withNewFields(keyType, newValueType));
+            }
+
+            if (originalType == OriginalType.LIST) {
+                if (firstFieldType.getFields().size() > 1
+                        || firstFieldType.getName().equals("array")
+                        || firstFieldType.getName().equals(type.getName() + "_tuple")) {
+                    return groupType.withNewFields(pruneTypeForPath(firstFieldType, path.subList(1, path.size())));
+                }
+
+                return groupType.withNewFields(firstFieldType.withNewFields(
+                        pruneTypeForPath(firstFieldType.getType(0), path.subList(1, path.size()))));
+            }
+        }
+        else { // Accessing a nested field
+
+            // The only non-subscript path element is NestedField, so this cast is safe
+            String name = ((Subfield.NestedField) pathElement).getName();
+
+            org.apache.parquet.schema.Type subType = pruneTypeForPath(getParquetTypeByName(name, groupType),
+                    path.subList(1, path.size()));
+            return groupType.withNewFields(subType);
+        }
+
+        // Fall back to the original type without pruning
+        return type;
+    }
 }
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeUtil.java
index 1dd0c9b7aeea..213889fc1592 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeUtil.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeUtil.java
@@ -36,6 +36,7 @@
 import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_FIELD_PREFIX;
 import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_NAME;
 import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_TYPE;
+import static io.trino.metastore.HiveType.fromTypeInfo;
 import static io.trino.plugin.hive.HiveStorageFormat.AVRO;
 import static io.trino.plugin.hive.HiveStorageFormat.ORC;
 import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION;
@@ -141,7 +142,15 @@ else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) {
             throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo));
         }
     }
-        return Optional.of(HiveType.fromTypeInfo(typeInfo));
+        return Optional.of(fromTypeInfo(typeInfo));
+    }
+
+    public static List<String> getHiveDereferenceNamesWithinArray(TypeInfo typeInfo, List<Integer> dereferences)
+    {
+        checkArgument(typeInfo instanceof ListTypeInfo, "typeInfo must be a ListTypeInfo");
+        // Minimal, non-optimized change that reuses
getHiveDereferenceNames + // Only single level of dereference into Array is possible for now + return HiveTypeUtil.getHiveDereferenceNames(fromTypeInfo(((ListTypeInfo) typeInfo).getListElementTypeInfo()), dereferences); } public static List getHiveDereferenceNames(HiveType hiveType, List dereferences) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveColumnHandle.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveColumnHandle.java index 2b56a12bfa3f..0ba592c48de9 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveColumnHandle.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveColumnHandle.java @@ -71,7 +71,8 @@ public void testProjectedColumn() ImmutableList.of(1), ImmutableList.of("b"), HiveType.HIVE_LONG, - BIGINT); + BIGINT, + ImmutableList.of()); HiveColumnHandle projectedColumn = new HiveColumnHandle( "struct_col", diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java index da439242ccfc..a9bd68502c27 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java @@ -2017,7 +2017,7 @@ public HiveColumnHandle toHiveColumnHandle(int columnIndex) columnIndex, toHiveType(baseType), baseType, - Optional.of(new HiveColumnProjectionInfo(ImmutableList.of(0), ImmutableList.of(name), toHiveType(type), type)), + Optional.of(new HiveColumnProjectionInfo(ImmutableList.of(0), ImmutableList.of(name), toHiveType(type), type, ImmutableList.of())), partitionKey ? PARTITION_KEY : REGULAR, Optional.empty()); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveReaderProjectionsUtil.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveReaderProjectionsUtil.java index abaabfd342a1..f5cd7da289ee 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveReaderProjectionsUtil.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveReaderProjectionsUtil.java @@ -76,7 +76,7 @@ public static HiveColumnHandle createProjectedColumnHandle(HiveColumnHandle colu List names = getHiveDereferenceNames(baseHiveType, indices); HiveType hiveType = getHiveTypeForDereferences(baseHiveType, indices).get(); - HiveColumnProjectionInfo columnProjection = new HiveColumnProjectionInfo(indices, names, hiveType, TESTING_TYPE_MANAGER.getType(getTypeSignature(hiveType))); + HiveColumnProjectionInfo columnProjection = new HiveColumnProjectionInfo(indices, names, hiveType, TESTING_TYPE_MANAGER.getType(getTypeSignature(hiveType)), ImmutableList.of()); return new HiveColumnHandle( column.getBaseColumnName(), diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java index 844787214493..d4e2a06e355e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java @@ -23,6 +23,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.metastore.Database; import io.trino.metastore.HiveMetastore; +import io.trino.plugin.base.subfield.Subfield; import io.trino.plugin.hive.HiveColumnHandle; import 
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java
index 844787214493..d4e2a06e355e 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java
@@ -23,6 +23,7 @@
 import io.trino.metadata.TestingFunctionResolution;
 import io.trino.metastore.Database;
 import io.trino.metastore.HiveMetastore;
+import io.trino.plugin.base.subfield.Subfield;
 import io.trino.plugin.hive.HiveColumnHandle;
 import io.trino.plugin.hive.HiveColumnProjectionInfo;
 import io.trino.plugin.hive.HiveTableHandle;
@@ -34,6 +35,7 @@
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.TupleDomain;
 import io.trino.spi.security.PrincipalType;
+import io.trino.spi.type.ArrayType;
 import io.trino.spi.type.RowType;
 import io.trino.spi.type.Type;
 import io.trino.sql.ir.Call;
@@ -41,14 +43,20 @@
 import io.trino.sql.ir.Constant;
 import io.trino.sql.ir.Expression;
 import io.trino.sql.ir.FieldReference;
+import io.trino.sql.ir.Lambda;
 import io.trino.sql.ir.Reference;
+import io.trino.sql.ir.Row;
 import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.assertions.SymbolAliases;
 import io.trino.sql.planner.iterative.rule.PruneTableScanColumns;
+import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaIntoTableScan;
+import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaThroughFilterIntoTableScan;
 import io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan;
 import io.trino.sql.planner.iterative.rule.PushProjectionIntoTableScan;
 import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
 import io.trino.sql.planner.plan.Assignments;
 import io.trino.testing.PlanTester;
+import io.trino.type.FunctionType;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.Test;
 
@@ -60,7 +68,9 @@
 import static com.google.common.io.MoreFiles.deleteRecursively;
 import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE;
+import static io.trino.SystemSessionProperties.ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN;
 import static io.trino.metastore.HiveType.HIVE_INT;
+import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME;
 import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR;
 import static io.trino.plugin.hive.HiveColumnHandle.createBaseColumn;
 import static io.trino.plugin.hive.HiveQueryRunner.HIVE_CATALOG;
@@ -69,7 +79,9 @@
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.IntegerType.INTEGER;
 import static io.trino.spi.type.RowType.field;
+import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
 import static io.trino.sql.ir.Comparison.Operator.EQUAL;
+import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN;
 import static io.trino.sql.planner.assertions.PlanMatchPattern.expression;
 import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
 import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
@@ -97,8 +109,22 @@ public class TestConnectorPushdownRulesWithHive
     private static final Session HIVE_SESSION = testSessionBuilder()
             .setCatalog(HIVE_CATALOG)
             .setSchema(SCHEMA_NAME)
+            .setSystemProperty(ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN, "true")
             .build();
 
+    private static final Type ANONYMOUS_ROW_TYPE = RowType.anonymous(ImmutableList.of(BIGINT, BIGINT));
+    private static final Type PRUNED_ROW_TYPE = RowType.anonymous(ImmutableList.of(BIGINT));
+    private static final Type PRUNED_ARRAY_ROW_TYPE = new ArrayType(PRUNED_ROW_TYPE);
+    private static final Type ARRAY_ROW_TYPE = new ArrayType(ANONYMOUS_ROW_TYPE);
+    private static final Type ROW_ARRAY_ROW_TYPE = RowType.anonymous(ImmutableList.of(ARRAY_ROW_TYPE));
+    private static final Reference LAMBDA_ELEMENT_REFERENCE = new Reference(ROW_TYPE, "transformarray$element");
+    private static final ResolvedFunction TRANSFORM = FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ANONYMOUS_ROW_TYPE), PRUNED_ROW_TYPE)));
+    private static final Call dereferenceFunctionCall = new Call(TRANSFORM, ImmutableList.of(
+            new Reference(ARRAY_ROW_TYPE, "array_of_struct"),
+            new Lambda(
+                    ImmutableList.of(new Symbol(ANONYMOUS_ROW_TYPE, "transformarray$element")),
+                    new Row(ImmutableList.of(new FieldReference(LAMBDA_ELEMENT_REFERENCE, 0))))));
+
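For orientation, dereferenceFunctionCall above is the planner IR for an array transform whose lambda only dereferences a field of the element. A sketch of the roughly equivalent SQL (table name invented, not verbatim planner output):

    // Roughly equivalent SQL for the IR above:
    //   SET SESSION enable_push_field_dereference_lambda_into_scan = true;
    //   SELECT transform(array_of_struct, x -> ROW(x.a)) FROM some_table;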
     @Override
     protected Optional<PlanTester> createPlanTester()
     {
@@ -149,7 +175,8 @@ public void testProjectionPushdown()
                         ImmutableList.of(0),
                         ImmutableList.of("a"),
                         toHiveType(BIGINT),
-                        BIGINT)),
+                        BIGINT,
+                        ImmutableList.of())),
                 REGULAR,
                 Optional.empty());
 
@@ -304,7 +331,8 @@ public void testPushdownWithDuplicateExpressions()
                         ImmutableList.of(0),
                         ImmutableList.of("a"),
                         toHiveType(BIGINT),
-                        BIGINT)),
+                        BIGINT,
+                        ImmutableList.of())),
                 REGULAR,
                 Optional.empty());
 
@@ -359,6 +387,466 @@ public void testPushdownWithDuplicateExpressions()
         metastore.dropTable(SCHEMA_NAME, tableName, true);
     }
 
+    @Test
+    public void testDereferenceInFieldReferenceLambdaPushdown()
+    {
+        String tableName = "array_filter_dereference_projection_test";
+        PushFieldReferenceLambdaIntoTableScan pushFieldReferenceLambdaIntoTableScan =
+                new PushFieldReferenceLambdaIntoTableScan(
+                        tester().getPlannerContext());
+
+        tester().getPlanTester().executeStatement(format(
+                "CREATE TABLE %s (array_of_struct) AS " +
+                        "SELECT cast(ARRAY[ROW(1, 2), ROW(3, 4)] as ARRAY(ROW(a bigint, b bigint))) as array_of_struct",
+                tableName));
+
+        HiveColumnHandle partialColumn = new HiveColumnHandle(
+                "array_of_struct",
+                0,
+                toHiveType(new ArrayType(ROW_TYPE)),
+                new ArrayType(ROW_TYPE),
+                Optional.of(new HiveColumnProjectionInfo(
+                        ImmutableList.of(),
+                        ImmutableList.of(),
+                        toHiveType(new ArrayType(ROW_TYPE)),
+                        new ArrayType(ROW_TYPE),
+                        ImmutableList.of(new Subfield("array_of_struct",
+                                ImmutableList.of(Subfield.AllSubscripts.getInstance(),
+                                        new Subfield.NestedField("a")))))),
+                REGULAR,
+                Optional.empty());
+
+        HiveTableHandle hiveTable = new HiveTableHandle(SCHEMA_NAME, tableName,
+                ImmutableMap.of(), ImmutableList.of(), ImmutableList.of(), Optional.empty());
+        TableHandle table = new TableHandle(catalogHandle, hiveTable, new HiveTransactionHandle(false));
+
+        HiveColumnHandle fullColumn = partialColumn.getBaseColumn();
+
+        // Base symbol referenced by other assignments, skip the optimization
+        tester().assertThat(pushFieldReferenceLambdaIntoTableScan)
+                .on(p ->
+                        p.project(
+                                Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE),
+                                        dereferenceFunctionCall,
+                                        p.symbol("array_of_struct", ARRAY_ROW_TYPE),
+                                        p.symbol("array_of_struct", ARRAY_ROW_TYPE).toSymbolReference()),
+                                p.tableScan(
+                                        table,
+                                        ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE)),
+                                        ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn))))
+                .doesNotFire();
+
+        // No subscript lambda exists, skip the optimization
+        tester().assertThat(pushFieldReferenceLambdaIntoTableScan)
+                .on(p ->
+                        p.project(
+                                Assignments.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE),
+                                        p.symbol("array_of_struct", ARRAY_ROW_TYPE).toSymbolReference()),
+                                p.tableScan(
+                                        table,
+                                        ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE)),
+                                        ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn))))
+                .doesNotFire();
+
+        // Transform input argument is not a symbol reference, skip the optimization
+        Call nestedSubscriptFunctionCall = new Call(FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ANONYMOUS_ROW_TYPE), PRUNED_ROW_TYPE))), ImmutableList.of(
+                new FieldReference(new Reference(ROW_ARRAY_ROW_TYPE,
"struct_of_array_of_struct"), 0), + new Lambda(ImmutableList.of(new Symbol(ANONYMOUS_ROW_TYPE, "transformarray$element")), + new Row(ImmutableList.of(new FieldReference( + LAMBDA_ELEMENT_REFERENCE, + 0)))))); + + tester().assertThat(pushFieldReferenceLambdaIntoTableScan) + .on(p -> + p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + nestedSubscriptfunctionCall), + p.tableScan( + table, + ImmutableList.of(p.symbol("struct_of_array_of_struct", ROW_ARRAY_ROW_TYPE)), + ImmutableMap.of(p.symbol("struct_of_array_of_struct", ROW_ARRAY_ROW_TYPE), fullColumn)))) + .doesNotFire(); + + // If already applied and same subfields generated, will not re-apply + tester().assertThat(pushFieldReferenceLambdaIntoTableScan) + .on(p -> + p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE)), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), partialColumn)))) + .doesNotFire(); + + // Overwrite the existing subfields with latest + HiveColumnHandle previousColumnHandle = new HiveColumnHandle( + "array_of_struct", + 0, + toHiveType(new ArrayType(ROW_TYPE)), + new ArrayType(ROW_TYPE), + Optional.of(new HiveColumnProjectionInfo( + ImmutableList.of(), + ImmutableList.of(), + toHiveType(new ArrayType(ROW_TYPE)), + new ArrayType(ROW_TYPE), + ImmutableList.of(new Subfield("array_of_struct", + ImmutableList.of(Subfield.AllSubscripts.getInstance(), + new Subfield.NestedField("previous")))))), + REGULAR, + Optional.empty()); + + tester().assertThat(pushFieldReferenceLambdaIntoTableScan) + .on(p -> + p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE)), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), previousColumnHandle)))) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + tableScan( + hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, + TupleDomain.all(), + ImmutableMap.of("array_of_struct", partialColumn::equals)))); + + // Subfields are added based on the subscript lambda + tester().assertThat(pushFieldReferenceLambdaIntoTableScan) + .on(p -> + p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE)), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn)))) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + tableScan( + hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn))::equals, + TupleDomain.all(), + ImmutableMap.of("array_of_struct", partialColumn::equals)))); + + // Subfields are added based on the subscript lambda, and extends the existing prefix + HiveColumnHandle nestedColumn = new HiveColumnHandle( + "struct_of_array_of_struct", + 0, + toHiveType(RowType.from(asList(field("array", new ArrayType(ROW_TYPE))))), + RowType.from(asList(field("array", new ArrayType(ROW_TYPE)))), + Optional.of(new HiveColumnProjectionInfo( + ImmutableList.of(0), + ImmutableList.of("array"), + 
+                        toHiveType(new ArrayType(ROW_TYPE)),
+                        new ArrayType(ROW_TYPE),
+                        ImmutableList.of())),
+                REGULAR,
+                Optional.empty());
+
+        HiveColumnHandle nestedPartialColumn = new HiveColumnHandle(
+                "struct_of_array_of_struct",
+                0,
+                toHiveType(RowType.from(asList(field("array", new ArrayType(ROW_TYPE))))),
+                RowType.from(asList(field("array", new ArrayType(ROW_TYPE)))),
+                Optional.of(new HiveColumnProjectionInfo(
+                        ImmutableList.of(0),
+                        ImmutableList.of("array"),
+                        toHiveType(new ArrayType(ROW_TYPE)),
+                        new ArrayType(ROW_TYPE),
+                        ImmutableList.of(new Subfield("struct_of_array_of_struct",
+                                ImmutableList.of(new Subfield.NestedField("array"),
+                                        Subfield.AllSubscripts.getInstance(),
+                                        new Subfield.NestedField("a")))))),
+                REGULAR,
+                Optional.empty());
+
+        tester().assertThat(pushFieldReferenceLambdaIntoTableScan)
+                .on(p ->
+                        p.project(
+                                Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall),
+                                p.tableScan(
+                                        table,
+                                        ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE)),
+                                        ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), nestedColumn))))
+                .matches(
+                        project(
+                                ImmutableMap.of("pruned_nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))),
+                                tableScan(
+                                        hiveTable.withProjectedColumns(ImmutableSet.of(nestedPartialColumn))::equals,
+                                        TupleDomain.all(),
+                                        ImmutableMap.of("array_of_struct", nestedPartialColumn::equals))));
+
+        metastore.dropTable(SCHEMA_NAME, tableName, true);
+    }
+
+    @Test
+    public void testDereferenceInSubscriptLambdaPushdownThroughFilter()
+    {
+        String tableName = "array_filter_dereference_with_filter_projection_test";
+        PushFieldReferenceLambdaThroughFilterIntoTableScan pushFieldReferenceLambdaThroughFilterIntoTableScan =
+                new PushFieldReferenceLambdaThroughFilterIntoTableScan(
+                        tester().getPlannerContext());
+
+        tester().getPlanTester().executeStatement(format(
+                "CREATE TABLE %s (array_of_struct) AS " +
+                        "SELECT cast(ARRAY[ROW(1, 2), ROW(3, 4)] as ARRAY(ROW(a bigint, b bigint))) as array_of_struct",
+                tableName));
+
+        HiveColumnHandle partialColumn = new HiveColumnHandle(
+                "array_of_struct",
+                0,
+                toHiveType(new ArrayType(ROW_TYPE)),
+                new ArrayType(ROW_TYPE),
+                Optional.of(new HiveColumnProjectionInfo(
+                        ImmutableList.of(),
+                        ImmutableList.of(),
+                        toHiveType(new ArrayType(ROW_TYPE)),
+                        new ArrayType(ROW_TYPE),
+                        ImmutableList.of(new Subfield("array_of_struct",
+                                ImmutableList.of(Subfield.AllSubscripts.getInstance(),
+                                        new Subfield.NestedField("a")))))),
+                REGULAR,
+                Optional.empty());
+
+        HiveColumnHandle bigIntColumn = new HiveColumnHandle(
+                "e",
+                0,
+                toHiveType(BIGINT),
+                BIGINT,
+                Optional.empty(),
+                REGULAR,
+                Optional.empty());
+
+        HiveTableHandle hiveTable = new HiveTableHandle(SCHEMA_NAME, tableName,
+                ImmutableMap.of(), ImmutableList.of(), ImmutableList.of(), Optional.empty());
+        TableHandle table = new TableHandle(catalogHandle, hiveTable, new HiveTransactionHandle(false));
+
+        HiveColumnHandle fullColumn = partialColumn.getBaseColumn();
+
+        // Base symbol referenced by other assignments, skip the optimization
+        tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan)
+                .on(p -> {
+                    Symbol e = p.symbol("e", BIGINT);
+                    return p.project(
+                            Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE),
+                                    dereferenceFunctionCall,
+                                    p.symbol("array_of_struct", ARRAY_ROW_TYPE),
+                                    p.symbol("array_of_struct", ARRAY_ROW_TYPE).toSymbolReference()),
+                            p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)),
Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn, + e, bigIntColumn)))); + }) + .doesNotFire(); + + // No subscript lambda exists, skip the optimization + tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan) + .on(p -> { + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), + p.symbol("array_of_struct", ARRAY_ROW_TYPE).toSymbolReference()), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn, + e, bigIntColumn)))); + }) + .doesNotFire(); + + // Transform input argument is not symbol reference, skip the optimization + Call nestedSubscriptfunctionCall = new Call(FUNCTIONS.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(ARRAY_ROW_TYPE, new FunctionType(ImmutableList.of(ANONYMOUS_ROW_TYPE), PRUNED_ROW_TYPE))), ImmutableList.of(new FieldReference(new Reference(ROW_ARRAY_ROW_TYPE, "struct_of_array_of_struct"), 0), + new Lambda(ImmutableList.of(new Symbol(ANONYMOUS_ROW_TYPE, "transformarray$element")), + new Row(ImmutableList.of(new FieldReference( + LAMBDA_ELEMENT_REFERENCE, + 0)))))); + + tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan) + .on(p -> { + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), + nestedSubscriptfunctionCall), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn, + e, bigIntColumn)))); + }) + .doesNotFire(); + + // If already applied and same subfields generated, will not re-apply + tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan) + .on(p -> { + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), partialColumn, + e, bigIntColumn)))); + }) + .doesNotFire(); + + // Overwrite the existing subfields with latest + HiveColumnHandle previousColumnHandle = new HiveColumnHandle( + "array_of_struct", + 0, + toHiveType(new ArrayType(ROW_TYPE)), + new ArrayType(ROW_TYPE), + Optional.of(new HiveColumnProjectionInfo( + ImmutableList.of(), + ImmutableList.of(), + toHiveType(new ArrayType(ROW_TYPE)), + new ArrayType(ROW_TYPE), + ImmutableList.of(new Subfield("array_of_struct", + ImmutableList.of(Subfield.AllSubscripts.getInstance(), + new Subfield.NestedField("previous")))))), + REGULAR, + Optional.empty()); + + tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan) + .on(p -> { + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + 
ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), previousColumnHandle, + e, bigIntColumn)))); + }) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + tableScan( + hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn, bigIntColumn))::equals, + TupleDomain.all(), + ImmutableMap.of("array_of_struct", partialColumn::equals, "e", bigIntColumn::equals))))); + + // Subfields are added based on the subscript lambda + tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan) + .on(p -> { + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), fullColumn, + e, bigIntColumn)))); + }) + .matches( + project( + ImmutableMap.of("pruned_nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + tableScan( + hiveTable.withProjectedColumns(ImmutableSet.of(partialColumn, bigIntColumn))::equals, + TupleDomain.all(), + ImmutableMap.of("array_of_struct", partialColumn::equals, "e", bigIntColumn::equals))))); + + // Subfields are added based on the subscript lambda, and extends the existing prefix + HiveColumnHandle nestedColumn = new HiveColumnHandle( + "struct_of_array_of_struct", + 0, + toHiveType(RowType.from(asList(field("array", new ArrayType(ROW_TYPE))))), + RowType.from(asList(field("array", new ArrayType(ROW_TYPE)))), + Optional.of(new HiveColumnProjectionInfo( + ImmutableList.of(0), + ImmutableList.of("array"), + toHiveType(new ArrayType(ROW_TYPE)), + new ArrayType(ROW_TYPE), + ImmutableList.of())), + REGULAR, + Optional.empty()); + + HiveColumnHandle nestedPartialColumn = new HiveColumnHandle( + "struct_of_array_of_struct", + 0, + toHiveType(RowType.from(asList(field("array", new ArrayType(ROW_TYPE))))), + RowType.from(asList(field("array", new ArrayType(ROW_TYPE)))), + Optional.of(new HiveColumnProjectionInfo( + ImmutableList.of(0), + ImmutableList.of("array"), + toHiveType(new ArrayType(ROW_TYPE)), + new ArrayType(ROW_TYPE), + ImmutableList.of(new Subfield("struct_of_array_of_struct", + ImmutableList.of(new Subfield.NestedField("array"), + Subfield.AllSubscripts.getInstance(), + new Subfield.NestedField("a")))))), + REGULAR, + Optional.empty()); + + tester().assertThat(pushFieldReferenceLambdaThroughFilterIntoTableScan) + .on(p -> { + Symbol e = p.symbol("e", BIGINT); + return p.project( + Assignments.of(p.symbol("pruned_nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + p.tableScan( + table, + ImmutableList.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), e), + ImmutableMap.of(p.symbol("array_of_struct", ARRAY_ROW_TYPE), nestedColumn, + e, bigIntColumn)))); + }) + .matches( + project( + 
ImmutableMap.of("pruned_nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "e"), new Constant(BIGINT, 1L)), + tableScan( + hiveTable.withProjectedColumns(ImmutableSet.of(nestedPartialColumn, bigIntColumn))::equals, + TupleDomain.all(), + ImmutableMap.of("array_of_struct", nestedPartialColumn::equals, "e", bigIntColumn::equals))))); + + // PushProjectionIntoTableScan will not be impacted and will not remove any array subscript column + PushProjectionIntoTableScan pushProjectionIntoTableScan = new PushProjectionIntoTableScan( + tester().getPlannerContext(), + new ScalarStatsCalculator(tester().getPlannerContext())); + + HiveColumnHandle structPartialColumn = new HiveColumnHandle( + "struct_of_int", + 0, + toHiveType(ROW_TYPE), + ROW_TYPE, + Optional.of(new HiveColumnProjectionInfo( + ImmutableList.of(0), + ImmutableList.of("a"), + toHiveType(BIGINT), + BIGINT, + ImmutableList.of())), + REGULAR, + Optional.empty()); + + tester().assertThat(pushProjectionIntoTableScan) + .on(p -> + p.project( + Assignments.of( + p.symbol("expr_deref", BIGINT), new FieldReference(p.symbol("struct_of_int", ROW_TYPE).toSymbolReference(), 0), + p.symbol("nested_array", PRUNED_ARRAY_ROW_TYPE), dereferenceFunctionCall), + p.tableScan( + table, + ImmutableList.of(p.symbol("struct_of_int", ROW_TYPE), p.symbol("array_of_struct", ARRAY_ROW_TYPE)), + ImmutableMap.of(p.symbol("struct_of_int", ROW_TYPE), structPartialColumn.getBaseColumn(), + p.symbol("array_of_struct", ARRAY_ROW_TYPE), partialColumn.getBaseColumn())))) + .matches(project( + ImmutableMap.of("expr_deref", expression(new Reference(BIGINT, "struct_of_int#a")), + "nested_array", expression(dereferenceFunctionCall, Optional.of(SymbolAliases.builder().put("transformarray$element", LAMBDA_ELEMENT_REFERENCE).build()))), + tableScan( + hiveTable.withProjectedColumns(ImmutableSet.of(structPartialColumn, partialColumn.getBaseColumn()))::equals, + TupleDomain.all(), + ImmutableMap.of("struct_of_int#a", structPartialColumn::equals, "array_of_struct", partialColumn.getBaseColumn()::equals)))); + + metastore.dropTable(SCHEMA_NAME, tableName, true); + } + @AfterAll public void cleanup() throws IOException diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java index faaf64abc117..b6eb6b1efd51 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java @@ -96,7 +96,8 @@ class TestOrcPredicates ImmutableList.of(1), ImmutableList.of("field1"), HiveType.HIVE_LONG, - BIGINT)), + BIGINT, + ImmutableList.of())), STRUCT_COLUMN.getColumnType(), STRUCT_COLUMN.getComment()); private static final List PROJECTED_COLUMNS = ImmutableList.of(BIGINT_COLUMN, STRUCT_FIELD1_COLUMN); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java index 48e92a420613..a0244e27bdc2 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import 
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java
index faaf64abc117..b6eb6b1efd51 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java
@@ -96,7 +96,8 @@ class TestOrcPredicates
                     ImmutableList.of(1),
                     ImmutableList.of("field1"),
                     HiveType.HIVE_LONG,
-                    BIGINT)),
+                    BIGINT,
+                    ImmutableList.of())),
             STRUCT_COLUMN.getColumnType(),
             STRUCT_COLUMN.getComment());
     private static final List<HiveColumnHandle> PROJECTED_COLUMNS = ImmutableList.of(BIGINT_COLUMN, STRUCT_FIELD1_COLUMN);
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java
index 48e92a420613..a0244e27bdc2 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java
@@ -15,6 +15,7 @@
 import com.google.common.collect.ImmutableList;
 import io.trino.metastore.HiveType;
+import io.trino.plugin.base.subfield.Subfield;
 import io.trino.plugin.hive.HiveColumnHandle;
 import io.trino.plugin.hive.HiveColumnProjectionInfo;
 import io.trino.spi.type.IntegerType;
@@ -22,6 +23,7 @@
 import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
 import org.junit.jupiter.api.Test;
 
 import java.util.Optional;
@@ -36,6 +38,11 @@
 public class TestParquetPageSourceFactory
 {
+    private static final String OPTIONAL_LEVEL_1 = "optional_level1";
+    private static final String OPTIONAL_LEVEL_2 = "optional_level2";
+    private static final String REQUIRED_LEVEL_3_0 = "required_level3_0";
+    private static final String REQUIRED_LEVEL_3_1 = "required_level3_1";
+
     @Test
     public void testGetNestedMixedRepetitionColumnType()
     {
@@ -43,6 +50,55 @@ public void testGetNestedMixedRepetitionColumnType()
         testGetNestedMixedRepetitionColumnType(false);
     }
 
+    @Test
+    public void testSchemaPruningWithSubfields()
+    {
+        RowType rowType = rowType(
+                RowType.field(
+                        OPTIONAL_LEVEL_2,
+                        rowType(RowType.field(
+                                        REQUIRED_LEVEL_3_0,
+                                        IntegerType.INTEGER),
+                                RowType.field(
+                                        REQUIRED_LEVEL_3_1,
+                                        IntegerType.INTEGER))));
+        // The Parquet schema is pruned by the Subfields, not by the name/index lists
+        HiveColumnHandle columnHandle = new HiveColumnHandle(
+                OPTIONAL_LEVEL_1,
+                0,
+                HiveType.valueOf("struct<optional_level2:struct<required_level3_0:int,required_level3_1:int>>"),
+                rowType,
+                Optional.of(
+                        new HiveColumnProjectionInfo(
+                                ImmutableList.of(1),
+                                ImmutableList.of(OPTIONAL_LEVEL_2),
+                                toHiveType(IntegerType.INTEGER),
+                                IntegerType.INTEGER,
+                                ImmutableList.of(new Subfield(OPTIONAL_LEVEL_1, ImmutableList.of(new Subfield.NestedField(OPTIONAL_LEVEL_2), new Subfield.NestedField(REQUIRED_LEVEL_3_0)))))),
+                REGULAR,
+                Optional.empty());
+        MessageType fileSchema = new MessageType(
+                "hive_schema",
+                new GroupType(OPTIONAL, OPTIONAL_LEVEL_1,
+                        new GroupType(OPTIONAL, OPTIONAL_LEVEL_2,
+                                new PrimitiveType(REQUIRED, INT32, REQUIRED_LEVEL_3_0),
+                                new PrimitiveType(REQUIRED, INT32, REQUIRED_LEVEL_3_1))));
+
+        MessageType prunedFileSchema = new MessageType(
+                "hive_schema",
+                new GroupType(OPTIONAL, OPTIONAL_LEVEL_1,
+                        new GroupType(OPTIONAL, OPTIONAL_LEVEL_2,
+                                new PrimitiveType(REQUIRED, INT32, REQUIRED_LEVEL_3_0))));
+
+        // The Hive column name is derived from the dereference names, while schema pruning still uses the original column name
+        Type newType = ParquetPageSourceFactory.getColumnType(columnHandle, fileSchema, true).get();
+        assertThat(columnHandle.getBaseColumnName()).isEqualTo(OPTIONAL_LEVEL_1);
+        assertThat(columnHandle.getName()).isEqualTo(OPTIONAL_LEVEL_1 + "#" + OPTIONAL_LEVEL_2);
+        assertThat(newType.getName()).isEqualTo(columnHandle.getBaseColumnName());
+        // The Parquet schema is pruned and required_level3_1 is removed
+        assertThat(newType).isEqualTo(prunedFileSchema.getType(OPTIONAL_LEVEL_1));
+    }
+
     private void testGetNestedMixedRepetitionColumnType(boolean useColumnNames)
     {
         RowType rowType = rowType(
@@ -61,7 +117,8 @@ private void testGetNestedMixedRepetitionColumnType(boolean useColumnNames)
                         ImmutableList.of(1, 1),
                         ImmutableList.of("optional_level2", "required_level3"),
                         toHiveType(IntegerType.INTEGER),
-                        IntegerType.INTEGER)),
+                        IntegerType.INTEGER,
+                        ImmutableList.of())),
                 REGULAR,
                 Optional.empty());
         MessageType fileSchema = new MessageType(
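Spelled out, the expected effect of testSchemaPruningWithSubfields is that required_level3_1 disappears from the file schema while the requested branch survives; printed via the Parquet MessageType's toString(), the pruned schema would read roughly:

    // message hive_schema {
    //   optional group optional_level1 {
    //     optional group optional_level2 {
    //       required int32 required_level3_0;
    //     }
    //   }
    // }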
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestPruneTypeForPath.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestPruneTypeForPath.java
new file mode 100644
index 000000000000..cdf4e9efec24
--- /dev/null
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestPruneTypeForPath.java
@@ -0,0 +1,512 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.hive.parquet;
+
+import io.trino.plugin.base.subfield.Subfield.AllSubscripts;
+import io.trino.plugin.base.subfield.Subfield.NestedField;
+import io.trino.plugin.base.subfield.Subfield.PathElement;
+import io.trino.plugin.base.subfield.Subfield.StringSubscript;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.OriginalType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.assertj.core.api.Assertions.assertThat;
+
+// Tests for ParquetPageSourceFactory.pruneColumnTypeForPath, adapted directly from Presto
+public class TestPruneTypeForPath
+{
+    @Test
+    public void fallback()
+    {
+        GroupType originalType = groupType("col", primitiveType("subfield1"), primitiveType("subfield2"));
+
+        // Request a non-existent field
+        List<PathElement> path = singletonList(new NestedField("subField3"));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(originalType);
+    }
+
+    @Test
+    public void noPruning()
+    {
+        GroupType originalType = groupType("col", primitiveType("subField1"));
+
+        List<PathElement> path = path(nestedField("subField1"));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(originalType);
+    }
+
+    @Test
+    public void oneLevelNesting()
+    {
+        GroupType originalType = groupType("col", primitiveType("subField1"), primitiveType("subField2"));
+
+        List<PathElement> path = path(nestedField("subField2"));
+
+        GroupType expectedPrunedType = groupType("col", primitiveType("subField2"));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void multiLevelNesting()
+    {
+        GroupType originalType = groupType("col",
+                primitiveType("subField1"),
+                groupType("subField2",
+                        primitiveType("subField3"),
+                        primitiveType("subField4")));
+
+        List<PathElement> path = path(nestedField("subField2"), nestedField("subField3"));
+
+        GroupType expectedPrunedType = groupType("col",
+                groupType("subField2",
+                        primitiveType("subField3")));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void arrayOfStructs()
+    {
+        GroupType originalType = groupType("col", OriginalType.LIST,
+                array(
+                        primitiveType("subField1"),
+                        groupType("subField2",
+                                primitiveType("subField3"),
+                                primitiveType("subField4"))));
+
+        List<PathElement> path = path(allSubscripts(), nestedField("subField2"),
+                nestedField("subField4"));
+
expectedPrunedType = groupType("col", OriginalType.LIST, + array( + groupType("subField2", + primitiveType("subField4")))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void arrayOfStructsWithSystemFieldName() + { + GroupType originalType = groupType("col", OriginalType.LIST, + array( + primitiveType("subField1"), + groupType("array_element", // System field name + primitiveType("subField2"), + primitiveType("subField3")))); + + List path = path(allSubscripts(), nestedField("array_element"), nestedField("subField3")); + + GroupType expectedPrunedType = groupType("col", OriginalType.LIST, + array( + groupType("array_element", + primitiveType("subField3")))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void arrayOfStructsType2() + { + GroupType originalType = groupType("col", OriginalType.LIST, + arrayType2( + primitiveType("subField1"), + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4")))); + + List path = path(allSubscripts(), nestedField("subField2"), nestedField("subField4")); + + GroupType expectedPrunedType = groupType("col", OriginalType.LIST, + arrayType2( + groupType("subField2", + primitiveType("subField4")))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void arrayOfStructsType3() + { + GroupType originalType = groupType("col", OriginalType.LIST, + arrayType3("col", + primitiveType("subField1"), + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4")))); + + List path = path(allSubscripts(), nestedField("subField2"), nestedField("subField4")); + + GroupType expectedPrunedType = groupType("col", OriginalType.LIST, + arrayType3("col", + groupType("subField2", + primitiveType("subField4")))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void arrayOfStructsType4() + { + GroupType originalType = groupType("col", OriginalType.LIST, + arrayType4( + primitiveType("subField1"), + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4")))); + + List path = path(allSubscripts(), nestedField("subField2"), nestedField("subField4")); + + GroupType expectedPrunedType = groupType("col", OriginalType.LIST, + arrayType4( + groupType("subField2", + primitiveType("subField4")))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void structContainingPrimitiveArray() + { + GroupType originalType = groupType("col", + groupType("subField1", OriginalType.LIST, primitiveArray()), + primitiveType("subField2")); + + List path = path(nestedField("subField1"), allSubscripts()); + + GroupType expectedPrunedType = groupType("col", + groupType("subField1", OriginalType.LIST, primitiveArray())); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void structContainingPrimitiveMap() + { + GroupType originalType = groupType("col", + 
groupType("subField1", OriginalType.LIST, map(primitiveType("value"))), + primitiveType("subField2")); + + List path = path(nestedField("subField1"), allSubscripts()); + + GroupType expectedPrunedType = groupType("col", + groupType("subField1", OriginalType.LIST, map(primitiveType("value")))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void structContainingPrimitiveArrayType2() + { + GroupType originalType = groupType("col", + groupType("subField1", OriginalType.LIST, primitiveArrayType2()), + primitiveType("subField2")); + + List path = path(nestedField("subField1"), allSubscripts()); + + GroupType expectedPrunedType = groupType("col", + groupType("subField1", OriginalType.LIST, primitiveArrayType2())); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void structWithArrayOfStructs() + { + GroupType originalType = groupType("col", + primitiveType("subField5"), + groupType("arrayOfStructs", OriginalType.LIST, + array( + primitiveType("subField1"), + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4"))))); + + List path = path(nestedField("arrayOfStructs"), allSubscripts(), nestedField("subField2"), + nestedField("subField4")); + + GroupType expectedPrunedType = groupType("col", + groupType("arrayOfStructs", OriginalType.LIST, + array( + groupType("subField2", + primitiveType("subField4"))))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void pathTerminatedAtArray() + { + GroupType originalType = groupType("col", + primitiveType("subField5"), + groupType("arrayOfStructs", OriginalType.LIST, + array( + primitiveType("subField1"), + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4"))))); + + List path = path(nestedField("arrayOfStructs"), allSubscripts()); + + GroupType expectedPrunedType = groupType("col", + groupType("arrayOfStructs", OriginalType.LIST, + array( + primitiveType("subField1"), + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4"))))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void arrayOfArrayOfStructs() + { + GroupType originalType = groupType("col", + primitiveType("subField1"), + groupType("arrayOfArrays", OriginalType.LIST, + array(OriginalType.LIST, array( + groupType("subField2", + primitiveType("subField3"), + primitiveType("subField4")), + primitiveType("subField5"))))); + + List path = path(nestedField("arrayOfArrays"), allSubscripts(), allSubscripts(), + nestedField("subField2"), nestedField("subField4")); + + GroupType expectedPrunedType = groupType("col", + groupType("arrayOfArrays", OriginalType.LIST, + array(OriginalType.LIST, array( + groupType("subField2", + primitiveType("subField4")))))); + + Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path); + assertThat(expectedPrunedType.equals(actualPrunedType)); + } + + @Test + public void arrayOfArrayOfStructsWithSystemFieldName() + { + GroupType originalType = groupType("col", + primitiveType("subField1"), + groupType("arrayOfArrays", OriginalType.LIST, + 
+                        array(OriginalType.LIST, array(
+                                groupType("array_element",
+                                        primitiveType("subField2"),
+                                        primitiveType("array_element")),
+                                primitiveType("subField4")))));
+
+        List<PathElement> path = path(nestedField("arrayOfArrays"), allSubscripts(), allSubscripts(),
+                nestedField("array_element"), nestedField("array_element"));
+
+        GroupType expectedPrunedType = groupType("col",
+                groupType("arrayOfArrays", OriginalType.LIST,
+                        array(OriginalType.LIST, array(
+                                groupType("array_element",
+                                        primitiveType("array_element"))))));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void mapOfStruct()
+    {
+        GroupType originalType = groupType("col", OriginalType.MAP,
+                map(groupType("value",
+                        primitiveType("subField1"),
+                        primitiveType("subField2"))));
+
+        List<PathElement> path = path(subscript("index"), nestedField("subField2"));
+
+        GroupType expectedPrunedType = groupType("col", OriginalType.MAP,
+                map(groupType("value",
+                        primitiveType("subField2"))));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void mapOfStructType2()
+    {
+        GroupType originalType = groupType("col", OriginalType.MAP_KEY_VALUE,
+                mapType2(groupType("value",
+                        primitiveType("subField1"),
+                        primitiveType("subField2"))));
+
+        List<PathElement> path = path(subscript("index"), nestedField("subField2"));
+
+        GroupType expectedPrunedType = groupType("col", OriginalType.MAP_KEY_VALUE,
+                mapType2(groupType("value",
+                        primitiveType("subField2"))));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, path);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void caseInsensitivityForLowercasePath()
+    {
+        GroupType originalType = groupType("col", OriginalType.MAP_KEY_VALUE,
+                primitiveType("subField1"),
+                primitiveType("subField2"));
+
+        List<PathElement> lowercasePath = path(nestedField("subfield2"));
+
+        GroupType expectedPrunedType = groupType("col", OriginalType.MAP_KEY_VALUE, primitiveType("subField2"));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, lowercasePath);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void caseInsensitivityForUppercasePath()
+    {
+        GroupType originalType = groupType("col", OriginalType.MAP_KEY_VALUE,
+                primitiveType("subField1"),
+                primitiveType("subField2"));
+
+        List<PathElement> uppercasePath = path(nestedField("SUBFIELD2"));
+
+        GroupType expectedPrunedType = groupType("col", OriginalType.MAP_KEY_VALUE, primitiveType("subField2"));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, uppercasePath);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
+
+    @Test
+    public void caseInsensitivityForMixedcasePath()
+    {
+        GroupType originalType = groupType("col", OriginalType.MAP_KEY_VALUE,
+                primitiveType("subField1"),
+                primitiveType("subField2"));
+
+        List<PathElement> mixedcasePath = path(nestedField("sUbfiEld2"));
+
+        GroupType expectedPrunedType = groupType("col", OriginalType.MAP_KEY_VALUE, primitiveType("subField2"));
+
+        Type actualPrunedType = ParquetPageSourceFactory.pruneColumnTypeForPath(originalType, mixedcasePath);
+        assertThat(actualPrunedType).isEqualTo(expectedPrunedType);
+    }
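The factory helpers after this note encode the physical list layouts that the pruner's LIST branch must recognize (see the name checks on "array" and "_tuple" in pruneColumnTypeForPath). For reference, the shapes are roughly:

    // 3-level, standard:              optional group col (LIST) { repeated group bag { optional group array_element { ... } } }
    // 2-level, second level "array":  optional group col (LIST) { repeated group array { ... } }
    // 2-level, "<parent>_tuple" name: optional group col (LIST) { repeated group col_tuple { ... } }
    // 2-level, multi-field group:     optional group col (LIST) { repeated group element { f1; f2; } }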
+
+    private static GroupType groupType(String name, Type... fields)
+    {
+        return new GroupType(Type.Repetition.OPTIONAL, name, fields);
+    }
+
+    private static GroupType groupType(String name, OriginalType originalType, Type... fields)
+    {
+        return new GroupType(Type.Repetition.OPTIONAL, name, originalType, fields);
+    }
+
+    private static GroupType array(Type... fields)
+    {
+        return new GroupType(Type.Repetition.REPEATED, "bag",
+                new GroupType(Type.Repetition.OPTIONAL, "array_element", fields));
+    }
+
+    private static GroupType array(OriginalType originalType, Type... fields)
+    {
+        return new GroupType(Type.Repetition.REPEATED, "bag",
+                new GroupType(Type.Repetition.OPTIONAL, "array_element", originalType, fields));
+    }
+
+    // Two-level group representing an array (instead of the standard three levels), with the known name "array" at the second level
+    private static GroupType arrayType2(Type... fields)
+    {
+        return new GroupType(Type.Repetition.REPEATED, "array", fields);
+    }
+
+    // Two-level group representing an array (instead of the standard three levels), with the known name "{parent}_tuple" at the second level
+    private static GroupType arrayType3(String parentFieldName, Type... fields)
+    {
+        return new GroupType(Type.Repetition.REPEATED, parentFieldName + "_tuple", fields);
+    }
+
+    // Two-level group representing an array, identified because the second level normally has exactly one child,
+    // but this case has multiple
+    private static GroupType arrayType4(Type... fields)
+    {
+        return new GroupType(Type.Repetition.REPEATED, "element", fields);
+    }
+
+    private static GroupType primitiveArray()
+    {
+        return new GroupType(Type.Repetition.REPEATED, "bag",
+                new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveType.PrimitiveTypeName.INT32, "array_element"));
+    }
+
+    private static PrimitiveType primitiveArrayType2()
+    {
+        return new PrimitiveType(Type.Repetition.REPEATED, PrimitiveType.PrimitiveTypeName.INT32, "element");
+    }
+
+    private static GroupType map(Type valueType)
+    {
+        return new GroupType(Type.Repetition.REPEATED, "map", OriginalType.MAP_KEY_VALUE, primitiveType("key"), valueType);
+    }
+
+    private static GroupType mapType2(Type valueType)
+    {
+        return new GroupType(Type.Repetition.REPEATED, "map", primitiveType("key"), valueType);
+    }
+
+    private static Type primitiveType(String name)
+    {
+        return new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveType.PrimitiveTypeName.INT32, name);
+    }
+
+    private static List<PathElement> path(PathElement... elements)
+    {
+        return asList(elements);
+    }
+
+    private static NestedField nestedField(String name)
+    {
+        return new NestedField(name);
+    }
+
+    private static AllSubscripts allSubscripts()
+    {
+        return AllSubscripts.getInstance();
+    }
+
+    private static StringSubscript subscript(String index)
+    {
+        return new StringSubscript(index);
+    }
+}
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java
index 2d483b5fde5a..fe1a2b7f8765 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java
@@ -156,7 +156,8 @@ private void testParquetTupleDomainStructWithPrimitiveColumnPredicate(boolean us
                 ImmutableList.of(1),
                 ImmutableList.of("b"),
                 HiveType.HIVE_INT,
-                INTEGER);
+                INTEGER,
+                ImmutableList.of());
 
         HiveColumnHandle projectedColumn = new HiveColumnHandle(
                 "row_field",
@@ -197,7 +198,8 @@ public void testParquetTupleDomainStructWithComplexColumnPredicate()
                 ImmutableList.of(2),
                 ImmutableList.of("C"),
                 HiveTypeTranslator.toHiveType(c1Type),
-                c1Type);
+                c1Type,
+                ImmutableList.of());
 
         HiveColumnHandle projectedColumn = new HiveColumnHandle(
                 "row_field",
@@ -237,7 +239,8 @@ public void testParquetTupleDomainStructWithMissingPrimitiveColumn()
                 ImmutableList.of(2),
                 ImmutableList.of("non_exist"),
                 HiveType.HIVE_INT,
-                INTEGER);
+                INTEGER,
+                ImmutableList.of());
 
         HiveColumnHandle projectedColumn = new HiveColumnHandle(
                 "row_field",