From d7cde8c5b63977f899737a1f0e8782a05cc344ec Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Fri, 1 Mar 2019 16:51:38 -0800 Subject: [PATCH 01/18] Rename listTableLayouts method More accurately, constructs a set of alternate plans with the filter pushed into the table scan based on available table layouts. --- .../prestosql/sql/planner/iterative/rule/PickTableLayout.java | 4 ++-- .../io/prestosql/sql/planner/optimizations/AddExchanges.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java index a0a9c850f8ee..2ea5da69cb74 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java @@ -228,7 +228,7 @@ private static PlanNode planTableScan( SqlParser parser, DomainTranslator domainTranslator) { - return listTableLayouts( + return pushFilterIntoTableScan( node, predicate, false, @@ -241,7 +241,7 @@ private static PlanNode planTableScan( .get(0); } - public static List listTableLayouts( + public static List pushFilterIntoTableScan( TableScanNode node, Expression predicate, boolean pruneWithPredicateExpression, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index 14e1cb96fdad..a728aeb92bdd 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -540,7 +540,7 @@ else if (redistributeWrites) { private PlanWithProperties planTableScan(TableScanNode node, Expression predicate, PreferredProperties preferredProperties) { - List possiblePlans = PickTableLayout.listTableLayouts(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); + List possiblePlans = PickTableLayout.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); List possiblePlansWithProperties = possiblePlans.stream() .map(planNode -> new PlanWithProperties(planNode, derivePropertiesRecursively(planNode))) .collect(toImmutableList()); From c68109cecf094b4cb7fb4c90a26670da1c2032c9 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Tue, 5 Mar 2019 09:45:12 -0800 Subject: [PATCH 02/18] Add null checks --- .../main/java/io/prestosql/metadata/TableLayoutResult.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java index bcf875138171..77c3a2905a70 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java @@ -26,6 +26,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; public class TableLayoutResult { @@ -34,8 +35,8 @@ public class TableLayoutResult public TableLayoutResult(TableLayout layout, TupleDomain unenforcedConstraint) { - this.layout = layout; - this.unenforcedConstraint = unenforcedConstraint; + this.layout = requireNonNull(layout, "layout is null"); + this.unenforcedConstraint = requireNonNull(unenforcedConstraint, "unenforcedConstraint is null"); } public TableLayout getLayout() From 801423bdbc27495159fbb407273e850a7a902c21 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Fri, 15 Feb 2019 18:18:42 -0800 Subject: [PATCH 03/18] Remove support for multiple table layouts --- .../benchmark/AbstractOperatorBenchmark.java | 4 +- .../hive/TestHiveIntegrationSmokeTest.java | 6 +- .../java/io/prestosql/metadata/Metadata.java | 2 +- .../prestosql/metadata/MetadataManager.java | 16 +- .../prestosql/metadata/TableLayoutResult.java | 18 - .../iterative/rule/ExtractSpatialJoins.java | 6 +- .../iterative/rule/PickTableLayout.java | 80 +- .../planner/optimizations/AddExchanges.java | 98 +-- .../optimizations/BeginTableWrite.java | 12 +- .../optimizations/MetadataQueryOptimizer.java | 7 +- .../metadata/AbstractMockMetadata.java | 2 +- .../optimizations/TestAddExchanges.java | 795 ------------------ 12 files changed, 59 insertions(+), 987 deletions(-) delete mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index 7c17c605aea1..7c720e1158fb 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -170,8 +170,8 @@ protected final OperatorFactory createTableScanOperator(int operatorId, PlanNode List columnHandles = columnHandlesBuilder.build(); // get the split for this table - List layouts = metadata.getLayouts(session, tableHandle, Constraint.alwaysTrue(), Optional.empty()); - Split split = getLocalQuerySplit(session, layouts.get(0).getLayout().getHandle()); + Optional layout = metadata.getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.empty()); + Split split = getLocalQuerySplit(session, layout.get().getLayout().getHandle()); return new OperatorFactory() { diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java index 6f5e69a4cd77..030a713eec4d 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java @@ -25,7 +25,6 @@ import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.TableHandle; import io.prestosql.metadata.TableLayout; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.metadata.TableMetadata; import io.prestosql.plugin.hive.HiveSessionProperties.InsertExistingPartitionsBehavior; import io.prestosql.spi.connector.CatalogSchemaTableName; @@ -1673,8 +1672,9 @@ private Object getHiveTableProperty(String tableName, Function tableHandle = metadata.getTableHandle(transactionSession, new QualifiedObjectName(catalog, TPCH_SCHEMA, tableName)); assertTrue(tableHandle.isPresent()); - List layouts = metadata.getLayouts(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()); - TableLayout layout = getOnlyElement(layouts).getLayout(); + TableLayout layout = metadata.getLayout(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()) + .get() + .getLayout(); return propertyGetter.apply((HiveTableLayoutHandle) layout.getHandle().getConnectorHandle()); }); } diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index 93aad49cd027..264fb3e4fc0f 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -73,7 +73,7 @@ public interface Metadata Optional getTableHandleForStatisticsCollection(Session session, QualifiedObjectName tableName, Map analyzeProperties); - List getLayouts(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns); + Optional getLayout(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns); TableLayout getLayout(Session session, TableLayoutHandle handle); diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 69a2cee9b473..49a76f3e8ffb 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -88,7 +88,6 @@ import java.util.concurrent.ConcurrentMap; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.metadata.QualifiedObjectName.convertFromSchemaTableName; import static io.prestosql.metadata.TableLayout.fromConnectorLayout; @@ -373,10 +372,10 @@ public Optional getSystemTable(Session session, QualifiedObjectName } @Override - public List getLayouts(Session session, TableHandle table, Constraint constraint, Optional> desiredColumns) + public Optional getLayout(Session session, TableHandle table, Constraint constraint, Optional> desiredColumns) { if (constraint.getSummary().isNone()) { - return ImmutableList.of(); + return Optional.empty(); } ConnectorId connectorId = table.getConnectorId(); @@ -387,10 +386,15 @@ public List getLayouts(Session session, TableHandle table, Co ConnectorTransactionHandle transaction = catalogMetadata.getTransactionHandleFor(connectorId); ConnectorSession connectorSession = session.toConnectorSession(connectorId); List layouts = metadata.getTableLayouts(connectorSession, connectorTable, constraint, desiredColumns); + if (layouts.isEmpty()) { + return Optional.empty(); + } + + if (layouts.size() > 1) { + throw new PrestoException(NOT_SUPPORTED, format("Connector returned multiple layouts for table %s", table)); + } - return layouts.stream() - .map(layout -> new TableLayoutResult(fromConnectorLayout(connectorId, transaction, layout.getTableLayout()), layout.getUnenforcedConstraint())) - .collect(toImmutableList()); + return Optional.of(new TableLayoutResult(fromConnectorLayout(connectorId, transaction, layouts.get(0).getTableLayout()), layouts.get(0).getUnenforcedConstraint())); } @Override diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java index 77c3a2905a70..c545333fa995 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java @@ -14,18 +14,13 @@ package io.prestosql.metadata; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.sql.planner.plan.TableScanNode; -import java.util.List; import java.util.Map; -import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class TableLayoutResult @@ -49,19 +44,6 @@ public TupleDomain getUnenforcedConstraint() return unenforcedConstraint; } - public boolean hasAllOutputs(TableScanNode node) - { - if (!layout.getColumns().isPresent()) { - return true; - } - Set columns = ImmutableSet.copyOf(layout.getColumns().get()); - List nodeColumnHandles = node.getOutputSymbols().stream() - .map(node.getAssignments()::get) - .collect(toImmutableList()); - - return columns.containsAll(nodeColumnHandles); - } - public static TupleDomain computeEnforced(TupleDomain predicate, TupleDomain unenforced) { if (predicate.isNone()) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index d02859f5e810..9a5accbfaf8c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -469,11 +469,11 @@ private static KdbTree loadKdbTree(String tableName, Session session, Metadata m ColumnHandle kdbTreeColumn = Iterables.getOnlyElement(visibleColumnHandles); - List layouts = metadata.getLayouts(session, tableHandle, Constraint.alwaysTrue(), Optional.of(ImmutableSet.of(kdbTreeColumn))); - checkSpatialPartitioningTable(!layouts.isEmpty(), "Table is empty: %s", name); + Optional layout = metadata.getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.of(ImmutableSet.of(kdbTreeColumn))); + checkSpatialPartitioningTable(layout.isPresent(), "Table is empty: %s", name); Optional kdbTree = Optional.empty(); - try (SplitSource splitSource = splitManager.getSplits(session, layouts.get(0).getLayout().getHandle(), UNGROUPED_SCHEDULING)) { + try (SplitSource splitSource = splitManager.getSplits(session, layout.get().getLayout().getHandle(), UNGROUPED_SCHEDULING)) { while (!Thread.currentThread().isInterrupted()) { SplitBatch splitBatch = getFutureValue(splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), 1000)); List splits = splitBatch.getSplits(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java index 2ea5da69cb74..9fc236b99158 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java @@ -47,14 +47,11 @@ import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.NullLiteral; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Sets.intersection; import static io.prestosql.SystemSessionProperties.isNewOptimizerEnabled; @@ -71,7 +68,6 @@ import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; /** * These rules should not be run after AddExchanges so as not to overwrite the TableLayout @@ -237,11 +233,10 @@ private static PlanNode planTableScan( idAllocator, metadata, parser, - domainTranslator) - .get(0); + domainTranslator); } - public static List pushFilterIntoTableScan( + public static PlanNode pushFilterIntoTableScan( TableScanNode node, Expression predicate, boolean pruneWithPredicateExpression, @@ -288,8 +283,7 @@ public static List pushFilterIntoTableScan( constraint = new Constraint<>(newDomain); } - // Layouts will be returned in order of the connector's preference - List layouts = metadata.getLayouts( + Optional layout = metadata.getLayout( session, node.getTable(), constraint, @@ -297,51 +291,37 @@ public static List pushFilterIntoTableScan( .map(node.getAssignments()::get) .collect(toImmutableSet()))); - if (layouts.isEmpty()) { - return ImmutableList.of(new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of())); + if (!layout.isPresent() || layout.get().getLayout().getPredicate().isNone()) { + return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of()); } - // Filter out layouts that cannot supply all the required columns - layouts = layouts.stream() - .filter(layout -> layout.hasAllOutputs(node)) - .collect(toList()); - checkState(!layouts.isEmpty(), "No usable layouts for %s", node); - - if (layouts.stream().anyMatch(layout -> layout.getLayout().getPredicate().isNone())) { - return ImmutableList.of(new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of())); + TableScanNode tableScan = new TableScanNode( + node.getId(), + node.getTable(), + node.getOutputSymbols(), + node.getAssignments(), + Optional.of(layout.get().getLayout().getHandle()), + layout.get().getLayout().getPredicate(), + computeEnforced(newDomain, layout.get().getUnenforcedConstraint())); + + // The order of the arguments to combineConjuncts matters: + // * Unenforced constraints go first because they can only be simple column references, + // which are not prone to logic errors such as out-of-bound access, div-by-zero, etc. + // * Conjuncts in non-deterministic expressions and non-TupleDomain-expressible expressions should + // retain their original (maybe intermixed) order from the input predicate. However, this is not implemented yet. + // * Short of implementing the previous bullet point, the current order of non-deterministic expressions + // and non-TupleDomain-expressible expressions should be retained. Changing the order can lead + // to failures of previously successful queries. + Expression resultingPredicate = combineConjuncts( + domainTranslator.toPredicate(layout.get().getUnenforcedConstraint().transform(assignments::get)), + filterNonDeterministicConjuncts(predicate), + decomposedPredicate.getRemainingExpression()); + + if (!TRUE_LITERAL.equals(resultingPredicate)) { + return new FilterNode(idAllocator.getNextId(), tableScan, resultingPredicate); } - return layouts.stream() - .map(layout -> { - TableScanNode tableScan = new TableScanNode( - node.getId(), - node.getTable(), - node.getOutputSymbols(), - node.getAssignments(), - Optional.of(layout.getLayout().getHandle()), - layout.getLayout().getPredicate(), - computeEnforced(newDomain, layout.getUnenforcedConstraint())); - - // The order of the arguments to combineConjuncts matters: - // * Unenforced constraints go first because they can only be simple column references, - // which are not prone to logic errors such as out-of-bound access, div-by-zero, etc. - // * Conjuncts in non-deterministic expressions and non-TupleDomain-expressible expressions should - // retain their original (maybe intermixed) order from the input predicate. However, this is not implemented yet. - // * Short of implementing the previous bullet point, the current order of non-deterministic expressions - // and non-TupleDomain-expressible expressions should be retained. Changing the order can lead - // to failures of previously successful queries. - Expression resultingPredicate = combineConjuncts( - domainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(assignments::get)), - filterNonDeterministicConjuncts(predicate), - decomposedPredicate.getRemainingExpression()); - - if (!TRUE_LITERAL.equals(resultingPredicate)) { - return new FilterNode(idAllocator.getNextId(), tableScan, resultingPredicate); - } - - return tableScan; - }) - .collect(toImmutableList()); + return tableScan; } private static class LayoutConstraintEvaluator diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index a728aeb92bdd..873e6eadb058 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -14,10 +14,6 @@ package io.prestosql.sql.planner.optimizations; import com.google.common.annotations.VisibleForTesting; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import com.google.common.collect.ComparisonChain; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableSet; @@ -78,10 +74,7 @@ import io.prestosql.sql.tree.SymbolReference; import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -89,7 +82,6 @@ import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -540,26 +532,8 @@ else if (redistributeWrites) { private PlanWithProperties planTableScan(TableScanNode node, Expression predicate, PreferredProperties preferredProperties) { - List possiblePlans = PickTableLayout.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); - List possiblePlansWithProperties = possiblePlans.stream() - .map(planNode -> new PlanWithProperties(planNode, derivePropertiesRecursively(planNode))) - .collect(toImmutableList()); - return pickPlan(possiblePlansWithProperties, preferredProperties); - } - - /** - * possiblePlans should be provided in layout preference order - */ - private PlanWithProperties pickPlan(List possiblePlans, PreferredProperties preferredProperties) - { - checkArgument(!possiblePlans.isEmpty()); - - if (preferStreamingOperators) { - possiblePlans = new ArrayList<>(possiblePlans); - Collections.sort(possiblePlans, Comparator.comparing(PlanWithProperties::getProperties, streamingExecutionPreference(preferredProperties))); // stable sort; is Collections.min() guaranteed to be stable? - } - - return possiblePlans.get(0); + PlanNode plan = PickTableLayout.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); + return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); } @Override @@ -1239,74 +1213,6 @@ private static Map computeIdentityTranslations(Assignments assig return outputToInput; } - @VisibleForTesting - static Comparator streamingExecutionPreference(PreferredProperties preferred) - { - // Calculating the matches can be a bit expensive, so cache the results between comparisons - LoadingCache>, List>>> matchCache = CacheBuilder.newBuilder() - .build(CacheLoader.from(actualProperties -> LocalProperties.match(actualProperties, preferred.getLocalProperties()))); - - return (actual1, actual2) -> { - List>> matchLayout1 = matchCache.getUnchecked(actual1.getLocalProperties()); - List>> matchLayout2 = matchCache.getUnchecked(actual2.getLocalProperties()); - - return ComparisonChain.start() - .compareTrueFirst(hasLocalOptimization(preferred.getLocalProperties(), matchLayout1), hasLocalOptimization(preferred.getLocalProperties(), matchLayout2)) - .compareTrueFirst(meetsPartitioningRequirements(preferred, actual1), meetsPartitioningRequirements(preferred, actual2)) - .compare(matchLayout1, matchLayout2, matchedLayoutPreference()) - .result(); - }; - } - - private static boolean hasLocalOptimization(List> desiredLayout, List>> matchResult) - { - checkArgument(desiredLayout.size() == matchResult.size()); - if (matchResult.isEmpty()) { - return false; - } - // Optimizations can be applied if the first LocalProperty has been modified in the match in any way - return !matchResult.get(0).equals(Optional.of(desiredLayout.get(0))); - } - - private static boolean meetsPartitioningRequirements(PreferredProperties preferred, ActualProperties actual) - { - if (!preferred.getGlobalProperties().isPresent()) { - return true; - } - PreferredProperties.Global preferredGlobal = preferred.getGlobalProperties().get(); - if (!preferredGlobal.isDistributed()) { - return actual.isSingleNode(); - } - if (!preferredGlobal.getPartitioningProperties().isPresent()) { - return !actual.isSingleNode(); - } - return actual.isStreamPartitionedOn(preferredGlobal.getPartitioningProperties().get().getPartitioningColumns()); - } - - // Prefer the match result that satisfied the most requirements - private static Comparator>>> matchedLayoutPreference() - { - return (matchLayout1, matchLayout2) -> { - Iterator>> match1Iterator = matchLayout1.iterator(); - Iterator>> match2Iterator = matchLayout2.iterator(); - while (match1Iterator.hasNext() && match2Iterator.hasNext()) { - Optional> match1 = match1Iterator.next(); - Optional> match2 = match2Iterator.next(); - if (match1.isPresent() && match2.isPresent()) { - return Integer.compare(match1.get().getColumns().size(), match2.get().getColumns().size()); - } - else if (match1.isPresent()) { - return 1; - } - else if (match2.isPresent()) { - return -1; - } - } - checkState(!match1Iterator.hasNext() && !match2Iterator.hasNext()); // Should be the same size - return 0; - }; - } - @VisibleForTesting static class PlanWithProperties { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java index 7057b2344e27..4b35396aa74b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java @@ -41,12 +41,10 @@ import io.prestosql.sql.planner.plan.TableWriterNode; import io.prestosql.sql.planner.plan.UnionNode; -import java.util.List; import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; import static io.prestosql.metadata.TableLayoutResult.computeEnforced; import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.prestosql.sql.planner.plan.ChildReplacer.replaceChildren; @@ -197,22 +195,20 @@ private PlanNode rewriteDeleteTableScan(PlanNode node, TableHandle handle) TableScanNode scan = (TableScanNode) node; TupleDomain originalEnforcedConstraint = scan.getEnforcedConstraint(); - List layouts = metadata.getLayouts( + Optional layout = metadata.getLayout( session, handle, new Constraint<>(originalEnforcedConstraint), Optional.of(ImmutableSet.copyOf(scan.getAssignments().values()))); - verify(layouts.size() == 1, "Expected exactly one layout for delete"); - TableLayoutResult layoutResult = Iterables.getOnlyElement(layouts); return new TableScanNode( scan.getId(), handle, scan.getOutputSymbols(), scan.getAssignments(), - Optional.of(layoutResult.getLayout().getHandle()), - layoutResult.getLayout().getPredicate(), - computeEnforced(originalEnforcedConstraint, layoutResult.getUnenforcedConstraint())); + Optional.of(layout.get().getLayout().getHandle()), + layout.get().getLayout().getPredicate(), + computeEnforced(originalEnforcedConstraint, layout.get().getUnenforcedConstraint())); } if (node instanceof FilterNode) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java index 3e70a8d9b905..03d541629a75 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -139,10 +139,9 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont // with a Values node TableLayout layout = null; if (!tableScan.getLayout().isPresent()) { - List layouts = metadata.getLayouts(session, tableScan.getTable(), Constraint.alwaysTrue(), Optional.empty()); - if (layouts.size() == 1) { - layout = Iterables.getOnlyElement(layouts).getLayout(); - } + layout = metadata.getLayout(session, tableScan.getTable(), Constraint.alwaysTrue(), Optional.empty()) + .map(TableLayoutResult::getLayout) + .orElse(null); } else { layout = metadata.getLayout(session, tableScan.getLayout().get()); diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index 3af3a4c0c168..f58e5a05db1a 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -114,7 +114,7 @@ public Optional getSystemTable(Session session, QualifiedObjectName } @Override - public List getLayouts(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns) + public Optional getLayout(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns) { throw new UnsupportedOperationException(); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java deleted file mode 100644 index 78c343a81e03..000000000000 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java +++ /dev/null @@ -1,795 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner.optimizations; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; -import io.prestosql.spi.block.SortOrder; -import io.prestosql.spi.connector.ConstantProperty; -import io.prestosql.spi.connector.GroupingProperty; -import io.prestosql.spi.connector.SortingProperty; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.optimizations.ActualProperties.Global; -import org.testng.annotations.Test; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Optional; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.prestosql.spi.block.SortOrder.ASC_NULLS_FIRST; -import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; -import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; -import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.partitionedOn; -import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.singleStreamPartition; -import static io.prestosql.sql.planner.optimizations.ActualProperties.builder; -import static io.prestosql.sql.planner.optimizations.AddExchanges.streamingExecutionPreference; -import static org.testng.Assert.assertEquals; - -/** - * These are unit test for the internal logic in AddExchanges. - * For plan tests see {@link TestAddExchangesPlans} - */ -public class TestAddExchanges -{ - @Test - public void testPickLayoutAnyPreference() - { - Comparator preference = streamingExecutionPreference(PreferredProperties.any()); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a", "b")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - // Given no preferences, the original input order should be maintained - assertEquals(stableSort(input, preference), input); - } - - @Test - public void testPickLayoutPartitionedPreference() - { - Comparator preference = streamingExecutionPreference(PreferredProperties.distributed()); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutUnpartitionedPreference() - { - Comparator preference = streamingExecutionPreference(PreferredProperties.undistributed()); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutPartitionedOnSingle() - { - Comparator preference = streamingExecutionPreference( - PreferredProperties.partitioned(ImmutableSet.of(symbol("a")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutPartitionedOnMultiple() - { - Comparator preference = streamingExecutionPreference( - PreferredProperties.partitioned(ImmutableSet.of(symbol("a"), symbol("b")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGrouped() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGroupedMultiple() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a", "b")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGroupedMultipleProperties() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a"), grouped("b")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGroupedWithSort() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a"), sorted("b", ASC_NULLS_FIRST)))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutUnpartitionedWithGroupAndSort() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.undistributedWithLocal(ImmutableList.of(grouped("a"), sorted("b", ASC_NULLS_FIRST)))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutPartitionedWithGroup() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.partitionedWithLocal( - ImmutableSet.of(symbol("a")), - ImmutableList.of(grouped("a")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - private static List stableSort(List list, Comparator comparator) - { - ArrayList copy = Lists.newArrayList(list); - Collections.sort(copy, comparator); - return copy; - } - - private static Global hashDistributedOn(String... columnNames) - { - return partitionedOn(FIXED_HASH_DISTRIBUTION, arguments(columnNames), Optional.of(arguments(columnNames))); - } - - public static Global singleStream() - { - return Global.streamPartitionedOn(ImmutableList.of()); - } - - private static Global streamPartitionedOn(String... columnNames) - { - return Global.streamPartitionedOn(arguments(columnNames)); - } - - private static ConstantProperty constant(String column) - { - return new ConstantProperty<>(symbol(column)); - } - - private static GroupingProperty grouped(String... columns) - { - return new GroupingProperty<>(Lists.transform(Arrays.asList(columns), Symbol::new)); - } - - private static SortingProperty sorted(String column, SortOrder order) - { - return new SortingProperty<>(symbol(column), order); - } - - private static Symbol symbol(String name) - { - return new Symbol(name); - } - - private static List arguments(String[] columnNames) - { - return Arrays.asList(columnNames).stream() - .map(Symbol::new) - .collect(toImmutableList()); - } -} From c6abef5dffa267d9d7ea8132e1b9528b14eb5cad Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Thu, 28 Feb 2019 09:58:39 -0800 Subject: [PATCH 04/18] Remove unused parameter --- .../prestosql/sql/planner/optimizations/AddExchanges.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index 873e6eadb058..a4f3a11a331b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -491,7 +491,7 @@ public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, PreferredPr public PlanWithProperties visitFilter(FilterNode node, PreferredProperties preferredProperties) { if (node.getSource() instanceof TableScanNode) { - return planTableScan((TableScanNode) node.getSource(), node.getPredicate(), preferredProperties); + return planTableScan((TableScanNode) node.getSource(), node.getPredicate()); } return rebaseAndDeriveProperties(node, planChild(node, preferredProperties)); @@ -500,7 +500,7 @@ public PlanWithProperties visitFilter(FilterNode node, PreferredProperties prefe @Override public PlanWithProperties visitTableScan(TableScanNode node, PreferredProperties preferredProperties) { - return planTableScan(node, TRUE_LITERAL, preferredProperties); + return planTableScan(node, TRUE_LITERAL); } @Override @@ -530,7 +530,7 @@ else if (redistributeWrites) { return rebaseAndDeriveProperties(node, source); } - private PlanWithProperties planTableScan(TableScanNode node, Expression predicate, PreferredProperties preferredProperties) + private PlanWithProperties planTableScan(TableScanNode node, Expression predicate) { PlanNode plan = PickTableLayout.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); From 3a3a5e73672a4ab124dd672ba3c7a2be15389356 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Fri, 1 Mar 2019 11:42:04 -0800 Subject: [PATCH 05/18] Move test to distributed query tests This test relies on a layout being picked during planning. For LocalQueryRunner this only happens because of a synthetic rule (PickLayout w/o predicate) that attaches a layout to the table scan. So, in a sense, it's just a coincidence that it works. In distributed execution, the job is done by AddExchanges, so we want to make sure we're testing that behavior. --- .../io/prestosql/tests/TestLocalQueries.java | 42 ------------------- .../tests/TestTpchDistributedQueries.java | 39 +++++++++++++++++ 2 files changed, 39 insertions(+), 42 deletions(-) diff --git a/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java b/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java index 164b82fe5207..f8a25232d7a7 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java @@ -14,32 +14,18 @@ package io.prestosql.tests; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.prestosql.Session; import io.prestosql.connector.ConnectorId; import io.prestosql.metadata.SessionPropertyManager; import io.prestosql.plugin.tpch.TpchConnectorFactory; -import io.prestosql.spi.connector.CatalogSchemaTableName; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.ColumnConstraint; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.FormattedDomain; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.FormattedMarker; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.FormattedRange; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.IoPlan; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.IoPlan.TableColumnInfo; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.testing.MaterializedResult; import org.testng.annotations.Test; -import java.util.Optional; - -import static com.google.common.collect.Iterables.getOnlyElement; -import static io.airlift.json.JsonCodec.jsonCodec; import static io.prestosql.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.prestosql.spi.predicate.Marker.Bound.EXACTLY; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.spi.type.VarcharType.createVarcharType; import static io.prestosql.testing.MaterializedResult.resultBuilder; import static io.prestosql.testing.TestingSession.TESTING_CATALOG; import static io.prestosql.testing.TestingSession.testSessionBuilder; @@ -115,34 +101,6 @@ public void testDecimal() assertQuery("SELECT 0.1", "SELECT CAST('0.1' AS DECIMAL)"); } - @Test - public void testIOExplain() - { - String query = "SELECT * FROM orders"; - MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) " + query); - TableColumnInfo input = new TableColumnInfo( - new CatalogSchemaTableName("local", "sf0.01", "orders"), - ImmutableSet.of( - new ColumnConstraint( - "orderstatus", - createVarcharType(1).getTypeSignature(), - new FormattedDomain( - false, - ImmutableSet.of( - new FormattedRange( - new FormattedMarker(Optional.of("F"), EXACTLY), - new FormattedMarker(Optional.of("F"), EXACTLY)), - new FormattedRange( - new FormattedMarker(Optional.of("O"), EXACTLY), - new FormattedMarker(Optional.of("O"), EXACTLY)), - new FormattedRange( - new FormattedMarker(Optional.of("P"), EXACTLY), - new FormattedMarker(Optional.of("P"), EXACTLY))))))); - assertEquals( - jsonCodec(IoPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet())), - new IoPlan(ImmutableSet.of(input), Optional.empty())); - } - @Test public void testHueQueries() { diff --git a/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java b/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java index ac3442e47fd6..7ff64aa3fe2b 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java @@ -14,10 +14,21 @@ package io.prestosql.tests; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableSet; +import io.prestosql.spi.connector.CatalogSchemaTableName; +import io.prestosql.sql.planner.planPrinter.IoPlanPrinter; +import io.prestosql.testing.MaterializedResult; import io.prestosql.tests.tpch.TpchQueryRunnerBuilder; import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; +import java.util.Optional; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.json.JsonCodec.jsonCodec; +import static io.prestosql.spi.predicate.Marker.Bound.EXACTLY; +import static io.prestosql.spi.type.VarcharType.createVarcharType; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; public class TestTpchDistributedQueries @@ -28,6 +39,34 @@ public TestTpchDistributedQueries() super(() -> TpchQueryRunnerBuilder.builder().build()); } + @Test + public void testIOExplain() + { + String query = "SELECT * FROM orders"; + MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) " + query); + IoPlanPrinter.IoPlan.TableColumnInfo input = new IoPlanPrinter.IoPlan.TableColumnInfo( + new CatalogSchemaTableName("tpch", "sf0.01", "orders"), + ImmutableSet.of( + new IoPlanPrinter.ColumnConstraint( + "orderstatus", + createVarcharType(1).getTypeSignature(), + new IoPlanPrinter.FormattedDomain( + false, + ImmutableSet.of( + new IoPlanPrinter.FormattedRange( + new IoPlanPrinter.FormattedMarker(Optional.of("F"), EXACTLY), + new IoPlanPrinter.FormattedMarker(Optional.of("F"), EXACTLY)), + new IoPlanPrinter.FormattedRange( + new IoPlanPrinter.FormattedMarker(Optional.of("O"), EXACTLY), + new IoPlanPrinter.FormattedMarker(Optional.of("O"), EXACTLY)), + new IoPlanPrinter.FormattedRange( + new IoPlanPrinter.FormattedMarker(Optional.of("P"), EXACTLY), + new IoPlanPrinter.FormattedMarker(Optional.of("P"), EXACTLY))))))); + assertEquals( + jsonCodec(IoPlanPrinter.IoPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet())), + new IoPlanPrinter.IoPlan(ImmutableSet.of(input), Optional.empty())); + } + @Test public void testTooLongQuery() { From 1285a778d2626e098cbd78b9d0ded28e43f41cbc Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 27 Feb 2019 17:55:25 -0800 Subject: [PATCH 06/18] Simplify unconditional PickLayout It's just selecting a layout for the raw table scan, so no need to go through the logic for pushing a predicate, etc. --- .../iterative/rule/PickTableLayout.java | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java index 9fc236b99158..23f3a4e7fa8e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java @@ -101,7 +101,7 @@ public PickTableLayoutForPredicate pickTableLayoutForPredicate() public PickTableLayoutWithoutPredicate pickTableLayoutWithoutPredicate() { - return new PickTableLayoutWithoutPredicate(metadata, parser, domainTranslator); + return new PickTableLayoutWithoutPredicate(metadata); } private static final class PickTableLayoutForPredicate @@ -179,14 +179,10 @@ private static final class PickTableLayoutWithoutPredicate implements Rule { private final Metadata metadata; - private final SqlParser parser; - private final DomainTranslator domainTranslator; - private PickTableLayoutWithoutPredicate(Metadata metadata, SqlParser parser, DomainTranslator domainTranslator) + private PickTableLayoutWithoutPredicate(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); - this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); } private static final Pattern PATTERN = tableScan(); @@ -210,7 +206,26 @@ public Result apply(TableScanNode tableScanNode, Captures captures, Context cont return Result.empty(); } - return Result.ofPlanNode(planTableScan(tableScanNode, TRUE_LITERAL, context.getSession(), context.getSymbolAllocator().getTypes(), context.getIdAllocator(), metadata, parser, domainTranslator)); + Optional layout = metadata.getLayout( + context.getSession(), + tableScanNode.getTable(), + Constraint.alwaysTrue(), + Optional.of(tableScanNode.getOutputSymbols().stream() + .map(tableScanNode.getAssignments()::get) + .collect(toImmutableSet()))); + + if (!layout.isPresent() || layout.get().getLayout().getPredicate().isNone()) { + return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), tableScanNode.getOutputSymbols(), ImmutableList.of())); + } + + return Result.ofPlanNode(new TableScanNode( + tableScanNode.getId(), + tableScanNode.getTable(), + tableScanNode.getOutputSymbols(), + tableScanNode.getAssignments(), + Optional.of(layout.get().getLayout().getHandle()), + TupleDomain.all(), + TupleDomain.all())); } } From 4aa9539bf47d803a84d81182faf78f88a0f17622 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Thu, 28 Feb 2019 09:39:19 -0800 Subject: [PATCH 07/18] Inline unnecessary method --- .../iterative/rule/PickTableLayout.java | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java index 23f3a4e7fa8e..034c48cdcb0d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java @@ -140,7 +140,16 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) { TableScanNode tableScan = captures.get(TABLE_SCAN); - PlanNode rewritten = planTableScan(tableScan, filterNode.getPredicate(), context.getSession(), context.getSymbolAllocator().getTypes(), context.getIdAllocator(), metadata, parser, domainTranslator); + PlanNode rewritten = pushFilterIntoTableScan( + tableScan, + filterNode.getPredicate(), + false, + context.getSession(), + context.getSymbolAllocator().getTypes(), + context.getIdAllocator(), + metadata, + parser, + domainTranslator); if (arePlansSame(filterNode, tableScan, rewritten)) { return Result.empty(); @@ -229,28 +238,6 @@ public Result apply(TableScanNode tableScanNode, Captures captures, Context cont } } - private static PlanNode planTableScan( - TableScanNode node, - Expression predicate, - Session session, - TypeProvider types, - PlanNodeIdAllocator idAllocator, - Metadata metadata, - SqlParser parser, - DomainTranslator domainTranslator) - { - return pushFilterIntoTableScan( - node, - predicate, - false, - session, - types, - idAllocator, - metadata, - parser, - domainTranslator); - } - public static PlanNode pushFilterIntoTableScan( TableScanNode node, Expression predicate, From 1d9e0d1260944b3fc68a3e82543fa79a9f55b498 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 27 Feb 2019 15:37:57 -0800 Subject: [PATCH 08/18] Hide TableLayouts from engine This change hides table layouts from the engine as a first-class concept. We keep the SPI as is for backward compatibility for now. When predicates are pushed into a table scan by PickLayout (now PushPredicateIntoTableScan) or AddExchanges, we now replace the table handle associated with the table scan with a new one that contains the reference to the ConnectorTableLayoutHandle under the covers. --- .../benchmark/AbstractOperatorBenchmark.java | 8 +- .../hive/TestHiveIntegrationSmokeTest.java | 11 ++- .../java/io/prestosql/metadata/Metadata.java | 18 ++-- .../prestosql/metadata/MetadataManager.java | 64 +++++++++----- .../io/prestosql/metadata/TableHandle.java | 35 ++++---- .../io/prestosql/metadata/TableLayout.java | 25 ++---- .../prestosql/metadata/TableLayoutHandle.java | 85 ------------------- .../prestosql/metadata/TableLayoutResult.java | 9 +- .../operator/MetadataDeleteOperator.java | 15 ++-- .../java/io/prestosql/split/SplitManager.java | 30 +++++-- .../planner/DistributedExecutionPlanner.java | 2 +- .../prestosql/sql/planner/InputExtractor.java | 12 ++- .../sql/planner/LocalExecutionPlanner.java | 2 +- .../prestosql/sql/planner/LogicalPlanner.java | 2 +- .../prestosql/sql/planner/PlanFragmenter.java | 20 ++--- .../prestosql/sql/planner/PlanOptimizers.java | 6 +- .../prestosql/sql/planner/QueryPlanner.java | 2 +- .../sql/planner/RelationPlanner.java | 2 +- .../iterative/rule/ExtractSpatialJoins.java | 7 +- .../rule/PruneIndexSourceColumns.java | 1 - .../iterative/rule/PruneTableScanColumns.java | 1 - ...t.java => PushPredicateIntoTableScan.java} | 75 +--------------- .../planner/optimizations/AddExchanges.java | 4 +- .../optimizations/BeginTableWrite.java | 19 +---- .../optimizations/IndexJoinOptimizer.java | 1 - .../MetadataDeleteOptimizer.java | 8 +- .../optimizations/MetadataQueryOptimizer.java | 15 +--- .../optimizations/PropertyDerivations.java | 4 +- .../PruneUnreferencedOutputs.java | 3 +- .../StreamPropertyDerivations.java | 4 +- .../UnaliasSymbolReferences.java | 2 +- .../sql/planner/plan/IndexSourceNode.java | 11 --- .../sql/planner/plan/MetadataDeleteNode.java | 14 +-- .../sql/planner/plan/TableScanNode.java | 41 +++------ .../sql/planner/planPrinter/PlanPrinter.java | 11 --- .../prestosql/testing/LocalQueryRunner.java | 7 +- .../io/prestosql/testing/TestingMetadata.java | 5 +- .../connector/MockConnectorFactory.java | 4 +- .../io/prestosql/cost/TestCostCalculator.java | 4 +- .../execution/MockRemoteTaskFactory.java | 6 +- .../io/prestosql/execution/TaskTestUtils.java | 5 +- .../TestPhasedExecutionSchedule.java | 10 ++- .../TestSourcePartitionedScheduler.java | 5 +- .../metadata/AbstractMockMetadata.java | 10 +-- .../TestEffectivePredicateExtractor.java | 16 +--- .../sql/planner/TestTypeValidator.java | 5 +- .../planner/assertions/TableScanMatcher.java | 8 +- .../TestPruneCountAggregationOverScalar.java | 7 +- .../rule/TestPruneIndexSourceColumns.java | 7 +- .../rule/TestPruneTableScanColumns.java | 7 +- ...va => TestPushPredicateIntoTableScan.java} | 80 ++++++----------- .../iterative/rule/TestRemoveEmptyDelete.java | 5 +- ...mCorrelatedSingleRowSubqueryToProject.java | 11 ++- .../iterative/rule/test/PlanBuilder.java | 26 +++--- ...ValidateAggregationsWithDefaultValues.java | 10 +-- .../TestValidateStreamingAggregations.java | 23 ++--- 56 files changed, 296 insertions(+), 534 deletions(-) delete mode 100644 presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java rename presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/{PickTableLayout.java => PushPredicateIntoTableScan.java} (82%) rename presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/{TestPickTableLayout.java => TestPushPredicateIntoTableScan.java} (71%) diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index 7c720e1158fb..688b561dcd0d 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -29,8 +29,6 @@ import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.Split; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.operator.Driver; import io.prestosql.operator.DriverContext; import io.prestosql.operator.FilterAndProjectOperator; @@ -48,7 +46,6 @@ import io.prestosql.spi.QueryId; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorPageSource; -import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.memory.MemoryPoolId; import io.prestosql.spi.type.Type; import io.prestosql.spiller.SpillSpaceTracker; @@ -170,8 +167,7 @@ protected final OperatorFactory createTableScanOperator(int operatorId, PlanNode List columnHandles = columnHandlesBuilder.build(); // get the split for this table - Optional layout = metadata.getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.empty()); - Split split = getLocalQuerySplit(session, layout.get().getLayout().getHandle()); + Split split = getLocalQuerySplit(session, tableHandle); return new OperatorFactory() { @@ -196,7 +192,7 @@ public OperatorFactory duplicate() }; } - private Split getLocalQuerySplit(Session session, TableLayoutHandle handle) + private Split getLocalQuerySplit(Session session, TableHandle handle) { SplitSource splitSource = localQueryRunner.getSplitManager().getSplits(session, handle, UNGROUPED_SCHEDULING); List splits = new ArrayList<>(); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java index 030a713eec4d..2e09bf5baba8 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java @@ -24,12 +24,12 @@ import io.prestosql.metadata.Metadata; import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayout; import io.prestosql.metadata.TableMetadata; import io.prestosql.plugin.hive.HiveSessionProperties.InsertExistingPartitionsBehavior; import io.prestosql.spi.connector.CatalogSchemaTableName; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.security.Identity; import io.prestosql.spi.security.SelectedRole; @@ -1672,10 +1672,13 @@ private Object getHiveTableProperty(String tableName, Function tableHandle = metadata.getTableHandle(transactionSession, new QualifiedObjectName(catalog, TPCH_SCHEMA, tableName)); assertTrue(tableHandle.isPresent()); - TableLayout layout = metadata.getLayout(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()) + ConnectorTableLayoutHandle connectorLayout = metadata.getLayout(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()) .get() - .getLayout(); - return propertyGetter.apply((HiveTableLayoutHandle) layout.getHandle().getConnectorHandle()); + .getNewTableHandle() + .getLayout() + .get(); + + return propertyGetter.apply((HiveTableLayoutHandle) connectorLayout); }); } diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index 264fb3e4fc0f..7dfeb8706ad5 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -75,23 +75,23 @@ public interface Metadata Optional getLayout(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns); - TableLayout getLayout(Session session, TableLayoutHandle handle); + TableLayout getLayout(Session session, TableHandle handle); /** - * Return a table layout handle whose partitioning is converted to the provided partitioning handle, - * but otherwise identical to the provided table layout handle. - * The provided table layout handle must be one that the connector can transparently convert to from - * the original partitioning handle associated with the provided table layout handle, + * Return a table handle whose partitioning is converted to the provided partitioning handle, + * but otherwise identical to the provided table handle. + * The provided table handle must be one that the connector can transparently convert to from + * the original partitioning handle associated with the provided table handle, * as promised by {@link #getCommonPartitioning}. */ - TableLayoutHandle makeCompatiblePartitioning(Session session, TableLayoutHandle tableLayoutHandle, PartitioningHandle partitioningHandle); + TableHandle makeCompatiblePartitioning(Session session, TableHandle table, PartitioningHandle partitioningHandle); /** * Return a partitioning handle which the connector can transparently convert both {@code left} and {@code right} into. */ Optional getCommonPartitioning(Session session, PartitioningHandle left, PartitioningHandle right); - Optional getInfo(Session session, TableLayoutHandle handle); + Optional getInfo(Session session, TableHandle handle); /** * Return the metadata for the specified table handle. @@ -241,14 +241,14 @@ public interface Metadata /** * @return whether delete without table scan is supported */ - boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle); + boolean supportsMetadataDelete(Session session, TableHandle tableHandle); /** * Delete the provide table layout * * @return number of rows deleted, or empty for unknown */ - OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle); + OptionalLong metadataDelete(Session session, TableHandle tableHandle); /** * Begin delete query diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 49a76f3e8ffb..990a0922556e 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -90,7 +90,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.metadata.QualifiedObjectName.convertFromSchemaTableName; -import static io.prestosql.metadata.TableLayout.fromConnectorLayout; import static io.prestosql.metadata.ViewDefinition.ViewColumn; import static io.prestosql.spi.StandardErrorCode.INVALID_VIEW; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; @@ -325,9 +324,14 @@ public Optional getTableHandle(Session session, QualifiedObjectName ConnectorId connectorId = catalogMetadata.getConnectorId(session, table); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - ConnectorTableHandle tableHandle = metadata.getTableHandle(session.toConnectorSession(connectorId), table.asSchemaTableName()); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + ConnectorTableHandle tableHandle = metadata.getTableHandle(connectorSession, table.asSchemaTableName()); if (tableHandle != null) { - return Optional.of(new TableHandle(connectorId, tableHandle)); + return Optional.of(new TableHandle( + connectorId, + tableHandle, + catalogMetadata.getTransactionHandleFor(connectorId), + Optional.empty())); } } return Optional.empty(); @@ -346,7 +350,11 @@ public Optional getTableHandleForStatisticsCollection(Session sessi ConnectorTableHandle tableHandle = metadata.getTableHandleForStatisticsCollection(session.toConnectorSession(connectorId), table.asSchemaTableName(), analyzeProperties); if (tableHandle != null) { - return Optional.of(new TableHandle(connectorId, tableHandle)); + return Optional.of(new TableHandle( + connectorId, + tableHandle, + catalogMetadata.getTransactionHandleFor(connectorId), + Optional.empty())); } } return Optional.empty(); @@ -394,30 +402,39 @@ public Optional getLayout(Session session, TableHandle table, throw new PrestoException(NOT_SUPPORTED, format("Connector returned multiple layouts for table %s", table)); } - return Optional.of(new TableLayoutResult(fromConnectorLayout(connectorId, transaction, layouts.get(0).getTableLayout()), layouts.get(0).getUnenforcedConstraint())); + ConnectorTableLayout tableLayout = layouts.get(0).getTableLayout(); + return Optional.of(new TableLayoutResult( + new TableHandle(connectorId, connectorTable, transaction, Optional.of(tableLayout.getHandle())), + new TableLayout(connectorId, transaction, tableLayout), + layouts.get(0).getUnenforcedConstraint())); } @Override - public TableLayout getLayout(Session session, TableLayoutHandle handle) + public TableLayout getLayout(Session session, TableHandle handle) { ConnectorId connectorId = handle.getConnectorId(); CatalogMetadata catalogMetadata = getCatalogMetadata(session, connectorId); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - ConnectorTransactionHandle transaction = catalogMetadata.getTransactionHandleFor(connectorId); - return fromConnectorLayout(connectorId, transaction, metadata.getTableLayout(session.toConnectorSession(connectorId), handle.getConnectorHandle())); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + + return handle.getLayout() + .map(layout -> new TableLayout(connectorId, handle.getTransaction(), metadata.getTableLayout(connectorSession, layout))) + .orElseGet(() -> getLayout(session, handle, Constraint.alwaysTrue(), Optional.empty()) + .get() + .getLayout()); } @Override - public TableLayoutHandle makeCompatiblePartitioning(Session session, TableLayoutHandle tableLayoutHandle, PartitioningHandle partitioningHandle) + public TableHandle makeCompatiblePartitioning(Session session, TableHandle tableHandle, PartitioningHandle partitioningHandle) { checkArgument(partitioningHandle.getConnectorId().isPresent(), "Expect partitioning handle from connector, got system partitioning handle"); ConnectorId connectorId = partitioningHandle.getConnectorId().get(); - checkArgument(connectorId.equals(tableLayoutHandle.getConnectorId()), "ConnectorId of tableLayoutHandle and partitioningHandle does not match"); + checkArgument(connectorId.equals(tableHandle.getConnectorId()), "ConnectorId of tableHandle and partitioningHandle does not match"); CatalogMetadata catalogMetadata = getCatalogMetadata(session, connectorId); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); ConnectorTransactionHandle transaction = catalogMetadata.getTransactionHandleFor(connectorId); - ConnectorTableLayoutHandle newTableLayoutHandle = metadata.makeCompatiblePartitioning(session.toConnectorSession(connectorId), tableLayoutHandle.getConnectorHandle(), partitioningHandle.getConnectorHandle()); - return new TableLayoutHandle(connectorId, transaction, newTableLayoutHandle); + ConnectorTableLayoutHandle newTableLayoutHandle = metadata.makeCompatiblePartitioning(session.toConnectorSession(connectorId), tableHandle.getLayout().get(), partitioningHandle.getConnectorHandle()); + return new TableHandle(connectorId, tableHandle.getConnectorHandle(), transaction, Optional.of(newTableLayoutHandle)); } @Override @@ -439,12 +456,19 @@ public Optional getCommonPartitioning(Session session, Parti } @Override - public Optional getInfo(Session session, TableLayoutHandle handle) + public Optional getInfo(Session session, TableHandle handle) { ConnectorId connectorId = handle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); - ConnectorTableLayout tableLayout = metadata.getTableLayout(session.toConnectorSession(connectorId), handle.getConnectorHandle()); - return metadata.getInfo(tableLayout.getHandle()); + + ConnectorTableLayoutHandle layoutHandle = handle.getLayout() + .orElseGet(() -> getLayout(session, handle, Constraint.alwaysTrue(), Optional.empty()) + .get() + .getNewTableHandle() + .getLayout() + .get()); + + return metadata.getInfo(layoutHandle); } @Override @@ -785,22 +809,22 @@ public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tabl } @Override - public boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public boolean supportsMetadataDelete(Session session, TableHandle tableHandle) { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); return metadata.supportsMetadataDelete( session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), - tableLayoutHandle.getConnectorHandle()); + tableHandle.getLayout().get()); } @Override - public OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public OptionalLong metadataDelete(Session session, TableHandle tableHandle) { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); - return metadata.metadataDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tableLayoutHandle.getConnectorHandle()); + return metadata.metadataDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tableHandle.getLayout().get()); } @Override @@ -809,7 +833,7 @@ public TableHandle beginDelete(Session session, TableHandle tableHandle) ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); ConnectorTableHandle newHandle = metadata.beginDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle()); - return new TableHandle(tableHandle.getConnectorId(), newHandle); + return new TableHandle(tableHandle.getConnectorId(), newHandle, tableHandle.getTransaction(), tableHandle.getLayout()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java b/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java index a117a664d0b7..dd987b7c01ba 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java @@ -17,8 +17,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.prestosql.connector.ConnectorId; import io.prestosql.spi.connector.ConnectorTableHandle; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; +import io.prestosql.spi.connector.ConnectorTransactionHandle; -import java.util.Objects; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -26,14 +28,23 @@ public final class TableHandle { private final ConnectorId connectorId; private final ConnectorTableHandle connectorHandle; + private final ConnectorTransactionHandle transaction; + + // Table layouts are deprecated, but we keep this here to hide the notion of layouts + // from the engine. TODO: it should be removed once table layouts are finally deleted + private final Optional layout; @JsonCreator public TableHandle( @JsonProperty("connectorId") ConnectorId connectorId, - @JsonProperty("connectorHandle") ConnectorTableHandle connectorHandle) + @JsonProperty("connectorHandle") ConnectorTableHandle connectorHandle, + @JsonProperty("transaction") ConnectorTransactionHandle transaction, + @JsonProperty("layout") Optional layout) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.connectorHandle = requireNonNull(connectorHandle, "connectorHandle is null"); + this.transaction = requireNonNull(transaction, "transaction is null"); + this.layout = requireNonNull(layout, "layout is null"); } @JsonProperty @@ -48,24 +59,16 @@ public ConnectorTableHandle getConnectorHandle() return connectorHandle; } - @Override - public int hashCode() + @JsonProperty + public Optional getLayout() { - return Objects.hash(connectorId, connectorHandle); + return layout; } - @Override - public boolean equals(Object obj) + @JsonProperty + public ConnectorTransactionHandle getTransaction() { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - final TableHandle other = (TableHandle) obj; - return Objects.equals(this.connectorId, other.connectorId) && - Objects.equals(this.connectorHandle, other.connectorHandle); + return transaction; } @Override diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java index 954280af0256..2d1c238cf0eb 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java @@ -32,15 +32,18 @@ public class TableLayout { - private final TableLayoutHandle handle; private final ConnectorTableLayout layout; + private final ConnectorId connectorId; + private final ConnectorTransactionHandle transaction; - public TableLayout(TableLayoutHandle handle, ConnectorTableLayout layout) + public TableLayout(ConnectorId connectorId, ConnectorTransactionHandle transaction, ConnectorTableLayout layout) { - requireNonNull(handle, "handle is null"); + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(transaction, "transaction is null"); requireNonNull(layout, "layout is null"); - this.handle = handle; + this.connectorId = connectorId; + this.transaction = transaction; this.layout = layout; } @@ -59,18 +62,13 @@ public List> getLocalProperties() return layout.getLocalProperties(); } - public TableLayoutHandle getHandle() - { - return handle; - } - public Optional getTablePartitioning() { return layout.getTablePartitioning() .map(nodePartitioning -> new TablePartitioning( new PartitioningHandle( - Optional.of(handle.getConnectorId()), - Optional.of(handle.getTransactionHandle()), + Optional.of(connectorId), + Optional.of(transaction), nodePartitioning.getPartitioningHandle()), nodePartitioning.getPartitioningColumns())); } @@ -85,11 +83,6 @@ public Optional getDiscretePredicates() return layout.getDiscretePredicates(); } - public static TableLayout fromConnectorLayout(ConnectorId connectorId, ConnectorTransactionHandle transactionHandle, ConnectorTableLayout layout) - { - return new TableLayout(new TableLayoutHandle(connectorId, transactionHandle, layout.getHandle()), layout); - } - public static class TablePartitioning { private final PartitioningHandle partitioningHandle; diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java deleted file mode 100644 index 7b1cae2e23fd..000000000000 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.metadata; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import io.prestosql.connector.ConnectorId; -import io.prestosql.spi.connector.ConnectorTableLayoutHandle; -import io.prestosql.spi.connector.ConnectorTransactionHandle; - -import java.util.Objects; - -import static java.util.Objects.requireNonNull; - -public final class TableLayoutHandle -{ - private final ConnectorId connectorId; - private final ConnectorTransactionHandle transactionHandle; - private final ConnectorTableLayoutHandle layout; - - @JsonCreator - public TableLayoutHandle( - @JsonProperty("connectorId") ConnectorId connectorId, - @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, - @JsonProperty("connectorHandle") ConnectorTableLayoutHandle layout) - { - requireNonNull(connectorId, "connectorId is null"); - requireNonNull(transactionHandle, "transactionHandle is null"); - requireNonNull(layout, "layout is null"); - - this.connectorId = connectorId; - this.transactionHandle = transactionHandle; - this.layout = layout; - } - - @JsonProperty - public ConnectorId getConnectorId() - { - return connectorId; - } - - @JsonProperty - public ConnectorTransactionHandle getTransactionHandle() - { - return transactionHandle; - } - - @JsonProperty - public ConnectorTableLayoutHandle getConnectorHandle() - { - return layout; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - TableLayoutHandle that = (TableLayoutHandle) o; - return Objects.equals(connectorId, that.connectorId) && - Objects.equals(transactionHandle, that.transactionHandle) && - Objects.equals(layout, that.layout); - } - - @Override - public int hashCode() - { - return Objects.hash(connectorId, transactionHandle, layout); - } -} diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java index c545333fa995..dc02158054c8 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java @@ -25,15 +25,22 @@ public class TableLayoutResult { + private final TableHandle newTableHandle; private final TableLayout layout; private final TupleDomain unenforcedConstraint; - public TableLayoutResult(TableLayout layout, TupleDomain unenforcedConstraint) + public TableLayoutResult(TableHandle newTable, TableLayout layout, TupleDomain unenforcedConstraint) { + this.newTableHandle = requireNonNull(newTable, "newTable is null"); this.layout = requireNonNull(layout, "layout is null"); this.unenforcedConstraint = requireNonNull(unenforcedConstraint, "unenforcedConstraint is null"); } + public TableHandle getNewTableHandle() + { + return newTableHandle; + } + public TableLayout getLayout() { return layout; diff --git a/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java b/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java index d264d2e1953c..fd174e895e3d 100644 --- a/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java @@ -17,7 +17,6 @@ import io.prestosql.Session; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.BlockBuilder; @@ -41,17 +40,15 @@ public static class MetadataDeleteOperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; - private final TableLayoutHandle tableLayout; private final Metadata metadata; private final Session session; private final TableHandle tableHandle; private boolean closed; - public MetadataDeleteOperatorFactory(int operatorId, PlanNodeId planNodeId, TableLayoutHandle tableLayout, Metadata metadata, Session session, TableHandle tableHandle) + public MetadataDeleteOperatorFactory(int operatorId, PlanNodeId planNodeId, Metadata metadata, Session session, TableHandle tableHandle) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); @@ -62,7 +59,7 @@ public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, MetadataDeleteOperator.class.getSimpleName()); - return new MetadataDeleteOperator(context, tableLayout, metadata, session, tableHandle); + return new MetadataDeleteOperator(context, metadata, session, tableHandle); } @Override @@ -74,22 +71,20 @@ public void noMoreOperators() @Override public OperatorFactory duplicate() { - return new MetadataDeleteOperatorFactory(operatorId, planNodeId, tableLayout, metadata, session, tableHandle); + return new MetadataDeleteOperatorFactory(operatorId, planNodeId, metadata, session, tableHandle); } } private final OperatorContext operatorContext; - private final TableLayoutHandle tableLayout; private final Metadata metadata; private final Session session; private final TableHandle tableHandle; private boolean finished; - public MetadataDeleteOperator(OperatorContext operatorContext, TableLayoutHandle tableLayout, Metadata metadata, Session session, TableHandle tableHandle) + public MetadataDeleteOperator(OperatorContext operatorContext, Metadata metadata, Session session, TableHandle tableHandle) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); @@ -132,7 +127,7 @@ public Page getOutput() } finished = true; - OptionalLong rowsDeletedCount = metadata.metadataDelete(session, tableHandle, tableLayout); + OptionalLong rowsDeletedCount = metadata.metadataDelete(session, tableHandle); // output page will only be constructed once, // so a new PageBuilder is constructed (instead of using PageBuilder.reset) diff --git a/presto-main/src/main/java/io/prestosql/split/SplitManager.java b/presto-main/src/main/java/io/prestosql/split/SplitManager.java index 7a9bc42fa3a2..72ad74bce3c6 100644 --- a/presto-main/src/main/java/io/prestosql/split/SplitManager.java +++ b/presto-main/src/main/java/io/prestosql/split/SplitManager.java @@ -16,14 +16,18 @@ import io.prestosql.Session; import io.prestosql.connector.ConnectorId; import io.prestosql.execution.QueryManagerConfig; -import io.prestosql.metadata.TableLayoutHandle; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.TableHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorSplitManager; import io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; import io.prestosql.spi.connector.ConnectorSplitSource; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; +import io.prestosql.spi.connector.Constraint; import javax.inject.Inject; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -36,10 +40,15 @@ public class SplitManager private final ConcurrentMap splitManagers = new ConcurrentHashMap<>(); private final int minScheduleSplitBatchSize; + // This is used to fetch a table layout if the TableHandle doesn't have one set + // It's a temporary measure until we get rid of TableLayouts from the SPI + private final Metadata metadata; + @Inject - public SplitManager(QueryManagerConfig config) + public SplitManager(QueryManagerConfig config, Metadata metadata) { this.minScheduleSplitBatchSize = config.getMinScheduleSplitBatchSize(); + this.metadata = metadata; } public void addConnectorSplitManager(ConnectorId connectorId, ConnectorSplitManager connectorSplitManager) @@ -54,20 +63,27 @@ public void removeConnectorSplitManager(ConnectorId connectorId) splitManagers.remove(connectorId); } - public SplitSource getSplits(Session session, TableLayoutHandle layout, SplitSchedulingStrategy splitSchedulingStrategy) + public SplitSource getSplits(Session session, TableHandle table, SplitSchedulingStrategy splitSchedulingStrategy) { - ConnectorId connectorId = layout.getConnectorId(); + ConnectorId connectorId = table.getConnectorId(); ConnectorSplitManager splitManager = getConnectorSplitManager(connectorId); ConnectorSession connectorSession = session.toConnectorSession(connectorId); + ConnectorTableLayoutHandle layout = table.getLayout() + .orElseGet(() -> metadata.getLayout(session, table, Constraint.alwaysTrue(), Optional.empty()) + .get() + .getNewTableHandle() + .getLayout() + .get()); + ConnectorSplitSource source = splitManager.getSplits( - layout.getTransactionHandle(), + table.getTransaction(), connectorSession, - layout.getConnectorHandle(), + layout, splitSchedulingStrategy); - SplitSource splitSource = new ConnectorAwareSplitSource(connectorId, layout.getTransactionHandle(), source); + SplitSource splitSource = new ConnectorAwareSplitSource(connectorId, table.getTransaction(), source); if (minScheduleSplitBatchSize > 1) { splitSource = new BufferingSplitSource(splitSource, minScheduleSplitBatchSize); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java index bae4b88fbf3c..6ec1f2cf07b3 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java @@ -146,7 +146,7 @@ public Map visitTableScan(TableScanNode node, Void cont // get dataSource for table SplitSource splitSource = splitManager.getSplits( session, - node.getLayout().get(), + node.getTable(), stageExecutionDescriptor.isScanGroupedExecution(node.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING); splitSources.add(splitSource); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java b/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java index d2f17ccf7f06..69b93766656c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java @@ -20,8 +20,6 @@ import io.prestosql.execution.Input; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; -import io.prestosql.metadata.TableMetadata; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.SchemaTableName; @@ -59,10 +57,10 @@ private static Column createColumn(ColumnMetadata columnMetadata) return new Column(columnMetadata.getName(), columnMetadata.getType().toString()); } - private Input createInput(TableMetadata table, Optional layout, Set columns) + private Input createInput(Session session, TableHandle table, Set columns) { - SchemaTableName schemaTable = table.getTable(); - Optional inputMetadata = layout.flatMap(tableLayout -> metadata.getInfo(session, tableLayout)); + SchemaTableName schemaTable = metadata.getTableMetadata(session, table).getTable(); + Optional inputMetadata = metadata.getInfo(session, table); return new Input(table.getConnectorId(), schemaTable.getSchemaName(), schemaTable.getTableName(), inputMetadata, ImmutableList.copyOf(columns)); } @@ -86,7 +84,7 @@ public Void visitTableScan(TableScanNode node, Void context) columns.add(createColumn(metadata.getColumnMetadata(session, tableHandle, columnHandle))); } - inputs.add(createInput(metadata.getTableMetadata(session, tableHandle), node.getLayout(), columns)); + inputs.add(createInput(session, tableHandle, columns)); return null; } @@ -101,7 +99,7 @@ public Void visitIndexSource(IndexSourceNode node, Void context) columns.add(createColumn(metadata.getColumnMetadata(session, tableHandle, columnHandle))); } - inputs.add(createInput(metadata.getTableMetadata(session, tableHandle), node.getLayout(), columns)); + inputs.add(createInput(session, tableHandle, columns)); return null; } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 4aad7f3b2d90..db5b3d1ef9a1 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -2343,7 +2343,7 @@ public PhysicalOperation visitDelete(DeleteNode node, LocalExecutionPlanContext @Override public PhysicalOperation visitMetadataDelete(MetadataDeleteNode node, LocalExecutionPlanContext context) { - OperatorFactory operatorFactory = new MetadataDeleteOperatorFactory(context.getNextOperatorId(), node.getId(), node.getTableLayout(), metadata, session, node.getTarget().getHandle()); + OperatorFactory operatorFactory = new MetadataDeleteOperatorFactory(context.getNextOperatorId(), node.getId(), metadata, session, node.getTarget().getHandle()); return new PhysicalOperation(operatorFactory, makeLayout(node), context, UNGROUPED_EXECUTION); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java index 557685f6d1e8..eb453f53dd03 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java @@ -264,7 +264,7 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme idAllocator.getNextId(), new AggregationNode( idAllocator.getNextId(), - new TableScanNode(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.build()), + TableScanNode.newInstance(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.build()), statisticAggregations.getAggregations(), singleGroupingSet(groupingSymbols), ImmutableList.of(), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java index c84dbae72976..90ee22c5fb92 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java @@ -21,9 +21,8 @@ import io.prestosql.cost.StatsAndCosts; import io.prestosql.execution.QueryManagerConfig; import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.TableLayout; +import io.prestosql.metadata.TableHandle; import io.prestosql.metadata.TableLayout.TablePartitioning; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ConnectorPartitionHandle; import io.prestosql.spi.connector.ConnectorPartitioningHandle; @@ -279,9 +278,8 @@ public PlanNode visitMetadataDelete(MetadataDeleteNode node, RewriteContext context) { - PartitioningHandle partitioning = node.getLayout() - .map(layout -> metadata.getLayout(session, layout)) - .flatMap(TableLayout::getTablePartitioning) + PartitioningHandle partitioning = metadata.getLayout(session, node.getTable()) + .getTablePartitioning() .map(TablePartitioning::getPartitioningHandle) .orElse(SOURCE_DISTRIBUTION); @@ -645,7 +643,7 @@ private GroupedExecutionProperties processWindowFunction(PlanNode node) @Override public GroupedExecutionProperties visitTableScan(TableScanNode node, Void context) { - Optional tablePartitioning = metadata.getLayout(session, node.getLayout().get()).getTablePartitioning(); + Optional tablePartitioning = metadata.getLayout(session, node.getTable()).getTablePartitioning(); if (!tablePartitioning.isPresent()) { return GroupedExecutionProperties.notCapable(); } @@ -750,9 +748,8 @@ public PartitioningHandleReassigner(PartitioningHandle fragmentPartitioningHandl @Override public PlanNode visitTableScan(TableScanNode node, RewriteContext context) { - PartitioningHandle partitioning = node.getLayout() - .map(layout -> metadata.getLayout(session, layout)) - .flatMap(TableLayout::getTablePartitioning) + PartitioningHandle partitioning = metadata.getLayout(session, node.getTable()) + .getTablePartitioning() .map(TablePartitioning::getPartitioningHandle) .orElse(SOURCE_DISTRIBUTION); if (partitioning.equals(fragmentPartitioningHandle)) { @@ -760,13 +757,12 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) return node; } - TableLayoutHandle newTableLayoutHandle = metadata.makeCompatiblePartitioning(session, node.getLayout().get(), fragmentPartitioningHandle); + TableHandle newTable = metadata.makeCompatiblePartitioning(session, node.getTable(), fragmentPartitioningHandle); return new TableScanNode( node.getId(), - node.getTable(), + newTable, node.getOutputSymbols(), node.getAssignments(), - Optional.of(newTableLayoutHandle), node.getCurrentConstraint(), node.getEnforcedConstraint()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 7eb783582d48..e7c2a55c3f36 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -55,7 +55,6 @@ import io.prestosql.sql.planner.iterative.rule.MergeLimitWithTopN; import io.prestosql.sql.planner.iterative.rule.MergeLimits; import io.prestosql.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct; -import io.prestosql.sql.planner.iterative.rule.PickTableLayout; import io.prestosql.sql.planner.iterative.rule.PruneAggregationColumns; import io.prestosql.sql.planner.iterative.rule.PruneAggregationSourceColumns; import io.prestosql.sql.planner.iterative.rule.PruneCountAggregationOverScalar; @@ -82,6 +81,7 @@ import io.prestosql.sql.planner.iterative.rule.PushLimitThroughSemiJoin; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughExchange; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughJoin; +import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushProjectionThroughExchange; import io.prestosql.sql.planner.iterative.rule.PushProjectionThroughUnion; import io.prestosql.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId; @@ -357,7 +357,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new PickTableLayout(metadata, sqlParser).rules()), + new PushPredicateIntoTableScan(metadata, sqlParser).rules()), new PruneUnreferencedOutputs(), new IterativeOptimizer( ruleStats, @@ -407,7 +407,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new PickTableLayout(metadata, sqlParser).rules()), + new PushPredicateIntoTableScan(metadata, sqlParser).rules()), projectionPushDown, new PruneUnreferencedOutputs(), new IterativeOptimizer( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java index a880f0b1ff55..8efe82ed5bbc 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java @@ -218,7 +218,7 @@ public DeleteNode plan(Delete node) fields.add(rowIdField); // create table scan - PlanNode tableScan = new TableScanNode(idAllocator.getNextId(), handle, outputSymbols.build(), columns.build()); + PlanNode tableScan = TableScanNode.newInstance(idAllocator.getNextId(), handle, outputSymbols.build(), columns.build()); Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build(); RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputSymbols.build()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index 3f98e345be67..3b1b26304a95 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -158,7 +158,7 @@ protected RelationPlan visitTable(Table node, Void context) } List outputSymbols = outputSymbolsBuilder.build(); - PlanNode root = new TableScanNode(idAllocator.getNextId(), handle, outputSymbols, columns.build()); + PlanNode root = TableScanNode.newInstance(idAllocator.getNextId(), handle, outputSymbols, columns.build()); return new RelationPlan(root, scope, outputSymbols); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index 9a5accbfaf8c..b0c118ace6de 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -31,12 +31,10 @@ import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.Split; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorPageSource; -import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; @@ -469,11 +467,8 @@ private static KdbTree loadKdbTree(String tableName, Session session, Metadata m ColumnHandle kdbTreeColumn = Iterables.getOnlyElement(visibleColumnHandles); - Optional layout = metadata.getLayout(session, tableHandle, Constraint.alwaysTrue(), Optional.of(ImmutableSet.of(kdbTreeColumn))); - checkSpatialPartitioningTable(layout.isPresent(), "Table is empty: %s", name); - Optional kdbTree = Optional.empty(); - try (SplitSource splitSource = splitManager.getSplits(session, layout.get().getLayout().getHandle(), UNGROUPED_SCHEDULING)) { + try (SplitSource splitSource = splitManager.getSplits(session, tableHandle, UNGROUPED_SCHEDULING)) { while (!Thread.currentThread().isInterrupted()) { SplitBatch splitBatch = getFutureValue(splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), 1000)); List splits = splitBatch.getSplits(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java index 6fa88ff5e9bf..b6e0c5bc6236 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java @@ -60,7 +60,6 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, indexSourceNode.getId(), indexSourceNode.getIndexHandle(), indexSourceNode.getTableHandle(), - indexSourceNode.getLayout(), prunedLookupSymbols, prunedOutputList, prunedAssignments, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java index e8cacd4c31f5..928fa34367f7 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -42,7 +42,6 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, tableScanNode.getTable(), filteredCopy(tableScanNode.getOutputSymbols(), referencedOutputs::contains), filterKeys(tableScanNode.getAssignments(), referencedOutputs::contains), - tableScanNode.getLayout(), tableScanNode.getCurrentConstraint(), tableScanNode.getEnforcedConstraint())); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java similarity index 82% rename from presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java rename to presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 034c48cdcb0d..fead481c3d14 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -73,13 +73,13 @@ * These rules should not be run after AddExchanges so as not to overwrite the TableLayout * chosen by AddExchanges */ -public class PickTableLayout +public class PushPredicateIntoTableScan { private final Metadata metadata; private final SqlParser parser; private final DomainTranslator domainTranslator; - public PickTableLayout(Metadata metadata, SqlParser parser) + public PushPredicateIntoTableScan(Metadata metadata, SqlParser parser) { this.metadata = requireNonNull(metadata, "metadata is null"); this.parser = requireNonNull(parser, "parser is null"); @@ -88,10 +88,7 @@ public PickTableLayout(Metadata metadata, SqlParser parser) public Set> rules() { - return ImmutableSet.of( - checkRulesAreFiredBeforeAddExchangesRule(), - pickTableLayoutForPredicate(), - pickTableLayoutWithoutPredicate()); + return ImmutableSet.of(pickTableLayoutForPredicate()); } public PickTableLayoutForPredicate pickTableLayoutForPredicate() @@ -99,11 +96,6 @@ public PickTableLayoutForPredicate pickTableLayoutForPredicate() return new PickTableLayoutForPredicate(metadata, parser, domainTranslator); } - public PickTableLayoutWithoutPredicate pickTableLayoutWithoutPredicate() - { - return new PickTableLayoutWithoutPredicate(metadata); - } - private static final class PickTableLayoutForPredicate implements Rule { @@ -175,69 +167,11 @@ private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNod TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource(); - if (!tableScan.getLayout().isPresent() && rewrittenTableScan.getLayout().isPresent()) { - return false; - } - return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint()) && Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint()); } } - private static final class PickTableLayoutWithoutPredicate - implements Rule - { - private final Metadata metadata; - - private PickTableLayoutWithoutPredicate(Metadata metadata) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - } - - private static final Pattern PATTERN = tableScan(); - - @Override - public Pattern getPattern() - { - return PATTERN; - } - - @Override - public boolean isEnabled(Session session) - { - return isNewOptimizerEnabled(session); - } - - @Override - public Result apply(TableScanNode tableScanNode, Captures captures, Context context) - { - if (tableScanNode.getLayout().isPresent()) { - return Result.empty(); - } - - Optional layout = metadata.getLayout( - context.getSession(), - tableScanNode.getTable(), - Constraint.alwaysTrue(), - Optional.of(tableScanNode.getOutputSymbols().stream() - .map(tableScanNode.getAssignments()::get) - .collect(toImmutableSet()))); - - if (!layout.isPresent() || layout.get().getLayout().getPredicate().isNone()) { - return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), tableScanNode.getOutputSymbols(), ImmutableList.of())); - } - - return Result.ofPlanNode(new TableScanNode( - tableScanNode.getId(), - tableScanNode.getTable(), - tableScanNode.getOutputSymbols(), - tableScanNode.getAssignments(), - Optional.of(layout.get().getLayout().getHandle()), - TupleDomain.all(), - TupleDomain.all())); - } - } - public static PlanNode pushFilterIntoTableScan( TableScanNode node, Expression predicate, @@ -299,10 +233,9 @@ public static PlanNode pushFilterIntoTableScan( TableScanNode tableScan = new TableScanNode( node.getId(), - node.getTable(), + layout.get().getNewTableHandle(), node.getOutputSymbols(), node.getAssignments(), - Optional.of(layout.get().getLayout().getHandle()), layout.get().getLayout().getPredicate(), computeEnforced(newDomain, layout.get().getUnenforcedConstraint())); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index a4f3a11a331b..e4a97ef1298b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -35,7 +35,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.planner.iterative.rule.PickTableLayout; +import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.ApplyNode; import io.prestosql.sql.planner.plan.Assignments; @@ -532,7 +532,7 @@ else if (redistributeWrites) { private PlanWithProperties planTableScan(TableScanNode node, Expression predicate) { - PlanNode plan = PickTableLayout.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); + PlanNode plan = PushPredicateIntoTableScan.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java index 4b35396aa74b..85495eba72b7 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java @@ -14,16 +14,11 @@ package io.prestosql.sql.planner.optimizations; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutResult; -import io.prestosql.spi.connector.ColumnHandle; -import io.prestosql.spi.connector.Constraint; -import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.TypeProvider; @@ -45,7 +40,6 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkState; -import static io.prestosql.metadata.TableLayoutResult.computeEnforced; import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.prestosql.sql.planner.plan.ChildReplacer.replaceChildren; import static java.util.stream.Collectors.toSet; @@ -193,22 +187,13 @@ private PlanNode rewriteDeleteTableScan(PlanNode node, TableHandle handle) { if (node instanceof TableScanNode) { TableScanNode scan = (TableScanNode) node; - TupleDomain originalEnforcedConstraint = scan.getEnforcedConstraint(); - - Optional layout = metadata.getLayout( - session, - handle, - new Constraint<>(originalEnforcedConstraint), - Optional.of(ImmutableSet.copyOf(scan.getAssignments().values()))); - return new TableScanNode( scan.getId(), handle, scan.getOutputSymbols(), scan.getAssignments(), - Optional.of(layout.get().getLayout().getHandle()), - layout.get().getLayout().getPredicate(), - computeEnforced(originalEnforcedConstraint, layout.get().getUnenforcedConstraint())); + scan.getCurrentConstraint(), + scan.getEnforcedConstraint()); } if (node instanceof FilterNode) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java index ae72b2958f5c..a4b7de827f3e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java @@ -301,7 +301,6 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context idAllocator.getNextId(), resolvedIndex.getIndexHandle(), node.getTable(), - node.getLayout(), context.getLookupSymbols(), node.getOutputSymbols(), node.getAssignments(), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java index d16369289be2..10614c65a26b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java @@ -27,6 +27,7 @@ import io.prestosql.sql.planner.plan.SimplePlanRewriter; import io.prestosql.sql.planner.plan.TableFinishNode; import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.TableWriterNode; import java.util.List; import java.util.Optional; @@ -89,10 +90,13 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont return context.defaultRewrite(node); } TableScanNode tableScanNode = tableScan.get(); - if (!metadata.supportsMetadataDelete(session, tableScanNode.getTable(), tableScanNode.getLayout().get())) { + if (!metadata.supportsMetadataDelete(session, tableScanNode.getTable())) { return context.defaultRewrite(node); } - return new MetadataDeleteNode(idAllocator.getNextId(), delete.get().getTarget(), Iterables.getOnlyElement(node.getOutputSymbols()), tableScanNode.getLayout().get()); + return new MetadataDeleteNode( + idAllocator.getNextId(), + new TableWriterNode.DeleteHandle(tableScanNode.getTable(), delete.get().getTarget().getSchemaTableName()), + Iterables.getOnlyElement(node.getOutputSymbols())); } private static Optional findNode(PlanNode source, Class clazz) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java index 03d541629a75..5da43eb4ed37 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -22,10 +22,8 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableLayout; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; -import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.connector.DiscretePredicates; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.predicate.TupleDomain; @@ -137,17 +135,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont // Materialize the list of partitions and replace the TableScan node // with a Values node - TableLayout layout = null; - if (!tableScan.getLayout().isPresent()) { - layout = metadata.getLayout(session, tableScan.getTable(), Constraint.alwaysTrue(), Optional.empty()) - .map(TableLayoutResult::getLayout) - .orElse(null); - } - else { - layout = metadata.getLayout(session, tableScan.getLayout().get()); - } - - if (layout == null || !layout.getDiscretePredicates().isPresent()) { + TableLayout layout = metadata.getLayout(session, tableScan.getTable()); + if (!layout.getDiscretePredicates().isPresent()) { return context.defaultRewrite(node); } DiscretePredicates predicates = layout.getDiscretePredicates().get(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java index 8306c2e9f90b..8657f7f3e31d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java @@ -709,9 +709,7 @@ public ActualProperties visitValues(ValuesNode node, List cont @Override public ActualProperties visitTableScan(TableScanNode node, List inputProperties) { - checkArgument(node.getLayout().isPresent(), "table layout has not yet been chosen"); - - TableLayout layout = metadata.getLayout(session, node.getLayout().get()); + TableLayout layout = metadata.getLayout(session, node.getTable()); Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); ActualProperties.Builder properties = ActualProperties.builder(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java index 5207fc49c057..f568bb10b3b8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -320,7 +320,7 @@ public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext newAssignments = newOutputSymbols.stream() .collect(Collectors.toMap(Function.identity(), node.getAssignments()::get)); - return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), node.getLayout(), newLookupSymbols, newOutputSymbols, newAssignments, node.getCurrentConstraint()); + return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbols, newOutputSymbols, newAssignments, node.getCurrentConstraint()); } @Override @@ -426,7 +426,6 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext> c node.getTable(), newOutputs, newAssignments, - node.getLayout(), node.getCurrentConstraint(), node.getEnforcedConstraint()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java index 27b5930ae57f..d7f620263238 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java @@ -254,9 +254,7 @@ public StreamProperties visitValues(ValuesNode node, List cont @Override public StreamProperties visitTableScan(TableScanNode node, List inputProperties) { - checkArgument(node.getLayout().isPresent(), "table layout has not yet been chosen"); - - TableLayout layout = metadata.getLayout(session, node.getLayout().get()); + TableLayout layout = metadata.getLayout(session, node.getTable()); Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); // Globally constant assignments diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java index ad377f9e5095..c35919cf46d8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -525,7 +525,7 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont @Override public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) { - return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), node.getLayout(), canonicalize(node.getLookupSymbols()), node.getOutputSymbols(), node.getAssignments(), node.getCurrentConstraint()); + return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), canonicalize(node.getLookupSymbols()), node.getOutputSymbols(), node.getAssignments(), node.getCurrentConstraint()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java index d908fea62650..877e18ab0871 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java @@ -20,14 +20,12 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.metadata.IndexHandle; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.planner.Symbol; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -38,7 +36,6 @@ public class IndexSourceNode { private final IndexHandle indexHandle; private final TableHandle tableHandle; - private final Optional tableLayout; // only necessary for event listeners private final Set lookupSymbols; private final List outputSymbols; private final Map assignments; // symbol -> column @@ -49,7 +46,6 @@ public IndexSourceNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("indexHandle") IndexHandle indexHandle, @JsonProperty("tableHandle") TableHandle tableHandle, - @JsonProperty("tableLayout") Optional tableLayout, @JsonProperty("lookupSymbols") Set lookupSymbols, @JsonProperty("outputSymbols") List outputSymbols, @JsonProperty("assignments") Map assignments, @@ -58,7 +54,6 @@ public IndexSourceNode( super(id); this.indexHandle = requireNonNull(indexHandle, "indexHandle is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.lookupSymbols = ImmutableSet.copyOf(requireNonNull(lookupSymbols, "lookupSymbols is null")); this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); @@ -81,12 +76,6 @@ public TableHandle getTableHandle() return tableHandle; } - @JsonProperty - public Optional getLayout() - { - return tableLayout; - } - @JsonProperty public Set getLookupSymbols() { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java index 807b541f5e44..406d22f1383f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java @@ -16,7 +16,6 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.plan.TableWriterNode.DeleteHandle; @@ -32,20 +31,17 @@ public class MetadataDeleteNode { private final DeleteHandle target; private final Symbol output; - private final TableLayoutHandle tableLayout; @JsonCreator public MetadataDeleteNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("target") DeleteHandle target, - @JsonProperty("output") Symbol output, - @JsonProperty("tableLayout") TableLayoutHandle tableLayout) + @JsonProperty("output") Symbol output) { super(id); this.target = requireNonNull(target, "target is null"); this.output = requireNonNull(output, "output is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); } @JsonProperty @@ -66,12 +62,6 @@ public List getOutputSymbols() return ImmutableList.of(output); } - @JsonProperty - public TableLayoutHandle getTableLayout() - { - return tableLayout; - } - @Override public List getSources() { @@ -87,6 +77,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new MetadataDeleteNode(getId(), target, output, tableLayout); + return new MetadataDeleteNode(getId(), target, output); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java index ad7eaf0d1ecd..f95f8ba44ab0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.planner.Symbol; @@ -27,7 +26,6 @@ import java.util.List; import java.util.Map; -import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -42,21 +40,30 @@ public class TableScanNode private final List outputSymbols; private final Map assignments; // symbol -> column - private final Optional tableLayout; - // Used during predicate refinement over multiple passes of predicate pushdown // TODO: think about how to get rid of this in new planner private final TupleDomain currentConstraint; private final TupleDomain enforcedConstraint; + // We need this factory method to disambiguate with the constructor used for deserializing + // from a json object. The deserializer sets some fields which are never transported + // to null + public static TableScanNode newInstance( + PlanNodeId id, + TableHandle table, + List outputs, + Map assignments) + { + return new TableScanNode(id, table, outputs, assignments, TupleDomain.all(), TupleDomain.all()); + } + @JsonCreator public TableScanNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("table") TableHandle table, @JsonProperty("outputSymbols") List outputs, - @JsonProperty("assignments") Map assignments, - @JsonProperty("layout") Optional tableLayout) + @JsonProperty("assignments") Map assignments) { // This constructor is for JSON deserialization only. Do not use. super(id); @@ -64,26 +71,15 @@ public TableScanNode( this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); checkArgument(assignments.keySet().containsAll(outputs), "assignments does not cover all of outputs"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.currentConstraint = null; this.enforcedConstraint = null; } - public TableScanNode( - PlanNodeId id, - TableHandle table, - List outputs, - Map assignments) - { - this(id, table, outputs, assignments, Optional.empty(), TupleDomain.all(), TupleDomain.all()); - } - public TableScanNode( PlanNodeId id, TableHandle table, List outputs, Map assignments, - Optional tableLayout, TupleDomain currentConstraint, TupleDomain enforcedConstraint) { @@ -92,12 +88,8 @@ public TableScanNode( this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); checkArgument(assignments.keySet().containsAll(outputs), "assignments does not cover all of outputs"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.currentConstraint = requireNonNull(currentConstraint, "currentConstraint is null"); this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); - if (!currentConstraint.isAll() || !enforcedConstraint.isAll()) { - checkArgument(tableLayout.isPresent(), "tableLayout must be present when currentConstraint or enforcedConstraint is non-trivial"); - } } @JsonProperty("table") @@ -106,12 +98,6 @@ public TableHandle getTable() return table; } - @JsonProperty - public Optional getLayout() - { - return tableLayout; - } - @Override @JsonProperty("outputSymbols") public List getOutputSymbols() @@ -171,7 +157,6 @@ public String toString() { return toStringHelper(this) .add("table", table) - .add("tableLayout", tableLayout) .add("outputSymbols", outputSymbols) .add("assignments", assignments) .add("currentConstraint", currentConstraint) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java index 1834e19a73b3..985070896e9c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java @@ -36,7 +36,6 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.operator.StageExecutionDescriptor; import io.prestosql.spi.connector.ColumnHandle; -import io.prestosql.spi.connector.ConnectorTableLayoutHandle; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.Marker; import io.prestosql.spi.predicate.NullableValue; @@ -758,16 +757,6 @@ private Void visitScanFilterAndProjectInfo( private void printTableScanInfo(NodeRepresentation nodeOutput, TableScanNode node) { - TableHandle table = node.getTable(); - - if (node.getLayout().isPresent()) { - // TODO: find a better way to do this - ConnectorTableLayoutHandle layout = node.getLayout().get().getConnectorHandle(); - if (!table.getConnectorHandle().toString().equals(layout.toString())) { - nodeOutput.appendDetailsLine("LAYOUT: %s", layout); - } - } - TupleDomain predicate = node.getCurrentConstraint(); if (predicate.isNone()) { nodeOutput.appendDetailsLine(":: NONE"); diff --git a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java index 6a599189f8a9..b8c1a3c4041e 100644 --- a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java @@ -85,7 +85,6 @@ import io.prestosql.metadata.SchemaPropertyManager; import io.prestosql.metadata.SessionPropertyManager; import io.prestosql.metadata.Split; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.metadata.TablePropertyManager; import io.prestosql.metadata.ViewDefinition; import io.prestosql.operator.Driver; @@ -300,7 +299,6 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, notificationExecutor); this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler); - this.splitManager = new SplitManager(new QueryManagerConfig()); this.blockEncodingManager = new BlockEncodingManager(typeRegistry); this.metadata = new MetadataManager( featuresConfig, @@ -312,6 +310,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ColumnPropertyManager(), new AnalyzePropertyManager(), transactionManager); + this.splitManager = new SplitManager(new QueryManagerConfig(), metadata); this.planFragmenter = new PlanFragmenter(this.metadata, this.nodePartitioningManager, new QueryManagerConfig()); this.joinCompiler = new JoinCompiler(metadata, featuresConfig); this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); @@ -731,11 +730,9 @@ private List createDrivers(Session session, Plan plan, OutputFactory out List sources = new ArrayList<>(); long sequenceId = 0; for (TableScanNode tableScan : findTableScanNodes(subplan.getFragment().getRoot())) { - TableLayoutHandle layout = tableScan.getLayout().get(); - SplitSource splitSource = splitManager.getSplits( session, - layout, + tableScan.getTable(), stageExecutionDescriptor.isScanGroupedExecution(tableScan.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING); ImmutableSet.Builder scheduledSplits = ImmutableSet.builder(); diff --git a/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java b/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java index 859b0daabe94..a5907363ae7d 100644 --- a/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java +++ b/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java @@ -37,6 +37,7 @@ import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; import io.prestosql.spi.connector.ViewNotFoundException; +import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.security.PrestoPrincipal; import io.prestosql.spi.security.Privilege; import io.prestosql.spi.statistics.ComputedStatistics; @@ -95,13 +96,13 @@ public ConnectorTableHandle getTableHandleForStatisticsCollection(ConnectorSessi @Override public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { - return ImmutableList.of(); + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(TestingHandle.INSTANCE), TupleDomain.all())); } @Override public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) { - throw new UnsupportedOperationException(); + return new ConnectorTableLayout(TestingHandle.INSTANCE); } @Override diff --git a/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java b/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java index 189a56608cf6..1aee4584c583 100644 --- a/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java +++ b/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java @@ -39,7 +39,9 @@ import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; +import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.transaction.IsolationLevel; +import io.prestosql.testing.TestingHandle; import java.util.List; import java.util.Map; @@ -197,7 +199,7 @@ public Map> listTableColumns(ConnectorSess @Override public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { - return ImmutableList.of(); + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(TestingHandle.INSTANCE), TupleDomain.all())); } @Override diff --git a/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java b/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java index cd199d8f21df..31f1abcceff3 100644 --- a/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java @@ -29,7 +29,6 @@ import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; import io.prestosql.plugin.tpch.TpchTableLayoutHandle; @@ -747,10 +746,9 @@ private TableScanNode tableScan(String id, String... symbols) TpchTableHandle tableHandle = new TpchTableHandle("orders", 1.0); return new TableScanNode( new PlanNodeId(id), - new TableHandle(new ConnectorId("tpch"), new TpchTableHandle("orders", 1.0)), + new TableHandle(new ConnectorId("tpch"), tableHandle, INSTANCE, Optional.of(new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))), symbolsList, assignments.build(), - Optional.of(new TableLayoutHandle(new ConnectorId("tpch"), INSTANCE, new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))), TupleDomain.all(), TupleDomain.all()); } diff --git a/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java b/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java index e4f51d3ba8cb..61c59f4e9e44 100644 --- a/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java +++ b/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java @@ -49,8 +49,10 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.testing.TestingHandle; import io.prestosql.testing.TestingMetadata.TestingColumnHandle; import io.prestosql.testing.TestingMetadata.TestingTableHandle; +import io.prestosql.testing.TestingTransactionHandle; import org.joda.time.DateTime; import javax.annotation.concurrent.GuardedBy; @@ -108,9 +110,9 @@ public MockRemoteTask createTableScanTask(TaskId taskId, Node newNode, List getLayout(Session session, TableHandle tableH } @Override - public TableLayout getLayout(Session session, TableLayoutHandle handle) + public TableLayout getLayout(Session session, TableHandle handle) { throw new UnsupportedOperationException(); } @Override - public TableLayoutHandle makeCompatiblePartitioning(Session session, TableLayoutHandle tableLayoutHandle, PartitioningHandle partitioningHandle) + public TableHandle makeCompatiblePartitioning(Session session, TableHandle table, PartitioningHandle partitioningHandle) { throw new UnsupportedOperationException(); } @@ -138,7 +138,7 @@ public Optional getCommonPartitioning(Session session, Parti } @Override - public Optional getInfo(Session session, TableLayoutHandle handle) + public Optional getInfo(Session session, TableHandle handle) { throw new UnsupportedOperationException(); } @@ -306,13 +306,13 @@ public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tabl } @Override - public boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public boolean supportsMetadataDelete(Session session, TableHandle tableHandle) { throw new UnsupportedOperationException(); } @Override - public OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public OptionalLong metadataDelete(Session session, TableHandle tableHandle) { throw new UnsupportedOperationException(); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java index 5657d336f7b6..ccc92b70772c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java @@ -27,7 +27,6 @@ import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.Domain; @@ -89,11 +88,7 @@ @Test(singleThreaded = true) public class TestEffectivePredicateExtractor { - private static final TableHandle DUAL_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle()); - private static final TableLayoutHandle TESTING_TABLE_LAYOUT = new TableLayoutHandle( - new ConnectorId("x"), - TestingTransactionHandle.create(), - TestingHandle.INSTANCE); + private static final TableHandle DUAL_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)); private static final Symbol A = new Symbol("a"); private static final Symbol B = new Symbol("b"); @@ -130,7 +125,7 @@ public void setUp() .build(); Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C, D, E, F))); - baseTableScan = new TableScanNode( + baseTableScan = TableScanNode.newInstance( newId(), DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), @@ -323,7 +318,7 @@ public void testTableScan() { // Effective predicate is True if there is no effective predicate Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C, D))); - PlanNode node = new TableScanNode( + PlanNode node = TableScanNode.newInstance( newId(), DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), @@ -336,7 +331,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.of(TESTING_TABLE_LAYOUT), TupleDomain.none(), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); @@ -347,7 +341,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.of(TESTING_TABLE_LAYOUT), TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(A), Domain.singleValue(BIGINT, 1L))), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); @@ -358,7 +351,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.of(TESTING_TABLE_LAYOUT), TupleDomain.withColumnDomains(ImmutableMap.of( scanAssignments.get(A), Domain.singleValue(BIGINT, 1L), scanAssignments.get(B), Domain.singleValue(BIGINT, 2L))), @@ -371,7 +363,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.empty(), TupleDomain.all(), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); @@ -717,7 +708,6 @@ private static TableScanNode tableScanNode(Map scanAssignm DUAL_TABLE_HANDLE, ImmutableList.copyOf(scanAssignments.keySet()), scanAssignments, - Optional.empty(), TupleDomain.all(), TupleDomain.all()); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java index 3fda22c4962b..e3f7f760c431 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java @@ -44,8 +44,10 @@ import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.WindowFrame; +import io.prestosql.testing.TestingHandle; import io.prestosql.testing.TestingMetadata.TestingColumnHandle; import io.prestosql.testing.TestingMetadata.TestingTableHandle; +import io.prestosql.testing.TestingTransactionHandle; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -66,7 +68,7 @@ @Test(singleThreaded = true) public class TestTypeValidator { - private static final TableHandle TEST_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle()); + private static final TableHandle TEST_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)); private static final SqlParser SQL_PARSER = new SqlParser(); private static final TypeValidator TYPE_VALIDATOR = new TypeValidator(); @@ -101,7 +103,6 @@ public void setUp() TEST_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.empty(), TupleDomain.all(), TupleDomain.all()); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java index 3f7aa740f1bc..ba690e3482c0 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java @@ -61,13 +61,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return new MatchResult( expectedTableName.equalsIgnoreCase(actualTableName) && ((!expectedConstraint.isPresent()) || - domainsMatch(expectedConstraint, tableScanNode.getCurrentConstraint(), tableScanNode.getTable(), session, metadata)) && - hasTableLayout(tableScanNode)); - } - - private boolean hasTableLayout(TableScanNode tableScanNode) - { - return !hasTableLayout.isPresent() || hasTableLayout.get() == tableScanNode.getLayout().isPresent(); + domainsMatch(expectedConstraint, tableScanNode.getCurrentConstraint(), tableScanNode.getTable(), session, metadata))); } @Override diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java index 770037558250..5654be2f0163 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -19,6 +19,7 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.spi.type.BigintType; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; @@ -29,6 +30,8 @@ import io.prestosql.sql.tree.SymbolReference; import org.testng.annotations.Test; +import java.util.Optional; + import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCALE_FACTOR; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; @@ -143,7 +146,9 @@ public void testDoesNotFireOnNestedNonCountAggregate() p.tableScan( new TableHandle( new ConnectorId("local"), - new TpchTableHandle("orders", TINY_SCALE_FACTOR)), + new TpchTableHandle("orders", TINY_SCALE_FACTOR), + TpchTransactionHandle.INSTANCE, + Optional.empty()), ImmutableList.of(totalPrice), ImmutableMap.of(totalPrice, new TpchColumnHandle(totalPrice.getName(), DOUBLE)))))); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java index 9b9511478d7b..1589d0ad5e90 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java @@ -21,6 +21,7 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; @@ -31,6 +32,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import org.testng.annotations.Test; +import java.util.Optional; import java.util.function.Predicate; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -77,6 +79,7 @@ private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate rule : pickTableLayout.rules()) { + for (Rule rule : pushPredicateIntoTableScan.rules()) { tester().assertThat(rule) .on(p -> p.values(p.symbol("a", BIGINT))) .doesNotFire(); } } - @Test - public void doesNotFireIfTableScanHasTableLayout() - { - tester().assertThat(pickTableLayout.pickTableLayoutWithoutPredicate()) - .on(p -> p.tableScan( - nationTableHandle, - ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle))) - .doesNotFire(); - } - @Test public void eliminateTableScanWhenNoLayoutExist() { - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("orderstatus = 'G'"), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1))), - Optional.of(ordersTableLayoutHandle)))) + ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches(values("A")); } @@ -116,13 +99,12 @@ public void eliminateTableScanWhenNoLayoutExist() public void replaceWithExistsWhenNoLayoutExist() { ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("nationkey = BIGINT '44'"), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), ImmutableMap.of(p.symbol("nationkey", BIGINT), columnHandle), - Optional.of(nationTableLayoutHandle), TupleDomain.none(), TupleDomain.none()))) .matches(values("A")); @@ -131,37 +113,24 @@ public void replaceWithExistsWhenNoLayoutExist() @Test public void doesNotFireIfRuleNotChangePlan() { - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("nationkey % 17 = BIGINT '44' AND nationkey % 15 = BIGINT '43'"), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle), TupleDomain.all(), TupleDomain.all()))) .doesNotFire(); } - @Test - public void ruleAddedTableLayoutToTableScan() - { - tester().assertThat(pickTableLayout.pickTableLayoutWithoutPredicate()) - .on(p -> p.tableScan( - nationTableHandle, - ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))) - .matches( - constrainedTableScanWithTableLayout("nation", ImmutableMap.of(), ImmutableMap.of("nationkey", "nationkey"))); - } - @Test public void ruleAddedTableLayoutToFilterTableScan() { Map filterConstraint = ImmutableMap.builder() .put("orderstatus", singleValue(createVarcharType(1), utf8Slice("F"))) .build(); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("orderstatus = CAST ('F' AS VARCHAR(1))"), p.tableScan( ordersTableHandle, @@ -174,13 +143,12 @@ public void ruleAddedTableLayoutToFilterTableScan() @Test public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() { - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("orderstatus = 'F'"), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1))), - Optional.of(ordersTableLayoutHandle)))) + ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches( constrainedTableScanWithTableLayout( "orders", @@ -192,7 +160,7 @@ public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() public void ruleWithPushdownableToTableLayoutPredicate() { Type orderStatusType = createVarcharType(1); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("orderstatus = 'O'"), p.tableScan( ordersTableHandle, @@ -208,7 +176,7 @@ public void ruleWithPushdownableToTableLayoutPredicate() public void nonDeterministicPredicate() { Type orderStatusType = createVarcharType(1); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) .on(p -> p.filter(expression("orderstatus = 'O' AND rand() = 0"), p.tableScan( ordersTableHandle, diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java index 8e95676875e4..7fc51213d520 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java @@ -17,12 +17,15 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.type.BigintType; import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.Test; +import java.util.Optional; + import static io.prestosql.sql.planner.iterative.rule.test.RuleTester.CONNECTOR_ID; public class TestRemoveEmptyDelete @@ -35,7 +38,7 @@ public void testDoesNotFire() .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), p.tableScan( - new TableHandle(CONNECTOR_ID, new TpchTableHandle("nation", 1.0)), + new TableHandle(CONNECTOR_ID, new TpchTableHandle("nation", 1.0), TpchTransactionHandle.INSTANCE, Optional.empty()), ImmutableList.of(), ImmutableMap.of()), p.symbol("a", BigintType.BIGINT))) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 82e1cc7f18f4..7557675fc899 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -19,11 +19,14 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.Assignments; import org.testng.annotations.Test; +import java.util.Optional; + import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCALE_FACTOR; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; @@ -48,9 +51,13 @@ public void testRewrite() .on(p -> p.lateral( ImmutableList.of(p.symbol("l_nationkey")), - p.tableScan(new TableHandle( + p.tableScan( + new TableHandle( new ConnectorId("local"), - new TpchTableHandle("nation", TINY_SCALE_FACTOR)), ImmutableList.of(p.symbol("l_nationkey")), + new TpchTableHandle("nation", TINY_SCALE_FACTOR), + TpchTransactionHandle.INSTANCE, + Optional.empty()), + ImmutableList.of(p.symbol("l_nationkey")), ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", BIGINT))), p.project( diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index 68cab31431b3..587d25e340c4 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -24,7 +24,6 @@ import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.SchemaTableName; @@ -77,7 +76,9 @@ import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.NullLiteral; +import io.prestosql.testing.TestingHandle; import io.prestosql.testing.TestingMetadata.TestingTableHandle; +import io.prestosql.testing.TestingTransactionHandle; import java.util.ArrayList; import java.util.Arrays; @@ -363,29 +364,24 @@ public LateralJoinNode lateral(List correlation, PlanNode input, PlanNod public TableScanNode tableScan(List symbols, Map assignments) { - TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); - return tableScan(tableHandle, symbols, assignments, Optional.empty(), TupleDomain.all(), TupleDomain.all()); - } - - public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) - { - return tableScan(tableHandle, symbols, assignments, Optional.empty()); + return tableScan( + new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)), + symbols, + assignments); } public TableScanNode tableScan( TableHandle tableHandle, List symbols, - Map assignments, - Optional tableLayout) + Map assignments) { - return tableScan(tableHandle, symbols, assignments, tableLayout, TupleDomain.all(), TupleDomain.all()); + return tableScan(tableHandle, symbols, assignments, TupleDomain.all(), TupleDomain.all()); } public TableScanNode tableScan( TableHandle tableHandle, List symbols, Map assignments, - Optional tableLayout, TupleDomain currentConstraint, TupleDomain enforcedConstraint) { @@ -394,7 +390,6 @@ public TableScanNode tableScan( tableHandle, symbols, assignments, - tableLayout, currentConstraint, enforcedConstraint); } @@ -404,7 +399,9 @@ public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode del TableWriterNode.DeleteHandle deleteHandle = new TableWriterNode.DeleteHandle( new TableHandle( new ConnectorId("testConnector"), - new TestingTableHandle()), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.of(TestingHandle.INSTANCE)), schemaTableName); return new TableFinishNode( idAllocator.getNextId(), @@ -488,7 +485,6 @@ public IndexSourceNode indexSource( TestingConnectorTransactionHandle.INSTANCE, TestingConnectorIndexHandle.INSTANCE), tableHandle, - Optional.empty(), lookupSymbols, outputSymbols, assignments, diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 928e2a95ec31..0184499637f9 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -20,11 +20,8 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; -import io.prestosql.plugin.tpch.TpchTableLayoutHandle; -import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; @@ -66,13 +63,12 @@ public void setup() ConnectorId connectorId = getCurrentConnectorId(); TableHandle nationTableHandle = new TableHandle( connectorId, - new TpchTableHandle("nation", 1.0)); - TableLayoutHandle nationTableLayoutHandle = new TableLayoutHandle(connectorId, + new TpchTableHandle("nation", 1.0), TestingTransactionHandle.create(), - new TpchTableLayoutHandle((TpchTableHandle) nationTableHandle.getConnectorHandle(), TupleDomain.all())); + Optional.empty()); TpchColumnHandle nationkeyColumnHandle = new TpchColumnHandle("nationkey", BIGINT); symbol = new Symbol("nationkey"); - tableScanNode = builder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableMap.of(symbol, nationkeyColumnHandle), Optional.of(nationTableLayoutHandle)); + tableScanNode = builder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableMap.of(symbol, nationkeyColumnHandle)); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by remote hash exchange") diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java index 218d4fb7c473..ae2576516a7a 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java @@ -19,18 +19,15 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; -import io.prestosql.plugin.tpch.TpchTableLayoutHandle; -import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; import io.prestosql.sql.planner.plan.PlanNode; -import io.prestosql.testing.TestingTransactionHandle; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -47,7 +44,6 @@ public class TestValidateStreamingAggregations private SqlParser sqlParser; private PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private TableHandle nationTableHandle; - private TableLayoutHandle nationTableLayoutHandle; @BeforeClass public void setup() @@ -58,11 +54,9 @@ public void setup() ConnectorId connectorId = getCurrentConnectorId(); nationTableHandle = new TableHandle( connectorId, - new TpchTableHandle("nation", 1.0)); - - nationTableLayoutHandle = new TableLayoutHandle(connectorId, - TestingTransactionHandle.create(), - new TpchTableLayoutHandle((TpchTableHandle) nationTableHandle.getConnectorHandle(), TupleDomain.all())); + new TpchTableHandle("nation", 1.0), + TpchTransactionHandle.INSTANCE, + Optional.empty()); } @Test @@ -76,8 +70,7 @@ public void testValidateSuccessful() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle))))); + ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))))); validatePlan( p -> p.aggregation( @@ -89,8 +82,7 @@ public void testValidateSuccessful() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle)))))); + ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT))))))); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Streaming aggregation with input not grouped on the grouping keys") @@ -105,8 +97,7 @@ public void testValidateFailed() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle))))); + ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))))); } private void validatePlan(Function planProvider) From 622ce722f6e8cc3c88d441574c5faed8a6acb02c Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Tue, 5 Mar 2019 09:50:26 -0800 Subject: [PATCH 09/18] Make PushPredicateIntoTableScan top level rule It now contains a single rule, so no point in having it return a rule set. --- .../prestosql/sql/planner/PlanOptimizers.java | 4 +- .../rule/PushPredicateIntoTableScan.java | 110 +++++++----------- .../rule/TestPushPredicateIntoTableScan.java | 23 ++-- 3 files changed, 54 insertions(+), 83 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index e7c2a55c3f36..f9dc3f265ef9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -357,7 +357,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new PushPredicateIntoTableScan(metadata, sqlParser).rules()), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))), new PruneUnreferencedOutputs(), new IterativeOptimizer( ruleStats, @@ -407,7 +407,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new PushPredicateIntoTableScan(metadata, sqlParser).rules()), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))), projectionPushDown, new PruneUnreferencedOutputs(), new IterativeOptimizer( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index fead481c3d14..244200de417d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.matching.Capture; @@ -61,7 +60,6 @@ import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; import static io.prestosql.sql.ExpressionUtils.filterNonDeterministicConjuncts; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static io.prestosql.sql.planner.iterative.rule.PreconditionRules.checkRulesAreFiredBeforeAddExchangesRule; import static io.prestosql.sql.planner.plan.Patterns.filter; import static io.prestosql.sql.planner.plan.Patterns.source; import static io.prestosql.sql.planner.plan.Patterns.tableScan; @@ -74,7 +72,13 @@ * chosen by AddExchanges */ public class PushPredicateIntoTableScan + implements Rule { + private static final Capture TABLE_SCAN = newCapture(); + + private static final Pattern PATTERN = filter().with(source().matching( + tableScan().capturedAs(TABLE_SCAN))); + private final Metadata metadata; private final SqlParser parser; private final DomainTranslator domainTranslator; @@ -86,90 +90,60 @@ public PushPredicateIntoTableScan(Metadata metadata, SqlParser parser) this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); } - public Set> rules() + @Override + public Pattern getPattern() { - return ImmutableSet.of(pickTableLayoutForPredicate()); + return PATTERN; } - public PickTableLayoutForPredicate pickTableLayoutForPredicate() + @Override + public boolean isEnabled(Session session) { - return new PickTableLayoutForPredicate(metadata, parser, domainTranslator); + return isNewOptimizerEnabled(session); } - private static final class PickTableLayoutForPredicate - implements Rule + @Override + public Result apply(FilterNode filterNode, Captures captures, Context context) { - private final Metadata metadata; - private final SqlParser parser; - private final DomainTranslator domainTranslator; + TableScanNode tableScan = captures.get(TABLE_SCAN); + + PlanNode rewritten = pushFilterIntoTableScan( + tableScan, + filterNode.getPredicate(), + false, + context.getSession(), + context.getSymbolAllocator().getTypes(), + context.getIdAllocator(), + metadata, + parser, + domainTranslator); - private PickTableLayoutForPredicate(Metadata metadata, SqlParser parser, DomainTranslator domainTranslator) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); - this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); + if (arePlansSame(filterNode, tableScan, rewritten)) { + return Result.empty(); } - private static final Capture TABLE_SCAN = newCapture(); - - private static final Pattern PATTERN = filter().with(source().matching( - tableScan().capturedAs(TABLE_SCAN))); + return Result.ofPlanNode(rewritten); + } - @Override - public Pattern getPattern() - { - return PATTERN; + private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten) + { + if (!(rewritten instanceof FilterNode)) { + return false; } - @Override - public boolean isEnabled(Session session) - { - return isNewOptimizerEnabled(session); + FilterNode rewrittenFilter = (FilterNode) rewritten; + if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) { + return false; } - @Override - public Result apply(FilterNode filterNode, Captures captures, Context context) - { - TableScanNode tableScan = captures.get(TABLE_SCAN); - - PlanNode rewritten = pushFilterIntoTableScan( - tableScan, - filterNode.getPredicate(), - false, - context.getSession(), - context.getSymbolAllocator().getTypes(), - context.getIdAllocator(), - metadata, - parser, - domainTranslator); - - if (arePlansSame(filterNode, tableScan, rewritten)) { - return Result.empty(); - } - - return Result.ofPlanNode(rewritten); + if (!(rewrittenFilter.getSource() instanceof TableScanNode)) { + return false; } - private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten) - { - if (!(rewritten instanceof FilterNode)) { - return false; - } + TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource(); - FilterNode rewrittenFilter = (FilterNode) rewritten; - if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) { - return false; - } - - if (!(rewrittenFilter.getSource() instanceof TableScanNode)) { - return false; - } - - TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource(); - - return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint()) - && Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint()); - } + return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint()) + && Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint()); } public static PlanNode pushFilterIntoTableScan( diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java index 8c58ffe31e16..b10f0add15bc 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java @@ -26,7 +26,6 @@ import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -76,17 +75,15 @@ public void setUpBeforeClass() @Test public void doesNotFireIfNoTableScan() { - for (Rule rule : pushPredicateIntoTableScan.rules()) { - tester().assertThat(rule) - .on(p -> p.values(p.symbol("a", BIGINT))) - .doesNotFire(); - } + tester().assertThat(pushPredicateIntoTableScan) + .on(p -> p.values(p.symbol("a", BIGINT))) + .doesNotFire(); } @Test public void eliminateTableScanWhenNoLayoutExist() { - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'G'"), p.tableScan( ordersTableHandle, @@ -99,7 +96,7 @@ public void eliminateTableScanWhenNoLayoutExist() public void replaceWithExistsWhenNoLayoutExist() { ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("nationkey = BIGINT '44'"), p.tableScan( nationTableHandle, @@ -113,7 +110,7 @@ public void replaceWithExistsWhenNoLayoutExist() @Test public void doesNotFireIfRuleNotChangePlan() { - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("nationkey % 17 = BIGINT '44' AND nationkey % 15 = BIGINT '43'"), p.tableScan( nationTableHandle, @@ -130,7 +127,7 @@ public void ruleAddedTableLayoutToFilterTableScan() Map filterConstraint = ImmutableMap.builder() .put("orderstatus", singleValue(createVarcharType(1), utf8Slice("F"))) .build(); - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = CAST ('F' AS VARCHAR(1))"), p.tableScan( ordersTableHandle, @@ -143,7 +140,7 @@ public void ruleAddedTableLayoutToFilterTableScan() @Test public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() { - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'F'"), p.tableScan( ordersTableHandle, @@ -160,7 +157,7 @@ public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() public void ruleWithPushdownableToTableLayoutPredicate() { Type orderStatusType = createVarcharType(1); - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'O'"), p.tableScan( ordersTableHandle, @@ -176,7 +173,7 @@ public void ruleWithPushdownableToTableLayoutPredicate() public void nonDeterministicPredicate() { Type orderStatusType = createVarcharType(1); - tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'O' AND rand() = 0"), p.tableScan( ordersTableHandle, From a966589c513e030ab1770764fcafa8e77e0ae271 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 11:20:09 -0800 Subject: [PATCH 10/18] Remove unnecessary expression translations In order to translate expression to row expressions, the code was first replacing all symbol references with field references for the corresponding ordinal inputs. This is unnecessary, as the translation can be done on demand as the expression is translated to a row expression. --- .../sql/planner/LocalExecutionPlanner.java | 67 ++++++------------- .../optimizations/ExpressionEquivalence.java | 22 +++--- .../SqlToRowExpressionTranslator.java | 13 +++- ...BenchmarkScanFilterAndProjectOperator.java | 26 +++---- .../operator/scalar/FunctionAssertions.java | 38 +++++------ .../sql/TestSqlToRowExpressionTranslator.java | 2 +- .../sql/gen/PageProcessorBenchmark.java | 26 +++---- .../type/BenchmarkDecimalOperators.java | 15 ++--- 8 files changed, 83 insertions(+), 126 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index db5b3d1ef9a1..fca17d4ab0fa 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -217,7 +217,6 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.getOnlyElement; @@ -252,7 +251,6 @@ import static io.prestosql.spi.type.TypeUtils.writeNativeValue; import static io.prestosql.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static io.prestosql.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; @@ -279,7 +277,6 @@ import static io.prestosql.util.SpatialJoinUtils.ST_WITHIN; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; -import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -1170,7 +1167,6 @@ private PhysicalOperation visitScanFilterAndProject( // if source is a table scan we fold it directly into the filter and project // otherwise we plan it as a normal operator Map sourceLayout; - Map sourceTypes; List columns = null; PhysicalOperation source = null; if (sourceNode instanceof TableScanNode) { @@ -1178,7 +1174,6 @@ private PhysicalOperation visitScanFilterAndProject( // extract the column handles and channel to type mapping sourceLayout = new LinkedHashMap<>(); - sourceTypes = new LinkedHashMap<>(); columns = new ArrayList<>(); int channel = 0; for (Symbol symbol : tableScanNode.getOutputSymbols()) { @@ -1187,9 +1182,6 @@ private PhysicalOperation visitScanFilterAndProject( Integer input = channel; sourceLayout.put(symbol, input); - Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); - sourceTypes.put(input, type); - channel++; } } @@ -1209,7 +1201,6 @@ else if (sourceNode instanceof SampleNode) { // plan source source = sourceNode.accept(this, context); sourceLayout = source.getLayout(); - sourceTypes = getInputTypes(source.getLayout(), source.getTypes()); } // build output mapping @@ -1220,27 +1211,24 @@ else if (sourceNode instanceof SampleNode) { } Map outputMappings = outputMappingsBuilder.build(); - // compiler uses inputs instead of symbols, so rewrite the expressions first - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Optional rewrittenFilter = filterExpression.map(symbolToInputRewriter::rewrite); - - List rewrittenProjections = new ArrayList<>(); + List projections = new ArrayList<>(); for (Symbol symbol : outputSymbols) { - rewrittenProjections.add(symbolToInputRewriter.rewrite(assignments.get(symbol))); + projections.add(assignments.get(symbol)); } - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( context.getSession(), metadata, sqlParser, - sourceTypes, - concat(rewrittenFilter.map(ImmutableList::of).orElse(ImmutableList.of()), rewrittenProjections), + context.getTypes(), + concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions()), emptyList(), - NOOP); + NOOP, + false); - Optional translatedFilter = rewrittenFilter.map(filter -> toRowExpression(filter, expressionTypes)); - List translatedProjections = rewrittenProjections.stream() - .map(expression -> toRowExpression(expression, expressionTypes)) + Optional translatedFilter = filterExpression.map(filter -> toRowExpression(filter, expressionTypes, sourceLayout)); + List translatedProjections = projections.stream() + .map(expression -> toRowExpression(expression, expressionTypes, sourceLayout)) .collect(toImmutableList()); try { @@ -1256,7 +1244,7 @@ else if (sourceNode instanceof SampleNode) { cursorProcessor, pageProcessor, columns, - getTypes(rewrittenProjections, expressionTypes), + getTypes(projections, expressionTypes), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -1269,7 +1257,7 @@ else if (sourceNode instanceof SampleNode) { context.getNextOperatorId(), planNodeId, pageProcessor, - getTypes(rewrittenProjections, expressionTypes), + getTypes(projections, expressionTypes), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -1281,19 +1269,9 @@ else if (sourceNode instanceof SampleNode) { } } - private RowExpression toRowExpression(Expression expression, Map, Type> types) - { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); - } - - private Map getInputTypes(Map layout, List types) + private RowExpression toRowExpression(Expression expression, Map, Type> types, Map layout) { - ImmutableMap.Builder inputTypes = ImmutableMap.builder(); - for (Integer input : ImmutableSet.copyOf(layout.values())) { - Type type = types.get(input); - inputTypes.put(input, type); - } - return inputTypes.build(); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true, layout); } @Override @@ -2058,20 +2036,17 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( { Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Map sourceTypes = joinSourcesLayout.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey()))); - - Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, - sourceTypes, - rewrittenFilter, + types, + filterExpression, emptyList(), /* parameters have already been replaced */ - NOOP); + NOOP, + false); - RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes); + RowExpression translatedFilter = toRowExpression(filterExpression, expressionTypes, joinSourcesLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } @@ -2603,7 +2578,7 @@ private AccumulatorFactory buildAccumulatorFactory( NOOP)) .build(); - LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes); + LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of()); Class lambdaProviderClass = compileLambdaProvider(lambda, metadata.getFunctionRegistry(), lambdaInterfaces.get(i)); try { lambdaProviders.add((LambdaProvider) constructorMethodHandle(lambdaProviderClass, ConnectorSession.class).invoke(session.toConnectorSession())); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java index 81b519f6a3a5..5a5066b5721a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java @@ -26,7 +26,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.CallExpression; import io.prestosql.sql.relational.ConstantExpression; @@ -58,7 +57,7 @@ import static io.prestosql.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.prestosql.spi.function.OperatorType.NOT_EQUAL; import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static java.lang.Integer.min; import static java.util.Collections.emptyList; @@ -80,15 +79,13 @@ public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types) { Map symbolInput = new HashMap<>(); - Map inputTypes = new HashMap<>(); int inputId = 0; for (Entry entry : types.allTypes().entrySet()) { symbolInput.put(entry.getKey(), inputId); - inputTypes.put(inputId, entry.getValue()); inputId++; } - RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, inputTypes); - RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, inputTypes); + RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, types); + RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, types); RowExpression canonicalizedLeft = leftRowExpression.accept(CANONICALIZATION_VISITOR, null); RowExpression canonicalizedRight = rightRowExpression.accept(CANONICALIZATION_VISITOR, null); @@ -96,23 +93,20 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi return canonicalizedLeft.equals(canonicalizedRight); } - private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, Map inputTypes) + private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, TypeProvider types) { - // replace qualified names with input references since row expressions do not support these - Expression expressionWithInputReferences = new SymbolToInputRewriter(symbolInput).rewrite(expression); - // determine the type of every expression - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, - inputTypes, - expressionWithInputReferences, + types, + expression, emptyList(), /* parameters have already been replaced */ WarningCollector.NOOP); // convert to row expression - return translate(expressionWithInputReferences, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + return translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, symbolInput); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java index 720182a33668..fbf247765108 100644 --- a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java @@ -31,6 +31,7 @@ import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; import io.prestosql.spi.type.VarcharType; +import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.relational.optimizer.ExpressionOptimizer; import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.ArithmeticUnaryExpression; @@ -144,12 +145,14 @@ public static RowExpression translate( FunctionRegistry functionRegistry, TypeManager typeManager, Session session, - boolean optimize) + boolean optimize, + Map layout) { Visitor visitor = new Visitor( functionKind, types, typeManager, + layout, session.getTimeZoneKey(), isLegacyRowFieldOrdinalAccessEnabled(session), SystemSessionProperties.isLegacyTimestamp(session)); @@ -171,6 +174,7 @@ private static class Visitor private final FunctionKind functionKind; private final Map, Type> types; private final TypeManager typeManager; + private final Map layout; private final TimeZoneKey timeZoneKey; private final boolean legacyRowFieldOrdinalAccess; @Deprecated @@ -180,6 +184,7 @@ private Visitor( FunctionKind functionKind, Map, Type> types, TypeManager typeManager, + Map layout, TimeZoneKey timeZoneKey, boolean legacyRowFieldOrdinalAccess, boolean isLegacyTimestamp) @@ -187,6 +192,7 @@ private Visitor( this.functionKind = functionKind; this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null")); this.typeManager = typeManager; + this.layout = layout; this.timeZoneKey = timeZoneKey; this.legacyRowFieldOrdinalAccess = legacyRowFieldOrdinalAccess; this.isLegacyTimestamp = isLegacyTimestamp; @@ -363,6 +369,11 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context) @Override protected RowExpression visitSymbolReference(SymbolReference node, Void context) { + Integer field = layout.get(Symbol.from(node)); + if (field != null) { + return field(field, getType(node)); + } + return new VariableReferenceExpression(node.getName(), getType(node)); } diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java index 439402947f45..418b2730462b 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java @@ -33,7 +33,6 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -81,7 +80,7 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.testing.TestingSplit.createLocalSplit; import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; @@ -203,10 +202,10 @@ private List createInputPages(List types) private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR); + return rowExpression("cast(varchar0 as bigint) % 2 = 0"); } if (type == BIGINT) { - return rowExpression("bigint0 % 2 = 0", BIGINT); + return rowExpression("bigint0 % 2 = 0"); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -216,32 +215,25 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression("bigint" + i + " + 5", type)); + builder.add(rowExpression("bigint" + i + " + 5")); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type)); + builder.add(rowExpression("concat(varchar" + i + ", 'foo')")); } } return builder.build(); } - private RowExpression rowExpression(String expression, Type type) + private RowExpression rowExpression(String value) { - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i = 0; i < columnCount; i++) { - builder.put(i, type); - } - Map types = builder.build(); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout); } private static Page createPage(List types, int positions, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index 00d181faa4b6..2c81971cec69 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -66,7 +66,6 @@ import io.prestosql.sql.gen.ExpressionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -129,7 +128,7 @@ import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressionsWithSymbols; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; @@ -168,17 +167,17 @@ public final class FunctionAssertions private static final Page ZERO_CHANNEL_PAGE = new Page(1); - private static final Map INPUT_TYPES = ImmutableMap.builder() - .put(0, BIGINT) - .put(1, VARCHAR) - .put(2, DOUBLE) - .put(3, BOOLEAN) - .put(4, BIGINT) - .put(5, VARCHAR) - .put(6, VARCHAR) - .put(7, TIMESTAMP_WITH_TIME_ZONE) - .put(8, VARBINARY) - .put(9, INTEGER) + private static final Map INPUT_TYPES = ImmutableMap.builder() + .put(new Symbol("bound_long"), BIGINT) + .put(new Symbol("bound_string"), VARCHAR) + .put(new Symbol("bound_double"), DOUBLE) + .put(new Symbol("bound_boolean"), BOOLEAN) + .put(new Symbol("bound_timestamp"), BIGINT) + .put(new Symbol("bound_pattern"), VARCHAR) + .put(new Symbol("bound_null_string"), VARCHAR) + .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) + .put(new Symbol("bound_binary_literal"), VARBINARY) + .put(new Symbol("bound_integer"), INTEGER) .build(); private static final Map INPUT_MAPPING = ImmutableMap.builder() @@ -630,16 +629,15 @@ private List executeProjectionWithAll(String projection, Type expectedTy private RowExpression toRowExpression(Session session, Expression projectionExpression) { - Expression translatedProjection = new SymbolToInputRewriter(INPUT_MAPPING).rewrite(projectionExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, SQL_PARSER, - INPUT_TYPES, - ImmutableList.of(translatedProjection), + TypeProvider.copyOf(INPUT_TYPES), + projectionExpression, ImmutableList.of(), WarningCollector.NOOP); - return toRowExpression(translatedProjection, expressionTypes); + return toRowExpression(projectionExpression, expressionTypes, INPUT_MAPPING); } private Object selectSingleValue(OperatorFactory operatorFactory, Type type, Session session) @@ -955,9 +953,9 @@ private static SourceOperatorFactory compileScanFilterProject(Optional, Type> expressionTypes) + private RowExpression toRowExpression(Expression projection, Map, Type> expressionTypes, Map layout) { - return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, layout); } private static Page getAtMostOnePage(Operator operator, Page sourcePage) diff --git a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java index b8071a11f17a..349e453dd337 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java @@ -92,7 +92,7 @@ private RowExpression translateAndOptimize(Expression expression) private RowExpression translateAndOptimize(Expression expression, Map, Type> types) { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true, ImmutableMap.of()); } private Expression simplifyExpression(Expression expression) diff --git a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java index 2a556a9d5767..597b82f96169 100644 --- a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java @@ -30,7 +30,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -66,7 +65,7 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.stream.Collectors.toList; @@ -151,10 +150,10 @@ public List> columnOriented() private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR); + return rowExpression("cast(varchar0 as bigint) % 2 = 0"); } if (type == BIGINT) { - return rowExpression("bigint0 % 2 = 0", BIGINT); + return rowExpression("bigint0 % 2 = 0"); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -164,32 +163,25 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression("bigint" + i + " + 5", type)); + builder.add(rowExpression("bigint" + i + " + 5")); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type)); + builder.add(rowExpression("concat(varchar" + i + ", 'foo')")); } } return builder.build(); } - private RowExpression rowExpression(String expression, Type type) + private RowExpression rowExpression(String value) { - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i = 0; i < columnCount; i++) { - builder.put(i, type); - } - Map types = builder.build(); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout); } private static Page createPage(List types, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java index e6402826a62b..7c25281bb558 100644 --- a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java @@ -30,7 +30,6 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -71,14 +70,13 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DecimalType.createDecimalType; import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.testing.TestingConnectorSession.SESSION; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.math.BigInteger.ONE; import static java.math.BigInteger.ZERO; import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toMap; import static org.openjdk.jmh.annotations.Scope.Thread; @State(Scope.Thread) @@ -611,15 +609,12 @@ protected void setDoubleMaxValue(double doubleMaxValue) this.doubleMaxValue = doubleMaxValue; } - private RowExpression rowExpression(String expression) + private RowExpression rowExpression(String value) { - Expression inputReferenceExpression = new SymbolToInputRewriter(sourceLayout).rewrite(createExpression(expression, metadata, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); - Map types = sourceLayout.entrySet().stream() - .collect(toMap(Map.Entry::getValue, entry -> symbolTypes.get(entry.getKey()))); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, metadata, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, metadata, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true, sourceLayout); } private Object generateRandomValue(Type type) From 64ed432b81b7328d97444b33083bd7d46d0e3719 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 12:31:45 -0800 Subject: [PATCH 11/18] Fix expression type The inferred type of the former expression is INTEGER, which doesn't match the signature of combineHash function call. --- .../sql/planner/optimizations/HashGenerationOptimizer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java index 57695177bc5f..2397d3f43a0f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java @@ -839,7 +839,7 @@ public static Optional getHashExpression(List symbols) return Optional.empty(); } - Expression result = new LongLiteral(String.valueOf(INITIAL_HASH_VALUE)); + Expression result = new GenericLiteral(StandardTypes.BIGINT, String.valueOf(INITIAL_HASH_VALUE)); for (Symbol symbol : symbols) { Expression hashField = new FunctionCall( QualifiedName.of(HASH_CODE), From cc952bc8901a77d6752c105ebd316aa02083f6fc Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 11:32:31 -0800 Subject: [PATCH 12/18] Remove interpreted page processors They were only being used in tests. The engine no longer relies on them for query execution. --- .../benchmark/AbstractOperatorBenchmark.java | 24 +- .../project/InterpretedPageFilter.java | 104 ------- .../project/InterpretedPageProjection.java | 158 ---------- .../sql/planner/ExpressionInterpreter.java | 64 +--- .../sql/planner/LocalExecutionPlanner.java | 2 +- .../SymbolToInputParameterRewriter.java | 109 ------- .../optimizations/ExpressionEquivalence.java | 2 +- .../SqlToRowExpressionTranslator.java | 4 +- ...BenchmarkScanFilterAndProjectOperator.java | 2 +- .../operator/scalar/FunctionAssertions.java | 78 +++-- .../sql/TestSqlToRowExpressionTranslator.java | 2 +- .../sql/gen/PageProcessorBenchmark.java | 2 +- .../TestInterpretedPageFilterFunction.java | 220 ------------- ...TestInterpretedPageProjectionFunction.java | 288 ------------------ .../type/BenchmarkDecimalOperators.java | 2 +- 15 files changed, 77 insertions(+), 984 deletions(-) delete mode 100644 presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java delete mode 100644 presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java delete mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java delete mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java delete mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index 688b561dcd0d..3b3171c37318 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -23,6 +23,7 @@ import io.prestosql.execution.Lifespan; import io.prestosql.execution.TaskId; import io.prestosql.execution.TaskStateMachine; +import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.memory.MemoryPool; import io.prestosql.memory.QueryContext; import io.prestosql.metadata.Metadata; @@ -39,7 +40,6 @@ import io.prestosql.operator.TaskContext; import io.prestosql.operator.TaskStats; import io.prestosql.operator.project.InputPageProjection; -import io.prestosql.operator.project.InterpretedPageProjection; import io.prestosql.operator.project.PageProcessor; import io.prestosql.operator.project.PageProjection; import io.prestosql.security.AllowAllAccessControl; @@ -50,11 +50,14 @@ import io.prestosql.spi.type.Type; import io.prestosql.spiller.SpillSpaceTracker; import io.prestosql.split.SplitSource; +import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.HashGenerationOptimizer; import io.prestosql.sql.planner.plan.PlanNodeId; +import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.NodeRef; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.transaction.TransactionId; @@ -76,9 +79,12 @@ import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.prestosql.SystemSessionProperties.getFilterAndProjectMinOutputPageRowCount; import static io.prestosql.SystemSessionProperties.getFilterAndProjectMinOutputPageSize; +import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -222,13 +228,19 @@ protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNo Optional hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet())); verify(hashExpression.isPresent()); - projections.add(new InterpretedPageProjection( - hashExpression.get(), - TypeProvider.copyOf(symbolTypes.build()), - symbolToInputMapping.build(), + + Map, Type> expressionTypes = getExpressionTypes( + session, localQueryRunner.getMetadata(), localQueryRunner.getSqlParser(), - session)); + TypeProvider.copyOf(symbolTypes.build()), + hashExpression.get(), + ImmutableList.of(), + WarningCollector.NOOP); + RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionRegistry(), localQueryRunner.getTypeManager(), session, false); + + PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0); + projections.add(functionCompiler.compileProjection(translated, Optional.empty()).get()); return new FilterAndProjectOperator.FilterAndProjectOperatorFactory( operatorId, diff --git a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java b/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java deleted file mode 100644 index d2134e0acd86..000000000000 --- a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.operator.project; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; -import io.prestosql.metadata.Metadata; -import io.prestosql.spi.Page; -import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.planner.DeterminismEvaluator; -import io.prestosql.sql.planner.ExpressionInterpreter; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputParameterRewriter; -import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; - -import javax.annotation.concurrent.NotThreadSafe; - -import java.util.List; -import java.util.Map; - -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; -import static java.lang.Boolean.TRUE; -import static java.util.Collections.emptyList; - -@NotThreadSafe -public class InterpretedPageFilter - implements PageFilter -{ - private final ExpressionInterpreter evaluator; - private final InputChannels inputChannels; - private final boolean deterministic; - private boolean[] selectedPositions = new boolean[0]; - - public InterpretedPageFilter( - Expression expression, - TypeProvider symbolTypes, - Map symbolToInputMappings, - Metadata metadata, - SqlParser sqlParser, - Session session) - { - SymbolToInputParameterRewriter rewriter = new SymbolToInputParameterRewriter(symbolTypes, symbolToInputMappings); - Expression rewritten = rewriter.rewrite(expression); - this.inputChannels = new InputChannels(rewriter.getInputChannels()); - this.deterministic = DeterminismEvaluator.isDeterministic(expression); - - // analyze rewritten expression so we can know the type of every expression in the tree - List inputTypes = rewriter.getInputTypes(); - ImmutableMap.Builder parameterTypes = ImmutableMap.builder(); - for (int parameter = 0; parameter < inputTypes.size(); parameter++) { - Type type = inputTypes.get(parameter); - parameterTypes.put(parameter, type); - } - Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList(), WarningCollector.NOOP); - this.evaluator = ExpressionInterpreter.expressionInterpreter(rewritten, metadata, session, expressionTypes); - } - - @Override - public boolean isDeterministic() - { - return deterministic; - } - - @Override - public InputChannels getInputChannels() - { - return inputChannels; - } - - @Override - public SelectedPositions filter(ConnectorSession session, Page page) - { - if (selectedPositions.length < page.getPositionCount()) { - selectedPositions = new boolean[page.getPositionCount()]; - } - - for (int position = 0; position < page.getPositionCount(); position++) { - selectedPositions[position] = filter(page, position); - } - - return PageFilter.positionsArrayToSelectedPositions(selectedPositions, page.getPositionCount()); - } - - private boolean filter(Page page, int position) - { - return TRUE.equals(evaluator.evaluate(position, page)); - } -} diff --git a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java b/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java deleted file mode 100644 index 733c0e506bbc..000000000000 --- a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.operator.project; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; -import io.prestosql.metadata.Metadata; -import io.prestosql.operator.DriverYieldSignal; -import io.prestosql.operator.Work; -import io.prestosql.spi.Page; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.planner.DeterminismEvaluator; -import io.prestosql.sql.planner.ExpressionInterpreter; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputParameterRewriter; -import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; - -import java.util.List; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkState; -import static io.prestosql.spi.type.TypeUtils.writeNativeValue; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; -import static java.util.Collections.emptyList; -import static java.util.Objects.requireNonNull; - -public class InterpretedPageProjection - implements PageProjection -{ - private final ExpressionInterpreter evaluator; - private final InputChannels inputChannels; - private final boolean deterministic; - private BlockBuilder blockBuilder; - - public InterpretedPageProjection( - Expression expression, - TypeProvider symbolTypes, - Map symbolToInputMappings, - Metadata metadata, - SqlParser sqlParser, - Session session) - { - SymbolToInputParameterRewriter rewriter = new SymbolToInputParameterRewriter(symbolTypes, symbolToInputMappings); - Expression rewritten = rewriter.rewrite(expression); - this.inputChannels = new InputChannels(rewriter.getInputChannels()); - this.deterministic = DeterminismEvaluator.isDeterministic(expression); - - // analyze rewritten expression so we can know the type of every expression in the tree - List inputTypes = rewriter.getInputTypes(); - ImmutableMap.Builder parameterTypes = ImmutableMap.builder(); - for (int parameter = 0; parameter < inputTypes.size(); parameter++) { - Type type = inputTypes.get(parameter); - parameterTypes.put(parameter, type); - } - Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList(), WarningCollector.NOOP); - this.evaluator = ExpressionInterpreter.expressionInterpreter(rewritten, metadata, session, expressionTypes); - - blockBuilder = evaluator.getType().createBlockBuilder(null, 1); - } - - @Override - public Type getType() - { - return evaluator.getType(); - } - - @Override - public boolean isDeterministic() - { - return deterministic; - } - - @Override - public InputChannels getInputChannels() - { - return inputChannels; - } - - @Override - public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) - { - return new InterpretedPageProjectionWork(yieldSignal, page, selectedPositions); - } - - private class InterpretedPageProjectionWork - implements Work - { - private final DriverYieldSignal yieldSignal; - private final Page page; - private final SelectedPositions selectedPositions; - - private int nextIndexOrPosition; - private Block result; - - public InterpretedPageProjectionWork(DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) - { - this.yieldSignal = requireNonNull(yieldSignal, "yieldSignal is null"); - this.page = requireNonNull(page, "page is null"); - this.selectedPositions = requireNonNull(selectedPositions, "selectedPositions is null"); - this.nextIndexOrPosition = selectedPositions.getOffset(); - } - - @Override - public boolean process() - { - checkState(result == null, "result has been generated"); - int length = selectedPositions.getOffset() + selectedPositions.size(); - if (selectedPositions.isList()) { - int[] positions = selectedPositions.getPositions(); - while (nextIndexOrPosition < length) { - writeNativeValue(evaluator.getType(), blockBuilder, evaluator.evaluate(positions[nextIndexOrPosition], page)); - nextIndexOrPosition++; - if (yieldSignal.isSet()) { - return false; - } - } - } - else { - while (nextIndexOrPosition < length) { - writeNativeValue(evaluator.getType(), blockBuilder, evaluator.evaluate(nextIndexOrPosition, page)); - nextIndexOrPosition++; - if (yieldSignal.isSet()) { - return false; - } - } - } - - result = blockBuilder.build(); - blockBuilder = blockBuilder.newBlockBuilderLike(null); - return true; - } - - @Override - public Block getResult() - { - checkState(result != null, "result has not been generated"); - return result; - } - } -} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java index 690a1d1891e8..a3e27e155909 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java @@ -28,7 +28,6 @@ import io.prestosql.metadata.Signature; import io.prestosql.operator.scalar.ArraySubscriptOperator; import io.prestosql.operator.scalar.ScalarFunctionImplementation; -import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; @@ -252,10 +251,10 @@ public Object evaluate() return visitor.process(expression, new NoPagePositionContext()); } - public Object evaluate(int position, Page page) + public Object evaluate(SymbolResolver inputs) { checkState(!optimize, "evaluate(int, Page) not allowed for optimizer"); - return visitor.process(expression, new SinglePagePositionContext(position, page)); + return visitor.process(expression, inputs); } public Object optimize(SymbolResolver inputs) @@ -271,39 +270,7 @@ private class Visitor @Override public Object visitFieldReference(FieldReference node, Object context) { - Type type = type(node); - - int channel = node.getFieldIndex(); - if (context instanceof PagePositionContext) { - PagePositionContext pagePositionContext = (PagePositionContext) context; - int position = pagePositionContext.getPosition(channel); - Block block = pagePositionContext.getBlock(channel); - - if (block.isNull(position)) { - return null; - } - - Class javaType = type.getJavaType(); - if (javaType == boolean.class) { - return type.getBoolean(block, position); - } - else if (javaType == long.class) { - return type.getLong(block, position); - } - else if (javaType == double.class) { - return type.getDouble(block, position); - } - else if (javaType == Slice.class) { - return type.getSlice(block, position); - } - else if (javaType == Block.class) { - return type.getObject(block, position); - } - else { - throw new UnsupportedOperationException("not yet implemented"); - } - } - throw new UnsupportedOperationException("Inputs must be set"); + throw new UnsupportedOperationException("Field references not supported in interpreter"); } @Override @@ -1299,31 +1266,6 @@ public int getPosition(int channel) } } - private static class SinglePagePositionContext - implements PagePositionContext - { - private final int position; - private final Page page; - - private SinglePagePositionContext(int position, Page page) - { - this.position = position; - this.page = page; - } - - @Override - public Block getBlock(int channel) - { - return page.getBlock(channel); - } - - @Override - public int getPosition(int channel) - { - return position; - } - } - private static Expression createFailureFunction(RuntimeException exception, Type type) { requireNonNull(exception, "Exception is null"); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index fca17d4ab0fa..54f6f95fe0e6 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -1271,7 +1271,7 @@ else if (sourceNode instanceof SampleNode) { private RowExpression toRowExpression(Expression expression, Map, Type> types, Map layout) { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true, layout); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, layout, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java deleted file mode 100644 index bee2488bd716..000000000000 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.ExpressionRewriter; -import io.prestosql.sql.tree.ExpressionTreeRewriter; -import io.prestosql.sql.tree.FieldReference; -import io.prestosql.sql.tree.LambdaExpression; -import io.prestosql.sql.tree.SymbolReference; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public class SymbolToInputParameterRewriter -{ - private final Map symbolToChannelMapping; - private final TypeProvider types; - - private final Map fieldToParameter = new HashMap<>(); - private final List inputChannels = new ArrayList<>(); - private final List inputTypes = new ArrayList<>(); - private int nextParameter; - - public List getInputChannels() - { - return ImmutableList.copyOf(inputChannels); - } - - public List getInputTypes() - { - return ImmutableList.copyOf(inputTypes); - } - - public SymbolToInputParameterRewriter(TypeProvider types, Map symbolToChannelMapping) - { - this.types = requireNonNull(types, "symbolToTypeMapping is null"); - - requireNonNull(symbolToChannelMapping, "symbolToChannelMapping is null"); - this.symbolToChannelMapping = ImmutableMap.copyOf(symbolToChannelMapping); - } - - public Expression rewrite(Expression expression) - { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() - { - @Override - public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter treeRewriter) - { - Symbol symbol = Symbol.from(node); - Integer channel = symbolToChannelMapping.get(symbol); - if (channel == null) { - checkArgument(context.isInLambda(), "Cannot resolve symbol %s", node.getName()); - return node; - } - - Type type = types.get(symbol); - checkArgument(type != null, "Cannot resolve symbol %s", node.getName()); - - int parameter = fieldToParameter.computeIfAbsent(channel, field -> { - inputChannels.add(field); - inputTypes.add(type); - return nextParameter++; - }); - return new FieldReference(parameter); - } - - @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter treeRewriter) - { - return treeRewriter.defaultRewrite(node, new Context(true)); - } - }, expression, new Context(false)); - } - - private static class Context - { - private final boolean inLambda; - - public Context(boolean inLambda) - { - this.inLambda = inLambda; - } - - public boolean isInLambda() - { - return inLambda; - } - } -} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java index 5a5066b5721a..d051b49f5941 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java @@ -106,7 +106,7 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma WarningCollector.NOOP); // convert to row expression - return translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, symbolInput); + return translate(expression, SCALAR, expressionTypes, symbolInput, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java index fbf247765108..7dc1d044c6ad 100644 --- a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java @@ -142,11 +142,11 @@ public static RowExpression translate( Expression expression, FunctionKind functionKind, Map, Type> types, + Map layout, FunctionRegistry functionRegistry, TypeManager typeManager, Session session, - boolean optimize, - Map layout) + boolean optimize) { Visitor visitor = new Visitor( functionKind, diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java index 418b2730462b..b1e99ae34fc4 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java @@ -233,7 +233,7 @@ private RowExpression rowExpression(String value) Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); } private static Page createPage(List types, int positions, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index 2c81971cec69..a80249668457 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -37,9 +37,6 @@ import io.prestosql.operator.SourceOperator; import io.prestosql.operator.SourceOperatorFactory; import io.prestosql.operator.project.CursorProcessor; -import io.prestosql.operator.project.InterpretedPageFilter; -import io.prestosql.operator.project.InterpretedPageProjection; -import io.prestosql.operator.project.PageFilter; import io.prestosql.operator.project.PageProcessor; import io.prestosql.operator.project.PageProjection; import io.prestosql.spi.ErrorCodeSupplier; @@ -56,6 +53,7 @@ import io.prestosql.spi.connector.InMemoryRecordSet; import io.prestosql.spi.connector.RecordPageSource; import io.prestosql.spi.connector.RecordSet; +import io.prestosql.spi.predicate.Utils; import io.prestosql.spi.type.TimeZoneKey; import io.prestosql.spi.type.Type; import io.prestosql.split.PageSourceProvider; @@ -65,6 +63,7 @@ import io.prestosql.sql.analyzer.SemanticException; import io.prestosql.sql.gen.ExpressionCompiler; import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -132,10 +131,10 @@ import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; -import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.prestosql.testing.TestingTaskContext.createTaskContext; import static io.prestosql.type.UnknownType.UNKNOWN; import static java.lang.String.format; +import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -598,8 +597,7 @@ private List executeProjectionWithAll(String projection, Type expectedTy results.add(directOperatorValue); // interpret - Operator interpretedFilterProject = interpretedFilterProject(Optional.empty(), projectionExpression, expectedType, session); - Object interpretedValue = selectSingleValue(interpretedFilterProject, expectedType); + Object interpretedValue = interpret(projectionExpression, expectedType, session); results.add(interpretedValue); // execute over normal operator @@ -704,7 +702,10 @@ private List executeFilterWithAll(String filter, Session session, boole } // interpret - boolean interpretedValue = executeFilter(interpretedFilterProject(Optional.of(filterExpression), TRUE_LITERAL, BOOLEAN, session)); + Boolean interpretedValue = (Boolean) interpret(filterExpression, BOOLEAN, session); + if (interpretedValue == null) { + interpretedValue = false; + } results.add(interpretedValue); // execute over normal operator @@ -867,29 +868,46 @@ protected Void visitSymbolReference(SymbolReference node, Void context) return hasSymbolReferences.get(); } - private Operator interpretedFilterProject(Optional filter, Expression projection, Type expectedType, Session session) + private Object interpret(Expression expression, Type expectedType, Session session) { - Optional pageFilter = filter - .map(expression -> new InterpretedPageFilter( - expression, - SYMBOL_TYPES, - INPUT_MAPPING, - metadata, - SQL_PARSER, - session)); - - PageProjection pageProjection = new InterpretedPageProjection(projection, SYMBOL_TYPES, INPUT_MAPPING, metadata, SQL_PARSER, session); - assertEquals(pageProjection.getType(), expectedType); - - PageProcessor processor = new PageProcessor(pageFilter, ImmutableList.of(pageProjection)); - OperatorFactory operatorFactory = new FilterAndProjectOperatorFactory( - 0, - new PlanNodeId("test"), - () -> processor, - ImmutableList.of(pageProjection.getType()), - new DataSize(0, BYTE), - 0); - return operatorFactory.createOperator(createDriverContext(session)); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, SQL_PARSER, SYMBOL_TYPES, expression, emptyList(), WarningCollector.NOOP); + ExpressionInterpreter evaluator = ExpressionInterpreter.expressionInterpreter(expression, metadata, session, expressionTypes); + + Object result = evaluator.evaluate(symbol -> { + int position = 0; + int channel = INPUT_MAPPING.get(symbol); + Type type = SYMBOL_TYPES.get(symbol); + + Block block = SOURCE_PAGE.getBlock(channel); + + if (block.isNull(position)) { + return null; + } + + Class javaType = type.getJavaType(); + if (javaType == boolean.class) { + return type.getBoolean(block, position); + } + else if (javaType == long.class) { + return type.getLong(block, position); + } + else if (javaType == double.class) { + return type.getDouble(block, position); + } + else if (javaType == Slice.class) { + return type.getSlice(block, position); + } + else if (javaType == Block.class) { + return type.getObject(block, position); + } + else { + throw new UnsupportedOperationException("not yet implemented"); + } + }); + + // convert result from stack type to Type ObjectValue + Block block = Utils.nativeValueToBlock(expectedType, result); + return expectedType.getObjectValue(session.toConnectorSession(), block, 0); } private static OperatorFactory compileFilterWithNoInputColumns(RowExpression filter, ExpressionCompiler compiler) @@ -955,7 +973,7 @@ private static SourceOperatorFactory compileScanFilterProject(Optional, Type> expressionTypes, Map layout) { - return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, layout); + return translate(projection, SCALAR, expressionTypes, layout, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); } private static Page getAtMostOnePage(Operator operator, Page sourcePage) diff --git a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java index 349e453dd337..20a52f4d2d53 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java @@ -92,7 +92,7 @@ private RowExpression translateAndOptimize(Expression expression) private RowExpression translateAndOptimize(Expression expression, Map, Type> types) { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true, ImmutableMap.of()); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, ImmutableMap.of(), metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); } private Expression simplifyExpression(Expression expression) diff --git a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java index 597b82f96169..459f335a3f35 100644 --- a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java @@ -181,7 +181,7 @@ private RowExpression rowExpression(String value) Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); } private static Page createPage(List types, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java deleted file mode 100644 index c724efa16bf4..000000000000 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.MetadataManager; -import io.prestosql.operator.project.InterpretedPageFilter; -import io.prestosql.operator.project.SelectedPositions; -import io.prestosql.spi.Page; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.tree.ComparisonExpression; -import org.testng.annotations.Test; - -import static io.prestosql.SessionTestUtils.TEST_SESSION; -import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; -import static java.lang.String.format; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertTrue; - -public class TestInterpretedPageFilterFunction -{ - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - - @Test - public void testNullLiteral() - { - assertFilter("null", false); - } - - @Test - public void testBooleanLiteral() - { - assertFilter("true", true); - assertFilter("false", false); - } - - @Test - public void testNotExpression() - { - assertFilter("not true", false); - assertFilter("not false", true); - assertFilter("not null", false); - } - - @Test - public void testAndExpression() - { - assertFilter("true and true", true); - assertFilter("true and false", false); - assertFilter("true and null", false); - - assertFilter("false and true", false); - assertFilter("false and false", false); - assertFilter("false and null", false); - - assertFilter("null and true", false); - assertFilter("null and false", false); - assertFilter("null and null", false); - } - - @Test - public void testORExpression() - { - assertFilter("true or true", true); - assertFilter("true or false", true); - assertFilter("true or null", true); - - assertFilter("false or true", true); - assertFilter("false or false", false); - assertFilter("false or null", false); - - assertFilter("null or true", true); - assertFilter("null or false", false); - assertFilter("null or null", false); - } - - @Test - public void testIsNullExpression() - { - assertFilter("null is null", true); - assertFilter("42 is null", false); - } - - @Test - public void testIsNotNullExpression() - { - assertFilter("42 is not null", true); - assertFilter("null is not null", false); - } - - @Test - public void testComparisonExpression() - { - assertFilter("42 = 42", true); - assertFilter("42 = 42.0", true); - assertFilter("42.42 = 42.42", true); - assertFilter("'foo' = 'foo'", true); - - assertFilter("42 = 87", false); - assertFilter("42 = 22.2", false); - assertFilter("42.42 = 22.2", false); - assertFilter("'foo' = 'bar'", false); - - assertFilter("42 != 87", true); - assertFilter("42 != 22.2", true); - assertFilter("42.42 != 22.22", true); - assertFilter("'foo' != 'bar'", true); - - assertFilter("42 != 42", false); - assertFilter("42 != 42.0", false); - assertFilter("42.42 != 42.42", false); - assertFilter("'foo' != 'foo'", false); - - assertFilter("42 < 88", true); - assertFilter("42 < 88.8", true); - assertFilter("42.42 < 88.8", true); - assertFilter("'bar' < 'foo'", true); - - assertFilter("88 < 42", false); - assertFilter("88 < 42.42", false); - assertFilter("88.8 < 42.42", false); - assertFilter("'foo' < 'bar'", false); - - assertFilter("42 <= 88", true); - assertFilter("42 <= 88.8", true); - assertFilter("42.42 <= 88.8", true); - assertFilter("'bar' <= 'foo'", true); - - assertFilter("42 <= 42", true); - assertFilter("42 <= 42.0", true); - assertFilter("42.42 <= 42.42", true); - assertFilter("'foo' <= 'foo'", true); - - assertFilter("88 <= 42", false); - assertFilter("88 <= 42.42", false); - assertFilter("88.8 <= 42.42", false); - assertFilter("'foo' <= 'bar'", false); - - assertFilter("88 >= 42", true); - assertFilter("88.8 >= 42.0", true); - assertFilter("88.8 >= 42.42", true); - assertFilter("'foo' >= 'bar'", true); - - assertFilter("42 >= 88", false); - assertFilter("42.42 >= 88.0", false); - assertFilter("42.42 >= 88.88", false); - assertFilter("'bar' >= 'foo'", false); - - assertFilter("88 >= 42", true); - assertFilter("88.8 >= 42.0", true); - assertFilter("88.8 >= 42.42", true); - assertFilter("'foo' >= 'bar'", true); - assertFilter("42 >= 42", true); - assertFilter("42 >= 42.0", true); - assertFilter("42.42 >= 42.42", true); - assertFilter("'foo' >= 'foo'", true); - - assertFilter("42 >= 88", false); - assertFilter("42.42 >= 88.0", false); - assertFilter("42.42 >= 88.88", false); - assertFilter("'bar' >= 'foo'", false); - } - - @Test - public void testComparisonExpressionWithNulls() - { - for (ComparisonExpression.Operator operator : ComparisonExpression.Operator.values()) { - if (operator == ComparisonExpression.Operator.IS_DISTINCT_FROM) { - // IS DISTINCT FROM has different NULL semantics - continue; - } - - assertFilter(format("NULL %s NULL", operator.getValue()), false); - - assertFilter(format("42 %s NULL", operator.getValue()), false); - assertFilter(format("NULL %s 42", operator.getValue()), false); - - assertFilter(format("11.1 %s NULL", operator.getValue()), false); - assertFilter(format("NULL %s 11.1", operator.getValue()), false); - } - } - - private static void assertFilter(String expression, boolean expectedValue) - { - InterpretedPageFilter filterFunction = new InterpretedPageFilter( - createExpression(expression, METADATA, TypeProvider.empty()), - TypeProvider.empty(), - ImmutableMap.of(), - METADATA, - SQL_PARSER, - TEST_SESSION); - - SelectedPositions selectedPositions = filterFunction.filter(TEST_SESSION.toConnectorSession(), new Page(1)); - assertEquals(selectedPositions.size(), expectedValue ? 1 : 0); - if (expectedValue) { - if (selectedPositions.isList()) { - assertEquals(selectedPositions.getPositions()[selectedPositions.getOffset()], 0); - } - else { - assertEquals(selectedPositions.getOffset(), 0); - } - } - else { - assertTrue(selectedPositions.isEmpty()); - } - } -} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java deleted file mode 100644 index 3e0d1b7b1432..000000000000 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.block.BlockAssertions; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.MetadataManager; -import io.prestosql.operator.DriverYieldSignal; -import io.prestosql.operator.Work; -import io.prestosql.operator.project.InterpretedPageProjection; -import io.prestosql.spi.Page; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.tree.ArithmeticBinaryExpression; -import org.testng.annotations.AfterClass; -import org.testng.annotations.Test; - -import javax.annotation.Nullable; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; - -import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.prestosql.SessionTestUtils.TEST_SESSION; -import static io.prestosql.operator.project.SelectedPositions.positionsList; -import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; -import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.spi.type.TypeUtils.writeNativeValue; -import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; - -public class TestInterpretedPageProjectionFunction -{ - // todo add cases for decimal - - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - private static final ScheduledExecutorService executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("test-%s")); - - @AfterClass(alwaysRun = true) - public void tearDown() - { - executor.shutdownNow(); - } - - @Test - public void testBooleanExpression() - { - assertProjection("true", true); - assertProjection("false", false); - assertProjection("1 = 1", true); - assertProjection("1 = 0", false); - assertProjection("true and false", false); - } - - @Test - public void testArithmeticExpression() - { - assertProjection("42 + 87", 42 + 87); - assertProjection("42 + 22.2E0", 42 + 22.2); - assertProjection("11.1E0 + 22.2E0", 11.1 + 22.2); - - assertProjection("42 - 87", 42 - 87); - assertProjection("42 - 22.2E0", 42 - 22.2); - assertProjection("11.1E0 - 22.2E0", 11.1 - 22.2); - - assertProjection("42 * 87", 42 * 87); - assertProjection("42 * 22.2E0", 42 * 22.2); - assertProjection("11.1E0 * 22.2E0", 11.1 * 22.2); - - assertProjection("42 / 87", 42 / 87); - assertProjection("42 / 22.2E0", 42 / 22.2); - assertProjection("11.1E0 / 22.2E0", 11.1 / 22.2); - - assertProjection("42 % 87", 42 % 87); - assertProjection("42 % 22.2E0", 42 % 22.2); - assertProjection("11.1E0 % 22.2E0", 11.1 % 22.2); - - assertProjection("42 + BIGINT '87'", 42 + 87L); - assertProjection("BIGINT '42' - 22.2E0", 42L - 22.2); - assertProjection("42 * BIGINT '87'", 42 * 87L); - assertProjection("BIGINT '11' / 22.2E0", 11L / 22.2); - assertProjection("11.1E0 % BIGINT '22'", 11.1 % 22L); - } - - @Test - public void testArithmeticExpressionWithNulls() - { - for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) { - assertProjection("CAST(NULL AS INTEGER) " + operator.getValue() + " CAST(NULL AS INTEGER)", null); - - assertProjection("42 " + operator.getValue() + " NULL", null); - assertProjection("NULL " + operator.getValue() + " 42", null); - - assertProjection("11.1 " + operator.getValue() + " CAST(NULL AS INTEGER)", null); - assertProjection("CAST(NULL AS INTEGER) " + operator.getValue() + " 11.1", null); - } - } - - @Test - public void testCoalesceExpression() - { - assertProjection("COALESCE(42, 87, 100)", 42); - assertProjection("COALESCE(NULL, 87, 100)", 87); - assertProjection("COALESCE(42, NULL, 100)", 42); - assertProjection("COALESCE(42, NULL, BIGINT '100')", 42L); - assertProjection("COALESCE(NULL, NULL, 100)", 100); - assertProjection("COALESCE(NULL, NULL, BIGINT '100')", 100L); - - assertProjection("COALESCE(42.2E0, 87.2E0, 100.2E0)", 42.2); - assertProjection("COALESCE(NULL, 87.2E0, 100.2E0)", 87.2); - assertProjection("COALESCE(42.2E0, NULL, 100.2E0)", 42.2); - assertProjection("COALESCE(NULL, NULL, 100.2E0)", 100.2); - - assertProjection("COALESCE('foo', 'bar', 'zah')", "foo"); - assertProjection("COALESCE(NULL, 'bar', 'zah')", "bar"); - assertProjection("COALESCE('foo', NULL, 'zah')", "foo"); - assertProjection("COALESCE(NULL, NULL, 'zah')", "zah"); - - assertProjection("COALESCE(NULL, NULL, NULL)", null); - } - - @Test - public void testNullIf() - { - assertProjection("NULLIF(42, 42)", null); - assertProjection("NULLIF(42, 42.0E0)", null); - assertProjection("NULLIF(42.42E0, 42.42E0)", null); - assertProjection("NULLIF('foo', 'foo')", null); - - assertProjection("NULLIF(42, 87)", 42); - assertProjection("NULLIF(42, 22.2E0)", 42); - assertProjection("NULLIF(42, BIGINT '87')", 42); - assertProjection("NULLIF(BIGINT '42', 22.2E0)", 42L); - assertProjection("NULLIF(42.42E0, 22.2E0)", 42.42); - assertProjection("NULLIF('foo', 'bar')", "foo"); - - assertProjection("NULLIF(NULL, NULL)", null); - - assertProjection("NULLIF(42, NULL)", 42); - assertProjection("NULLIF(NULL, 42)", null); - - assertProjection("NULLIF(11.1E0, NULL)", 11.1); - assertProjection("NULLIF(NULL, 11.1E0)", null); - } - - @Test - public void testSymbolReference() - { - Symbol symbol = new Symbol("symbol"); - ImmutableMap symbolToInputMappings = ImmutableMap.of(symbol, 0); - assertProjection("symbol", true, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BOOLEAN)), 0, createBlock(BOOLEAN, true)); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BOOLEAN)), 0, createNullBlock(BOOLEAN)); - - assertProjection("symbol", 42L, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BIGINT)), 0, createBlock(BIGINT, 42)); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BIGINT)), 0, createNullBlock(BIGINT)); - - assertProjection("symbol", 11.1, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, DOUBLE)), 0, createBlock(DOUBLE, 11.1)); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, DOUBLE)), 0, createNullBlock(DOUBLE)); - - assertProjection("symbol", "foo", symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, VARCHAR)), 0, createBlock(VARCHAR, "foo")); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, VARCHAR)), 0, createNullBlock(VARCHAR)); - } - - private static void assertProjection(String expression, @Nullable Object expectedValue) - { - assertProjection( - expression, - expectedValue, - ImmutableMap.of(), - TypeProvider.empty(), - 0); - } - - private static void assertProjection( - String expression, - @Nullable Object expectedValue, - Map symbolToInputMappings, - TypeProvider symbolTypes, - int position, - Block... blocks) - { - assertProjection(expression, new Object[] {expectedValue}, symbolToInputMappings, symbolTypes, new int[] {position}, blocks); - } - - private static void assertProjection( - String expression, - Object[] expectedValues, - Map symbolToInputMappings, - TypeProvider symbolTypes, - int[] positions, - Block... blocks) - { - InterpretedPageProjection projectionFunction = new InterpretedPageProjection( - createExpression(expression, METADATA, symbolTypes), - symbolTypes, - symbolToInputMappings, - METADATA, - SQL_PARSER, - TEST_SESSION); - - // project with yield - DriverYieldSignal yieldSignal = new DriverYieldSignal(); - Work work = projectionFunction.project( - TEST_SESSION.toConnectorSession(), - yieldSignal, - new Page(positions.length, blocks), - positionsList(positions, 0, positions.length)); - - Block block; - // Get nothing for the first position.length compute due to yield - // Currently we enforce a yield check for every position; free feel to adjust the number if the behavior changes - for (int i = 0; i < positions.length; i++) { - yieldSignal.setWithDelay(1, executor); - yieldSignal.forceYieldForTesting(); - assertFalse(work.process()); - yieldSignal.reset(); - } - // the next yield is not going to prevent a block to be produced - yieldSignal.setWithDelay(1, executor); - yieldSignal.forceYieldForTesting(); - yieldSignal.reset(); - assertTrue(work.process()); - block = work.getResult(); - - List actualValues = BlockAssertions.toValues(projectionFunction.getType(), block); - assertEquals(actualValues.size(), positions.length); - assertEquals(expectedValues.length, positions.length); - for (int i = 0; i < positions.length; i++) { - assertEquals(actualValues.get(i), expectedValues[i]); - } - - // project without yield - work = projectionFunction.project( - TEST_SESSION.toConnectorSession(), - new DriverYieldSignal(), - new Page(positions.length, blocks), - positionsList(positions, 0, positions.length)); - assertTrue(work.process()); - block = work.getResult(); - - actualValues = BlockAssertions.toValues(projectionFunction.getType(), block); - assertEquals(actualValues.size(), positions.length); - assertEquals(expectedValues.length, positions.length); - for (int i = 0; i < positions.length; i++) { - assertEquals(actualValues.get(i), expectedValues[i]); - } - } - - private static Block createBlock(Type type, Object value) - { - return createBlock(type, new Object[] {value}); - } - - private static Block createNullBlock(Type type) - { - return createBlock(type, new Object[] {null}); - } - - private static Block createBlock(Type type, Object[] values) - { - BlockBuilder blockBuilder = type.createBlockBuilder(null, values.length); - for (Object value : values) { - writeNativeValue(type, blockBuilder, value); - } - return blockBuilder.build(); - } -} diff --git a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java index 7c25281bb558..cac7e0d97a23 100644 --- a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java @@ -614,7 +614,7 @@ private RowExpression rowExpression(String value) Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, metadata, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true, sourceLayout); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); } private Object generateRandomValue(Type type) From c8d352e7c15ee97bd014d5f06d9688c229259296 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 11:33:24 -0800 Subject: [PATCH 13/18] Remove unused functions --- .../sql/analyzer/ExpressionAnalyzer.java | 56 ------------------- 1 file changed, 56 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index f6642d34305a..f9dcd0042d92 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -110,7 +110,6 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -1472,30 +1471,6 @@ public static Map, Type> getExpressionTypes( return analyzeExpressionsWithSymbols(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes(); } - public static Map, Type> getExpressionTypesFromInput( - Session session, - Metadata metadata, - SqlParser sqlParser, - Map types, - Expression expression, - List parameters, - WarningCollector warningCollector) - { - return getExpressionTypesFromInput(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, warningCollector); - } - - public static Map, Type> getExpressionTypesFromInput( - Session session, - Metadata metadata, - SqlParser sqlParser, - Map types, - Iterable expressions, - List parameters, - WarningCollector warningCollector) - { - return analyzeExpressionsWithInputs(session, metadata, sqlParser, types, expressions, parameters, warningCollector).getExpressionTypes(); - } - public static ExpressionAnalysis analyzeExpressionsWithSymbols( Session session, Metadata metadata, @@ -1509,37 +1484,6 @@ public static ExpressionAnalysis analyzeExpressionsWithSymbols( return analyzeExpressions(session, metadata, sqlParser, new RelationType(), types, expressions, parameters, warningCollector, isDescribe); } - private static ExpressionAnalysis analyzeExpressionsWithInputs( - Session session, - Metadata metadata, - SqlParser sqlParser, - Map types, - Iterable expressions, - List parameters, - WarningCollector warningCollector) - { - Field[] fields = new Field[types.size()]; - for (Entry entry : types.entrySet()) { - fields[entry.getKey()] = io.prestosql.sql.analyzer.Field.newUnqualified(Optional.empty(), entry.getValue()); - } - RelationType tupleDescriptor = new RelationType(fields); - - return analyzeExpressions(session, metadata, sqlParser, tupleDescriptor, TypeProvider.empty(), expressions, parameters, warningCollector); - } - - public static ExpressionAnalysis analyzeExpressions( - Session session, - Metadata metadata, - SqlParser sqlParser, - RelationType tupleDescriptor, - TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector) - { - return analyzeExpressions(session, metadata, sqlParser, tupleDescriptor, types, expressions, parameters, warningCollector, false); - } - private static ExpressionAnalysis analyzeExpressions( Session session, Metadata metadata, From 6260a698526a6ab9cd9e0a2e3291fe0d27352737 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 13:46:49 -0800 Subject: [PATCH 14/18] Rename analyzeExpressionsWithSymbols method There's no longer a conflict with analyzeExpressionsWithInputs so simplify the name --- .../java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java | 4 ++-- .../java/io/prestosql/operator/scalar/FunctionAssertions.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index f9dcd0042d92..49d510efb2c9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -1468,10 +1468,10 @@ public static Map, Type> getExpressionTypes( WarningCollector warningCollector, boolean isDescribe) { - return analyzeExpressionsWithSymbols(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes(); + return analyzeExpressions(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes(); } - public static ExpressionAnalysis analyzeExpressionsWithSymbols( + public static ExpressionAnalysis analyzeExpressions( Session session, Metadata metadata, SqlParser sqlParser, diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index a80249668457..d4b4b2ee136e 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -126,7 +126,7 @@ import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressionsWithSymbols; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressions; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; @@ -748,7 +748,7 @@ public static Expression createExpression(Session session, String expression, Me parsedExpression = rewriteIdentifiersToSymbolReferences(parsedExpression); - final ExpressionAnalysis analysis = analyzeExpressionsWithSymbols( + final ExpressionAnalysis analysis = analyzeExpressions( session, metadata, SQL_PARSER, From c705d2d18aaad2d52c3cb3980094dc8673258ee3 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 13:47:53 -0800 Subject: [PATCH 15/18] Inline analyzeExpressions method There's only one caller, so no need for an extra indirection. --- .../sql/analyzer/ExpressionAnalyzer.java | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index 49d510efb2c9..dabb2e2876e5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -1480,27 +1480,13 @@ public static ExpressionAnalysis analyzeExpressions( List parameters, WarningCollector warningCollector, boolean isDescribe) - { - return analyzeExpressions(session, metadata, sqlParser, new RelationType(), types, expressions, parameters, warningCollector, isDescribe); - } - - private static ExpressionAnalysis analyzeExpressions( - Session session, - Metadata metadata, - SqlParser sqlParser, - RelationType tupleDescriptor, - TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) { // expressions at this point can not have sub queries so deny all access checks // in the future, we will need a full access controller here to verify access to functions Analysis analysis = new Analysis(null, parameters, isDescribe); ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, new DenyAllAccessControl(), types, warningCollector); for (Expression expression : expressions) { - analyzer.analyze(expression, Scope.builder().withRelationType(RelationId.anonymous(), tupleDescriptor).build()); + analyzer.analyze(expression, Scope.builder().withRelationType(RelationId.anonymous(), new RelationType()).build()); } return new ExpressionAnalysis( From 55f3e220246b9ea3beefca39d4587c670e5b49ef Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 15:02:53 -0800 Subject: [PATCH 16/18] Remove unnecessary expression rewrite --- .../sql/planner/LocalExecutionPlanner.java | 18 +---- .../sql/planner/SymbolToInputRewriter.java | 76 ------------------- 2 files changed, 2 insertions(+), 92 deletions(-) delete mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 54f6f95fe0e6..800c88eb7563 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -182,7 +182,6 @@ import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.FieldReference; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.LambdaArgumentDeclaration; import io.prestosql.sql.tree.LambdaExpression; @@ -1965,10 +1964,8 @@ private JoinBridgeManager createLookupSourceFact Optional sortChannel = sortExpressionContext .map(SortExpressionContext::getSortExpression) - .map(sortExpression -> sortExpressionAsSortChannel( - sortExpression, - probeSource.getLayout(), - buildSource.getLayout())); + .map(Symbol::from) + .map(sortSymbol -> createJoinSourcesLayout(buildSource.getLayout(), probeSource.getLayout()).get(sortSymbol)); List searchFunctionFactories = sortExpressionContext .map(SortExpressionContext::getSearchExpressions) @@ -2050,17 +2047,6 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } - private int sortExpressionAsSortChannel( - Expression sortExpression, - Map probeLayout, - Map buildLayout) - { - Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Expression rewrittenSortExpression = new SymbolToInputRewriter(joinSourcesLayout).rewrite(sortExpression); - checkArgument(rewrittenSortExpression instanceof FieldReference, "Unsupported expression type [%s]", rewrittenSortExpression); - return ((FieldReference) rewrittenSortExpression).getFieldIndex(); - } - private OperatorFactory createLookupJoin( JoinNode node, PhysicalOperation probeSource, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java deleted file mode 100644 index b7b563ea20db..000000000000 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.ExpressionRewriter; -import io.prestosql.sql.tree.ExpressionTreeRewriter; -import io.prestosql.sql.tree.FieldReference; -import io.prestosql.sql.tree.LambdaExpression; -import io.prestosql.sql.tree.SymbolReference; - -import java.util.Map; - -import static java.util.Objects.requireNonNull; - -public class SymbolToInputRewriter -{ - private final Map symbolToChannelMapping; - - public SymbolToInputRewriter(Map symbolToChannelMapping) - { - requireNonNull(symbolToChannelMapping, "symbolToChannelMapping is null"); - this.symbolToChannelMapping = ImmutableMap.copyOf(symbolToChannelMapping); - } - - public Expression rewrite(Expression expression) - { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() - { - @Override - public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter treeRewriter) - { - Integer channel = symbolToChannelMapping.get(Symbol.from(node)); - if (channel == null) { - Preconditions.checkArgument(context.isInLambda(), "Cannot resolve symbol %s", node.getName()); - return node; - } - return new FieldReference(channel); - } - - @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter treeRewriter) - { - return treeRewriter.defaultRewrite(node, new Context(true)); - } - }, expression, new Context(false)); - } - - private static class Context - { - private final boolean inLambda; - - public Context(boolean inLambda) - { - this.inLambda = inLambda; - } - - public boolean isInLambda() - { - return inLambda; - } - } -} From 49edfa7f834169d7624857a0d79e8e3e4736ca5e Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 17:36:47 -0800 Subject: [PATCH 17/18] Encapsulate expression type analysis in planner The new class is to facilitate obtaining the type of an expression and its subexpressions during planning (i.e., when interacting with IR expression) and to remove spurious dependencies on the SQL parser. It will eventually get removed when we split the AST from the IR and we encode the type directly into IR expressions. --- .../benchmark/AbstractOperatorBenchmark.java | 14 ++--- .../TestExtractSpatialInnerJoin.java | 2 +- .../TestExtractSpatialLeftJoin.java | 2 +- .../execution/SqlQueryExecution.java | 3 +- .../io/prestosql/server/ServerMainModule.java | 2 + .../sql/analyzer/ExpressionAnalyzer.java | 38 ------------ .../sql/analyzer/QueryExplainer.java | 3 +- .../sql/analyzer/StatementAnalyzer.java | 9 +-- .../planner/DesugarAtTimeZoneRewriter.java | 10 +--- .../sql/planner/DomainTranslator.java | 19 ++---- .../sql/planner/LocalExecutionPlanner.java | 50 +++------------- .../prestosql/sql/planner/LogicalPlanner.java | 15 +++-- .../prestosql/sql/planner/PlanOptimizers.java | 25 ++++---- .../prestosql/sql/planner/TypeAnalyzer.java | 59 +++++++++++++++++++ ...wPartialAggregationOverGroupIdRuleSet.java | 10 ++-- .../iterative/rule/DesugarAtTimeZone.java | 12 ++-- .../iterative/rule/ExtractSpatialJoins.java | 54 +++++++---------- .../rule/PushPredicateIntoTableScan.java | 25 +++----- .../iterative/rule/SimplifyExpressions.java | 21 +++---- .../planner/optimizations/AddExchanges.java | 14 ++--- .../optimizations/AddLocalExchanges.java | 14 ++--- .../optimizations/ExpressionEquivalence.java | 31 ++++------ .../optimizations/PredicatePushDown.java | 46 ++++----------- .../optimizations/PropertyDerivations.java | 27 ++++----- .../StreamPropertyDerivations.java | 18 +++--- .../sanity/NoDuplicatePlanNodeIdsChecker.java | 4 +- .../sanity/NoIdentifierLeftChecker.java | 4 +- .../NoSubqueryExpressionLeftChecker.java | 4 +- .../sql/planner/sanity/PlanSanityChecker.java | 12 ++-- .../sql/planner/sanity/TypeValidator.java | 21 +++---- ...ValidateAggregationsWithDefaultValues.java | 16 ++--- .../sanity/ValidateDependenciesChecker.java | 4 +- .../sanity/ValidateStreamingAggregations.java | 14 ++--- .../sanity/VerifyNoFilteredAggregations.java | 4 +- .../sanity/VerifyOnlyOneOutputNode.java | 4 +- .../prestosql/testing/LocalQueryRunner.java | 7 ++- .../io/prestosql/execution/TaskTestUtils.java | 3 +- ...BenchmarkScanFilterAndProjectOperator.java | 18 +++--- .../operator/scalar/FunctionAssertions.java | 17 ++---- .../sql/TestExpressionInterpreter.java | 9 ++- .../sql/gen/PageProcessorBenchmark.java | 8 +-- .../sql/planner/TestTypeValidator.java | 4 +- .../rule/TestPushPredicateIntoTableScan.java | 3 +- .../rule/TestSimplifyExpressions.java | 3 +- .../iterative/rule/test/RuleTester.java | 13 ++-- .../optimizations/TestEliminateSorts.java | 3 +- .../TestExpressionEquivalence.java | 3 +- .../optimizations/TestReorderWindows.java | 3 +- ...ValidateAggregationsWithDefaultValues.java | 5 +- .../TestValidateStreamingAggregations.java | 8 +-- .../type/BenchmarkDecimalOperators.java | 17 ++++-- .../tests/AbstractTestQueryFramework.java | 3 +- 52 files changed, 328 insertions(+), 409 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index 3b3171c37318..568a30387cee 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -23,7 +23,6 @@ import io.prestosql.execution.Lifespan; import io.prestosql.execution.TaskId; import io.prestosql.execution.TaskStateMachine; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.memory.MemoryPool; import io.prestosql.memory.QueryContext; import io.prestosql.metadata.Metadata; @@ -52,6 +51,7 @@ import io.prestosql.split.SplitSource; import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.HashGenerationOptimizer; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -83,7 +83,6 @@ import static io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; @@ -229,14 +228,9 @@ protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNo Optional hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet())); verify(hashExpression.isPresent()); - Map, Type> expressionTypes = getExpressionTypes( - session, - localQueryRunner.getMetadata(), - localQueryRunner.getSqlParser(), - TypeProvider.copyOf(symbolTypes.build()), - hashExpression.get(), - ImmutableList.of(), - WarningCollector.NOOP); + Map, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata()) + .getTypes(session, TypeProvider.copyOf(symbolTypes.build()), hashExpression.get()); + RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionRegistry(), localQueryRunner.getTypeManager(), session, false); PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0); diff --git a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java index f25c31450df4..b16089288821 100644 --- a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java +++ b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java @@ -389,6 +389,6 @@ public void testPushDownAnd() private RuleAssert assertRuleApplication() { RuleTester tester = tester(); - return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser())); + return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer())); } } diff --git a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java index d5eac82517f0..44e8374fc25e 100644 --- a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java +++ b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java @@ -258,6 +258,6 @@ public void testPushDownAnd() private RuleAssert assertRuleApplication() { RuleTester tester = tester(); - return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser())); + return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer())); } } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java index a46e38ff8017..52ba8216f3ac 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java @@ -62,6 +62,7 @@ import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.StageExecutionPlan; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.tree.Explain; import io.prestosql.transaction.TransactionManager; @@ -414,7 +415,7 @@ private PlanRoot doAnalyzeQuery() // plan query PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, stateMachine.getWarningCollector()); + LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, new TypeAnalyzer(sqlParser, metadata), statsCalculator, costCalculator, stateMachine.getWarningCollector()); Plan plan = logicalPlanner.plan(analysis); queryPlan.set(plan); diff --git a/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java b/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java index c8b15755063d..af58b6b1fe7d 100644 --- a/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java +++ b/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java @@ -127,6 +127,7 @@ import io.prestosql.sql.planner.CompilerConfig; import io.prestosql.sql.planner.LocalExecutionPlanner; import io.prestosql.sql.planner.NodePartitioningManager; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.transaction.TransactionManagerConfig; @@ -354,6 +355,7 @@ protected void setup(Binder binder) binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); // type + binder.bind(TypeAnalyzer.class).in(Scopes.SINGLETON); binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index dabb2e2876e5..8e0f0e04acb5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -1433,44 +1433,6 @@ public static Signature resolveFunction(FunctionCall node, List, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Expression expression, - List parameters, - WarningCollector warningCollector) - { - return getExpressionTypes(session, metadata, sqlParser, types, expression, parameters, warningCollector, false); - } - - public static Map, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Expression expression, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return getExpressionTypes(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, warningCollector, isDescribe); - } - - public static Map, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return analyzeExpressions(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes(); - } - public static ExpressionAnalysis analyzeExpressions( Session session, Metadata metadata, diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java index 834b9a313ee8..d0fd2416acec 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java @@ -29,6 +29,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.planPrinter.IoPlanPrinter; import io.prestosql.sql.planner.planPrinter.PlanPrinter; @@ -175,7 +176,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List s throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references"); } - Map, Type> expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = ExpressionAnalyzer.analyzeExpressions( session, metadata, sqlParser, TypeProvider.empty(), - relation.getSamplePercentage(), + ImmutableList.of(relation.getSamplePercentage()), analysis.getParameters(), WarningCollector.NOOP, - analysis.isDescribe()); + analysis.isDescribe()) + .getExpressionTypes(); + ExpressionInterpreter samplePercentageEval = expressionOptimizer(relation.getSamplePercentage(), metadata, session, expressionTypes); Object samplePercentageObject = samplePercentageEval.optimize(symbol -> { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java index e9020862a05c..0c6a2375443f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java @@ -16,10 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.tree.AtTimeZone; import io.prestosql.sql.tree.Cast; import io.prestosql.sql.tree.Expression; @@ -36,8 +34,6 @@ import static io.prestosql.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; import static io.prestosql.spi.type.TimestampType.TIMESTAMP; import static io.prestosql.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class DesugarAtTimeZoneRewriter @@ -49,15 +45,15 @@ public static Expression rewrite(Expression expression, Map, private DesugarAtTimeZoneRewriter() {} - public static Expression rewrite(Expression expression, Session session, Metadata metadata, SqlParser sqlParser, SymbolAllocator symbolAllocator) + public static Expression rewrite(Expression expression, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { return expression; } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); return rewrite(expression, expressionTypes); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java index cfa67ca0d135..61eb2a8697cc 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.spi.block.Block; @@ -34,7 +33,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.ExpressionUtils; import io.prestosql.sql.InterpretedFunctionInvoker; -import io.prestosql.sql.analyzer.ExpressionAnalyzer; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.tree.AstVisitor; import io.prestosql.sql.tree.BetweenPredicate; @@ -78,7 +76,6 @@ import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; -import static java.util.Collections.emptyList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; @@ -277,7 +274,7 @@ public static ExtractionResult fromPredicate( Expression predicate, TypeProvider types) { - return new Visitor(metadata, session, types).process(predicate, false); + return new Visitor(metadata, session, types, new TypeAnalyzer(new SqlParser(), metadata)).process(predicate, false); } private static class Visitor @@ -288,14 +285,16 @@ private static class Visitor private final Session session; private final TypeProvider types; private final InterpretedFunctionInvoker functionInvoker; + private final TypeAnalyzer typeAnalyzer; - private Visitor(Metadata metadata, Session session, TypeProvider types) + private Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.functionInvoker = new InterpretedFunctionInvoker(metadata.getFunctionRegistry()); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } private Type checkedTypeLookup(Symbol symbol) @@ -424,7 +423,7 @@ else if (symbolExpression instanceof Cast) { return super.visitComparisonExpression(node, complement); } - Type castSourceType = typeOf(castExpression.getExpression(), session, metadata, types); // type of expression which is then cast to type of value + Type castSourceType = typeAnalyzer.getType(session, types, castExpression.getExpression()); // type of expression which is then cast to type of value // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side Optional coercedExpression = coerceComparisonWithRounding( @@ -489,7 +488,7 @@ private boolean isImplicitCoercion(Cast cast) private Map, Type> analyzeExpression(Expression expression) { - return ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); + return typeAnalyzer.getTypes(session, types, expression); } private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) @@ -757,12 +756,6 @@ protected ExtractionResult visitNullLiteral(NullLiteral node, Boolean complement } } - private static Type typeOf(Expression expression, Session session, Metadata metadata, TypeProvider types) - { - Map, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); - return expressionTypes.get(NodeRef.of(expression)); - } - private static class NormalizedSimpleComparison { private final Expression symbolExpression; diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 800c88eb7563..47aefe8fec2e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -133,7 +133,6 @@ import io.prestosql.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import io.prestosql.sql.gen.OrderingCompiler; import io.prestosql.sql.gen.PageFunctionCompiler; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.optimizations.IndexJoinOptimizer; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -230,7 +229,6 @@ import static io.prestosql.SystemSessionProperties.isSpillEnabled; import static io.prestosql.SystemSessionProperties.isSpillOrderBy; import static io.prestosql.SystemSessionProperties.isSpillWindowOperator; -import static io.prestosql.execution.warnings.WarningCollector.NOOP; import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.operator.DistinctLimitOperator.DistinctLimitOperatorFactory; import static io.prestosql.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory; @@ -249,7 +247,6 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.TypeUtils.writeNativeValue; import static io.prestosql.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; @@ -276,14 +273,13 @@ import static io.prestosql.util.SpatialJoinUtils.ST_WITHIN; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; public class LocalExecutionPlanner { private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final Optional explainAnalyzeContext; private final PageSourceProvider pageSourceProvider; private final IndexManager indexManager; @@ -310,7 +306,7 @@ public class LocalExecutionPlanner @Inject public LocalExecutionPlanner( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, Optional explainAnalyzeContext, PageSourceProvider pageSourceProvider, IndexManager indexManager, @@ -337,7 +333,7 @@ public LocalExecutionPlanner( this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.exchangeClientSupplier = exchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.expressionCompiler = requireNonNull(expressionCompiler, "compiler is null"); this.pageFunctionCompiler = requireNonNull(pageFunctionCompiler, "pageFunctionCompiler is null"); @@ -1215,15 +1211,10 @@ else if (sourceNode instanceof SampleNode) { projections.add(assignments.get(symbol)); } - Map, Type> expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = typeAnalyzer.getTypes( context.getSession(), - metadata, - sqlParser, context.getTypes(), - concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions()), - emptyList(), - NOOP, - false); + concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions())); Optional translatedFilter = filterExpression.map(filter -> toRowExpression(filter, expressionTypes, sourceLayout)); List translatedProjections = projections.stream() @@ -1300,15 +1291,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext PageBuilder pageBuilder = new PageBuilder(node.getRows().size(), outputTypes); for (List row : node.getRows()) { pageBuilder.declarePosition(); - Map, Type> expressionTypes = getExpressionTypes( - context.getSession(), - metadata, - sqlParser, - TypeProvider.empty(), - ImmutableList.copyOf(row), - emptyList(), - NOOP, - false); + Map, Type> expressionTypes = typeAnalyzer.getTypes(context.getSession(), TypeProvider.empty(), ImmutableList.copyOf(row)); for (int i = 0; i < row.size(); i++) { // evaluate the literal value Object result = ExpressionInterpreter.expressionInterpreter(row.get(i), metadata, context.getSession(), expressionTypes).evaluate(); @@ -2033,17 +2016,7 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( { Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - types, - filterExpression, - emptyList(), /* parameters have already been replaced */ - NOOP, - false); - - RowExpression translatedFilter = toRowExpression(filterExpression, expressionTypes, joinSourcesLayout); + RowExpression translatedFilter = toRowExpression(filterExpression, typeAnalyzer.getTypes(session, types, filterExpression), joinSourcesLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } @@ -2554,14 +2527,7 @@ private AccumulatorFactory buildAccumulatorFactory( // expressions from lambda arguments .putAll(lambdaArgumentExpressionTypes) // expressions from lambda body - .putAll(getExpressionTypes( - session, - metadata, - sqlParser, - TypeProvider.copyOf(lambdaArgumentSymbolTypes), - lambdaExpression.getBody(), - emptyList(), - NOOP)) + .putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody())) .build(); LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java index eb453f53dd03..a6df3075bc2d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java @@ -42,7 +42,6 @@ import io.prestosql.sql.analyzer.RelationId; import io.prestosql.sql.analyzer.RelationType; import io.prestosql.sql.analyzer.Scope; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.plan.AggregationNode; @@ -115,7 +114,7 @@ public enum Stage private final PlanSanityChecker planSanityChecker; private final SymbolAllocator symbolAllocator = new SymbolAllocator(); private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final StatisticsAggregationPlanner statisticsAggregationPlanner; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; @@ -125,12 +124,12 @@ public LogicalPlanner(Session session, List planOptimizers, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector) { - this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, warningCollector); + this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, metadata, typeAnalyzer, statsCalculator, costCalculator, warningCollector); } public LogicalPlanner(Session session, @@ -138,7 +137,7 @@ public LogicalPlanner(Session session, PlanSanityChecker planSanityChecker, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector) @@ -148,7 +147,7 @@ public LogicalPlanner(Session session, this.planSanityChecker = requireNonNull(planSanityChecker, "planSanityChecker is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, metadata); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); @@ -164,7 +163,7 @@ public Plan plan(Analysis analysis, Stage stage) { PlanNode root = planStatement(analysis, analysis.getStatement()); - planSanityChecker.validateIntermediatePlan(root, session, metadata, sqlParser, symbolAllocator.getTypes(), warningCollector); + planSanityChecker.validateIntermediatePlan(root, session, metadata, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); if (stage.ordinal() >= Stage.OPTIMIZED.ordinal()) { for (PlanOptimizer optimizer : planOptimizers) { @@ -175,7 +174,7 @@ public Plan plan(Analysis analysis, Stage stage) if (stage.ordinal() >= Stage.OPTIMIZED_AND_VALIDATED.ordinal()) { // make sure we produce a valid plan after optimizations run. This is mainly to catch programming errors - planSanityChecker.validateFinalPlan(root, session, metadata, sqlParser, symbolAllocator.getTypes(), warningCollector); + planSanityChecker.validateFinalPlan(root, session, metadata, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); } TypeProvider types = symbolAllocator.getTypes(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index f9dc3f265ef9..16e8817697fb 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -27,7 +27,6 @@ import io.prestosql.split.PageSourceManager; import io.prestosql.split.SplitManager; import io.prestosql.sql.analyzer.FeaturesConfig; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.iterative.IterativeOptimizer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; @@ -145,7 +144,7 @@ public class PlanOptimizers @Inject public PlanOptimizers( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, FeaturesConfig featuresConfig, NodeSchedulerConfig nodeSchedulerConfig, InternalNodeManager nodeManager, @@ -160,7 +159,7 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator) { this(metadata, - sqlParser, + typeAnalyzer, featuresConfig, taskManagerConfig, false, @@ -190,7 +189,7 @@ public void destroy() public PlanOptimizers( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, FeaturesConfig featuresConfig, TaskManagerConfig taskManagerConfig, boolean forceSingleNode, @@ -249,9 +248,9 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new SimplifyExpressions(metadata, sqlParser).rules()); + new SimplifyExpressions(metadata, typeAnalyzer).rules()); - PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser)); + PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, typeAnalyzer)); builder.add( // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers @@ -261,7 +260,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.>builder() .addAll(new DesugarLambdaExpression().rules()) - .addAll(new DesugarAtTimeZone(metadata, sqlParser).rules()) + .addAll(new DesugarAtTimeZone(metadata, typeAnalyzer).rules()) .addAll(new DesugarCurrentUser().rules()) .addAll(new DesugarCurrentPath().rules()) .addAll(new DesugarTryExpression().rules()) @@ -357,7 +356,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), new PruneUnreferencedOutputs(), new IterativeOptimizer( ruleStats, @@ -407,7 +406,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), projectionPushDown, new PruneUnreferencedOutputs(), new IterativeOptimizer( @@ -440,7 +439,7 @@ public PlanOptimizers( costCalculator, ImmutableSet.>builder() .add(new RemoveRedundantIdentityProjections()) - .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, sqlParser).rules()) + .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, typeAnalyzer).rules()) .add(new InlineProjections()) .build())); @@ -461,7 +460,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges - builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser))); + builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, typeAnalyzer))); } //noinspection UnusedAssignment estimatedExchangesCostCalculator = null; // Prevent accidental use after AddExchanges @@ -491,7 +490,7 @@ public PlanOptimizers( .build())); // Optimizers above this don't understand local exchanges, so be careful moving this. - builder.add(new AddLocalExchanges(metadata, sqlParser)); + builder.add(new AddLocalExchanges(metadata, typeAnalyzer)); // Optimizers above this do not need to care about aggregations with the type other than SINGLE // This optimizer must be run after all exchange-related optimizers @@ -507,7 +506,7 @@ public PlanOptimizers( ruleStats, statsCalculator, costCalculator, - new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, sqlParser, taskCountEstimator, taskManagerConfig).rules())); + new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, typeAnalyzer, taskCountEstimator, taskManagerConfig).rules())); builder.add(new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java new file mode 100644 index 000000000000..b9b97cb1b178 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.analyzer.ExpressionAnalyzer; +import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.NodeRef; + +import javax.inject.Inject; + +import java.util.Map; + +// This class is to facilitate obtaining the type of an expression and its subexpressions +// during planning (i.e., when interacting with IR expression). It will eventually get +// removed when we split the AST from the IR and we encode the type directly into IR expressions. +public class TypeAnalyzer +{ + private final SqlParser parser; + private final Metadata metadata; + + @Inject + public TypeAnalyzer(SqlParser parser, Metadata metadata) + { + this.parser = parser; + this.metadata = metadata; + } + + public Map, Type> getTypes(Session session, TypeProvider inputTypes, Iterable expressions) + { + return ExpressionAnalyzer.analyzeExpressions(session, metadata, parser, inputTypes, expressions, ImmutableList.of(), WarningCollector.NOOP, false).getExpressionTypes(); + } + + public Map, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) + { + return getTypes(session, inputTypes, ImmutableList.of(expression)); + } + + public Type getType(Session session, TypeProvider inputTypes, Expression expression) + { + return getTypes(session, inputTypes, expression).get(NodeRef.of(expression)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java index 6e6639dab739..010b8b9adcf9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -26,10 +26,10 @@ import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.PartitioningScheme; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.optimizations.StreamPreferredProperties; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -128,18 +128,18 @@ public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet private static final double ANTI_SKEWNESS_MARGIN = 3; private final Metadata metadata; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final TaskCountEstimator taskCountEstimator; private final DataSize maxPartialAggregationMemoryUsage; public AddExchangesBelowPartialAggregationOverGroupIdRuleSet( Metadata metadata, - SqlParser parser, + TypeAnalyzer typeAnalyzer, TaskCountEstimator taskCountEstimator, TaskManagerConfig taskManagerConfig) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null"); this.maxPartialAggregationMemoryUsage = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage(); } @@ -342,7 +342,7 @@ private StreamProperties derivePropertiesRecursively(PlanNode node, Context cont List inputProperties = resolvedPlanNode.getSources().stream() .map(source -> derivePropertiesRecursively(source, context)) .collect(toImmutableList()); - return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), parser); + return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), typeAnalyzer); } } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java index c2ebb7b3679f..1ee472b74a5a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DesugarAtTimeZoneRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import java.util.Set; @@ -26,9 +26,9 @@ public class DesugarAtTimeZone extends ExpressionRewriteRuleSet { - public DesugarAtTimeZone(Metadata metadata, SqlParser sqlParser) + public DesugarAtTimeZone(Metadata metadata, TypeAnalyzer typeAnalyzer) { - super(createRewrite(metadata, sqlParser)); + super(createRewrite(metadata, typeAnalyzer)); } @Override @@ -42,11 +42,11 @@ public Set> rules() valuesExpressionRewrite()); } - private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + private static ExpressionRewriter createRewrite(Metadata metadata, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); - return (expression, context) -> DesugarAtTimeZoneRewriter.rewrite(expression, context.getSession(), metadata, sqlParser, context.getSymbolAllocator()); + return (expression, context) -> DesugarAtTimeZoneRewriter.rewrite(expression, context.getSession(), metadata, typeAnalyzer, context.getSymbolAllocator()); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index b0c118ace6de..170ca6294796 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -21,7 +21,6 @@ import com.google.common.collect.Iterables; import io.prestosql.Session; import io.prestosql.execution.Lifespan; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.geospatial.KdbTree; import io.prestosql.geospatial.KdbTreeUtils; import io.prestosql.matching.Capture; @@ -42,8 +41,8 @@ import io.prestosql.split.SplitManager; import io.prestosql.split.SplitSource; import io.prestosql.split.SplitSource.SplitBatch; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.Rule.Context; import io.prestosql.sql.planner.iterative.Rule.Result; @@ -59,7 +58,6 @@ import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.StringLiteral; import io.prestosql.sql.tree.SymbolReference; @@ -85,7 +83,6 @@ import static io.prestosql.spi.type.IntegerType.INTEGER; import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; @@ -98,7 +95,6 @@ import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -158,21 +154,21 @@ public class ExtractSpatialJoins private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } public Set> rules() { return ImmutableSet.of( - new ExtractSpatialInnerJoin(metadata, splitManager, pageSourceManager, sqlParser), - new ExtractSpatialLeftJoin(metadata, splitManager, pageSourceManager, sqlParser)); + new ExtractSpatialInnerJoin(metadata, splitManager, pageSourceManager, typeAnalyzer), + new ExtractSpatialLeftJoin(metadata, splitManager, pageSourceManager, typeAnalyzer)); } @VisibleForTesting @@ -186,14 +182,14 @@ public static final class ExtractSpatialInnerJoin private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialInnerJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialInnerJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -215,7 +211,7 @@ public Result apply(FilterNode node, Captures captures, Context context) Expression filter = node.getPredicate(); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -223,7 +219,7 @@ public Result apply(FilterNode node, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -242,14 +238,14 @@ public static final class ExtractSpatialLeftJoin private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -270,7 +266,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) Expression filter = joinNode.getFilter().get(); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -278,7 +274,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -298,7 +294,7 @@ private static Result tryCreateSpatialJoin( Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, - SqlParser sqlParser) + TypeAnalyzer typeAnalyzer) { PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -350,7 +346,7 @@ private static Result tryCreateSpatialJoin( joinNode.getDistributionType(), joinNode.isSpillable()); - return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, sqlParser); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, typeAnalyzer); } private static Result tryCreateSpatialJoin( @@ -364,7 +360,7 @@ private static Result tryCreateSpatialJoin( Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, - SqlParser sqlParser) + TypeAnalyzer typeAnalyzer) { // TODO Add support for distributed left spatial joins Optional spatialPartitioningTableName = joinNode.getType() == INNER ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty(); @@ -377,8 +373,8 @@ private static Result tryCreateSpatialJoin( Expression secondArgument = arguments.get(1); Type sphericalGeographyType = metadata.getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE); - if (getExpressionType(firstArgument, context, metadata, sqlParser).equals(sphericalGeographyType) - || getExpressionType(secondArgument, context, metadata, sqlParser).equals(sphericalGeographyType)) { + if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType) + || typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) { return Result.empty(); } @@ -446,14 +442,6 @@ else if (alignment < 0) { kdbTree.map(KdbTreeUtils::toJson))); } - private static Type getExpressionType(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) - { - Type type = getExpressionTypes(context.getSession(), metadata, sqlParser, context.getSymbolAllocator().getTypes(), expression, emptyList(), WarningCollector.NOOP) - .get(NodeRef.of(expression)); - verify(type != null); - return type; - } - private static KdbTree loadKdbTree(String tableName, Session session, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) { QualifiedObjectName name = toQualifiedObjectName(tableName, session.getCatalog().get(), session.getSchema().get()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 244200de417d..07625f27ee58 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.matching.Capture; import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; @@ -27,8 +26,6 @@ import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.LiteralEncoder; @@ -36,6 +33,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.plan.FilterNode; @@ -43,7 +41,6 @@ import io.prestosql.sql.planner.plan.TableScanNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.NullLiteral; import java.util.Map; @@ -59,12 +56,10 @@ import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; import static io.prestosql.sql.ExpressionUtils.filterNonDeterministicConjuncts; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.plan.Patterns.filter; import static io.prestosql.sql.planner.plan.Patterns.source; import static io.prestosql.sql.planner.plan.Patterns.tableScan; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -80,13 +75,13 @@ public class PushPredicateIntoTableScan tableScan().capturedAs(TABLE_SCAN))); private final Metadata metadata; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final DomainTranslator domainTranslator; - public PushPredicateIntoTableScan(Metadata metadata, SqlParser parser) + public PushPredicateIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); } @@ -115,7 +110,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) context.getSymbolAllocator().getTypes(), context.getIdAllocator(), metadata, - parser, + typeAnalyzer, domainTranslator); if (arePlansSame(filterNode, tableScan, rewritten)) { @@ -154,7 +149,7 @@ public static PlanNode pushFilterIntoTableScan( TypeProvider types, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser parser, + TypeAnalyzer typeAnalyzer, DomainTranslator domainTranslator) { // don't include non-deterministic predicates @@ -176,7 +171,7 @@ public static PlanNode pushFilterIntoTableScan( if (pruneWithPredicateExpression) { LayoutConstraintEvaluator evaluator = new LayoutConstraintEvaluator( metadata, - parser, + typeAnalyzer, session, types, node.getAssignments(), @@ -239,13 +234,11 @@ private static class LayoutConstraintEvaluator private final ExpressionInterpreter evaluator; private final Set arguments; - public LayoutConstraintEvaluator(Metadata metadata, SqlParser parser, Session session, TypeProvider types, Map assignments, Expression expression) + public LayoutConstraintEvaluator(Metadata metadata, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, Map assignments, Expression expression) { this.assignments = assignments; - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); - - evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); + evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, typeAnalyzer.getTypes(session, types, expression)); arguments = SymbolsExtractor.extractUnique(expression).stream() .map(assignments::get) .collect(toImmutableSet()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java index 0ccf965931ee..95c39cd64958 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java @@ -16,14 +16,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.LiteralEncoder; import io.prestosql.sql.planner.NoOpSymbolResolver; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.NodeRef; @@ -32,33 +31,31 @@ import java.util.Map; import java.util.Set; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter.extractCommonPredicates; import static io.prestosql.sql.planner.iterative.rule.PushDownNegationsExpressionRewriter.pushDownNegations; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class SimplifyExpressions extends ExpressionRewriteRuleSet { @VisibleForTesting - static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, SqlParser sqlParser) + static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { return expression; } expression = pushDownNegations(expression); expression = extractCommonPredicates(expression); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); return literalEncoder.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } - public SimplifyExpressions(Metadata metadata, SqlParser sqlParser) + public SimplifyExpressions(Metadata metadata, TypeAnalyzer typeAnalyzer) { - super(createRewrite(metadata, sqlParser)); + super(createRewrite(metadata, typeAnalyzer)); } @Override @@ -71,12 +68,12 @@ public Set> rules() valuesExpressionRewrite()); // ApplyNode and AggregationNode are not supported, because ExpressionInterpreter doesn't support them } - private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + private static ExpressionRewriter createRewrite(Metadata metadata, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); LiteralEncoder literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); - return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), metadata, literalEncoder, sqlParser); + return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), metadata, literalEncoder, typeAnalyzer); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index e4a97ef1298b..57f98b9d8064 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -26,7 +26,6 @@ import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; import io.prestosql.spi.connector.SortingProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.LiteralEncoder; import io.prestosql.sql.planner.Partitioning; @@ -34,6 +33,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; import io.prestosql.sql.planner.plan.AggregationNode; @@ -112,15 +112,15 @@ public class AddExchanges implements PlanOptimizer { - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final Metadata metadata; private final DomainTranslator domainTranslator; - public AddExchanges(Metadata metadata, SqlParser parser) + public AddExchanges(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = metadata; this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); - this.parser = parser; + this.typeAnalyzer = typeAnalyzer; } @Override @@ -532,7 +532,7 @@ else if (redistributeWrites) { private PlanWithProperties planTableScan(TableScanNode node, Expression predicate) { - PlanNode plan = PushPredicateIntoTableScan.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); + PlanNode plan = PushPredicateIntoTableScan.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, typeAnalyzer, domainTranslator); return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); } @@ -1190,7 +1190,7 @@ private ActualProperties deriveProperties(PlanNode result, ActualProperties inpu private ActualProperties deriveProperties(PlanNode result, List inputProperties) { // TODO: move this logic to PlanSanityChecker once PropertyDerivations.deriveProperties fully supports local exchanges - ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, parser); + ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, typeAnalyzer); verify(result instanceof SemiJoinNode || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated) || outputProperties.isNullsAndAnyReplicated(), "SemiJoinNode is the only node that can strip null replication"); return outputProperties; @@ -1198,7 +1198,7 @@ private ActualProperties deriveProperties(PlanNode result, List inputProperties) { - return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, parser)); + return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, typeAnalyzer)); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java index d051b49f5941..efae3082d79e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java @@ -20,12 +20,11 @@ import com.google.common.collect.Ordering; import io.airlift.slice.Slice; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.CallExpression; import io.prestosql.sql.relational.ConstantExpression; @@ -35,7 +34,6 @@ import io.prestosql.sql.relational.RowExpressionVisitor; import io.prestosql.sql.relational.VariableReferenceExpression; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import java.util.Comparator; import java.util.HashMap; @@ -57,10 +55,8 @@ import static io.prestosql.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.prestosql.spi.function.OperatorType.NOT_EQUAL; import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static java.lang.Integer.min; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class ExpressionEquivalence @@ -68,12 +64,12 @@ public class ExpressionEquivalence private static final Ordering ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator()); private static final CanonicalizationVisitor CANONICALIZATION_VISITOR = new CanonicalizationVisitor(); private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) + public ExpressionEquivalence(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types) @@ -95,18 +91,15 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, TypeProvider types) { - // determine the type of every expression - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - types, + return translate( expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); - - // convert to row expression - return translate(expression, SCALAR, expressionTypes, symbolInput, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + SCALAR, + typeAnalyzer.getTypes(session, types, expression), + symbolInput, + metadata.getFunctionRegistry(), + metadata.getTypeManager(), + session, + false); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java index 27f7e0a558ca..246616248cec 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java @@ -20,7 +20,6 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DeterminismEvaluator; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.EffectivePredicateExtractor; @@ -32,6 +31,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AssignUniqueId; @@ -84,7 +84,6 @@ import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.extractConjuncts; import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.DeterminismEvaluator.isDeterministic; import static io.prestosql.sql.planner.EqualityInference.createEqualityInference; import static io.prestosql.sql.planner.ExpressionSymbolInliner.inlineSymbols; @@ -93,7 +92,6 @@ import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; import static io.prestosql.sql.planner.plan.JoinNode.Type.RIGHT; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class PredicatePushDown @@ -102,14 +100,14 @@ public class PredicatePushDown private final Metadata metadata; private final LiteralEncoder literalEncoder; private final EffectivePredicateExtractor effectivePredicateExtractor; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public PredicatePushDown(Metadata metadata, SqlParser sqlParser) + public PredicatePushDown(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(literalEncoder)); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -121,7 +119,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith( - new Rewriter(symbolAllocator, idAllocator, metadata, literalEncoder, effectivePredicateExtractor, sqlParser, session, types), + new Rewriter(symbolAllocator, idAllocator, metadata, literalEncoder, effectivePredicateExtractor, typeAnalyzer, session, types), plan, TRUE_LITERAL); } @@ -134,7 +132,7 @@ private static class Rewriter private final Metadata metadata; private final LiteralEncoder literalEncoder; private final EffectivePredicateExtractor effectivePredicateExtractor; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final Session session; private final TypeProvider types; private final ExpressionEquivalence expressionEquivalence; @@ -145,7 +143,7 @@ private Rewriter( Metadata metadata, LiteralEncoder literalEncoder, EffectivePredicateExtractor effectivePredicateExtractor, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) { @@ -154,10 +152,10 @@ private Rewriter( this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = requireNonNull(literalEncoder, "literalEncoder is null"); this.effectivePredicateExtractor = requireNonNull(effectivePredicateExtractor, "effectivePredicateExtractor is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); - this.expressionEquivalence = new ExpressionEquivalence(metadata, sqlParser); + this.expressionEquivalence = new ExpressionEquivalence(metadata, typeAnalyzer); } @Override @@ -638,7 +636,7 @@ private Symbol symbolForExpression(Expression expression) return Symbol.from(expression); } - return symbolAllocator.newSymbol(expression, extractType(expression)); + return symbolAllocator.newSymbol(expression, typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression)); } private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection outerSymbols) @@ -891,12 +889,6 @@ private static Expression extractJoinPredicate(JoinNode joinNode) return combineConjuncts(builder.build()); } - private Type extractType(Expression expression) - { - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), /* parameters have already been replaced */WarningCollector.NOOP); - return expressionTypes.get(NodeRef.of(expression)); - } - private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) { checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType()); @@ -948,14 +940,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex // Temporary implementation for joins because the SimplifyExpressions optimizers can not run properly on join clauses private Expression simplifyExpression(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - symbolAllocator.getTypes(), - expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); return literalEncoder.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } @@ -970,14 +955,7 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r */ private Object nullInputEvaluator(final Collection nullSymbols, Expression expression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - symbolAllocator.getTypes(), - expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); return ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes) .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java index 8657f7f3e31d..b0de8196db49 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java @@ -21,7 +21,6 @@ import com.google.common.collect.Sets; import io.prestosql.Session; import io.prestosql.SystemSessionProperties; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableLayout; import io.prestosql.metadata.TableLayout.TablePartitioning; @@ -32,13 +31,13 @@ import io.prestosql.spi.connector.SortingProperty; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.NoOpSymbolResolver; import io.prestosql.sql.planner.OrderingScheme; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.ActualProperties.Global; import io.prestosql.sql.planner.plan.AggregationNode; @@ -95,7 +94,6 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.SystemSessionProperties.planWithTableNodePartitioning; import static io.prestosql.spi.predicate.TupleDomain.extractFixedValues; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.SystemPartitioningHandle.ARBITRARY_DISTRIBUTION; import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.coordinatorSingleStreamPartition; @@ -104,7 +102,6 @@ import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.streamPartitionedOn; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.REMOTE; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; @@ -112,17 +109,17 @@ public class PropertyDerivations { private PropertyDerivations() {} - public static ActualProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session, types, parser)) + .map(source -> derivePropertiesRecursively(source, metadata, session, types, typeAnalyzer)) .collect(toImmutableList()); - return deriveProperties(node, inputProperties, metadata, session, types, parser); + return deriveProperties(node, inputProperties, metadata, session, types, typeAnalyzer); } - public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - ActualProperties output = node.accept(new Visitor(metadata, session, types, parser), inputProperties); + ActualProperties output = node.accept(new Visitor(metadata, session, types, typeAnalyzer), inputProperties); output.getNodePartitioning().ifPresent(partitioning -> verify(node.getOutputSymbols().containsAll(partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); @@ -137,9 +134,9 @@ public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties streamBackdoorDeriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - return node.accept(new Visitor(metadata, session, types, parser), inputProperties); + return node.accept(new Visitor(metadata, session, types, typeAnalyzer), inputProperties); } private static class Visitor @@ -148,14 +145,14 @@ private static class Visitor private final Metadata metadata; private final Session session; private final TypeProvider types; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; - public Visitor(Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { this.metadata = metadata; this.session = session; this.types = types; - this.parser = parser; + this.typeAnalyzer = typeAnalyzer; } @Override @@ -636,7 +633,7 @@ public ActualProperties visitProject(ProjectNode node, List in for (Map.Entry assignment : node.getAssignments().entrySet()) { Expression expression = assignment.getValue(); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); // TODO: diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java index d7f620263238..ed1c0ff221d6 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java @@ -23,9 +23,9 @@ import io.prestosql.metadata.TableLayout; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.LocalProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning.ArgumentBinding; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.ApplyNode; @@ -96,27 +96,27 @@ public final class StreamPropertyDerivations { private StreamPropertyDerivations() {} - public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session, types, parser)) + .map(source -> derivePropertiesRecursively(source, metadata, session, types, typeAnalyzer)) .collect(toImmutableList()); - return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, types, parser); + return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, types, typeAnalyzer); } - public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, types, parser); + return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, types, typeAnalyzer); } - public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { requireNonNull(node, "node is null"); requireNonNull(inputProperties, "inputProperties is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); - requireNonNull(parser, "parser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); // properties.otherActualProperties will never be null here because the only way // an external caller should obtain StreamProperties is from this method, and the @@ -129,7 +129,7 @@ public static StreamProperties deriveProperties(PlanNode node, List planNodeIds = new HashMap<>(); searchFrom(planNode) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java index ddb9e03c65dc..a7d0f133cac0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java @@ -17,8 +17,8 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.sql.analyzer.ExpressionTreeUtils; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.tree.Identifier; @@ -29,7 +29,7 @@ public final class NoIdentifierLeftChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { List identifiers = ExpressionTreeUtils.extractExpressions(ExpressionExtractor.extractExpressions(plan), Identifier.class); if (!identifiers.isEmpty()) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java index 916235efc1b9..88c4e94c63b8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java @@ -16,8 +16,8 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.tree.DefaultTraversalVisitor; @@ -30,7 +30,7 @@ public final class NoSubqueryExpressionLeftChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { for (Expression expression : ExpressionExtractor.extractExpressions(plan)) { new DefaultTraversalVisitor() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java index cc7a468aad48..6c6cb14d1de4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java @@ -18,7 +18,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; @@ -56,19 +56,19 @@ public PlanSanityChecker(boolean forceSingleNode) .build(); } - public void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types, warningCollector)); + checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, typeAnalyzer, types, warningCollector)); } - public void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types, warningCollector)); + checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, typeAnalyzer, types, warningCollector)); } public interface Checker { - void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector); + void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector); } private enum Stage diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java index daf7934aa459..30c239dc29fd 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java @@ -21,9 +21,9 @@ import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.SimplePlanVisitor; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -33,16 +33,13 @@ import io.prestosql.sql.planner.plan.WindowNode; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.SymbolReference; import java.util.List; import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.type.UnknownType.UNKNOWN; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -54,9 +51,9 @@ public final class TypeValidator public TypeValidator() {} @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - plan.accept(new Visitor(session, metadata, sqlParser, types, warningCollector), null); + plan.accept(new Visitor(session, metadata, typeAnalyzer, types, warningCollector), null); } private static class Visitor @@ -64,15 +61,15 @@ private static class Visitor { private final Session session; private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; private final WarningCollector warningCollector; - public Visitor(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } @@ -119,8 +116,7 @@ public Void visitProject(ProjectNode node, Void context) verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); continue; } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); + Type actualType = typeAnalyzer.getType(session, types, entry.getValue()); verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } @@ -165,8 +161,7 @@ private void checkSignature(Symbol symbol, Signature signature) private void checkCall(Symbol symbol, FunctionCall call) { Type expectedType = types.get(symbol); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, call, emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(call)); + Type actualType = typeAnalyzer.getType(session, types, call); verifyTypeSignature(symbol, expectedType.getTypeSignature(), actualType.getTypeSignature()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java index a6f760773aaf..22152e412d34 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.ActualProperties; import io.prestosql.sql.planner.optimizations.PropertyDerivations; @@ -60,9 +60,9 @@ public ValidateAggregationsWithDefaultValues(boolean forceSingleNode) } @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata, sqlParser, types), null); + planNode.accept(new Visitor(session, metadata, typeAnalyzer, types), null); } private class Visitor @@ -70,14 +70,14 @@ private class Visitor { final Session session; final Metadata metadata; - final SqlParser parser; + final TypeAnalyzer typeAnalyzer; final TypeProvider types; - Visitor(Session session, Metadata metadata, SqlParser parser, TypeProvider types) + Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); } @@ -115,14 +115,14 @@ public Optional visitAggregation(AggregationNode node, Void conte // No remote repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed on a single node. - ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, parser); + ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, typeAnalyzer); checkArgument(forceSingleNode || globalProperties.isSingleNode(), "Final aggregation with default value not separated from partial aggregation by remote hash exchange"); if (!seenExchanges.localRepartitionExchange) { // No local repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed by single thread. - StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, parser); + StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, typeAnalyzer); checkArgument(localProperties.isSingleStream(), "Final aggregation with default value not separated from partial aggregation by local hash exchange"); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java index 56a8175897b2..ea0949bbc717 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java @@ -18,9 +18,9 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -85,7 +85,7 @@ public final class ValidateDependenciesChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { validate(plan); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java index 3d2c37400bac..bf0b96310413 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java @@ -20,8 +20,8 @@ import io.prestosql.metadata.Metadata; import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.LocalProperties; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -44,9 +44,9 @@ public class ValidateStreamingAggregations implements Checker { @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata, sqlParser, types, warningCollector), null); + planNode.accept(new Visitor(session, metadata, typeAnalyzer, types, warningCollector), null); } private static final class Visitor @@ -54,15 +54,15 @@ private static final class Visitor { private final Session session; private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; private final WarningCollector warningCollector; - private Visitor(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + private Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { this.session = session; this.metadata = metadata; - this.sqlParser = sqlParser; + this.typeAnalyzer = typeAnalyzer; this.types = types; this.warningCollector = warningCollector; } @@ -81,7 +81,7 @@ public Void visitAggregation(AggregationNode node, Void context) return null; } - StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, types, sqlParser); + StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, types, typeAnalyzer); List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedSymbols())); Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java index 80999b7024dd..479e11487d8c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.PlanNode; @@ -27,7 +27,7 @@ public final class VerifyNoFilteredAggregations implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { searchFrom(plan) .where(AggregationNode.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java index db860491ae86..9c81a94d61a0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanNode; @@ -28,7 +28,7 @@ public final class VerifyOnlyOneOutputNode implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { int outputPlanNodesCount = searchFrom(plan) .where(OutputNode.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java index b8c1a3c4041e..e1a63a7389e6 100644 --- a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java @@ -135,6 +135,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -694,7 +695,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), Optional.empty(), pageSourceManager, indexManager, @@ -806,7 +807,7 @@ public List getPlanOptimizers(boolean forceSingleNode) { return new PlanOptimizers( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), featuresConfig, taskManagerConfig, forceSingleNode, @@ -844,7 +845,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR); private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); - private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA); private static final int TOTAL_POSITIONS = 1_000_000; private static final DataSize FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = new DataSize(500, KILOBYTE); @@ -232,8 +229,15 @@ private RowExpression rowExpression(String value) { Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate( + expression, + SCALAR, + TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + METADATA.getFunctionRegistry(), + METADATA.getTypeManager(), + TEST_SESSION, + true); } private static Page createPage(List types, int positions, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index d4b4b2ee136e..f536c574d131 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -65,6 +65,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -127,14 +128,12 @@ import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static io.prestosql.testing.TestingTaskContext.createTaskContext; import static io.prestosql.type.UnknownType.UNKNOWN; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -211,6 +210,7 @@ public final class FunctionAssertions private final Session session; private final LocalQueryRunner runner; private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; private final ExpressionCompiler compiler; public FunctionAssertions() @@ -229,6 +229,7 @@ public FunctionAssertions(Session session, FeaturesConfig featuresConfig) runner = new LocalQueryRunner(session, featuresConfig); metadata = runner.getMetadata(); compiler = runner.getExpressionCompiler(); + typeAnalyzer = new TypeAnalyzer(SQL_PARSER, metadata); } public TypeRegistry getTypeRegistry() @@ -627,15 +628,7 @@ private List executeProjectionWithAll(String projection, Type expectedTy private RowExpression toRowExpression(Session session, Expression projectionExpression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - SQL_PARSER, - TypeProvider.copyOf(INPUT_TYPES), - projectionExpression, - ImmutableList.of(), - WarningCollector.NOOP); - return toRowExpression(projectionExpression, expressionTypes, INPUT_MAPPING); + return toRowExpression(projectionExpression, typeAnalyzer.getTypes(session, TypeProvider.copyOf(INPUT_TYPES), projectionExpression), INPUT_MAPPING); } private Object selectSingleValue(OperatorFactory operatorFactory, Type type, Session session) @@ -870,7 +863,7 @@ protected Void visitSymbolReference(SymbolReference node, Void context) private Object interpret(Expression expression, Type expectedType, Session session) { - Map, Type> expressionTypes = getExpressionTypes(session, metadata, SQL_PARSER, SYMBOL_TYPES, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, SYMBOL_TYPES, expression); ExpressionInterpreter evaluator = ExpressionInterpreter.expressionInterpreter(expression, metadata, session, expressionTypes); Object result = evaluator.evaluate(symbol -> { diff --git a/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java b/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java index 3a2810049625..203d2900537a 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.scalar.FunctionAssertions; @@ -31,6 +30,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; @@ -69,13 +69,11 @@ import static io.prestosql.sql.ExpressionFormatter.formatExpression; import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.ExpressionInterpreter.expressionInterpreter; import static io.prestosql.sql.planner.ExpressionInterpreter.expressionOptimizer; import static io.prestosql.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static io.prestosql.util.DateTimeZoneIndex.getDateTimeZone; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; @@ -116,6 +114,7 @@ public class TestExpressionInterpreter private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(SQL_PARSER, METADATA); @Test public void testAnd() @@ -1454,7 +1453,7 @@ private static Object optimize(@Language("SQL") String expression) Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, parsedExpression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, SYMBOL_TYPES, parsedExpression); ExpressionInterpreter interpreter = expressionOptimizer(parsedExpression, METADATA, TEST_SESSION, expressionTypes); return interpreter.optimize(symbol -> { switch (symbol.getName().toLowerCase(ENGLISH)) { @@ -1511,7 +1510,7 @@ private static void assertRoundTrip(String expression) private static Object evaluate(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, SYMBOL_TYPES, expression); ExpressionInterpreter interpreter = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes); return interpreter.evaluate(); diff --git a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java index 459f335a3f35..7c98ae1e8693 100644 --- a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.SequencePageBuilder; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.DriverYieldSignal; @@ -30,6 +29,7 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -65,8 +65,6 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.stream.Collectors.toList; @@ -80,8 +78,8 @@ public class PageProcessorBenchmark { private static final Map TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR); - private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA); private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); private static final int POSITIONS = 1024; @@ -180,7 +178,7 @@ private RowExpression rowExpression(String value) { Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression); return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java index e3f7f760c431..2d26b5d0b953 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java @@ -21,6 +21,7 @@ import io.prestosql.connector.ConnectorId; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.FunctionKind; +import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; import io.prestosql.spi.connector.ColumnHandle; @@ -393,7 +394,8 @@ public void testInvalidUnion() private void assertTypesValid(PlanNode node) { - TYPE_VALIDATOR.validate(node, TEST_SESSION, createTestMetadataManager(), SQL_PARSER, symbolAllocator.getTypes(), WarningCollector.NOOP); + MetadataManager metadata = createTestMetadataManager(); + TYPE_VALIDATOR.validate(node, TEST_SESSION, metadata, new TypeAnalyzer(SQL_PARSER, metadata), symbolAllocator.getTypes(), WarningCollector.NOOP); } private static PlanNodeId newId() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java index b10f0add15bc..0e12314ff29d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java @@ -26,6 +26,7 @@ import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -53,7 +54,7 @@ public class TestPushPredicateIntoTableScan @BeforeClass public void setUpBeforeClass() { - pushPredicateIntoTableScan = new PushPredicateIntoTableScan(tester().getMetadata(), new SqlParser()); + pushPredicateIntoTableScan = new PushPredicateIntoTableScan(tester().getMetadata(), new TypeAnalyzer(new SqlParser(), tester().getMetadata())); connectorId = tester().getCurrentConnectorId(); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java index ae0cbaa33e3e..28eb9868829d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -20,6 +20,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; @@ -118,7 +119,7 @@ private static void assertSimplifies(String expression, String expected) { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected)); - Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), METADATA, LITERAL_ENCODER, SQL_PARSER); + Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), METADATA, LITERAL_ENCODER, new TypeAnalyzer(SQL_PARSER, METADATA)); assertEquals( normalize(rewritten), normalize(expectedExpression)); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java index 7b6c8f7212fc..b7cd186dc8c0 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java @@ -22,7 +22,7 @@ import io.prestosql.spi.Plugin; import io.prestosql.split.PageSourceManager; import io.prestosql.split.SplitManager; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.transaction.TransactionManager; @@ -48,7 +48,7 @@ public class RuleTester private final SplitManager splitManager; private final PageSourceManager pageSourceManager; private final AccessControl accessControl; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; public RuleTester() { @@ -91,7 +91,7 @@ public RuleTester(List plugins, Map sessionProperties, O this.splitManager = queryRunner.getSplitManager(); this.pageSourceManager = queryRunner.getPageSourceManager(); this.accessControl = queryRunner.getAccessControl(); - this.sqlParser = queryRunner.getSqlParser(); + this.typeAnalyzer = new TypeAnalyzer(queryRunner.getSqlParser(), metadata); } public RuleAssert assertThat(Rule rule) @@ -120,12 +120,9 @@ public PageSourceManager getPageSourceManager() return pageSourceManager; } - // TODO: this is only being used by rules that need to get the type of an expression - // In the short term, it should be encapsulated into something that knows how to provide types - // Rules should *not* need to use the parser otherwise. - public SqlParser getSqlParser() + public TypeAnalyzer getTypeAnalyzer() { - return sqlParser; + return typeAnalyzer; } public ConnectorId getCurrentConnectorId() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java index e659209b6b42..6cdc7591fbe6 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java @@ -19,6 +19,7 @@ import io.prestosql.spi.block.SortOrder; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.assertions.ExpectedValueProvider; import io.prestosql.sql.planner.assertions.PlanMatchPattern; @@ -89,7 +90,7 @@ public void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new AddExchanges(getQueryRunner().getMetadata(), new SqlParser()), + new AddExchanges(getQueryRunner().getMetadata(), new TypeAnalyzer(new SqlParser(), getQueryRunner().getMetadata())), new PruneUnreferencedOutputs(), new IterativeOptimizer( new RuleStatsRecorder(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java index 1c9d617c91a1..a61aad2080c6 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java @@ -21,6 +21,7 @@ import io.prestosql.sql.parser.ParsingOptions; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.tree.Expression; import org.intellij.lang.annotations.Language; @@ -41,7 +42,7 @@ public class TestExpressionEquivalence { private static final SqlParser SQL_PARSER = new SqlParser(); private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); - private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, SQL_PARSER); + private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, new TypeAnalyzer(SQL_PARSER, METADATA)); @Test public void testEquivalent() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java index dfaa88c1eb4d..c4a46ae629d2 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.spi.block.SortOrder; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.assertions.ExpectedValueProvider; import io.prestosql.sql.planner.assertions.PlanMatchPattern; @@ -322,7 +323,7 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new PredicatePushDown(getQueryRunner().getMetadata(), getQueryRunner().getSqlParser()), + new PredicatePushDown(getQueryRunner().getMetadata(), new TypeAnalyzer(getQueryRunner().getSqlParser(), getQueryRunner().getMetadata())), new IterativeOptimizer( new RuleStatsRecorder(), getQueryRunner().getStatsCalculator(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 0184499637f9..2bfffc331b8e 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -25,6 +25,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; @@ -48,8 +49,6 @@ public class TestValidateAggregationsWithDefaultValues extends BasePlanTest { - private static final SqlParser SQL_PARSER = new SqlParser(); - private Metadata metadata; private PlanBuilder builder; private Symbol symbol; @@ -192,7 +191,7 @@ private void validatePlan(PlanNode root, boolean forceSingleNode) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, SQL_PARSER, TypeProvider.empty(), WarningCollector.NOOP); + new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, new TypeAnalyzer(new SqlParser(), metadata), TypeProvider.empty(), WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java index ae2576516a7a..23e5e1684c3a 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java @@ -22,8 +22,8 @@ import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; import io.prestosql.plugin.tpch.TpchTransactionHandle; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; @@ -41,7 +41,7 @@ public class TestValidateStreamingAggregations extends BasePlanTest { private Metadata metadata; - private SqlParser sqlParser; + private TypeAnalyzer typeAnalyzer; private PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private TableHandle nationTableHandle; @@ -49,7 +49,7 @@ public class TestValidateStreamingAggregations public void setup() { metadata = getQueryRunner().getMetadata(); - sqlParser = getQueryRunner().getSqlParser(); + typeAnalyzer = new TypeAnalyzer(getQueryRunner().getSqlParser(), metadata); ConnectorId connectorId = getCurrentConnectorId(); nationTableHandle = new TableHandle( @@ -109,7 +109,7 @@ private void validatePlan(Function planProvider) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateStreamingAggregations().validate(planNode, session, metadata, sqlParser, types, WarningCollector.NOOP); + new ValidateStreamingAggregations().validate(planNode, session, metadata, typeAnalyzer, types, WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java index cac7e0d97a23..046898b61722 100644 --- a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.prestosql.RowPagesBuilder; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.DriverYieldSignal; import io.prestosql.operator.project.PageProcessor; @@ -30,11 +29,11 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Measurement; @@ -70,12 +69,10 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DecimalType.createDecimalType; import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.testing.TestingConnectorSession.SESSION; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.math.BigInteger.ONE; import static java.math.BigInteger.ZERO; -import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; import static org.openjdk.jmh.annotations.Scope.Thread; @@ -546,6 +543,7 @@ private Object execute(BaseState state) private static class BaseState { private final MetadataManager metadata = createTestMetadataManager(); + private final TypeAnalyzer typeAnalyzer = new TypeAnalyzer(new SqlParser(), metadata); private final Session session = testSessionBuilder().build(); private final Random random = new Random(); @@ -613,8 +611,15 @@ private RowExpression rowExpression(String value) { Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, metadata, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate( + expression, + SCALAR, + typeAnalyzer.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + metadata.getFunctionRegistry(), + metadata.getTypeManager(), + TEST_SESSION, + true); } private Object generateRandomValue(Type type) diff --git a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java index a11c3c7abedc..45d55e5bb99a 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java @@ -34,6 +34,7 @@ import io.prestosql.sql.planner.Plan; import io.prestosql.sql.planner.PlanFragmenter; import io.prestosql.sql.planner.PlanOptimizers; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.tree.ExplainType; import io.prestosql.testing.MaterializedResult; @@ -345,7 +346,7 @@ private QueryExplainer getQueryExplainer() CostCalculator costCalculator = new CostCalculatorUsingExchanges(taskCountEstimator); List optimizers = new PlanOptimizers( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), featuresConfig, new TaskManagerConfig(), forceSingleNode, From 205302268eac798f9a26ac5e4fe1d44826516a3c Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 17:42:15 -0800 Subject: [PATCH 18/18] WIP - Connector predicate pushdown --- .../metadata/FilterApplicationResult.java | 71 ++++++++ .../java/io/prestosql/metadata/Metadata.java | 4 + .../prestosql/metadata/MetadataManager.java | 44 +++++ .../io/prestosql/spi/expression/Apply.java | 43 +++++ .../spi/expression/ColumnReference.java | 34 ++++ .../spi/expression/ConnectorExpression.java | 31 ++++ .../ConnectorExpressionTranslator.java | 159 ++++++++++++++++++ .../io/prestosql/spi/expression/Constant.java | 33 ++++ .../prestosql/spi/expression/FunctionId.java | 29 ++++ .../prestosql/sql/planner/PlanOptimizers.java | 5 +- .../rule/PushFilterIntoTableScan.java | 114 +++++++++++++ .../metadata/AbstractMockMetadata.java | 7 + .../io/prestosql/tests/TestLocalQueries.java | 9 + 13 files changed, 582 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java create mode 100644 presto-main/src/main/java/io/prestosql/spi/expression/Apply.java create mode 100644 presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java create mode 100644 presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java create mode 100644 presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java create mode 100644 presto-main/src/main/java/io/prestosql/spi/expression/Constant.java create mode 100644 presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java diff --git a/presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java b/presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java new file mode 100644 index 000000000000..6377680564f5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.metadata; + +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.expression.ConnectorExpression; + +import java.util.List; + +public class FilterApplicationResult +{ + private final TableHandle table; + private final ConnectorExpression remainingFilter; + private final List newProjections; + + public FilterApplicationResult(TableHandle table, ConnectorExpression remainingFilter, List newProjections) + { + this.table = table; + this.remainingFilter = remainingFilter; + this.newProjections = newProjections; + } + + public TableHandle getTable() + { + return table; + } + + public ConnectorExpression getRemainingFilter() + { + return remainingFilter; + } + + public List getNewProjections() + { + return newProjections; + } + + public static class Column + { + private final ColumnHandle column; + private final Type type; + + public Column(ColumnHandle column, Type type) + { + this.column = column; + this.type = type; + } + + public ColumnHandle getColumn() + { + return column; + } + + public Type getType() + { + return type; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index 7dfeb8706ad5..257493f87c58 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -37,6 +37,7 @@ import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.planner.PartitioningHandle; +import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.sql.tree.QualifiedName; import java.util.Collection; @@ -380,4 +381,7 @@ public interface Metadata ColumnPropertyManager getColumnPropertyManager(); AnalyzePropertyManager getAnalyzePropertyManager(); + + // => TableHandle + remaining filter + new projections + Optional applyFilter(TableHandle table, ConnectorExpression expression); } diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 990a0922556e..dcb0604b0cf6 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -65,6 +65,9 @@ import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.analyzer.FeaturesConfig; import io.prestosql.sql.planner.PartitioningHandle; +import io.prestosql.spi.expression.Apply; +import io.prestosql.spi.expression.ColumnReference; +import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.transaction.TransactionManager; import io.prestosql.type.TypeDeserializer; @@ -1151,6 +1154,47 @@ public AnalyzePropertyManager getAnalyzePropertyManager() return analyzePropertyManager; } + @Override + public Optional applyFilter(TableHandle table, ConnectorExpression expression) + { + // TODO: dispatch to connector that owns "table" + + + /////////////////////////////////// testing code + class CustomColumn implements ColumnHandle { + int id; + + public CustomColumn(int id) + { + this.id = id; + } + + @Override + public int hashCode() + { + return id; + } + + @Override + public boolean equals(Object obj) + { + return id == ((CustomColumn) obj).id; + } + } + + if (expression instanceof Apply) { + ColumnHandle column = new CustomColumn(1); + return Optional.of(new FilterApplicationResult( + table, + new ColumnReference(column, BOOLEAN), + ImmutableList.of(new FilterApplicationResult.Column(column, BOOLEAN)))); + } + /////////////////////////////////// testing code + + + return Optional.empty(); + } + private ViewDefinition deserializeView(String data) { try { diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/Apply.java b/presto-main/src/main/java/io/prestosql/spi/expression/Apply.java new file mode 100644 index 000000000000..a4890dcb4c6f --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/Apply.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.type.Type; + +import java.util.List; + +public class Apply + extends ConnectorExpression +{ + private final FunctionId function; + private final List arguments; + + public Apply(Type returnType, FunctionId function, List arguments) + { + super(returnType); + this.function = function; + this.arguments = arguments; + } + + // TODO: this will need to be a FunctionHandle + public FunctionId getFunction() + { + return function; + } + + public List getArguments() + { + return arguments; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java b/presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java new file mode 100644 index 000000000000..ec8452c476df --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; + +public class ColumnReference + extends ConnectorExpression +{ + private final ColumnHandle column; + + public ColumnReference(ColumnHandle column, Type type) + { + super(type); + this.column = column; + } + + public ColumnHandle getColumn() + { + return column; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java new file mode 100644 index 000000000000..2c2e96f64f07 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.type.Type; + +public class ConnectorExpression +{ + private final Type type; + + public ConnectorExpression(Type type) + { + this.type = type; + } + + public Type getType() + { + return type; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java new file mode 100644 index 000000000000..c230f1b1cd53 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import com.google.common.collect.ImmutableList; +import io.prestosql.Session; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.planner.LiteralEncoder; +import io.prestosql.sql.planner.LiteralInterpreter; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.tree.AstVisitor; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.FunctionCall; +import io.prestosql.sql.tree.GenericLiteral; +import io.prestosql.sql.tree.NodeRef; +import io.prestosql.sql.tree.QualifiedName; +import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.stream.Collectors; + +public class ConnectorExpressionTranslator +{ + private ConnectorExpressionTranslator() + { + } + + public static Expression translate(ConnectorExpression expression, Map mappings, Metadata metadata) + { + return new ConnectorToSqlExpressionTranslator(mappings, metadata).translate(expression); + } + + public static ConnectorExpression translate(Session session, Expression expression, Map assignments, TypeAnalyzer types, TypeProvider inputTypes, Metadata metadata) + { + return new SqlToConnectorExpressionTranslator(session, metadata, assignments, types.getTypes(session, inputTypes, expression)) + .process(expression); + } + + private static class ConnectorToSqlExpressionTranslator + { + private final Map mappings; + private final LiteralEncoder literalEncoder; + + public ConnectorToSqlExpressionTranslator(Map mappings, Metadata metadata) + { + this.mappings = mappings; + this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); + } + + private String nameOf(FunctionId function) + { + return function.getName(); // TODO + } + + public Expression translate(ConnectorExpression expression) + { + if (expression instanceof Constant) { + return literalEncoder.toExpression(((Constant) expression).getValue(), expression.getType()); + } + + if (expression instanceof ColumnReference) { + return mappings.get(((ColumnReference) expression).getColumn()).toSymbolReference(); + } + + if (expression instanceof Apply) { + Apply apply = (Apply) expression; + + return new FunctionCall( + QualifiedName.of(nameOf(apply.getFunction())), + apply.getArguments().stream() + .map(this::translate) + .collect(Collectors.toList())); + } + + throw new UnsupportedOperationException("Expression type not supported: " + expression.getClass().getName()); + + } + } + + private static class SqlToConnectorExpressionTranslator + extends AstVisitor + { + private final Session session; + private final Metadata metadata; + private final Map assignments; + private final Map, Type> types; + + private SqlToConnectorExpressionTranslator(Session session, Metadata metadata, Map assignments, Map, Type> types) + { + this.session = session; + this.metadata = metadata; + this.assignments = assignments; + this.types = types; + } + + private Type typeOf(Expression node) + { + return types.get(NodeRef.of(node)); + } + + // TODO: need to return a FunctionHandle for the operator + private FunctionId signatureOf(ComparisonExpression.Operator operator, Type left, Type right) + { + return new FunctionId("$operator_" + operator.name()); + } + + @Override + protected ConnectorExpression visitComparisonExpression(ComparisonExpression node, Void context) + { + ConnectorExpression left = process(node.getLeft()); + ConnectorExpression right = process(node.getRight()); + + return new Apply( + typeOf(node), signatureOf(node.getOperator(), left.getType(), right.getType()), + ImmutableList.of(left, right)); + } + + @Override + protected ConnectorExpression visitSymbolReference(SymbolReference node, Void context) + { + return new ColumnReference(assignments.get(Symbol.from(node)), typeOf(node)); + } + + @Override + protected ConnectorExpression visitGenericLiteral(GenericLiteral node, Void context) + { + return new Constant(LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), node), typeOf(node)); + } + + @Override + protected ConnectorExpression visitStringLiteral(StringLiteral node, Void context) + { + return new Constant(node.getSlice(), typeOf(node)); + } + + @Override + protected ConnectorExpression visitExpression(Expression node, Void context) + { + throw new UnsupportedOperationException("not yet implemented: expression translator for " + node.getClass().getName()); + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/Constant.java b/presto-main/src/main/java/io/prestosql/spi/expression/Constant.java new file mode 100644 index 000000000000..15a91c2efec5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/Constant.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.type.Type; + +public class Constant + extends ConnectorExpression +{ + private final Object value; + + public Constant(Object value, Type type) + { + super(type); + this.value = value; + } + + public Object getValue() + { + return value; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java b/presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java new file mode 100644 index 000000000000..88b650ce7225 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +public class FunctionId +{ + private final String identifier; + + public FunctionId(String identifier) + { + this.identifier = identifier; + } + + public String getName() + { + return identifier; + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 16e8817697fb..a144efc1b6fa 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -74,6 +74,7 @@ import io.prestosql.sql.planner.iterative.rule.PruneValuesColumns; import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns; import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; +import io.prestosql.sql.planner.iterative.rule.PushFilterIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughProject; @@ -203,6 +204,7 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator) { this.exporter = exporter; + ImmutableList.Builder builder = ImmutableList.builder(); Set> predicatePushDownRules = ImmutableSet.of( @@ -297,7 +299,8 @@ public PlanOptimizers( new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(), new PruneOrderByInAggregation(metadata.getFunctionRegistry()), - new RewriteSpatialPartitioningAggregation(metadata))) + new RewriteSpatialPartitioningAggregation(metadata), + new PushFilterIntoTableScan(metadata, typeAnalyzer))) .build()), simplifyOptimizer, new UnaliasSymbolReferences(), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java new file mode 100644 index 000000000000..0259b8b173b2 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.metadata.FilterApplicationResult; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.spi.expression.ConnectorExpression; +import io.prestosql.spi.expression.ConnectorExpressionTranslator; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.TableScanNode; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.plan.Patterns.filter; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.tableScan; + +public class PushFilterIntoTableScan + implements Rule +{ + private static final Capture TABLE_SCAN = newCapture(); + private static final Pattern PATTERN = filter().with(source().matching( + tableScan().capturedAs(TABLE_SCAN))); + + private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; + + public PushFilterIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer) + { + this.metadata = metadata; + this.typeAnalyzer = typeAnalyzer; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(FilterNode filter, Captures captures, Context context) + { + TableScanNode tableScan = captures.get(TABLE_SCAN); + + ConnectorExpression expression = ConnectorExpressionTranslator.translate( + context.getSession(), + filter.getPredicate(), + tableScan.getAssignments(), + typeAnalyzer, + context.getSymbolAllocator().getTypes(), + metadata); + + Optional result = metadata.applyFilter(tableScan.getTable(), expression); + if (!result.isPresent()) { + return Result.empty(); + } + + Map mappings = new HashMap<>(); + for (Map.Entry assignment : tableScan.getAssignments().entrySet()) { + mappings.put(assignment.getValue(), assignment.getKey()); + } + + List newOutputs = new ArrayList<>(); + Map newAssignments = new HashMap<>(); + + newOutputs.addAll(tableScan.getOutputSymbols()); + newAssignments.putAll(tableScan.getAssignments()); + for (FilterApplicationResult.Column newProjection : result.get().getNewProjections()) { + Symbol symbol = context.getSymbolAllocator().newSymbol("column", newProjection.getType()); + + mappings.put(newProjection.getColumn(), symbol); + newOutputs.add(symbol); + newAssignments.put(symbol, newProjection.getColumn()); + } + + return Result.ofPlanNode( + new ProjectNode( // to preserve the schema of the transformed output + context.getIdAllocator().getNextId(), + new FilterNode( + filter.getId(), + TableScanNode.newInstance( + tableScan.getId(), + result.get().getTable(), + newOutputs, + newAssignments), + ConnectorExpressionTranslator.translate(result.get().getRemainingFilter(), mappings, metadata)), + Assignments.identity(filter.getOutputSymbols()))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index 433c7e026a4c..8184931265a1 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -36,6 +36,7 @@ import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.planner.PartitioningHandle; +import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.sql.tree.QualifiedName; import java.util.Collection; @@ -508,4 +509,10 @@ public boolean catalogExists(Session session, String catalogName) { throw new UnsupportedOperationException(); } + + @Override + public Optional applyFilter(TableHandle table, ConnectorExpression expression) + { + return Optional.empty(); + } } diff --git a/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java b/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java index f8a25232d7a7..2864fc4d4bc1 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java @@ -110,4 +110,13 @@ public void testHueQueries() // https://github.com/cloudera/hue/blob/b49e98c1250c502be596667ce1f0fe118983b432/desktop/libs/notebook/src/notebook/connectors/jdbc.py#L213 assertQuerySucceeds(getSession(), "SELECT column_name, data_type, column_comment FROM information_schema.columns WHERE table_schema='local' AND TABLE_NAME='nation'"); } + + + @Test + public void testX() + { + ((LocalQueryRunner) getQueryRunner()).printPlan(); + computeActual("SELECT * FROM orders WHERE orderkey = BIGINT '1'"); + } + }