diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index 7c17c605aea..568a30387ce 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -29,8 +29,6 @@ import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.Split; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.operator.Driver; import io.prestosql.operator.DriverContext; import io.prestosql.operator.FilterAndProjectOperator; @@ -41,23 +39,25 @@ import io.prestosql.operator.TaskContext; import io.prestosql.operator.TaskStats; import io.prestosql.operator.project.InputPageProjection; -import io.prestosql.operator.project.InterpretedPageProjection; import io.prestosql.operator.project.PageProcessor; import io.prestosql.operator.project.PageProjection; import io.prestosql.security.AllowAllAccessControl; import io.prestosql.spi.QueryId; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorPageSource; -import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.memory.MemoryPoolId; import io.prestosql.spi.type.Type; import io.prestosql.spiller.SpillSpaceTracker; import io.prestosql.split.SplitSource; +import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.HashGenerationOptimizer; import io.prestosql.sql.planner.plan.PlanNodeId; +import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.NodeRef; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.transaction.TransactionId; @@ -79,9 +79,11 @@ import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.prestosql.SystemSessionProperties.getFilterAndProjectMinOutputPageRowCount; import static io.prestosql.SystemSessionProperties.getFilterAndProjectMinOutputPageSize; +import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -170,8 +172,7 @@ protected final OperatorFactory createTableScanOperator(int operatorId, PlanNode List columnHandles = columnHandlesBuilder.build(); // get the split for this table - List layouts = metadata.getLayouts(session, tableHandle, Constraint.alwaysTrue(), Optional.empty()); - Split split = getLocalQuerySplit(session, layouts.get(0).getLayout().getHandle()); + Split split = getLocalQuerySplit(session, tableHandle); return new OperatorFactory() { @@ -196,7 +197,7 @@ public OperatorFactory duplicate() }; } - private Split getLocalQuerySplit(Session session, TableLayoutHandle handle) + private Split getLocalQuerySplit(Session session, TableHandle handle) { SplitSource splitSource = localQueryRunner.getSplitManager().getSplits(session, handle, UNGROUPED_SCHEDULING); List splits = new ArrayList<>(); @@ -226,13 +227,14 @@ protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNo Optional hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet())); verify(hashExpression.isPresent()); - projections.add(new InterpretedPageProjection( - hashExpression.get(), - TypeProvider.copyOf(symbolTypes.build()), - symbolToInputMapping.build(), - localQueryRunner.getMetadata(), - localQueryRunner.getSqlParser(), - session)); + + Map, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata()) + .getTypes(session, TypeProvider.copyOf(symbolTypes.build()), hashExpression.get()); + + RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionRegistry(), localQueryRunner.getTypeManager(), session, false); + + PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0); + projections.add(functionCompiler.compileProjection(translated, Optional.empty()).get()); return new FilterAndProjectOperator.FilterAndProjectOperatorFactory( operatorId, diff --git a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java index f25c31450df..b1608928882 100644 --- a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java +++ b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java @@ -389,6 +389,6 @@ public void testPushDownAnd() private RuleAssert assertRuleApplication() { RuleTester tester = tester(); - return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser())); + return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer())); } } diff --git a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java index d5eac82517f..44e8374fc25 100644 --- a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java +++ b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java @@ -258,6 +258,6 @@ public void testPushDownAnd() private RuleAssert assertRuleApplication() { RuleTester tester = tester(); - return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser())); + return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer())); } } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java index 6f5e69a4cd7..2e09bf5baba 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java @@ -24,13 +24,12 @@ import io.prestosql.metadata.Metadata; import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayout; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.metadata.TableMetadata; import io.prestosql.plugin.hive.HiveSessionProperties.InsertExistingPartitionsBehavior; import io.prestosql.spi.connector.CatalogSchemaTableName; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.security.Identity; import io.prestosql.spi.security.SelectedRole; @@ -1673,9 +1672,13 @@ private Object getHiveTableProperty(String tableName, Function tableHandle = metadata.getTableHandle(transactionSession, new QualifiedObjectName(catalog, TPCH_SCHEMA, tableName)); assertTrue(tableHandle.isPresent()); - List layouts = metadata.getLayouts(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()); - TableLayout layout = getOnlyElement(layouts).getLayout(); - return propertyGetter.apply((HiveTableLayoutHandle) layout.getHandle().getConnectorHandle()); + ConnectorTableLayoutHandle connectorLayout = metadata.getLayout(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()) + .get() + .getNewTableHandle() + .getLayout() + .get(); + + return propertyGetter.apply((HiveTableLayoutHandle) connectorLayout); }); } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java index a46e38ff801..52ba8216f3a 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java @@ -62,6 +62,7 @@ import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.StageExecutionPlan; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.tree.Explain; import io.prestosql.transaction.TransactionManager; @@ -414,7 +415,7 @@ private PlanRoot doAnalyzeQuery() // plan query PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, stateMachine.getWarningCollector()); + LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, new TypeAnalyzer(sqlParser, metadata), statsCalculator, costCalculator, stateMachine.getWarningCollector()); Plan plan = logicalPlanner.plan(analysis); queryPlan.set(plan); diff --git a/presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java b/presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java new file mode 100644 index 00000000000..6377680564f --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/metadata/FilterApplicationResult.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.metadata; + +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.expression.ConnectorExpression; + +import java.util.List; + +public class FilterApplicationResult +{ + private final TableHandle table; + private final ConnectorExpression remainingFilter; + private final List newProjections; + + public FilterApplicationResult(TableHandle table, ConnectorExpression remainingFilter, List newProjections) + { + this.table = table; + this.remainingFilter = remainingFilter; + this.newProjections = newProjections; + } + + public TableHandle getTable() + { + return table; + } + + public ConnectorExpression getRemainingFilter() + { + return remainingFilter; + } + + public List getNewProjections() + { + return newProjections; + } + + public static class Column + { + private final ColumnHandle column; + private final Type type; + + public Column(ColumnHandle column, Type type) + { + this.column = column; + this.type = type; + } + + public ColumnHandle getColumn() + { + return column; + } + + public Type getType() + { + return type; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index 93aad49cd02..257493f87c5 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -37,6 +37,7 @@ import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.planner.PartitioningHandle; +import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.sql.tree.QualifiedName; import java.util.Collection; @@ -73,25 +74,25 @@ public interface Metadata Optional getTableHandleForStatisticsCollection(Session session, QualifiedObjectName tableName, Map analyzeProperties); - List getLayouts(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns); + Optional getLayout(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns); - TableLayout getLayout(Session session, TableLayoutHandle handle); + TableLayout getLayout(Session session, TableHandle handle); /** - * Return a table layout handle whose partitioning is converted to the provided partitioning handle, - * but otherwise identical to the provided table layout handle. - * The provided table layout handle must be one that the connector can transparently convert to from - * the original partitioning handle associated with the provided table layout handle, + * Return a table handle whose partitioning is converted to the provided partitioning handle, + * but otherwise identical to the provided table handle. + * The provided table handle must be one that the connector can transparently convert to from + * the original partitioning handle associated with the provided table handle, * as promised by {@link #getCommonPartitioning}. */ - TableLayoutHandle makeCompatiblePartitioning(Session session, TableLayoutHandle tableLayoutHandle, PartitioningHandle partitioningHandle); + TableHandle makeCompatiblePartitioning(Session session, TableHandle table, PartitioningHandle partitioningHandle); /** * Return a partitioning handle which the connector can transparently convert both {@code left} and {@code right} into. */ Optional getCommonPartitioning(Session session, PartitioningHandle left, PartitioningHandle right); - Optional getInfo(Session session, TableLayoutHandle handle); + Optional getInfo(Session session, TableHandle handle); /** * Return the metadata for the specified table handle. @@ -241,14 +242,14 @@ public interface Metadata /** * @return whether delete without table scan is supported */ - boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle); + boolean supportsMetadataDelete(Session session, TableHandle tableHandle); /** * Delete the provide table layout * * @return number of rows deleted, or empty for unknown */ - OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle); + OptionalLong metadataDelete(Session session, TableHandle tableHandle); /** * Begin delete query @@ -380,4 +381,7 @@ public interface Metadata ColumnPropertyManager getColumnPropertyManager(); AnalyzePropertyManager getAnalyzePropertyManager(); + + // => TableHandle + remaining filter + new projections + Optional applyFilter(TableHandle table, ConnectorExpression expression); } diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 69a2cee9b47..dcb0604b0cf 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -65,6 +65,9 @@ import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.analyzer.FeaturesConfig; import io.prestosql.sql.planner.PartitioningHandle; +import io.prestosql.spi.expression.Apply; +import io.prestosql.spi.expression.ColumnReference; +import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.transaction.TransactionManager; import io.prestosql.type.TypeDeserializer; @@ -88,10 +91,8 @@ import java.util.concurrent.ConcurrentMap; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.metadata.QualifiedObjectName.convertFromSchemaTableName; -import static io.prestosql.metadata.TableLayout.fromConnectorLayout; import static io.prestosql.metadata.ViewDefinition.ViewColumn; import static io.prestosql.spi.StandardErrorCode.INVALID_VIEW; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; @@ -326,9 +327,14 @@ public Optional getTableHandle(Session session, QualifiedObjectName ConnectorId connectorId = catalogMetadata.getConnectorId(session, table); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - ConnectorTableHandle tableHandle = metadata.getTableHandle(session.toConnectorSession(connectorId), table.asSchemaTableName()); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + ConnectorTableHandle tableHandle = metadata.getTableHandle(connectorSession, table.asSchemaTableName()); if (tableHandle != null) { - return Optional.of(new TableHandle(connectorId, tableHandle)); + return Optional.of(new TableHandle( + connectorId, + tableHandle, + catalogMetadata.getTransactionHandleFor(connectorId), + Optional.empty())); } } return Optional.empty(); @@ -347,7 +353,11 @@ public Optional getTableHandleForStatisticsCollection(Session sessi ConnectorTableHandle tableHandle = metadata.getTableHandleForStatisticsCollection(session.toConnectorSession(connectorId), table.asSchemaTableName(), analyzeProperties); if (tableHandle != null) { - return Optional.of(new TableHandle(connectorId, tableHandle)); + return Optional.of(new TableHandle( + connectorId, + tableHandle, + catalogMetadata.getTransactionHandleFor(connectorId), + Optional.empty())); } } return Optional.empty(); @@ -373,10 +383,10 @@ public Optional getSystemTable(Session session, QualifiedObjectName } @Override - public List getLayouts(Session session, TableHandle table, Constraint constraint, Optional> desiredColumns) + public Optional getLayout(Session session, TableHandle table, Constraint constraint, Optional> desiredColumns) { if (constraint.getSummary().isNone()) { - return ImmutableList.of(); + return Optional.empty(); } ConnectorId connectorId = table.getConnectorId(); @@ -387,33 +397,47 @@ public List getLayouts(Session session, TableHandle table, Co ConnectorTransactionHandle transaction = catalogMetadata.getTransactionHandleFor(connectorId); ConnectorSession connectorSession = session.toConnectorSession(connectorId); List layouts = metadata.getTableLayouts(connectorSession, connectorTable, constraint, desiredColumns); + if (layouts.isEmpty()) { + return Optional.empty(); + } - return layouts.stream() - .map(layout -> new TableLayoutResult(fromConnectorLayout(connectorId, transaction, layout.getTableLayout()), layout.getUnenforcedConstraint())) - .collect(toImmutableList()); + if (layouts.size() > 1) { + throw new PrestoException(NOT_SUPPORTED, format("Connector returned multiple layouts for table %s", table)); + } + + ConnectorTableLayout tableLayout = layouts.get(0).getTableLayout(); + return Optional.of(new TableLayoutResult( + new TableHandle(connectorId, connectorTable, transaction, Optional.of(tableLayout.getHandle())), + new TableLayout(connectorId, transaction, tableLayout), + layouts.get(0).getUnenforcedConstraint())); } @Override - public TableLayout getLayout(Session session, TableLayoutHandle handle) + public TableLayout getLayout(Session session, TableHandle handle) { ConnectorId connectorId = handle.getConnectorId(); CatalogMetadata catalogMetadata = getCatalogMetadata(session, connectorId); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - ConnectorTransactionHandle transaction = catalogMetadata.getTransactionHandleFor(connectorId); - return fromConnectorLayout(connectorId, transaction, metadata.getTableLayout(session.toConnectorSession(connectorId), handle.getConnectorHandle())); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + + return handle.getLayout() + .map(layout -> new TableLayout(connectorId, handle.getTransaction(), metadata.getTableLayout(connectorSession, layout))) + .orElseGet(() -> getLayout(session, handle, Constraint.alwaysTrue(), Optional.empty()) + .get() + .getLayout()); } @Override - public TableLayoutHandle makeCompatiblePartitioning(Session session, TableLayoutHandle tableLayoutHandle, PartitioningHandle partitioningHandle) + public TableHandle makeCompatiblePartitioning(Session session, TableHandle tableHandle, PartitioningHandle partitioningHandle) { checkArgument(partitioningHandle.getConnectorId().isPresent(), "Expect partitioning handle from connector, got system partitioning handle"); ConnectorId connectorId = partitioningHandle.getConnectorId().get(); - checkArgument(connectorId.equals(tableLayoutHandle.getConnectorId()), "ConnectorId of tableLayoutHandle and partitioningHandle does not match"); + checkArgument(connectorId.equals(tableHandle.getConnectorId()), "ConnectorId of tableHandle and partitioningHandle does not match"); CatalogMetadata catalogMetadata = getCatalogMetadata(session, connectorId); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); ConnectorTransactionHandle transaction = catalogMetadata.getTransactionHandleFor(connectorId); - ConnectorTableLayoutHandle newTableLayoutHandle = metadata.makeCompatiblePartitioning(session.toConnectorSession(connectorId), tableLayoutHandle.getConnectorHandle(), partitioningHandle.getConnectorHandle()); - return new TableLayoutHandle(connectorId, transaction, newTableLayoutHandle); + ConnectorTableLayoutHandle newTableLayoutHandle = metadata.makeCompatiblePartitioning(session.toConnectorSession(connectorId), tableHandle.getLayout().get(), partitioningHandle.getConnectorHandle()); + return new TableHandle(connectorId, tableHandle.getConnectorHandle(), transaction, Optional.of(newTableLayoutHandle)); } @Override @@ -435,12 +459,19 @@ public Optional getCommonPartitioning(Session session, Parti } @Override - public Optional getInfo(Session session, TableLayoutHandle handle) + public Optional getInfo(Session session, TableHandle handle) { ConnectorId connectorId = handle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); - ConnectorTableLayout tableLayout = metadata.getTableLayout(session.toConnectorSession(connectorId), handle.getConnectorHandle()); - return metadata.getInfo(tableLayout.getHandle()); + + ConnectorTableLayoutHandle layoutHandle = handle.getLayout() + .orElseGet(() -> getLayout(session, handle, Constraint.alwaysTrue(), Optional.empty()) + .get() + .getNewTableHandle() + .getLayout() + .get()); + + return metadata.getInfo(layoutHandle); } @Override @@ -781,22 +812,22 @@ public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tabl } @Override - public boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public boolean supportsMetadataDelete(Session session, TableHandle tableHandle) { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); return metadata.supportsMetadataDelete( session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), - tableLayoutHandle.getConnectorHandle()); + tableHandle.getLayout().get()); } @Override - public OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public OptionalLong metadataDelete(Session session, TableHandle tableHandle) { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); - return metadata.metadataDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tableLayoutHandle.getConnectorHandle()); + return metadata.metadataDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tableHandle.getLayout().get()); } @Override @@ -805,7 +836,7 @@ public TableHandle beginDelete(Session session, TableHandle tableHandle) ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); ConnectorTableHandle newHandle = metadata.beginDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle()); - return new TableHandle(tableHandle.getConnectorId(), newHandle); + return new TableHandle(tableHandle.getConnectorId(), newHandle, tableHandle.getTransaction(), tableHandle.getLayout()); } @Override @@ -1123,6 +1154,47 @@ public AnalyzePropertyManager getAnalyzePropertyManager() return analyzePropertyManager; } + @Override + public Optional applyFilter(TableHandle table, ConnectorExpression expression) + { + // TODO: dispatch to connector that owns "table" + + + /////////////////////////////////// testing code + class CustomColumn implements ColumnHandle { + int id; + + public CustomColumn(int id) + { + this.id = id; + } + + @Override + public int hashCode() + { + return id; + } + + @Override + public boolean equals(Object obj) + { + return id == ((CustomColumn) obj).id; + } + } + + if (expression instanceof Apply) { + ColumnHandle column = new CustomColumn(1); + return Optional.of(new FilterApplicationResult( + table, + new ColumnReference(column, BOOLEAN), + ImmutableList.of(new FilterApplicationResult.Column(column, BOOLEAN)))); + } + /////////////////////////////////// testing code + + + return Optional.empty(); + } + private ViewDefinition deserializeView(String data) { try { diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java b/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java index a117a664d0b..dd987b7c01b 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableHandle.java @@ -17,8 +17,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.prestosql.connector.ConnectorId; import io.prestosql.spi.connector.ConnectorTableHandle; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; +import io.prestosql.spi.connector.ConnectorTransactionHandle; -import java.util.Objects; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -26,14 +28,23 @@ public final class TableHandle { private final ConnectorId connectorId; private final ConnectorTableHandle connectorHandle; + private final ConnectorTransactionHandle transaction; + + // Table layouts are deprecated, but we keep this here to hide the notion of layouts + // from the engine. TODO: it should be removed once table layouts are finally deleted + private final Optional layout; @JsonCreator public TableHandle( @JsonProperty("connectorId") ConnectorId connectorId, - @JsonProperty("connectorHandle") ConnectorTableHandle connectorHandle) + @JsonProperty("connectorHandle") ConnectorTableHandle connectorHandle, + @JsonProperty("transaction") ConnectorTransactionHandle transaction, + @JsonProperty("layout") Optional layout) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.connectorHandle = requireNonNull(connectorHandle, "connectorHandle is null"); + this.transaction = requireNonNull(transaction, "transaction is null"); + this.layout = requireNonNull(layout, "layout is null"); } @JsonProperty @@ -48,24 +59,16 @@ public ConnectorTableHandle getConnectorHandle() return connectorHandle; } - @Override - public int hashCode() + @JsonProperty + public Optional getLayout() { - return Objects.hash(connectorId, connectorHandle); + return layout; } - @Override - public boolean equals(Object obj) + @JsonProperty + public ConnectorTransactionHandle getTransaction() { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - final TableHandle other = (TableHandle) obj; - return Objects.equals(this.connectorId, other.connectorId) && - Objects.equals(this.connectorHandle, other.connectorHandle); + return transaction; } @Override diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java index 954280af025..2d1c238cf0e 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableLayout.java @@ -32,15 +32,18 @@ public class TableLayout { - private final TableLayoutHandle handle; private final ConnectorTableLayout layout; + private final ConnectorId connectorId; + private final ConnectorTransactionHandle transaction; - public TableLayout(TableLayoutHandle handle, ConnectorTableLayout layout) + public TableLayout(ConnectorId connectorId, ConnectorTransactionHandle transaction, ConnectorTableLayout layout) { - requireNonNull(handle, "handle is null"); + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(transaction, "transaction is null"); requireNonNull(layout, "layout is null"); - this.handle = handle; + this.connectorId = connectorId; + this.transaction = transaction; this.layout = layout; } @@ -59,18 +62,13 @@ public List> getLocalProperties() return layout.getLocalProperties(); } - public TableLayoutHandle getHandle() - { - return handle; - } - public Optional getTablePartitioning() { return layout.getTablePartitioning() .map(nodePartitioning -> new TablePartitioning( new PartitioningHandle( - Optional.of(handle.getConnectorId()), - Optional.of(handle.getTransactionHandle()), + Optional.of(connectorId), + Optional.of(transaction), nodePartitioning.getPartitioningHandle()), nodePartitioning.getPartitioningColumns())); } @@ -85,11 +83,6 @@ public Optional getDiscretePredicates() return layout.getDiscretePredicates(); } - public static TableLayout fromConnectorLayout(ConnectorId connectorId, ConnectorTransactionHandle transactionHandle, ConnectorTableLayout layout) - { - return new TableLayout(new TableLayoutHandle(connectorId, transactionHandle, layout.getHandle()), layout); - } - public static class TablePartitioning { private final PartitioningHandle partitioningHandle; diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java deleted file mode 100644 index 7b1cae2e23f..00000000000 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutHandle.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.metadata; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import io.prestosql.connector.ConnectorId; -import io.prestosql.spi.connector.ConnectorTableLayoutHandle; -import io.prestosql.spi.connector.ConnectorTransactionHandle; - -import java.util.Objects; - -import static java.util.Objects.requireNonNull; - -public final class TableLayoutHandle -{ - private final ConnectorId connectorId; - private final ConnectorTransactionHandle transactionHandle; - private final ConnectorTableLayoutHandle layout; - - @JsonCreator - public TableLayoutHandle( - @JsonProperty("connectorId") ConnectorId connectorId, - @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, - @JsonProperty("connectorHandle") ConnectorTableLayoutHandle layout) - { - requireNonNull(connectorId, "connectorId is null"); - requireNonNull(transactionHandle, "transactionHandle is null"); - requireNonNull(layout, "layout is null"); - - this.connectorId = connectorId; - this.transactionHandle = transactionHandle; - this.layout = layout; - } - - @JsonProperty - public ConnectorId getConnectorId() - { - return connectorId; - } - - @JsonProperty - public ConnectorTransactionHandle getTransactionHandle() - { - return transactionHandle; - } - - @JsonProperty - public ConnectorTableLayoutHandle getConnectorHandle() - { - return layout; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - TableLayoutHandle that = (TableLayoutHandle) o; - return Objects.equals(connectorId, that.connectorId) && - Objects.equals(transactionHandle, that.transactionHandle) && - Objects.equals(layout, that.layout); - } - - @Override - public int hashCode() - { - return Objects.hash(connectorId, transactionHandle, layout); - } -} diff --git a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java index bcf87513817..dc02158054c 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java +++ b/presto-main/src/main/java/io/prestosql/metadata/TableLayoutResult.java @@ -14,28 +14,31 @@ package io.prestosql.metadata; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.sql.planner.plan.TableScanNode; -import java.util.List; import java.util.Map; -import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; public class TableLayoutResult { + private final TableHandle newTableHandle; private final TableLayout layout; private final TupleDomain unenforcedConstraint; - public TableLayoutResult(TableLayout layout, TupleDomain unenforcedConstraint) + public TableLayoutResult(TableHandle newTable, TableLayout layout, TupleDomain unenforcedConstraint) { - this.layout = layout; - this.unenforcedConstraint = unenforcedConstraint; + this.newTableHandle = requireNonNull(newTable, "newTable is null"); + this.layout = requireNonNull(layout, "layout is null"); + this.unenforcedConstraint = requireNonNull(unenforcedConstraint, "unenforcedConstraint is null"); + } + + public TableHandle getNewTableHandle() + { + return newTableHandle; } public TableLayout getLayout() @@ -48,19 +51,6 @@ public TupleDomain getUnenforcedConstraint() return unenforcedConstraint; } - public boolean hasAllOutputs(TableScanNode node) - { - if (!layout.getColumns().isPresent()) { - return true; - } - Set columns = ImmutableSet.copyOf(layout.getColumns().get()); - List nodeColumnHandles = node.getOutputSymbols().stream() - .map(node.getAssignments()::get) - .collect(toImmutableList()); - - return columns.containsAll(nodeColumnHandles); - } - public static TupleDomain computeEnforced(TupleDomain predicate, TupleDomain unenforced) { if (predicate.isNone()) { diff --git a/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java b/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java index d264d2e1953..fd174e895e3 100644 --- a/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/MetadataDeleteOperator.java @@ -17,7 +17,6 @@ import io.prestosql.Session; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.BlockBuilder; @@ -41,17 +40,15 @@ public static class MetadataDeleteOperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; - private final TableLayoutHandle tableLayout; private final Metadata metadata; private final Session session; private final TableHandle tableHandle; private boolean closed; - public MetadataDeleteOperatorFactory(int operatorId, PlanNodeId planNodeId, TableLayoutHandle tableLayout, Metadata metadata, Session session, TableHandle tableHandle) + public MetadataDeleteOperatorFactory(int operatorId, PlanNodeId planNodeId, Metadata metadata, Session session, TableHandle tableHandle) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); @@ -62,7 +59,7 @@ public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, MetadataDeleteOperator.class.getSimpleName()); - return new MetadataDeleteOperator(context, tableLayout, metadata, session, tableHandle); + return new MetadataDeleteOperator(context, metadata, session, tableHandle); } @Override @@ -74,22 +71,20 @@ public void noMoreOperators() @Override public OperatorFactory duplicate() { - return new MetadataDeleteOperatorFactory(operatorId, planNodeId, tableLayout, metadata, session, tableHandle); + return new MetadataDeleteOperatorFactory(operatorId, planNodeId, metadata, session, tableHandle); } } private final OperatorContext operatorContext; - private final TableLayoutHandle tableLayout; private final Metadata metadata; private final Session session; private final TableHandle tableHandle; private boolean finished; - public MetadataDeleteOperator(OperatorContext operatorContext, TableLayoutHandle tableLayout, Metadata metadata, Session session, TableHandle tableHandle) + public MetadataDeleteOperator(OperatorContext operatorContext, Metadata metadata, Session session, TableHandle tableHandle) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); @@ -132,7 +127,7 @@ public Page getOutput() } finished = true; - OptionalLong rowsDeletedCount = metadata.metadataDelete(session, tableHandle, tableLayout); + OptionalLong rowsDeletedCount = metadata.metadataDelete(session, tableHandle); // output page will only be constructed once, // so a new PageBuilder is constructed (instead of using PageBuilder.reset) diff --git a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java b/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java deleted file mode 100644 index d2134e0acd8..00000000000 --- a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageFilter.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.operator.project; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; -import io.prestosql.metadata.Metadata; -import io.prestosql.spi.Page; -import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.planner.DeterminismEvaluator; -import io.prestosql.sql.planner.ExpressionInterpreter; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputParameterRewriter; -import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; - -import javax.annotation.concurrent.NotThreadSafe; - -import java.util.List; -import java.util.Map; - -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; -import static java.lang.Boolean.TRUE; -import static java.util.Collections.emptyList; - -@NotThreadSafe -public class InterpretedPageFilter - implements PageFilter -{ - private final ExpressionInterpreter evaluator; - private final InputChannels inputChannels; - private final boolean deterministic; - private boolean[] selectedPositions = new boolean[0]; - - public InterpretedPageFilter( - Expression expression, - TypeProvider symbolTypes, - Map symbolToInputMappings, - Metadata metadata, - SqlParser sqlParser, - Session session) - { - SymbolToInputParameterRewriter rewriter = new SymbolToInputParameterRewriter(symbolTypes, symbolToInputMappings); - Expression rewritten = rewriter.rewrite(expression); - this.inputChannels = new InputChannels(rewriter.getInputChannels()); - this.deterministic = DeterminismEvaluator.isDeterministic(expression); - - // analyze rewritten expression so we can know the type of every expression in the tree - List inputTypes = rewriter.getInputTypes(); - ImmutableMap.Builder parameterTypes = ImmutableMap.builder(); - for (int parameter = 0; parameter < inputTypes.size(); parameter++) { - Type type = inputTypes.get(parameter); - parameterTypes.put(parameter, type); - } - Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList(), WarningCollector.NOOP); - this.evaluator = ExpressionInterpreter.expressionInterpreter(rewritten, metadata, session, expressionTypes); - } - - @Override - public boolean isDeterministic() - { - return deterministic; - } - - @Override - public InputChannels getInputChannels() - { - return inputChannels; - } - - @Override - public SelectedPositions filter(ConnectorSession session, Page page) - { - if (selectedPositions.length < page.getPositionCount()) { - selectedPositions = new boolean[page.getPositionCount()]; - } - - for (int position = 0; position < page.getPositionCount(); position++) { - selectedPositions[position] = filter(page, position); - } - - return PageFilter.positionsArrayToSelectedPositions(selectedPositions, page.getPositionCount()); - } - - private boolean filter(Page page, int position) - { - return TRUE.equals(evaluator.evaluate(position, page)); - } -} diff --git a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java b/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java deleted file mode 100644 index 733c0e506bb..00000000000 --- a/presto-main/src/main/java/io/prestosql/operator/project/InterpretedPageProjection.java +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.operator.project; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; -import io.prestosql.metadata.Metadata; -import io.prestosql.operator.DriverYieldSignal; -import io.prestosql.operator.Work; -import io.prestosql.spi.Page; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.planner.DeterminismEvaluator; -import io.prestosql.sql.planner.ExpressionInterpreter; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputParameterRewriter; -import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; - -import java.util.List; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkState; -import static io.prestosql.spi.type.TypeUtils.writeNativeValue; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; -import static java.util.Collections.emptyList; -import static java.util.Objects.requireNonNull; - -public class InterpretedPageProjection - implements PageProjection -{ - private final ExpressionInterpreter evaluator; - private final InputChannels inputChannels; - private final boolean deterministic; - private BlockBuilder blockBuilder; - - public InterpretedPageProjection( - Expression expression, - TypeProvider symbolTypes, - Map symbolToInputMappings, - Metadata metadata, - SqlParser sqlParser, - Session session) - { - SymbolToInputParameterRewriter rewriter = new SymbolToInputParameterRewriter(symbolTypes, symbolToInputMappings); - Expression rewritten = rewriter.rewrite(expression); - this.inputChannels = new InputChannels(rewriter.getInputChannels()); - this.deterministic = DeterminismEvaluator.isDeterministic(expression); - - // analyze rewritten expression so we can know the type of every expression in the tree - List inputTypes = rewriter.getInputTypes(); - ImmutableMap.Builder parameterTypes = ImmutableMap.builder(); - for (int parameter = 0; parameter < inputTypes.size(); parameter++) { - Type type = inputTypes.get(parameter); - parameterTypes.put(parameter, type); - } - Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList(), WarningCollector.NOOP); - this.evaluator = ExpressionInterpreter.expressionInterpreter(rewritten, metadata, session, expressionTypes); - - blockBuilder = evaluator.getType().createBlockBuilder(null, 1); - } - - @Override - public Type getType() - { - return evaluator.getType(); - } - - @Override - public boolean isDeterministic() - { - return deterministic; - } - - @Override - public InputChannels getInputChannels() - { - return inputChannels; - } - - @Override - public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) - { - return new InterpretedPageProjectionWork(yieldSignal, page, selectedPositions); - } - - private class InterpretedPageProjectionWork - implements Work - { - private final DriverYieldSignal yieldSignal; - private final Page page; - private final SelectedPositions selectedPositions; - - private int nextIndexOrPosition; - private Block result; - - public InterpretedPageProjectionWork(DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) - { - this.yieldSignal = requireNonNull(yieldSignal, "yieldSignal is null"); - this.page = requireNonNull(page, "page is null"); - this.selectedPositions = requireNonNull(selectedPositions, "selectedPositions is null"); - this.nextIndexOrPosition = selectedPositions.getOffset(); - } - - @Override - public boolean process() - { - checkState(result == null, "result has been generated"); - int length = selectedPositions.getOffset() + selectedPositions.size(); - if (selectedPositions.isList()) { - int[] positions = selectedPositions.getPositions(); - while (nextIndexOrPosition < length) { - writeNativeValue(evaluator.getType(), blockBuilder, evaluator.evaluate(positions[nextIndexOrPosition], page)); - nextIndexOrPosition++; - if (yieldSignal.isSet()) { - return false; - } - } - } - else { - while (nextIndexOrPosition < length) { - writeNativeValue(evaluator.getType(), blockBuilder, evaluator.evaluate(nextIndexOrPosition, page)); - nextIndexOrPosition++; - if (yieldSignal.isSet()) { - return false; - } - } - } - - result = blockBuilder.build(); - blockBuilder = blockBuilder.newBlockBuilderLike(null); - return true; - } - - @Override - public Block getResult() - { - checkState(result != null, "result has not been generated"); - return result; - } - } -} diff --git a/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java b/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java index c8b15755063..af58b6b1fe7 100644 --- a/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java +++ b/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java @@ -127,6 +127,7 @@ import io.prestosql.sql.planner.CompilerConfig; import io.prestosql.sql.planner.LocalExecutionPlanner; import io.prestosql.sql.planner.NodePartitioningManager; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.transaction.TransactionManagerConfig; @@ -354,6 +355,7 @@ protected void setup(Binder binder) binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); // type + binder.bind(TypeAnalyzer.class).in(Scopes.SINGLETON); binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/Apply.java b/presto-main/src/main/java/io/prestosql/spi/expression/Apply.java new file mode 100644 index 00000000000..a4890dcb4c6 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/Apply.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.type.Type; + +import java.util.List; + +public class Apply + extends ConnectorExpression +{ + private final FunctionId function; + private final List arguments; + + public Apply(Type returnType, FunctionId function, List arguments) + { + super(returnType); + this.function = function; + this.arguments = arguments; + } + + // TODO: this will need to be a FunctionHandle + public FunctionId getFunction() + { + return function; + } + + public List getArguments() + { + return arguments; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java b/presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java new file mode 100644 index 00000000000..ec8452c476d --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/ColumnReference.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; + +public class ColumnReference + extends ConnectorExpression +{ + private final ColumnHandle column; + + public ColumnReference(ColumnHandle column, Type type) + { + super(type); + this.column = column; + } + + public ColumnHandle getColumn() + { + return column; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java new file mode 100644 index 00000000000..2c2e96f64f0 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpression.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.type.Type; + +public class ConnectorExpression +{ + private final Type type; + + public ConnectorExpression(Type type) + { + this.type = type; + } + + public Type getType() + { + return type; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java new file mode 100644 index 00000000000..c230f1b1cd5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/ConnectorExpressionTranslator.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import com.google.common.collect.ImmutableList; +import io.prestosql.Session; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.planner.LiteralEncoder; +import io.prestosql.sql.planner.LiteralInterpreter; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.tree.AstVisitor; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.FunctionCall; +import io.prestosql.sql.tree.GenericLiteral; +import io.prestosql.sql.tree.NodeRef; +import io.prestosql.sql.tree.QualifiedName; +import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.stream.Collectors; + +public class ConnectorExpressionTranslator +{ + private ConnectorExpressionTranslator() + { + } + + public static Expression translate(ConnectorExpression expression, Map mappings, Metadata metadata) + { + return new ConnectorToSqlExpressionTranslator(mappings, metadata).translate(expression); + } + + public static ConnectorExpression translate(Session session, Expression expression, Map assignments, TypeAnalyzer types, TypeProvider inputTypes, Metadata metadata) + { + return new SqlToConnectorExpressionTranslator(session, metadata, assignments, types.getTypes(session, inputTypes, expression)) + .process(expression); + } + + private static class ConnectorToSqlExpressionTranslator + { + private final Map mappings; + private final LiteralEncoder literalEncoder; + + public ConnectorToSqlExpressionTranslator(Map mappings, Metadata metadata) + { + this.mappings = mappings; + this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); + } + + private String nameOf(FunctionId function) + { + return function.getName(); // TODO + } + + public Expression translate(ConnectorExpression expression) + { + if (expression instanceof Constant) { + return literalEncoder.toExpression(((Constant) expression).getValue(), expression.getType()); + } + + if (expression instanceof ColumnReference) { + return mappings.get(((ColumnReference) expression).getColumn()).toSymbolReference(); + } + + if (expression instanceof Apply) { + Apply apply = (Apply) expression; + + return new FunctionCall( + QualifiedName.of(nameOf(apply.getFunction())), + apply.getArguments().stream() + .map(this::translate) + .collect(Collectors.toList())); + } + + throw new UnsupportedOperationException("Expression type not supported: " + expression.getClass().getName()); + + } + } + + private static class SqlToConnectorExpressionTranslator + extends AstVisitor + { + private final Session session; + private final Metadata metadata; + private final Map assignments; + private final Map, Type> types; + + private SqlToConnectorExpressionTranslator(Session session, Metadata metadata, Map assignments, Map, Type> types) + { + this.session = session; + this.metadata = metadata; + this.assignments = assignments; + this.types = types; + } + + private Type typeOf(Expression node) + { + return types.get(NodeRef.of(node)); + } + + // TODO: need to return a FunctionHandle for the operator + private FunctionId signatureOf(ComparisonExpression.Operator operator, Type left, Type right) + { + return new FunctionId("$operator_" + operator.name()); + } + + @Override + protected ConnectorExpression visitComparisonExpression(ComparisonExpression node, Void context) + { + ConnectorExpression left = process(node.getLeft()); + ConnectorExpression right = process(node.getRight()); + + return new Apply( + typeOf(node), signatureOf(node.getOperator(), left.getType(), right.getType()), + ImmutableList.of(left, right)); + } + + @Override + protected ConnectorExpression visitSymbolReference(SymbolReference node, Void context) + { + return new ColumnReference(assignments.get(Symbol.from(node)), typeOf(node)); + } + + @Override + protected ConnectorExpression visitGenericLiteral(GenericLiteral node, Void context) + { + return new Constant(LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), node), typeOf(node)); + } + + @Override + protected ConnectorExpression visitStringLiteral(StringLiteral node, Void context) + { + return new Constant(node.getSlice(), typeOf(node)); + } + + @Override + protected ConnectorExpression visitExpression(Expression node, Void context) + { + throw new UnsupportedOperationException("not yet implemented: expression translator for " + node.getClass().getName()); + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/Constant.java b/presto-main/src/main/java/io/prestosql/spi/expression/Constant.java new file mode 100644 index 00000000000..15a91c2efec --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/Constant.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +import io.prestosql.spi.type.Type; + +public class Constant + extends ConnectorExpression +{ + private final Object value; + + public Constant(Object value, Type type) + { + super(type); + this.value = value; + } + + public Object getValue() + { + return value; + } +} diff --git a/presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java b/presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java new file mode 100644 index 00000000000..88b650ce722 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/spi/expression/FunctionId.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.expression; + +public class FunctionId +{ + private final String identifier; + + public FunctionId(String identifier) + { + this.identifier = identifier; + } + + public String getName() + { + return identifier; + } +} diff --git a/presto-main/src/main/java/io/prestosql/split/SplitManager.java b/presto-main/src/main/java/io/prestosql/split/SplitManager.java index 7a9bc42fa3a..72ad74bce3c 100644 --- a/presto-main/src/main/java/io/prestosql/split/SplitManager.java +++ b/presto-main/src/main/java/io/prestosql/split/SplitManager.java @@ -16,14 +16,18 @@ import io.prestosql.Session; import io.prestosql.connector.ConnectorId; import io.prestosql.execution.QueryManagerConfig; -import io.prestosql.metadata.TableLayoutHandle; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.TableHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorSplitManager; import io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; import io.prestosql.spi.connector.ConnectorSplitSource; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; +import io.prestosql.spi.connector.Constraint; import javax.inject.Inject; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -36,10 +40,15 @@ public class SplitManager private final ConcurrentMap splitManagers = new ConcurrentHashMap<>(); private final int minScheduleSplitBatchSize; + // This is used to fetch a table layout if the TableHandle doesn't have one set + // It's a temporary measure until we get rid of TableLayouts from the SPI + private final Metadata metadata; + @Inject - public SplitManager(QueryManagerConfig config) + public SplitManager(QueryManagerConfig config, Metadata metadata) { this.minScheduleSplitBatchSize = config.getMinScheduleSplitBatchSize(); + this.metadata = metadata; } public void addConnectorSplitManager(ConnectorId connectorId, ConnectorSplitManager connectorSplitManager) @@ -54,20 +63,27 @@ public void removeConnectorSplitManager(ConnectorId connectorId) splitManagers.remove(connectorId); } - public SplitSource getSplits(Session session, TableLayoutHandle layout, SplitSchedulingStrategy splitSchedulingStrategy) + public SplitSource getSplits(Session session, TableHandle table, SplitSchedulingStrategy splitSchedulingStrategy) { - ConnectorId connectorId = layout.getConnectorId(); + ConnectorId connectorId = table.getConnectorId(); ConnectorSplitManager splitManager = getConnectorSplitManager(connectorId); ConnectorSession connectorSession = session.toConnectorSession(connectorId); + ConnectorTableLayoutHandle layout = table.getLayout() + .orElseGet(() -> metadata.getLayout(session, table, Constraint.alwaysTrue(), Optional.empty()) + .get() + .getNewTableHandle() + .getLayout() + .get()); + ConnectorSplitSource source = splitManager.getSplits( - layout.getTransactionHandle(), + table.getTransaction(), connectorSession, - layout.getConnectorHandle(), + layout, splitSchedulingStrategy); - SplitSource splitSource = new ConnectorAwareSplitSource(connectorId, layout.getTransactionHandle(), source); + SplitSource splitSource = new ConnectorAwareSplitSource(connectorId, table.getTransaction(), source); if (minScheduleSplitBatchSize > 1) { splitSource = new BufferingSplitSource(splitSource, minScheduleSplitBatchSize); } diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index f6642d34305..8e0f0e04acb 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -110,7 +110,6 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -1434,119 +1433,12 @@ public static Signature resolveFunction(FunctionCall node, List, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Expression expression, - List parameters, - WarningCollector warningCollector) - { - return getExpressionTypes(session, metadata, sqlParser, types, expression, parameters, warningCollector, false); - } - - public static Map, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Expression expression, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return getExpressionTypes(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, warningCollector, isDescribe); - } - - public static Map, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return analyzeExpressionsWithSymbols(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes(); - } - - public static Map, Type> getExpressionTypesFromInput( - Session session, - Metadata metadata, - SqlParser sqlParser, - Map types, - Expression expression, - List parameters, - WarningCollector warningCollector) - { - return getExpressionTypesFromInput(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, warningCollector); - } - - public static Map, Type> getExpressionTypesFromInput( - Session session, - Metadata metadata, - SqlParser sqlParser, - Map types, - Iterable expressions, - List parameters, - WarningCollector warningCollector) - { - return analyzeExpressionsWithInputs(session, metadata, sqlParser, types, expressions, parameters, warningCollector).getExpressionTypes(); - } - - public static ExpressionAnalysis analyzeExpressionsWithSymbols( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return analyzeExpressions(session, metadata, sqlParser, new RelationType(), types, expressions, parameters, warningCollector, isDescribe); - } - - private static ExpressionAnalysis analyzeExpressionsWithInputs( - Session session, - Metadata metadata, - SqlParser sqlParser, - Map types, - Iterable expressions, - List parameters, - WarningCollector warningCollector) - { - Field[] fields = new Field[types.size()]; - for (Entry entry : types.entrySet()) { - fields[entry.getKey()] = io.prestosql.sql.analyzer.Field.newUnqualified(Optional.empty(), entry.getValue()); - } - RelationType tupleDescriptor = new RelationType(fields); - - return analyzeExpressions(session, metadata, sqlParser, tupleDescriptor, TypeProvider.empty(), expressions, parameters, warningCollector); - } - public static ExpressionAnalysis analyzeExpressions( Session session, Metadata metadata, SqlParser sqlParser, - RelationType tupleDescriptor, TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector) - { - return analyzeExpressions(session, metadata, sqlParser, tupleDescriptor, types, expressions, parameters, warningCollector, false); - } - - private static ExpressionAnalysis analyzeExpressions( - Session session, - Metadata metadata, - SqlParser sqlParser, - RelationType tupleDescriptor, - TypeProvider types, - Iterable expressions, + Iterable expressions, List parameters, WarningCollector warningCollector, boolean isDescribe) @@ -1556,7 +1448,7 @@ private static ExpressionAnalysis analyzeExpressions( Analysis analysis = new Analysis(null, parameters, isDescribe); ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, new DenyAllAccessControl(), types, warningCollector); for (Expression expression : expressions) { - analyzer.analyze(expression, Scope.builder().withRelationType(RelationId.anonymous(), tupleDescriptor).build()); + analyzer.analyze(expression, Scope.builder().withRelationType(RelationId.anonymous(), new RelationType()).build()); } return new ExpressionAnalysis( diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java index 834b9a313ee..d0fd2416ace 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java @@ -29,6 +29,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.planPrinter.IoPlanPrinter; import io.prestosql.sql.planner.planPrinter.PlanPrinter; @@ -175,7 +176,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List s throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references"); } - Map, Type> expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = ExpressionAnalyzer.analyzeExpressions( session, metadata, sqlParser, TypeProvider.empty(), - relation.getSamplePercentage(), + ImmutableList.of(relation.getSamplePercentage()), analysis.getParameters(), WarningCollector.NOOP, - analysis.isDescribe()); + analysis.isDescribe()) + .getExpressionTypes(); + ExpressionInterpreter samplePercentageEval = expressionOptimizer(relation.getSamplePercentage(), metadata, session, expressionTypes); Object samplePercentageObject = samplePercentageEval.optimize(symbol -> { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java index e9020862a05..0c6a2375443 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java @@ -16,10 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.tree.AtTimeZone; import io.prestosql.sql.tree.Cast; import io.prestosql.sql.tree.Expression; @@ -36,8 +34,6 @@ import static io.prestosql.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; import static io.prestosql.spi.type.TimestampType.TIMESTAMP; import static io.prestosql.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class DesugarAtTimeZoneRewriter @@ -49,15 +45,15 @@ public static Expression rewrite(Expression expression, Map, private DesugarAtTimeZoneRewriter() {} - public static Expression rewrite(Expression expression, Session session, Metadata metadata, SqlParser sqlParser, SymbolAllocator symbolAllocator) + public static Expression rewrite(Expression expression, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { return expression; } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); return rewrite(expression, expressionTypes); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java index bae4b88fbf3..6ec1f2cf07b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java @@ -146,7 +146,7 @@ public Map visitTableScan(TableScanNode node, Void cont // get dataSource for table SplitSource splitSource = splitManager.getSplits( session, - node.getLayout().get(), + node.getTable(), stageExecutionDescriptor.isScanGroupedExecution(node.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING); splitSources.add(splitSource); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java index cfa67ca0d13..61eb2a8697c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.spi.block.Block; @@ -34,7 +33,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.ExpressionUtils; import io.prestosql.sql.InterpretedFunctionInvoker; -import io.prestosql.sql.analyzer.ExpressionAnalyzer; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.tree.AstVisitor; import io.prestosql.sql.tree.BetweenPredicate; @@ -78,7 +76,6 @@ import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; -import static java.util.Collections.emptyList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; @@ -277,7 +274,7 @@ public static ExtractionResult fromPredicate( Expression predicate, TypeProvider types) { - return new Visitor(metadata, session, types).process(predicate, false); + return new Visitor(metadata, session, types, new TypeAnalyzer(new SqlParser(), metadata)).process(predicate, false); } private static class Visitor @@ -288,14 +285,16 @@ private static class Visitor private final Session session; private final TypeProvider types; private final InterpretedFunctionInvoker functionInvoker; + private final TypeAnalyzer typeAnalyzer; - private Visitor(Metadata metadata, Session session, TypeProvider types) + private Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.functionInvoker = new InterpretedFunctionInvoker(metadata.getFunctionRegistry()); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } private Type checkedTypeLookup(Symbol symbol) @@ -424,7 +423,7 @@ else if (symbolExpression instanceof Cast) { return super.visitComparisonExpression(node, complement); } - Type castSourceType = typeOf(castExpression.getExpression(), session, metadata, types); // type of expression which is then cast to type of value + Type castSourceType = typeAnalyzer.getType(session, types, castExpression.getExpression()); // type of expression which is then cast to type of value // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side Optional coercedExpression = coerceComparisonWithRounding( @@ -489,7 +488,7 @@ private boolean isImplicitCoercion(Cast cast) private Map, Type> analyzeExpression(Expression expression) { - return ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); + return typeAnalyzer.getTypes(session, types, expression); } private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) @@ -757,12 +756,6 @@ protected ExtractionResult visitNullLiteral(NullLiteral node, Boolean complement } } - private static Type typeOf(Expression expression, Session session, Metadata metadata, TypeProvider types) - { - Map, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); - return expressionTypes.get(NodeRef.of(expression)); - } - private static class NormalizedSimpleComparison { private final Expression symbolExpression; diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java index 690a1d1891e..a3e27e15590 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java @@ -28,7 +28,6 @@ import io.prestosql.metadata.Signature; import io.prestosql.operator.scalar.ArraySubscriptOperator; import io.prestosql.operator.scalar.ScalarFunctionImplementation; -import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; @@ -252,10 +251,10 @@ public Object evaluate() return visitor.process(expression, new NoPagePositionContext()); } - public Object evaluate(int position, Page page) + public Object evaluate(SymbolResolver inputs) { checkState(!optimize, "evaluate(int, Page) not allowed for optimizer"); - return visitor.process(expression, new SinglePagePositionContext(position, page)); + return visitor.process(expression, inputs); } public Object optimize(SymbolResolver inputs) @@ -271,39 +270,7 @@ private class Visitor @Override public Object visitFieldReference(FieldReference node, Object context) { - Type type = type(node); - - int channel = node.getFieldIndex(); - if (context instanceof PagePositionContext) { - PagePositionContext pagePositionContext = (PagePositionContext) context; - int position = pagePositionContext.getPosition(channel); - Block block = pagePositionContext.getBlock(channel); - - if (block.isNull(position)) { - return null; - } - - Class javaType = type.getJavaType(); - if (javaType == boolean.class) { - return type.getBoolean(block, position); - } - else if (javaType == long.class) { - return type.getLong(block, position); - } - else if (javaType == double.class) { - return type.getDouble(block, position); - } - else if (javaType == Slice.class) { - return type.getSlice(block, position); - } - else if (javaType == Block.class) { - return type.getObject(block, position); - } - else { - throw new UnsupportedOperationException("not yet implemented"); - } - } - throw new UnsupportedOperationException("Inputs must be set"); + throw new UnsupportedOperationException("Field references not supported in interpreter"); } @Override @@ -1299,31 +1266,6 @@ public int getPosition(int channel) } } - private static class SinglePagePositionContext - implements PagePositionContext - { - private final int position; - private final Page page; - - private SinglePagePositionContext(int position, Page page) - { - this.position = position; - this.page = page; - } - - @Override - public Block getBlock(int channel) - { - return page.getBlock(channel); - } - - @Override - public int getPosition(int channel) - { - return position; - } - } - private static Expression createFailureFunction(RuntimeException exception, Type type) { requireNonNull(exception, "Exception is null"); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java b/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java index d2f17ccf7f0..69b93766656 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/InputExtractor.java @@ -20,8 +20,6 @@ import io.prestosql.execution.Input; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; -import io.prestosql.metadata.TableMetadata; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.SchemaTableName; @@ -59,10 +57,10 @@ private static Column createColumn(ColumnMetadata columnMetadata) return new Column(columnMetadata.getName(), columnMetadata.getType().toString()); } - private Input createInput(TableMetadata table, Optional layout, Set columns) + private Input createInput(Session session, TableHandle table, Set columns) { - SchemaTableName schemaTable = table.getTable(); - Optional inputMetadata = layout.flatMap(tableLayout -> metadata.getInfo(session, tableLayout)); + SchemaTableName schemaTable = metadata.getTableMetadata(session, table).getTable(); + Optional inputMetadata = metadata.getInfo(session, table); return new Input(table.getConnectorId(), schemaTable.getSchemaName(), schemaTable.getTableName(), inputMetadata, ImmutableList.copyOf(columns)); } @@ -86,7 +84,7 @@ public Void visitTableScan(TableScanNode node, Void context) columns.add(createColumn(metadata.getColumnMetadata(session, tableHandle, columnHandle))); } - inputs.add(createInput(metadata.getTableMetadata(session, tableHandle), node.getLayout(), columns)); + inputs.add(createInput(session, tableHandle, columns)); return null; } @@ -101,7 +99,7 @@ public Void visitIndexSource(IndexSourceNode node, Void context) columns.add(createColumn(metadata.getColumnMetadata(session, tableHandle, columnHandle))); } - inputs.add(createInput(metadata.getTableMetadata(session, tableHandle), node.getLayout(), columns)); + inputs.add(createInput(session, tableHandle, columns)); return null; } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 4aad7f3b2d9..47aefe8fec2 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -133,7 +133,6 @@ import io.prestosql.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import io.prestosql.sql.gen.OrderingCompiler; import io.prestosql.sql.gen.PageFunctionCompiler; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.optimizations.IndexJoinOptimizer; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -182,7 +181,6 @@ import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.FieldReference; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.LambdaArgumentDeclaration; import io.prestosql.sql.tree.LambdaExpression; @@ -217,7 +215,6 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.getOnlyElement; @@ -232,7 +229,6 @@ import static io.prestosql.SystemSessionProperties.isSpillEnabled; import static io.prestosql.SystemSessionProperties.isSpillOrderBy; import static io.prestosql.SystemSessionProperties.isSpillWindowOperator; -import static io.prestosql.execution.warnings.WarningCollector.NOOP; import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.operator.DistinctLimitOperator.DistinctLimitOperatorFactory; import static io.prestosql.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory; @@ -251,8 +247,6 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.TypeUtils.writeNativeValue; import static io.prestosql.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static io.prestosql.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; @@ -279,15 +273,13 @@ import static io.prestosql.util.SpatialJoinUtils.ST_WITHIN; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; -import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; public class LocalExecutionPlanner { private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final Optional explainAnalyzeContext; private final PageSourceProvider pageSourceProvider; private final IndexManager indexManager; @@ -314,7 +306,7 @@ public class LocalExecutionPlanner @Inject public LocalExecutionPlanner( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, Optional explainAnalyzeContext, PageSourceProvider pageSourceProvider, IndexManager indexManager, @@ -341,7 +333,7 @@ public LocalExecutionPlanner( this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.exchangeClientSupplier = exchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.expressionCompiler = requireNonNull(expressionCompiler, "compiler is null"); this.pageFunctionCompiler = requireNonNull(pageFunctionCompiler, "pageFunctionCompiler is null"); @@ -1170,7 +1162,6 @@ private PhysicalOperation visitScanFilterAndProject( // if source is a table scan we fold it directly into the filter and project // otherwise we plan it as a normal operator Map sourceLayout; - Map sourceTypes; List columns = null; PhysicalOperation source = null; if (sourceNode instanceof TableScanNode) { @@ -1178,7 +1169,6 @@ private PhysicalOperation visitScanFilterAndProject( // extract the column handles and channel to type mapping sourceLayout = new LinkedHashMap<>(); - sourceTypes = new LinkedHashMap<>(); columns = new ArrayList<>(); int channel = 0; for (Symbol symbol : tableScanNode.getOutputSymbols()) { @@ -1187,9 +1177,6 @@ private PhysicalOperation visitScanFilterAndProject( Integer input = channel; sourceLayout.put(symbol, input); - Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); - sourceTypes.put(input, type); - channel++; } } @@ -1209,7 +1196,6 @@ else if (sourceNode instanceof SampleNode) { // plan source source = sourceNode.accept(this, context); sourceLayout = source.getLayout(); - sourceTypes = getInputTypes(source.getLayout(), source.getTypes()); } // build output mapping @@ -1220,27 +1206,19 @@ else if (sourceNode instanceof SampleNode) { } Map outputMappings = outputMappingsBuilder.build(); - // compiler uses inputs instead of symbols, so rewrite the expressions first - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Optional rewrittenFilter = filterExpression.map(symbolToInputRewriter::rewrite); - - List rewrittenProjections = new ArrayList<>(); + List projections = new ArrayList<>(); for (Symbol symbol : outputSymbols) { - rewrittenProjections.add(symbolToInputRewriter.rewrite(assignments.get(symbol))); + projections.add(assignments.get(symbol)); } - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = typeAnalyzer.getTypes( context.getSession(), - metadata, - sqlParser, - sourceTypes, - concat(rewrittenFilter.map(ImmutableList::of).orElse(ImmutableList.of()), rewrittenProjections), - emptyList(), - NOOP); - - Optional translatedFilter = rewrittenFilter.map(filter -> toRowExpression(filter, expressionTypes)); - List translatedProjections = rewrittenProjections.stream() - .map(expression -> toRowExpression(expression, expressionTypes)) + context.getTypes(), + concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions())); + + Optional translatedFilter = filterExpression.map(filter -> toRowExpression(filter, expressionTypes, sourceLayout)); + List translatedProjections = projections.stream() + .map(expression -> toRowExpression(expression, expressionTypes, sourceLayout)) .collect(toImmutableList()); try { @@ -1256,7 +1234,7 @@ else if (sourceNode instanceof SampleNode) { cursorProcessor, pageProcessor, columns, - getTypes(rewrittenProjections, expressionTypes), + getTypes(projections, expressionTypes), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -1269,7 +1247,7 @@ else if (sourceNode instanceof SampleNode) { context.getNextOperatorId(), planNodeId, pageProcessor, - getTypes(rewrittenProjections, expressionTypes), + getTypes(projections, expressionTypes), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -1281,19 +1259,9 @@ else if (sourceNode instanceof SampleNode) { } } - private RowExpression toRowExpression(Expression expression, Map, Type> types) - { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); - } - - private Map getInputTypes(Map layout, List types) + private RowExpression toRowExpression(Expression expression, Map, Type> types, Map layout) { - ImmutableMap.Builder inputTypes = ImmutableMap.builder(); - for (Integer input : ImmutableSet.copyOf(layout.values())) { - Type type = types.get(input); - inputTypes.put(input, type); - } - return inputTypes.build(); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, layout, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); } @Override @@ -1323,15 +1291,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext PageBuilder pageBuilder = new PageBuilder(node.getRows().size(), outputTypes); for (List row : node.getRows()) { pageBuilder.declarePosition(); - Map, Type> expressionTypes = getExpressionTypes( - context.getSession(), - metadata, - sqlParser, - TypeProvider.empty(), - ImmutableList.copyOf(row), - emptyList(), - NOOP, - false); + Map, Type> expressionTypes = typeAnalyzer.getTypes(context.getSession(), TypeProvider.empty(), ImmutableList.copyOf(row)); for (int i = 0; i < row.size(); i++) { // evaluate the literal value Object result = ExpressionInterpreter.expressionInterpreter(row.get(i), metadata, context.getSession(), expressionTypes).evaluate(); @@ -1987,10 +1947,8 @@ private JoinBridgeManager createLookupSourceFact Optional sortChannel = sortExpressionContext .map(SortExpressionContext::getSortExpression) - .map(sortExpression -> sortExpressionAsSortChannel( - sortExpression, - probeSource.getLayout(), - buildSource.getLayout())); + .map(Symbol::from) + .map(sortSymbol -> createJoinSourcesLayout(buildSource.getLayout(), probeSource.getLayout()).get(sortSymbol)); List searchFunctionFactories = sortExpressionContext .map(SortExpressionContext::getSearchExpressions) @@ -2058,34 +2016,10 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( { Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Map sourceTypes = joinSourcesLayout.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey()))); - - Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( - session, - metadata, - sqlParser, - sourceTypes, - rewrittenFilter, - emptyList(), /* parameters have already been replaced */ - NOOP); - - RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes); + RowExpression translatedFilter = toRowExpression(filterExpression, typeAnalyzer.getTypes(session, types, filterExpression), joinSourcesLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } - private int sortExpressionAsSortChannel( - Expression sortExpression, - Map probeLayout, - Map buildLayout) - { - Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Expression rewrittenSortExpression = new SymbolToInputRewriter(joinSourcesLayout).rewrite(sortExpression); - checkArgument(rewrittenSortExpression instanceof FieldReference, "Unsupported expression type [%s]", rewrittenSortExpression); - return ((FieldReference) rewrittenSortExpression).getFieldIndex(); - } - private OperatorFactory createLookupJoin( JoinNode node, PhysicalOperation probeSource, @@ -2343,7 +2277,7 @@ public PhysicalOperation visitDelete(DeleteNode node, LocalExecutionPlanContext @Override public PhysicalOperation visitMetadataDelete(MetadataDeleteNode node, LocalExecutionPlanContext context) { - OperatorFactory operatorFactory = new MetadataDeleteOperatorFactory(context.getNextOperatorId(), node.getId(), node.getTableLayout(), metadata, session, node.getTarget().getHandle()); + OperatorFactory operatorFactory = new MetadataDeleteOperatorFactory(context.getNextOperatorId(), node.getId(), metadata, session, node.getTarget().getHandle()); return new PhysicalOperation(operatorFactory, makeLayout(node), context, UNGROUPED_EXECUTION); } @@ -2593,17 +2527,10 @@ private AccumulatorFactory buildAccumulatorFactory( // expressions from lambda arguments .putAll(lambdaArgumentExpressionTypes) // expressions from lambda body - .putAll(getExpressionTypes( - session, - metadata, - sqlParser, - TypeProvider.copyOf(lambdaArgumentSymbolTypes), - lambdaExpression.getBody(), - emptyList(), - NOOP)) + .putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody())) .build(); - LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes); + LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of()); Class lambdaProviderClass = compileLambdaProvider(lambda, metadata.getFunctionRegistry(), lambdaInterfaces.get(i)); try { lambdaProviders.add((LambdaProvider) constructorMethodHandle(lambdaProviderClass, ConnectorSession.class).invoke(session.toConnectorSession())); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java index 557685f6d1e..a6df3075bc2 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java @@ -42,7 +42,6 @@ import io.prestosql.sql.analyzer.RelationId; import io.prestosql.sql.analyzer.RelationType; import io.prestosql.sql.analyzer.Scope; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.plan.AggregationNode; @@ -115,7 +114,7 @@ public enum Stage private final PlanSanityChecker planSanityChecker; private final SymbolAllocator symbolAllocator = new SymbolAllocator(); private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final StatisticsAggregationPlanner statisticsAggregationPlanner; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; @@ -125,12 +124,12 @@ public LogicalPlanner(Session session, List planOptimizers, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector) { - this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, warningCollector); + this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, metadata, typeAnalyzer, statsCalculator, costCalculator, warningCollector); } public LogicalPlanner(Session session, @@ -138,7 +137,7 @@ public LogicalPlanner(Session session, PlanSanityChecker planSanityChecker, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector) @@ -148,7 +147,7 @@ public LogicalPlanner(Session session, this.planSanityChecker = requireNonNull(planSanityChecker, "planSanityChecker is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, metadata); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); @@ -164,7 +163,7 @@ public Plan plan(Analysis analysis, Stage stage) { PlanNode root = planStatement(analysis, analysis.getStatement()); - planSanityChecker.validateIntermediatePlan(root, session, metadata, sqlParser, symbolAllocator.getTypes(), warningCollector); + planSanityChecker.validateIntermediatePlan(root, session, metadata, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); if (stage.ordinal() >= Stage.OPTIMIZED.ordinal()) { for (PlanOptimizer optimizer : planOptimizers) { @@ -175,7 +174,7 @@ public Plan plan(Analysis analysis, Stage stage) if (stage.ordinal() >= Stage.OPTIMIZED_AND_VALIDATED.ordinal()) { // make sure we produce a valid plan after optimizations run. This is mainly to catch programming errors - planSanityChecker.validateFinalPlan(root, session, metadata, sqlParser, symbolAllocator.getTypes(), warningCollector); + planSanityChecker.validateFinalPlan(root, session, metadata, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); } TypeProvider types = symbolAllocator.getTypes(); @@ -264,7 +263,7 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme idAllocator.getNextId(), new AggregationNode( idAllocator.getNextId(), - new TableScanNode(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.build()), + TableScanNode.newInstance(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.build()), statisticAggregations.getAggregations(), singleGroupingSet(groupingSymbols), ImmutableList.of(), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java index c84dbae7297..90ee22c5fb9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanFragmenter.java @@ -21,9 +21,8 @@ import io.prestosql.cost.StatsAndCosts; import io.prestosql.execution.QueryManagerConfig; import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.TableLayout; +import io.prestosql.metadata.TableHandle; import io.prestosql.metadata.TableLayout.TablePartitioning; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ConnectorPartitionHandle; import io.prestosql.spi.connector.ConnectorPartitioningHandle; @@ -279,9 +278,8 @@ public PlanNode visitMetadataDelete(MetadataDeleteNode node, RewriteContext context) { - PartitioningHandle partitioning = node.getLayout() - .map(layout -> metadata.getLayout(session, layout)) - .flatMap(TableLayout::getTablePartitioning) + PartitioningHandle partitioning = metadata.getLayout(session, node.getTable()) + .getTablePartitioning() .map(TablePartitioning::getPartitioningHandle) .orElse(SOURCE_DISTRIBUTION); @@ -645,7 +643,7 @@ private GroupedExecutionProperties processWindowFunction(PlanNode node) @Override public GroupedExecutionProperties visitTableScan(TableScanNode node, Void context) { - Optional tablePartitioning = metadata.getLayout(session, node.getLayout().get()).getTablePartitioning(); + Optional tablePartitioning = metadata.getLayout(session, node.getTable()).getTablePartitioning(); if (!tablePartitioning.isPresent()) { return GroupedExecutionProperties.notCapable(); } @@ -750,9 +748,8 @@ public PartitioningHandleReassigner(PartitioningHandle fragmentPartitioningHandl @Override public PlanNode visitTableScan(TableScanNode node, RewriteContext context) { - PartitioningHandle partitioning = node.getLayout() - .map(layout -> metadata.getLayout(session, layout)) - .flatMap(TableLayout::getTablePartitioning) + PartitioningHandle partitioning = metadata.getLayout(session, node.getTable()) + .getTablePartitioning() .map(TablePartitioning::getPartitioningHandle) .orElse(SOURCE_DISTRIBUTION); if (partitioning.equals(fragmentPartitioningHandle)) { @@ -760,13 +757,12 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) return node; } - TableLayoutHandle newTableLayoutHandle = metadata.makeCompatiblePartitioning(session, node.getLayout().get(), fragmentPartitioningHandle); + TableHandle newTable = metadata.makeCompatiblePartitioning(session, node.getTable(), fragmentPartitioningHandle); return new TableScanNode( node.getId(), - node.getTable(), + newTable, node.getOutputSymbols(), node.getAssignments(), - Optional.of(newTableLayoutHandle), node.getCurrentConstraint(), node.getEnforcedConstraint()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 7eb783582d4..a144efc1b6f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -27,7 +27,6 @@ import io.prestosql.split.PageSourceManager; import io.prestosql.split.SplitManager; import io.prestosql.sql.analyzer.FeaturesConfig; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.iterative.IterativeOptimizer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; @@ -55,7 +54,6 @@ import io.prestosql.sql.planner.iterative.rule.MergeLimitWithTopN; import io.prestosql.sql.planner.iterative.rule.MergeLimits; import io.prestosql.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct; -import io.prestosql.sql.planner.iterative.rule.PickTableLayout; import io.prestosql.sql.planner.iterative.rule.PruneAggregationColumns; import io.prestosql.sql.planner.iterative.rule.PruneAggregationSourceColumns; import io.prestosql.sql.planner.iterative.rule.PruneCountAggregationOverScalar; @@ -76,12 +74,14 @@ import io.prestosql.sql.planner.iterative.rule.PruneValuesColumns; import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns; import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; +import io.prestosql.sql.planner.iterative.rule.PushFilterIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughProject; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughSemiJoin; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughExchange; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughJoin; +import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushProjectionThroughExchange; import io.prestosql.sql.planner.iterative.rule.PushProjectionThroughUnion; import io.prestosql.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId; @@ -145,7 +145,7 @@ public class PlanOptimizers @Inject public PlanOptimizers( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, FeaturesConfig featuresConfig, NodeSchedulerConfig nodeSchedulerConfig, InternalNodeManager nodeManager, @@ -160,7 +160,7 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator) { this(metadata, - sqlParser, + typeAnalyzer, featuresConfig, taskManagerConfig, false, @@ -190,7 +190,7 @@ public void destroy() public PlanOptimizers( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, FeaturesConfig featuresConfig, TaskManagerConfig taskManagerConfig, boolean forceSingleNode, @@ -204,6 +204,7 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator) { this.exporter = exporter; + ImmutableList.Builder builder = ImmutableList.builder(); Set> predicatePushDownRules = ImmutableSet.of( @@ -249,9 +250,9 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new SimplifyExpressions(metadata, sqlParser).rules()); + new SimplifyExpressions(metadata, typeAnalyzer).rules()); - PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser)); + PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, typeAnalyzer)); builder.add( // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers @@ -261,7 +262,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.>builder() .addAll(new DesugarLambdaExpression().rules()) - .addAll(new DesugarAtTimeZone(metadata, sqlParser).rules()) + .addAll(new DesugarAtTimeZone(metadata, typeAnalyzer).rules()) .addAll(new DesugarCurrentUser().rules()) .addAll(new DesugarCurrentPath().rules()) .addAll(new DesugarTryExpression().rules()) @@ -298,7 +299,8 @@ public PlanOptimizers( new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(), new PruneOrderByInAggregation(metadata.getFunctionRegistry()), - new RewriteSpatialPartitioningAggregation(metadata))) + new RewriteSpatialPartitioningAggregation(metadata), + new PushFilterIntoTableScan(metadata, typeAnalyzer))) .build()), simplifyOptimizer, new UnaliasSymbolReferences(), @@ -357,7 +359,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new PickTableLayout(metadata, sqlParser).rules()), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), new PruneUnreferencedOutputs(), new IterativeOptimizer( ruleStats, @@ -407,7 +409,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new PickTableLayout(metadata, sqlParser).rules()), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), projectionPushDown, new PruneUnreferencedOutputs(), new IterativeOptimizer( @@ -440,7 +442,7 @@ public PlanOptimizers( costCalculator, ImmutableSet.>builder() .add(new RemoveRedundantIdentityProjections()) - .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, sqlParser).rules()) + .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, typeAnalyzer).rules()) .add(new InlineProjections()) .build())); @@ -461,7 +463,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges - builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser))); + builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, typeAnalyzer))); } //noinspection UnusedAssignment estimatedExchangesCostCalculator = null; // Prevent accidental use after AddExchanges @@ -491,7 +493,7 @@ public PlanOptimizers( .build())); // Optimizers above this don't understand local exchanges, so be careful moving this. - builder.add(new AddLocalExchanges(metadata, sqlParser)); + builder.add(new AddLocalExchanges(metadata, typeAnalyzer)); // Optimizers above this do not need to care about aggregations with the type other than SINGLE // This optimizer must be run after all exchange-related optimizers @@ -507,7 +509,7 @@ public PlanOptimizers( ruleStats, statsCalculator, costCalculator, - new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, sqlParser, taskCountEstimator, taskManagerConfig).rules())); + new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, typeAnalyzer, taskCountEstimator, taskManagerConfig).rules())); builder.add(new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java index a880f0b1ff5..8efe82ed5bb 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java @@ -218,7 +218,7 @@ public DeleteNode plan(Delete node) fields.add(rowIdField); // create table scan - PlanNode tableScan = new TableScanNode(idAllocator.getNextId(), handle, outputSymbols.build(), columns.build()); + PlanNode tableScan = TableScanNode.newInstance(idAllocator.getNextId(), handle, outputSymbols.build(), columns.build()); Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType(fields.build())).build(); RelationPlan relationPlan = new RelationPlan(tableScan, scope, outputSymbols.build()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index 3f98e345be6..3b1b26304a9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -158,7 +158,7 @@ protected RelationPlan visitTable(Table node, Void context) } List outputSymbols = outputSymbolsBuilder.build(); - PlanNode root = new TableScanNode(idAllocator.getNextId(), handle, outputSymbols, columns.build()); + PlanNode root = TableScanNode.newInstance(idAllocator.getNextId(), handle, outputSymbols, columns.build()); return new RelationPlan(root, scope, outputSymbols); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java deleted file mode 100644 index bee2488bd71..00000000000 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputParameterRewriter.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.ExpressionRewriter; -import io.prestosql.sql.tree.ExpressionTreeRewriter; -import io.prestosql.sql.tree.FieldReference; -import io.prestosql.sql.tree.LambdaExpression; -import io.prestosql.sql.tree.SymbolReference; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public class SymbolToInputParameterRewriter -{ - private final Map symbolToChannelMapping; - private final TypeProvider types; - - private final Map fieldToParameter = new HashMap<>(); - private final List inputChannels = new ArrayList<>(); - private final List inputTypes = new ArrayList<>(); - private int nextParameter; - - public List getInputChannels() - { - return ImmutableList.copyOf(inputChannels); - } - - public List getInputTypes() - { - return ImmutableList.copyOf(inputTypes); - } - - public SymbolToInputParameterRewriter(TypeProvider types, Map symbolToChannelMapping) - { - this.types = requireNonNull(types, "symbolToTypeMapping is null"); - - requireNonNull(symbolToChannelMapping, "symbolToChannelMapping is null"); - this.symbolToChannelMapping = ImmutableMap.copyOf(symbolToChannelMapping); - } - - public Expression rewrite(Expression expression) - { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() - { - @Override - public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter treeRewriter) - { - Symbol symbol = Symbol.from(node); - Integer channel = symbolToChannelMapping.get(symbol); - if (channel == null) { - checkArgument(context.isInLambda(), "Cannot resolve symbol %s", node.getName()); - return node; - } - - Type type = types.get(symbol); - checkArgument(type != null, "Cannot resolve symbol %s", node.getName()); - - int parameter = fieldToParameter.computeIfAbsent(channel, field -> { - inputChannels.add(field); - inputTypes.add(type); - return nextParameter++; - }); - return new FieldReference(parameter); - } - - @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter treeRewriter) - { - return treeRewriter.defaultRewrite(node, new Context(true)); - } - }, expression, new Context(false)); - } - - private static class Context - { - private final boolean inLambda; - - public Context(boolean inLambda) - { - this.inLambda = inLambda; - } - - public boolean isInLambda() - { - return inLambda; - } - } -} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java deleted file mode 100644 index b7b563ea20d..00000000000 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SymbolToInputRewriter.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.ExpressionRewriter; -import io.prestosql.sql.tree.ExpressionTreeRewriter; -import io.prestosql.sql.tree.FieldReference; -import io.prestosql.sql.tree.LambdaExpression; -import io.prestosql.sql.tree.SymbolReference; - -import java.util.Map; - -import static java.util.Objects.requireNonNull; - -public class SymbolToInputRewriter -{ - private final Map symbolToChannelMapping; - - public SymbolToInputRewriter(Map symbolToChannelMapping) - { - requireNonNull(symbolToChannelMapping, "symbolToChannelMapping is null"); - this.symbolToChannelMapping = ImmutableMap.copyOf(symbolToChannelMapping); - } - - public Expression rewrite(Expression expression) - { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() - { - @Override - public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter treeRewriter) - { - Integer channel = symbolToChannelMapping.get(Symbol.from(node)); - if (channel == null) { - Preconditions.checkArgument(context.isInLambda(), "Cannot resolve symbol %s", node.getName()); - return node; - } - return new FieldReference(channel); - } - - @Override - public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter treeRewriter) - { - return treeRewriter.defaultRewrite(node, new Context(true)); - } - }, expression, new Context(false)); - } - - private static class Context - { - private final boolean inLambda; - - public Context(boolean inLambda) - { - this.inLambda = inLambda; - } - - public boolean isInLambda() - { - return inLambda; - } - } -} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java new file mode 100644 index 00000000000..b9b97cb1b17 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.analyzer.ExpressionAnalyzer; +import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.NodeRef; + +import javax.inject.Inject; + +import java.util.Map; + +// This class is to facilitate obtaining the type of an expression and its subexpressions +// during planning (i.e., when interacting with IR expression). It will eventually get +// removed when we split the AST from the IR and we encode the type directly into IR expressions. +public class TypeAnalyzer +{ + private final SqlParser parser; + private final Metadata metadata; + + @Inject + public TypeAnalyzer(SqlParser parser, Metadata metadata) + { + this.parser = parser; + this.metadata = metadata; + } + + public Map, Type> getTypes(Session session, TypeProvider inputTypes, Iterable expressions) + { + return ExpressionAnalyzer.analyzeExpressions(session, metadata, parser, inputTypes, expressions, ImmutableList.of(), WarningCollector.NOOP, false).getExpressionTypes(); + } + + public Map, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) + { + return getTypes(session, inputTypes, ImmutableList.of(expression)); + } + + public Type getType(Session session, TypeProvider inputTypes, Expression expression) + { + return getTypes(session, inputTypes, expression).get(NodeRef.of(expression)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java index 6e6639dab73..010b8b9adcf 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -26,10 +26,10 @@ import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.PartitioningScheme; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.optimizations.StreamPreferredProperties; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -128,18 +128,18 @@ public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet private static final double ANTI_SKEWNESS_MARGIN = 3; private final Metadata metadata; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final TaskCountEstimator taskCountEstimator; private final DataSize maxPartialAggregationMemoryUsage; public AddExchangesBelowPartialAggregationOverGroupIdRuleSet( Metadata metadata, - SqlParser parser, + TypeAnalyzer typeAnalyzer, TaskCountEstimator taskCountEstimator, TaskManagerConfig taskManagerConfig) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null"); this.maxPartialAggregationMemoryUsage = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage(); } @@ -342,7 +342,7 @@ private StreamProperties derivePropertiesRecursively(PlanNode node, Context cont List inputProperties = resolvedPlanNode.getSources().stream() .map(source -> derivePropertiesRecursively(source, context)) .collect(toImmutableList()); - return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), parser); + return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), typeAnalyzer); } } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java index c2ebb7b3679..1ee472b74a5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DesugarAtTimeZoneRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import java.util.Set; @@ -26,9 +26,9 @@ public class DesugarAtTimeZone extends ExpressionRewriteRuleSet { - public DesugarAtTimeZone(Metadata metadata, SqlParser sqlParser) + public DesugarAtTimeZone(Metadata metadata, TypeAnalyzer typeAnalyzer) { - super(createRewrite(metadata, sqlParser)); + super(createRewrite(metadata, typeAnalyzer)); } @Override @@ -42,11 +42,11 @@ public Set> rules() valuesExpressionRewrite()); } - private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + private static ExpressionRewriter createRewrite(Metadata metadata, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); - return (expression, context) -> DesugarAtTimeZoneRewriter.rewrite(expression, context.getSession(), metadata, sqlParser, context.getSymbolAllocator()); + return (expression, context) -> DesugarAtTimeZoneRewriter.rewrite(expression, context.getSession(), metadata, typeAnalyzer, context.getSymbolAllocator()); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index d02859f5e81..170ca629479 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -21,7 +21,6 @@ import com.google.common.collect.Iterables; import io.prestosql.Session; import io.prestosql.execution.Lifespan; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.geospatial.KdbTree; import io.prestosql.geospatial.KdbTreeUtils; import io.prestosql.matching.Capture; @@ -31,12 +30,10 @@ import io.prestosql.metadata.QualifiedObjectName; import io.prestosql.metadata.Split; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorPageSource; -import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; @@ -44,8 +41,8 @@ import io.prestosql.split.SplitManager; import io.prestosql.split.SplitSource; import io.prestosql.split.SplitSource.SplitBatch; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.Rule.Context; import io.prestosql.sql.planner.iterative.Rule.Result; @@ -61,7 +58,6 @@ import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.StringLiteral; import io.prestosql.sql.tree.SymbolReference; @@ -87,7 +83,6 @@ import static io.prestosql.spi.type.IntegerType.INTEGER; import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; @@ -100,7 +95,6 @@ import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -160,21 +154,21 @@ public class ExtractSpatialJoins private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } public Set> rules() { return ImmutableSet.of( - new ExtractSpatialInnerJoin(metadata, splitManager, pageSourceManager, sqlParser), - new ExtractSpatialLeftJoin(metadata, splitManager, pageSourceManager, sqlParser)); + new ExtractSpatialInnerJoin(metadata, splitManager, pageSourceManager, typeAnalyzer), + new ExtractSpatialLeftJoin(metadata, splitManager, pageSourceManager, typeAnalyzer)); } @VisibleForTesting @@ -188,14 +182,14 @@ public static final class ExtractSpatialInnerJoin private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialInnerJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialInnerJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -217,7 +211,7 @@ public Result apply(FilterNode node, Captures captures, Context context) Expression filter = node.getPredicate(); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -225,7 +219,7 @@ public Result apply(FilterNode node, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -244,14 +238,14 @@ public static final class ExtractSpatialLeftJoin private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -272,7 +266,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) Expression filter = joinNode.getFilter().get(); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -280,7 +274,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -300,7 +294,7 @@ private static Result tryCreateSpatialJoin( Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, - SqlParser sqlParser) + TypeAnalyzer typeAnalyzer) { PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -352,7 +346,7 @@ private static Result tryCreateSpatialJoin( joinNode.getDistributionType(), joinNode.isSpillable()); - return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, sqlParser); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, typeAnalyzer); } private static Result tryCreateSpatialJoin( @@ -366,7 +360,7 @@ private static Result tryCreateSpatialJoin( Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, - SqlParser sqlParser) + TypeAnalyzer typeAnalyzer) { // TODO Add support for distributed left spatial joins Optional spatialPartitioningTableName = joinNode.getType() == INNER ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty(); @@ -379,8 +373,8 @@ private static Result tryCreateSpatialJoin( Expression secondArgument = arguments.get(1); Type sphericalGeographyType = metadata.getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE); - if (getExpressionType(firstArgument, context, metadata, sqlParser).equals(sphericalGeographyType) - || getExpressionType(secondArgument, context, metadata, sqlParser).equals(sphericalGeographyType)) { + if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType) + || typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) { return Result.empty(); } @@ -448,14 +442,6 @@ else if (alignment < 0) { kdbTree.map(KdbTreeUtils::toJson))); } - private static Type getExpressionType(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) - { - Type type = getExpressionTypes(context.getSession(), metadata, sqlParser, context.getSymbolAllocator().getTypes(), expression, emptyList(), WarningCollector.NOOP) - .get(NodeRef.of(expression)); - verify(type != null); - return type; - } - private static KdbTree loadKdbTree(String tableName, Session session, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) { QualifiedObjectName name = toQualifiedObjectName(tableName, session.getCatalog().get(), session.getSchema().get()); @@ -469,11 +455,8 @@ private static KdbTree loadKdbTree(String tableName, Session session, Metadata m ColumnHandle kdbTreeColumn = Iterables.getOnlyElement(visibleColumnHandles); - List layouts = metadata.getLayouts(session, tableHandle, Constraint.alwaysTrue(), Optional.of(ImmutableSet.of(kdbTreeColumn))); - checkSpatialPartitioningTable(!layouts.isEmpty(), "Table is empty: %s", name); - Optional kdbTree = Optional.empty(); - try (SplitSource splitSource = splitManager.getSplits(session, layouts.get(0).getLayout().getHandle(), UNGROUPED_SCHEDULING)) { + try (SplitSource splitSource = splitManager.getSplits(session, tableHandle, UNGROUPED_SCHEDULING)) { while (!Thread.currentThread().isInterrupted()) { SplitBatch splitBatch = getFutureValue(splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), 1000)); List splits = splitBatch.getSplits(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java deleted file mode 100644 index a0a9c850f8e..00000000000 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PickTableLayout.java +++ /dev/null @@ -1,384 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner.iterative.rule; - -import com.google.common.collect.ImmutableBiMap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; -import io.prestosql.matching.Capture; -import io.prestosql.matching.Captures; -import io.prestosql.matching.Pattern; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.TableLayoutResult; -import io.prestosql.operator.scalar.TryFunction; -import io.prestosql.spi.connector.ColumnHandle; -import io.prestosql.spi.connector.Constraint; -import io.prestosql.spi.predicate.NullableValue; -import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.planner.DomainTranslator; -import io.prestosql.sql.planner.ExpressionInterpreter; -import io.prestosql.sql.planner.LiteralEncoder; -import io.prestosql.sql.planner.LookupSymbolResolver; -import io.prestosql.sql.planner.PlanNodeIdAllocator; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolsExtractor; -import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.planner.iterative.Rule; -import io.prestosql.sql.planner.plan.FilterNode; -import io.prestosql.sql.planner.plan.PlanNode; -import io.prestosql.sql.planner.plan.TableScanNode; -import io.prestosql.sql.planner.plan.ValuesNode; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; -import io.prestosql.sql.tree.NullLiteral; - -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Sets.intersection; -import static io.prestosql.SystemSessionProperties.isNewOptimizerEnabled; -import static io.prestosql.matching.Capture.newCapture; -import static io.prestosql.metadata.TableLayoutResult.computeEnforced; -import static io.prestosql.sql.ExpressionUtils.combineConjuncts; -import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; -import static io.prestosql.sql.ExpressionUtils.filterNonDeterministicConjuncts; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static io.prestosql.sql.planner.iterative.rule.PreconditionRules.checkRulesAreFiredBeforeAddExchangesRule; -import static io.prestosql.sql.planner.plan.Patterns.filter; -import static io.prestosql.sql.planner.plan.Patterns.source; -import static io.prestosql.sql.planner.plan.Patterns.tableScan; -import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static java.util.Collections.emptyList; -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; - -/** - * These rules should not be run after AddExchanges so as not to overwrite the TableLayout - * chosen by AddExchanges - */ -public class PickTableLayout -{ - private final Metadata metadata; - private final SqlParser parser; - private final DomainTranslator domainTranslator; - - public PickTableLayout(Metadata metadata, SqlParser parser) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); - this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); - } - - public Set> rules() - { - return ImmutableSet.of( - checkRulesAreFiredBeforeAddExchangesRule(), - pickTableLayoutForPredicate(), - pickTableLayoutWithoutPredicate()); - } - - public PickTableLayoutForPredicate pickTableLayoutForPredicate() - { - return new PickTableLayoutForPredicate(metadata, parser, domainTranslator); - } - - public PickTableLayoutWithoutPredicate pickTableLayoutWithoutPredicate() - { - return new PickTableLayoutWithoutPredicate(metadata, parser, domainTranslator); - } - - private static final class PickTableLayoutForPredicate - implements Rule - { - private final Metadata metadata; - private final SqlParser parser; - private final DomainTranslator domainTranslator; - - private PickTableLayoutForPredicate(Metadata metadata, SqlParser parser, DomainTranslator domainTranslator) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); - this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); - } - - private static final Capture TABLE_SCAN = newCapture(); - - private static final Pattern PATTERN = filter().with(source().matching( - tableScan().capturedAs(TABLE_SCAN))); - - @Override - public Pattern getPattern() - { - return PATTERN; - } - - @Override - public boolean isEnabled(Session session) - { - return isNewOptimizerEnabled(session); - } - - @Override - public Result apply(FilterNode filterNode, Captures captures, Context context) - { - TableScanNode tableScan = captures.get(TABLE_SCAN); - - PlanNode rewritten = planTableScan(tableScan, filterNode.getPredicate(), context.getSession(), context.getSymbolAllocator().getTypes(), context.getIdAllocator(), metadata, parser, domainTranslator); - - if (arePlansSame(filterNode, tableScan, rewritten)) { - return Result.empty(); - } - - return Result.ofPlanNode(rewritten); - } - - private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten) - { - if (!(rewritten instanceof FilterNode)) { - return false; - } - - FilterNode rewrittenFilter = (FilterNode) rewritten; - if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) { - return false; - } - - if (!(rewrittenFilter.getSource() instanceof TableScanNode)) { - return false; - } - - TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource(); - - if (!tableScan.getLayout().isPresent() && rewrittenTableScan.getLayout().isPresent()) { - return false; - } - - return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint()) - && Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint()); - } - } - - private static final class PickTableLayoutWithoutPredicate - implements Rule - { - private final Metadata metadata; - private final SqlParser parser; - private final DomainTranslator domainTranslator; - - private PickTableLayoutWithoutPredicate(Metadata metadata, SqlParser parser, DomainTranslator domainTranslator) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); - this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); - } - - private static final Pattern PATTERN = tableScan(); - - @Override - public Pattern getPattern() - { - return PATTERN; - } - - @Override - public boolean isEnabled(Session session) - { - return isNewOptimizerEnabled(session); - } - - @Override - public Result apply(TableScanNode tableScanNode, Captures captures, Context context) - { - if (tableScanNode.getLayout().isPresent()) { - return Result.empty(); - } - - return Result.ofPlanNode(planTableScan(tableScanNode, TRUE_LITERAL, context.getSession(), context.getSymbolAllocator().getTypes(), context.getIdAllocator(), metadata, parser, domainTranslator)); - } - } - - private static PlanNode planTableScan( - TableScanNode node, - Expression predicate, - Session session, - TypeProvider types, - PlanNodeIdAllocator idAllocator, - Metadata metadata, - SqlParser parser, - DomainTranslator domainTranslator) - { - return listTableLayouts( - node, - predicate, - false, - session, - types, - idAllocator, - metadata, - parser, - domainTranslator) - .get(0); - } - - public static List listTableLayouts( - TableScanNode node, - Expression predicate, - boolean pruneWithPredicateExpression, - Session session, - TypeProvider types, - PlanNodeIdAllocator idAllocator, - Metadata metadata, - SqlParser parser, - DomainTranslator domainTranslator) - { - // don't include non-deterministic predicates - Expression deterministicPredicate = filterDeterministicConjuncts(predicate); - - DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( - metadata, - session, - deterministicPredicate, - types); - - TupleDomain newDomain = decomposedPredicate.getTupleDomain() - .transform(node.getAssignments()::get) - .intersect(node.getEnforcedConstraint()); - - Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); - - Constraint constraint; - if (pruneWithPredicateExpression) { - LayoutConstraintEvaluator evaluator = new LayoutConstraintEvaluator( - metadata, - parser, - session, - types, - node.getAssignments(), - combineConjuncts( - deterministicPredicate, - // Simplify the tuple domain to avoid creating an expression with too many nodes, - // which would be expensive to evaluate in the call to isCandidate below. - domainTranslator.toPredicate(newDomain.simplify().transform(assignments::get)))); - constraint = new Constraint<>(newDomain, evaluator::isCandidate); - } - else { - // Currently, invoking the expression interpreter is very expensive. - // TODO invoke the interpreter unconditionally when the interpreter becomes cheap enough. - constraint = new Constraint<>(newDomain); - } - - // Layouts will be returned in order of the connector's preference - List layouts = metadata.getLayouts( - session, - node.getTable(), - constraint, - Optional.of(node.getOutputSymbols().stream() - .map(node.getAssignments()::get) - .collect(toImmutableSet()))); - - if (layouts.isEmpty()) { - return ImmutableList.of(new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of())); - } - - // Filter out layouts that cannot supply all the required columns - layouts = layouts.stream() - .filter(layout -> layout.hasAllOutputs(node)) - .collect(toList()); - checkState(!layouts.isEmpty(), "No usable layouts for %s", node); - - if (layouts.stream().anyMatch(layout -> layout.getLayout().getPredicate().isNone())) { - return ImmutableList.of(new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of())); - } - - return layouts.stream() - .map(layout -> { - TableScanNode tableScan = new TableScanNode( - node.getId(), - node.getTable(), - node.getOutputSymbols(), - node.getAssignments(), - Optional.of(layout.getLayout().getHandle()), - layout.getLayout().getPredicate(), - computeEnforced(newDomain, layout.getUnenforcedConstraint())); - - // The order of the arguments to combineConjuncts matters: - // * Unenforced constraints go first because they can only be simple column references, - // which are not prone to logic errors such as out-of-bound access, div-by-zero, etc. - // * Conjuncts in non-deterministic expressions and non-TupleDomain-expressible expressions should - // retain their original (maybe intermixed) order from the input predicate. However, this is not implemented yet. - // * Short of implementing the previous bullet point, the current order of non-deterministic expressions - // and non-TupleDomain-expressible expressions should be retained. Changing the order can lead - // to failures of previously successful queries. - Expression resultingPredicate = combineConjuncts( - domainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(assignments::get)), - filterNonDeterministicConjuncts(predicate), - decomposedPredicate.getRemainingExpression()); - - if (!TRUE_LITERAL.equals(resultingPredicate)) { - return new FilterNode(idAllocator.getNextId(), tableScan, resultingPredicate); - } - - return tableScan; - }) - .collect(toImmutableList()); - } - - private static class LayoutConstraintEvaluator - { - private final Map assignments; - private final ExpressionInterpreter evaluator; - private final Set arguments; - - public LayoutConstraintEvaluator(Metadata metadata, SqlParser parser, Session session, TypeProvider types, Map assignments, Expression expression) - { - this.assignments = assignments; - - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); - - evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); - arguments = SymbolsExtractor.extractUnique(expression).stream() - .map(assignments::get) - .collect(toImmutableSet()); - } - - private boolean isCandidate(Map bindings) - { - if (intersection(bindings.keySet(), arguments).isEmpty()) { - return true; - } - LookupSymbolResolver inputs = new LookupSymbolResolver(assignments, bindings); - - // Skip pruning if evaluation fails in a recoverable way. Failing here can cause - // spurious query failures for partitions that would otherwise be filtered out. - Object optimized = TryFunction.evaluate(() -> evaluator.optimize(inputs), true); - - // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned - if (Boolean.FALSE.equals(optimized) || optimized == null || optimized instanceof NullLiteral) { - return false; - } - - return true; - } - } -} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java index 6fa88ff5e9b..b6e0c5bc623 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneIndexSourceColumns.java @@ -60,7 +60,6 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, indexSourceNode.getId(), indexSourceNode.getIndexHandle(), indexSourceNode.getTableHandle(), - indexSourceNode.getLayout(), prunedLookupSymbols, prunedOutputList, prunedAssignments, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java index e8cacd4c31f..928fa34367f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -42,7 +42,6 @@ protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, tableScanNode.getTable(), filteredCopy(tableScanNode.getOutputSymbols(), referencedOutputs::contains), filterKeys(tableScanNode.getAssignments(), referencedOutputs::contains), - tableScanNode.getLayout(), tableScanNode.getCurrentConstraint(), tableScanNode.getEnforcedConstraint())); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java new file mode 100644 index 00000000000..0259b8b173b --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushFilterIntoTableScan.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.metadata.FilterApplicationResult; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.spi.expression.ConnectorExpression; +import io.prestosql.spi.expression.ConnectorExpressionTranslator; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.TableScanNode; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.plan.Patterns.filter; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.tableScan; + +public class PushFilterIntoTableScan + implements Rule +{ + private static final Capture TABLE_SCAN = newCapture(); + private static final Pattern PATTERN = filter().with(source().matching( + tableScan().capturedAs(TABLE_SCAN))); + + private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; + + public PushFilterIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer) + { + this.metadata = metadata; + this.typeAnalyzer = typeAnalyzer; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(FilterNode filter, Captures captures, Context context) + { + TableScanNode tableScan = captures.get(TABLE_SCAN); + + ConnectorExpression expression = ConnectorExpressionTranslator.translate( + context.getSession(), + filter.getPredicate(), + tableScan.getAssignments(), + typeAnalyzer, + context.getSymbolAllocator().getTypes(), + metadata); + + Optional result = metadata.applyFilter(tableScan.getTable(), expression); + if (!result.isPresent()) { + return Result.empty(); + } + + Map mappings = new HashMap<>(); + for (Map.Entry assignment : tableScan.getAssignments().entrySet()) { + mappings.put(assignment.getValue(), assignment.getKey()); + } + + List newOutputs = new ArrayList<>(); + Map newAssignments = new HashMap<>(); + + newOutputs.addAll(tableScan.getOutputSymbols()); + newAssignments.putAll(tableScan.getAssignments()); + for (FilterApplicationResult.Column newProjection : result.get().getNewProjections()) { + Symbol symbol = context.getSymbolAllocator().newSymbol("column", newProjection.getType()); + + mappings.put(newProjection.getColumn(), symbol); + newOutputs.add(symbol); + newAssignments.put(symbol, newProjection.getColumn()); + } + + return Result.ofPlanNode( + new ProjectNode( // to preserve the schema of the transformed output + context.getIdAllocator().getNextId(), + new FilterNode( + filter.getId(), + TableScanNode.newInstance( + tableScan.getId(), + result.get().getTable(), + newOutputs, + newAssignments), + ConnectorExpressionTranslator.translate(result.get().getRemainingFilter(), mappings, metadata)), + Assignments.identity(filter.getOutputSymbols()))); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java new file mode 100644 index 00000000000..07625f27ee5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -0,0 +1,266 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.Session; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.TableLayoutResult; +import io.prestosql.operator.scalar.TryFunction; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.Constraint; +import io.prestosql.spi.predicate.NullableValue; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.sql.planner.DomainTranslator; +import io.prestosql.sql.planner.ExpressionInterpreter; +import io.prestosql.sql.planner.LiteralEncoder; +import io.prestosql.sql.planner.LookupSymbolResolver; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.ValuesNode; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.NullLiteral; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.intersection; +import static io.prestosql.SystemSessionProperties.isNewOptimizerEnabled; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.metadata.TableLayoutResult.computeEnforced; +import static io.prestosql.sql.ExpressionUtils.combineConjuncts; +import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; +import static io.prestosql.sql.ExpressionUtils.filterNonDeterministicConjuncts; +import static io.prestosql.sql.planner.plan.Patterns.filter; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.tableScan; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static java.util.Objects.requireNonNull; + +/** + * These rules should not be run after AddExchanges so as not to overwrite the TableLayout + * chosen by AddExchanges + */ +public class PushPredicateIntoTableScan + implements Rule +{ + private static final Capture TABLE_SCAN = newCapture(); + + private static final Pattern PATTERN = filter().with(source().matching( + tableScan().capturedAs(TABLE_SCAN))); + + private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; + private final DomainTranslator domainTranslator; + + public PushPredicateIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(Session session) + { + return isNewOptimizerEnabled(session); + } + + @Override + public Result apply(FilterNode filterNode, Captures captures, Context context) + { + TableScanNode tableScan = captures.get(TABLE_SCAN); + + PlanNode rewritten = pushFilterIntoTableScan( + tableScan, + filterNode.getPredicate(), + false, + context.getSession(), + context.getSymbolAllocator().getTypes(), + context.getIdAllocator(), + metadata, + typeAnalyzer, + domainTranslator); + + if (arePlansSame(filterNode, tableScan, rewritten)) { + return Result.empty(); + } + + return Result.ofPlanNode(rewritten); + } + + private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten) + { + if (!(rewritten instanceof FilterNode)) { + return false; + } + + FilterNode rewrittenFilter = (FilterNode) rewritten; + if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) { + return false; + } + + if (!(rewrittenFilter.getSource() instanceof TableScanNode)) { + return false; + } + + TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource(); + + return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint()) + && Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint()); + } + + public static PlanNode pushFilterIntoTableScan( + TableScanNode node, + Expression predicate, + boolean pruneWithPredicateExpression, + Session session, + TypeProvider types, + PlanNodeIdAllocator idAllocator, + Metadata metadata, + TypeAnalyzer typeAnalyzer, + DomainTranslator domainTranslator) + { + // don't include non-deterministic predicates + Expression deterministicPredicate = filterDeterministicConjuncts(predicate); + + DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( + metadata, + session, + deterministicPredicate, + types); + + TupleDomain newDomain = decomposedPredicate.getTupleDomain() + .transform(node.getAssignments()::get) + .intersect(node.getEnforcedConstraint()); + + Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); + + Constraint constraint; + if (pruneWithPredicateExpression) { + LayoutConstraintEvaluator evaluator = new LayoutConstraintEvaluator( + metadata, + typeAnalyzer, + session, + types, + node.getAssignments(), + combineConjuncts( + deterministicPredicate, + // Simplify the tuple domain to avoid creating an expression with too many nodes, + // which would be expensive to evaluate in the call to isCandidate below. + domainTranslator.toPredicate(newDomain.simplify().transform(assignments::get)))); + constraint = new Constraint<>(newDomain, evaluator::isCandidate); + } + else { + // Currently, invoking the expression interpreter is very expensive. + // TODO invoke the interpreter unconditionally when the interpreter becomes cheap enough. + constraint = new Constraint<>(newDomain); + } + + Optional layout = metadata.getLayout( + session, + node.getTable(), + constraint, + Optional.of(node.getOutputSymbols().stream() + .map(node.getAssignments()::get) + .collect(toImmutableSet()))); + + if (!layout.isPresent() || layout.get().getLayout().getPredicate().isNone()) { + return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), ImmutableList.of()); + } + + TableScanNode tableScan = new TableScanNode( + node.getId(), + layout.get().getNewTableHandle(), + node.getOutputSymbols(), + node.getAssignments(), + layout.get().getLayout().getPredicate(), + computeEnforced(newDomain, layout.get().getUnenforcedConstraint())); + + // The order of the arguments to combineConjuncts matters: + // * Unenforced constraints go first because they can only be simple column references, + // which are not prone to logic errors such as out-of-bound access, div-by-zero, etc. + // * Conjuncts in non-deterministic expressions and non-TupleDomain-expressible expressions should + // retain their original (maybe intermixed) order from the input predicate. However, this is not implemented yet. + // * Short of implementing the previous bullet point, the current order of non-deterministic expressions + // and non-TupleDomain-expressible expressions should be retained. Changing the order can lead + // to failures of previously successful queries. + Expression resultingPredicate = combineConjuncts( + domainTranslator.toPredicate(layout.get().getUnenforcedConstraint().transform(assignments::get)), + filterNonDeterministicConjuncts(predicate), + decomposedPredicate.getRemainingExpression()); + + if (!TRUE_LITERAL.equals(resultingPredicate)) { + return new FilterNode(idAllocator.getNextId(), tableScan, resultingPredicate); + } + + return tableScan; + } + + private static class LayoutConstraintEvaluator + { + private final Map assignments; + private final ExpressionInterpreter evaluator; + private final Set arguments; + + public LayoutConstraintEvaluator(Metadata metadata, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, Map assignments, Expression expression) + { + this.assignments = assignments; + + evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, typeAnalyzer.getTypes(session, types, expression)); + arguments = SymbolsExtractor.extractUnique(expression).stream() + .map(assignments::get) + .collect(toImmutableSet()); + } + + private boolean isCandidate(Map bindings) + { + if (intersection(bindings.keySet(), arguments).isEmpty()) { + return true; + } + LookupSymbolResolver inputs = new LookupSymbolResolver(assignments, bindings); + + // Skip pruning if evaluation fails in a recoverable way. Failing here can cause + // spurious query failures for partitions that would otherwise be filtered out. + Object optimized = TryFunction.evaluate(() -> evaluator.optimize(inputs), true); + + // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned + if (Boolean.FALSE.equals(optimized) || optimized == null || optimized instanceof NullLiteral) { + return false; + } + + return true; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java index 0ccf965931e..95c39cd6495 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java @@ -16,14 +16,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.LiteralEncoder; import io.prestosql.sql.planner.NoOpSymbolResolver; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.NodeRef; @@ -32,33 +31,31 @@ import java.util.Map; import java.util.Set; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter.extractCommonPredicates; import static io.prestosql.sql.planner.iterative.rule.PushDownNegationsExpressionRewriter.pushDownNegations; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class SimplifyExpressions extends ExpressionRewriteRuleSet { @VisibleForTesting - static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, SqlParser sqlParser) + static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { return expression; } expression = pushDownNegations(expression); expression = extractCommonPredicates(expression); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); return literalEncoder.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } - public SimplifyExpressions(Metadata metadata, SqlParser sqlParser) + public SimplifyExpressions(Metadata metadata, TypeAnalyzer typeAnalyzer) { - super(createRewrite(metadata, sqlParser)); + super(createRewrite(metadata, typeAnalyzer)); } @Override @@ -71,12 +68,12 @@ public Set> rules() valuesExpressionRewrite()); // ApplyNode and AggregationNode are not supported, because ExpressionInterpreter doesn't support them } - private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + private static ExpressionRewriter createRewrite(Metadata metadata, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); LiteralEncoder literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); - return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), metadata, literalEncoder, sqlParser); + return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), metadata, literalEncoder, typeAnalyzer); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index 14e1cb96fda..57f98b9d806 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -14,10 +14,6 @@ package io.prestosql.sql.planner.optimizations; import com.google.common.annotations.VisibleForTesting; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import com.google.common.collect.ComparisonChain; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableSet; @@ -30,7 +26,6 @@ import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; import io.prestosql.spi.connector.SortingProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.LiteralEncoder; import io.prestosql.sql.planner.Partitioning; @@ -38,8 +33,9 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; -import io.prestosql.sql.planner.iterative.rule.PickTableLayout; +import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.ApplyNode; import io.prestosql.sql.planner.plan.Assignments; @@ -78,10 +74,7 @@ import io.prestosql.sql.tree.SymbolReference; import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -89,7 +82,6 @@ import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -120,15 +112,15 @@ public class AddExchanges implements PlanOptimizer { - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final Metadata metadata; private final DomainTranslator domainTranslator; - public AddExchanges(Metadata metadata, SqlParser parser) + public AddExchanges(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = metadata; this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); - this.parser = parser; + this.typeAnalyzer = typeAnalyzer; } @Override @@ -499,7 +491,7 @@ public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, PreferredPr public PlanWithProperties visitFilter(FilterNode node, PreferredProperties preferredProperties) { if (node.getSource() instanceof TableScanNode) { - return planTableScan((TableScanNode) node.getSource(), node.getPredicate(), preferredProperties); + return planTableScan((TableScanNode) node.getSource(), node.getPredicate()); } return rebaseAndDeriveProperties(node, planChild(node, preferredProperties)); @@ -508,7 +500,7 @@ public PlanWithProperties visitFilter(FilterNode node, PreferredProperties prefe @Override public PlanWithProperties visitTableScan(TableScanNode node, PreferredProperties preferredProperties) { - return planTableScan(node, TRUE_LITERAL, preferredProperties); + return planTableScan(node, TRUE_LITERAL); } @Override @@ -538,28 +530,10 @@ else if (redistributeWrites) { return rebaseAndDeriveProperties(node, source); } - private PlanWithProperties planTableScan(TableScanNode node, Expression predicate, PreferredProperties preferredProperties) + private PlanWithProperties planTableScan(TableScanNode node, Expression predicate) { - List possiblePlans = PickTableLayout.listTableLayouts(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); - List possiblePlansWithProperties = possiblePlans.stream() - .map(planNode -> new PlanWithProperties(planNode, derivePropertiesRecursively(planNode))) - .collect(toImmutableList()); - return pickPlan(possiblePlansWithProperties, preferredProperties); - } - - /** - * possiblePlans should be provided in layout preference order - */ - private PlanWithProperties pickPlan(List possiblePlans, PreferredProperties preferredProperties) - { - checkArgument(!possiblePlans.isEmpty()); - - if (preferStreamingOperators) { - possiblePlans = new ArrayList<>(possiblePlans); - Collections.sort(possiblePlans, Comparator.comparing(PlanWithProperties::getProperties, streamingExecutionPreference(preferredProperties))); // stable sort; is Collections.min() guaranteed to be stable? - } - - return possiblePlans.get(0); + PlanNode plan = PushPredicateIntoTableScan.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, typeAnalyzer, domainTranslator); + return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); } @Override @@ -1216,7 +1190,7 @@ private ActualProperties deriveProperties(PlanNode result, ActualProperties inpu private ActualProperties deriveProperties(PlanNode result, List inputProperties) { // TODO: move this logic to PlanSanityChecker once PropertyDerivations.deriveProperties fully supports local exchanges - ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, parser); + ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, typeAnalyzer); verify(result instanceof SemiJoinNode || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated) || outputProperties.isNullsAndAnyReplicated(), "SemiJoinNode is the only node that can strip null replication"); return outputProperties; @@ -1224,7 +1198,7 @@ private ActualProperties deriveProperties(PlanNode result, List computeIdentityTranslations(Assignments assig return outputToInput; } - @VisibleForTesting - static Comparator streamingExecutionPreference(PreferredProperties preferred) - { - // Calculating the matches can be a bit expensive, so cache the results between comparisons - LoadingCache>, List>>> matchCache = CacheBuilder.newBuilder() - .build(CacheLoader.from(actualProperties -> LocalProperties.match(actualProperties, preferred.getLocalProperties()))); - - return (actual1, actual2) -> { - List>> matchLayout1 = matchCache.getUnchecked(actual1.getLocalProperties()); - List>> matchLayout2 = matchCache.getUnchecked(actual2.getLocalProperties()); - - return ComparisonChain.start() - .compareTrueFirst(hasLocalOptimization(preferred.getLocalProperties(), matchLayout1), hasLocalOptimization(preferred.getLocalProperties(), matchLayout2)) - .compareTrueFirst(meetsPartitioningRequirements(preferred, actual1), meetsPartitioningRequirements(preferred, actual2)) - .compare(matchLayout1, matchLayout2, matchedLayoutPreference()) - .result(); - }; - } - - private static boolean hasLocalOptimization(List> desiredLayout, List>> matchResult) - { - checkArgument(desiredLayout.size() == matchResult.size()); - if (matchResult.isEmpty()) { - return false; - } - // Optimizations can be applied if the first LocalProperty has been modified in the match in any way - return !matchResult.get(0).equals(Optional.of(desiredLayout.get(0))); - } - - private static boolean meetsPartitioningRequirements(PreferredProperties preferred, ActualProperties actual) - { - if (!preferred.getGlobalProperties().isPresent()) { - return true; - } - PreferredProperties.Global preferredGlobal = preferred.getGlobalProperties().get(); - if (!preferredGlobal.isDistributed()) { - return actual.isSingleNode(); - } - if (!preferredGlobal.getPartitioningProperties().isPresent()) { - return !actual.isSingleNode(); - } - return actual.isStreamPartitionedOn(preferredGlobal.getPartitioningProperties().get().getPartitioningColumns()); - } - - // Prefer the match result that satisfied the most requirements - private static Comparator>>> matchedLayoutPreference() - { - return (matchLayout1, matchLayout2) -> { - Iterator>> match1Iterator = matchLayout1.iterator(); - Iterator>> match2Iterator = matchLayout2.iterator(); - while (match1Iterator.hasNext() && match2Iterator.hasNext()) { - Optional> match1 = match1Iterator.next(); - Optional> match2 = match2Iterator.next(); - if (match1.isPresent() && match2.isPresent()) { - return Integer.compare(match1.get().getColumns().size(), match2.get().getColumns().size()); - } - else if (match1.isPresent()) { - return 1; - } - else if (match2.isPresent()) { - return -1; - } - } - checkState(!match1Iterator.hasNext() && !match2Iterator.hasNext()); // Should be the same size - return 0; - }; - } - @VisibleForTesting static class PlanWithProperties { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddLocalExchanges.java index 9cc6a0cb79d..23be41b4f4f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddLocalExchanges.java @@ -23,12 +23,12 @@ import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; import io.prestosql.spi.connector.SortingProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.PartitioningScheme; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; import io.prestosql.sql.planner.plan.AggregationNode; @@ -96,12 +96,12 @@ public class AddLocalExchanges implements PlanOptimizer { private final Metadata metadata; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; - public AddLocalExchanges(Metadata metadata, SqlParser parser) + public AddLocalExchanges(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -645,7 +645,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, StreamPreferredProp parentPreferences.constrainTo(node.getProbeSource().getOutputSymbols()).withDefaultParallelism(session)); // index source does not support local parallel and must produce a single stream - StreamProperties indexStreamProperties = derivePropertiesRecursively(node.getIndexSource(), metadata, session, types, parser); + StreamProperties indexStreamProperties = derivePropertiesRecursively(node.getIndexSource(), metadata, session, types, typeAnalyzer); checkArgument(indexStreamProperties.getDistribution() == SINGLE, "index source must be single stream"); PlanWithProperties index = new PlanWithProperties(node.getIndexSource(), indexStreamProperties); @@ -746,12 +746,12 @@ private PlanWithProperties rebaseAndDeriveProperties(PlanNode node, List inputProperties) { - return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, parser)); + return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, typeAnalyzer)); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java index 7057b2344e2..85495eba72b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java @@ -14,16 +14,11 @@ package io.prestosql.sql.planner.optimizations; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutResult; -import io.prestosql.spi.connector.ColumnHandle; -import io.prestosql.spi.connector.Constraint; -import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.TypeProvider; @@ -41,13 +36,10 @@ import io.prestosql.sql.planner.plan.TableWriterNode; import io.prestosql.sql.planner.plan.UnionNode; -import java.util.List; import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static io.prestosql.metadata.TableLayoutResult.computeEnforced; import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.prestosql.sql.planner.plan.ChildReplacer.replaceChildren; import static java.util.stream.Collectors.toSet; @@ -195,24 +187,13 @@ private PlanNode rewriteDeleteTableScan(PlanNode node, TableHandle handle) { if (node instanceof TableScanNode) { TableScanNode scan = (TableScanNode) node; - TupleDomain originalEnforcedConstraint = scan.getEnforcedConstraint(); - - List layouts = metadata.getLayouts( - session, - handle, - new Constraint<>(originalEnforcedConstraint), - Optional.of(ImmutableSet.copyOf(scan.getAssignments().values()))); - verify(layouts.size() == 1, "Expected exactly one layout for delete"); - TableLayoutResult layoutResult = Iterables.getOnlyElement(layouts); - return new TableScanNode( scan.getId(), handle, scan.getOutputSymbols(), scan.getAssignments(), - Optional.of(layoutResult.getLayout().getHandle()), - layoutResult.getLayout().getPredicate(), - computeEnforced(originalEnforcedConstraint, layoutResult.getUnenforcedConstraint())); + scan.getCurrentConstraint(), + scan.getEnforcedConstraint()); } if (node instanceof FilterNode) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java index 81b519f6a3a..efae3082d79 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java @@ -20,13 +20,11 @@ import com.google.common.collect.Ordering; import io.airlift.slice.Slice; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.CallExpression; import io.prestosql.sql.relational.ConstantExpression; @@ -36,7 +34,6 @@ import io.prestosql.sql.relational.RowExpressionVisitor; import io.prestosql.sql.relational.VariableReferenceExpression; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import java.util.Comparator; import java.util.HashMap; @@ -58,10 +55,8 @@ import static io.prestosql.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.prestosql.spi.function.OperatorType.NOT_EQUAL; import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static java.lang.Integer.min; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class ExpressionEquivalence @@ -69,26 +64,24 @@ public class ExpressionEquivalence private static final Ordering ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator()); private static final CanonicalizationVisitor CANONICALIZATION_VISITOR = new CanonicalizationVisitor(); private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) + public ExpressionEquivalence(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types) { Map symbolInput = new HashMap<>(); - Map inputTypes = new HashMap<>(); int inputId = 0; for (Entry entry : types.allTypes().entrySet()) { symbolInput.put(entry.getKey(), inputId); - inputTypes.put(inputId, entry.getValue()); inputId++; } - RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, inputTypes); - RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, inputTypes); + RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, types); + RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, types); RowExpression canonicalizedLeft = leftRowExpression.accept(CANONICALIZATION_VISITOR, null); RowExpression canonicalizedRight = rightRowExpression.accept(CANONICALIZATION_VISITOR, null); @@ -96,23 +89,17 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi return canonicalizedLeft.equals(canonicalizedRight); } - private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, Map inputTypes) + private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, TypeProvider types) { - // replace qualified names with input references since row expressions do not support these - Expression expressionWithInputReferences = new SymbolToInputRewriter(symbolInput).rewrite(expression); - - // determine the type of every expression - Map, Type> expressionTypes = getExpressionTypesFromInput( + return translate( + expression, + SCALAR, + typeAnalyzer.getTypes(session, types, expression), + symbolInput, + metadata.getFunctionRegistry(), + metadata.getTypeManager(), session, - metadata, - sqlParser, - inputTypes, - expressionWithInputReferences, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); - - // convert to row expression - return translate(expressionWithInputReferences, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + false); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java index 57695177bc5..2397d3f43a0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java @@ -839,7 +839,7 @@ public static Optional getHashExpression(List symbols) return Optional.empty(); } - Expression result = new LongLiteral(String.valueOf(INITIAL_HASH_VALUE)); + Expression result = new GenericLiteral(StandardTypes.BIGINT, String.valueOf(INITIAL_HASH_VALUE)); for (Symbol symbol : symbols) { Expression hashField = new FunctionCall( QualifiedName.of(HASH_CODE), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java index ae72b2958f5..a4b7de827f3 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/IndexJoinOptimizer.java @@ -301,7 +301,6 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context idAllocator.getNextId(), resolvedIndex.getIndexHandle(), node.getTable(), - node.getLayout(), context.getLookupSymbols(), node.getOutputSymbols(), node.getAssignments(), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java index d16369289be..10614c65a26 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataDeleteOptimizer.java @@ -27,6 +27,7 @@ import io.prestosql.sql.planner.plan.SimplePlanRewriter; import io.prestosql.sql.planner.plan.TableFinishNode; import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.TableWriterNode; import java.util.List; import java.util.Optional; @@ -89,10 +90,13 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont return context.defaultRewrite(node); } TableScanNode tableScanNode = tableScan.get(); - if (!metadata.supportsMetadataDelete(session, tableScanNode.getTable(), tableScanNode.getLayout().get())) { + if (!metadata.supportsMetadataDelete(session, tableScanNode.getTable())) { return context.defaultRewrite(node); } - return new MetadataDeleteNode(idAllocator.getNextId(), delete.get().getTarget(), Iterables.getOnlyElement(node.getOutputSymbols()), tableScanNode.getLayout().get()); + return new MetadataDeleteNode( + idAllocator.getNextId(), + new TableWriterNode.DeleteHandle(tableScanNode.getTable(), delete.get().getTarget().getSchemaTableName()), + Iterables.getOnlyElement(node.getOutputSymbols())); } private static Optional findNode(PlanNode source, Class clazz) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java index 3e70a8d9b90..5da43eb4ed3 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -22,10 +22,8 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableLayout; -import io.prestosql.metadata.TableLayoutResult; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; -import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.connector.DiscretePredicates; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.predicate.TupleDomain; @@ -137,18 +135,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont // Materialize the list of partitions and replace the TableScan node // with a Values node - TableLayout layout = null; - if (!tableScan.getLayout().isPresent()) { - List layouts = metadata.getLayouts(session, tableScan.getTable(), Constraint.alwaysTrue(), Optional.empty()); - if (layouts.size() == 1) { - layout = Iterables.getOnlyElement(layouts).getLayout(); - } - } - else { - layout = metadata.getLayout(session, tableScan.getLayout().get()); - } - - if (layout == null || !layout.getDiscretePredicates().isPresent()) { + TableLayout layout = metadata.getLayout(session, tableScan.getTable()); + if (!layout.getDiscretePredicates().isPresent()) { return context.defaultRewrite(node); } DiscretePredicates predicates = layout.getDiscretePredicates().get(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java index 27f7e0a558c..246616248ce 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java @@ -20,7 +20,6 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DeterminismEvaluator; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.EffectivePredicateExtractor; @@ -32,6 +31,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AssignUniqueId; @@ -84,7 +84,6 @@ import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.extractConjuncts; import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.DeterminismEvaluator.isDeterministic; import static io.prestosql.sql.planner.EqualityInference.createEqualityInference; import static io.prestosql.sql.planner.ExpressionSymbolInliner.inlineSymbols; @@ -93,7 +92,6 @@ import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; import static io.prestosql.sql.planner.plan.JoinNode.Type.RIGHT; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class PredicatePushDown @@ -102,14 +100,14 @@ public class PredicatePushDown private final Metadata metadata; private final LiteralEncoder literalEncoder; private final EffectivePredicateExtractor effectivePredicateExtractor; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public PredicatePushDown(Metadata metadata, SqlParser sqlParser) + public PredicatePushDown(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(literalEncoder)); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -121,7 +119,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith( - new Rewriter(symbolAllocator, idAllocator, metadata, literalEncoder, effectivePredicateExtractor, sqlParser, session, types), + new Rewriter(symbolAllocator, idAllocator, metadata, literalEncoder, effectivePredicateExtractor, typeAnalyzer, session, types), plan, TRUE_LITERAL); } @@ -134,7 +132,7 @@ private static class Rewriter private final Metadata metadata; private final LiteralEncoder literalEncoder; private final EffectivePredicateExtractor effectivePredicateExtractor; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final Session session; private final TypeProvider types; private final ExpressionEquivalence expressionEquivalence; @@ -145,7 +143,7 @@ private Rewriter( Metadata metadata, LiteralEncoder literalEncoder, EffectivePredicateExtractor effectivePredicateExtractor, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) { @@ -154,10 +152,10 @@ private Rewriter( this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = requireNonNull(literalEncoder, "literalEncoder is null"); this.effectivePredicateExtractor = requireNonNull(effectivePredicateExtractor, "effectivePredicateExtractor is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); - this.expressionEquivalence = new ExpressionEquivalence(metadata, sqlParser); + this.expressionEquivalence = new ExpressionEquivalence(metadata, typeAnalyzer); } @Override @@ -638,7 +636,7 @@ private Symbol symbolForExpression(Expression expression) return Symbol.from(expression); } - return symbolAllocator.newSymbol(expression, extractType(expression)); + return symbolAllocator.newSymbol(expression, typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression)); } private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection outerSymbols) @@ -891,12 +889,6 @@ private static Expression extractJoinPredicate(JoinNode joinNode) return combineConjuncts(builder.build()); } - private Type extractType(Expression expression) - { - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), /* parameters have already been replaced */WarningCollector.NOOP); - return expressionTypes.get(NodeRef.of(expression)); - } - private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) { checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType()); @@ -948,14 +940,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex // Temporary implementation for joins because the SimplifyExpressions optimizers can not run properly on join clauses private Expression simplifyExpression(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - symbolAllocator.getTypes(), - expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); return literalEncoder.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } @@ -970,14 +955,7 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r */ private Object nullInputEvaluator(final Collection nullSymbols, Expression expression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - symbolAllocator.getTypes(), - expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); return ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes) .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java index 8306c2e9f90..b0de8196db4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java @@ -21,7 +21,6 @@ import com.google.common.collect.Sets; import io.prestosql.Session; import io.prestosql.SystemSessionProperties; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableLayout; import io.prestosql.metadata.TableLayout.TablePartitioning; @@ -32,13 +31,13 @@ import io.prestosql.spi.connector.SortingProperty; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.NoOpSymbolResolver; import io.prestosql.sql.planner.OrderingScheme; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.ActualProperties.Global; import io.prestosql.sql.planner.plan.AggregationNode; @@ -95,7 +94,6 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.SystemSessionProperties.planWithTableNodePartitioning; import static io.prestosql.spi.predicate.TupleDomain.extractFixedValues; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.SystemPartitioningHandle.ARBITRARY_DISTRIBUTION; import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.coordinatorSingleStreamPartition; @@ -104,7 +102,6 @@ import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.streamPartitionedOn; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.REMOTE; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; @@ -112,17 +109,17 @@ public class PropertyDerivations { private PropertyDerivations() {} - public static ActualProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session, types, parser)) + .map(source -> derivePropertiesRecursively(source, metadata, session, types, typeAnalyzer)) .collect(toImmutableList()); - return deriveProperties(node, inputProperties, metadata, session, types, parser); + return deriveProperties(node, inputProperties, metadata, session, types, typeAnalyzer); } - public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - ActualProperties output = node.accept(new Visitor(metadata, session, types, parser), inputProperties); + ActualProperties output = node.accept(new Visitor(metadata, session, types, typeAnalyzer), inputProperties); output.getNodePartitioning().ifPresent(partitioning -> verify(node.getOutputSymbols().containsAll(partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); @@ -137,9 +134,9 @@ public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties streamBackdoorDeriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - return node.accept(new Visitor(metadata, session, types, parser), inputProperties); + return node.accept(new Visitor(metadata, session, types, typeAnalyzer), inputProperties); } private static class Visitor @@ -148,14 +145,14 @@ private static class Visitor private final Metadata metadata; private final Session session; private final TypeProvider types; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; - public Visitor(Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { this.metadata = metadata; this.session = session; this.types = types; - this.parser = parser; + this.typeAnalyzer = typeAnalyzer; } @Override @@ -636,7 +633,7 @@ public ActualProperties visitProject(ProjectNode node, List in for (Map.Entry assignment : node.getAssignments().entrySet()) { Expression expression = assignment.getValue(); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); // TODO: @@ -709,9 +706,7 @@ public ActualProperties visitValues(ValuesNode node, List cont @Override public ActualProperties visitTableScan(TableScanNode node, List inputProperties) { - checkArgument(node.getLayout().isPresent(), "table layout has not yet been chosen"); - - TableLayout layout = metadata.getLayout(session, node.getLayout().get()); + TableLayout layout = metadata.getLayout(session, node.getTable()); Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); ActualProperties.Builder properties = ActualProperties.builder(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java index 5207fc49c05..f568bb10b3b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -320,7 +320,7 @@ public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext newAssignments = newOutputSymbols.stream() .collect(Collectors.toMap(Function.identity(), node.getAssignments()::get)); - return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), node.getLayout(), newLookupSymbols, newOutputSymbols, newAssignments, node.getCurrentConstraint()); + return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbols, newOutputSymbols, newAssignments, node.getCurrentConstraint()); } @Override @@ -426,7 +426,6 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext> c node.getTable(), newOutputs, newAssignments, - node.getLayout(), node.getCurrentConstraint(), node.getEnforcedConstraint()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java index 27b5930ae57..ed1c0ff221d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java @@ -23,9 +23,9 @@ import io.prestosql.metadata.TableLayout; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.LocalProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning.ArgumentBinding; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.ApplyNode; @@ -96,27 +96,27 @@ public final class StreamPropertyDerivations { private StreamPropertyDerivations() {} - public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session, types, parser)) + .map(source -> derivePropertiesRecursively(source, metadata, session, types, typeAnalyzer)) .collect(toImmutableList()); - return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, types, parser); + return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, types, typeAnalyzer); } - public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, types, parser); + return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, types, typeAnalyzer); } - public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { requireNonNull(node, "node is null"); requireNonNull(inputProperties, "inputProperties is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); - requireNonNull(parser, "parser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); // properties.otherActualProperties will never be null here because the only way // an external caller should obtain StreamProperties is from this method, and the @@ -129,7 +129,7 @@ public static StreamProperties deriveProperties(PlanNode node, List cont @Override public StreamProperties visitTableScan(TableScanNode node, List inputProperties) { - checkArgument(node.getLayout().isPresent(), "table layout has not yet been chosen"); - - TableLayout layout = metadata.getLayout(session, node.getLayout().get()); + TableLayout layout = metadata.getLayout(session, node.getTable()); Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); // Globally constant assignments diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java index ad377f9e509..c35919cf46d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -525,7 +525,7 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont @Override public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) { - return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), node.getLayout(), canonicalize(node.getLookupSymbols()), node.getOutputSymbols(), node.getAssignments(), node.getCurrentConstraint()); + return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), canonicalize(node.getLookupSymbols()), node.getOutputSymbols(), node.getAssignments(), node.getCurrentConstraint()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java index d908fea6265..877e18ab087 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/IndexSourceNode.java @@ -20,14 +20,12 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.metadata.IndexHandle; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.planner.Symbol; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -38,7 +36,6 @@ public class IndexSourceNode { private final IndexHandle indexHandle; private final TableHandle tableHandle; - private final Optional tableLayout; // only necessary for event listeners private final Set lookupSymbols; private final List outputSymbols; private final Map assignments; // symbol -> column @@ -49,7 +46,6 @@ public IndexSourceNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("indexHandle") IndexHandle indexHandle, @JsonProperty("tableHandle") TableHandle tableHandle, - @JsonProperty("tableLayout") Optional tableLayout, @JsonProperty("lookupSymbols") Set lookupSymbols, @JsonProperty("outputSymbols") List outputSymbols, @JsonProperty("assignments") Map assignments, @@ -58,7 +54,6 @@ public IndexSourceNode( super(id); this.indexHandle = requireNonNull(indexHandle, "indexHandle is null"); this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.lookupSymbols = ImmutableSet.copyOf(requireNonNull(lookupSymbols, "lookupSymbols is null")); this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); @@ -81,12 +76,6 @@ public TableHandle getTableHandle() return tableHandle; } - @JsonProperty - public Optional getLayout() - { - return tableLayout; - } - @JsonProperty public Set getLookupSymbols() { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java index 807b541f5e4..406d22f1383 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/MetadataDeleteNode.java @@ -16,7 +16,6 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.plan.TableWriterNode.DeleteHandle; @@ -32,20 +31,17 @@ public class MetadataDeleteNode { private final DeleteHandle target; private final Symbol output; - private final TableLayoutHandle tableLayout; @JsonCreator public MetadataDeleteNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("target") DeleteHandle target, - @JsonProperty("output") Symbol output, - @JsonProperty("tableLayout") TableLayoutHandle tableLayout) + @JsonProperty("output") Symbol output) { super(id); this.target = requireNonNull(target, "target is null"); this.output = requireNonNull(output, "output is null"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); } @JsonProperty @@ -66,12 +62,6 @@ public List getOutputSymbols() return ImmutableList.of(output); } - @JsonProperty - public TableLayoutHandle getTableLayout() - { - return tableLayout; - } - @Override public List getSources() { @@ -87,6 +77,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new MetadataDeleteNode(getId(), target, output, tableLayout); + return new MetadataDeleteNode(getId(), target, output); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java index ad7eaf0d1ec..f95f8ba44ab 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/TableScanNode.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.planner.Symbol; @@ -27,7 +26,6 @@ import java.util.List; import java.util.Map; -import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -42,21 +40,30 @@ public class TableScanNode private final List outputSymbols; private final Map assignments; // symbol -> column - private final Optional tableLayout; - // Used during predicate refinement over multiple passes of predicate pushdown // TODO: think about how to get rid of this in new planner private final TupleDomain currentConstraint; private final TupleDomain enforcedConstraint; + // We need this factory method to disambiguate with the constructor used for deserializing + // from a json object. The deserializer sets some fields which are never transported + // to null + public static TableScanNode newInstance( + PlanNodeId id, + TableHandle table, + List outputs, + Map assignments) + { + return new TableScanNode(id, table, outputs, assignments, TupleDomain.all(), TupleDomain.all()); + } + @JsonCreator public TableScanNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("table") TableHandle table, @JsonProperty("outputSymbols") List outputs, - @JsonProperty("assignments") Map assignments, - @JsonProperty("layout") Optional tableLayout) + @JsonProperty("assignments") Map assignments) { // This constructor is for JSON deserialization only. Do not use. super(id); @@ -64,26 +71,15 @@ public TableScanNode( this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); checkArgument(assignments.keySet().containsAll(outputs), "assignments does not cover all of outputs"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.currentConstraint = null; this.enforcedConstraint = null; } - public TableScanNode( - PlanNodeId id, - TableHandle table, - List outputs, - Map assignments) - { - this(id, table, outputs, assignments, Optional.empty(), TupleDomain.all(), TupleDomain.all()); - } - public TableScanNode( PlanNodeId id, TableHandle table, List outputs, Map assignments, - Optional tableLayout, TupleDomain currentConstraint, TupleDomain enforcedConstraint) { @@ -92,12 +88,8 @@ public TableScanNode( this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); checkArgument(assignments.keySet().containsAll(outputs), "assignments does not cover all of outputs"); - this.tableLayout = requireNonNull(tableLayout, "tableLayout is null"); this.currentConstraint = requireNonNull(currentConstraint, "currentConstraint is null"); this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); - if (!currentConstraint.isAll() || !enforcedConstraint.isAll()) { - checkArgument(tableLayout.isPresent(), "tableLayout must be present when currentConstraint or enforcedConstraint is non-trivial"); - } } @JsonProperty("table") @@ -106,12 +98,6 @@ public TableHandle getTable() return table; } - @JsonProperty - public Optional getLayout() - { - return tableLayout; - } - @Override @JsonProperty("outputSymbols") public List getOutputSymbols() @@ -171,7 +157,6 @@ public String toString() { return toStringHelper(this) .add("table", table) - .add("tableLayout", tableLayout) .add("outputSymbols", outputSymbols) .add("assignments", assignments) .add("currentConstraint", currentConstraint) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java index 1834e19a73b..985070896e9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/planPrinter/PlanPrinter.java @@ -36,7 +36,6 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.operator.StageExecutionDescriptor; import io.prestosql.spi.connector.ColumnHandle; -import io.prestosql.spi.connector.ConnectorTableLayoutHandle; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.Marker; import io.prestosql.spi.predicate.NullableValue; @@ -758,16 +757,6 @@ private Void visitScanFilterAndProjectInfo( private void printTableScanInfo(NodeRepresentation nodeOutput, TableScanNode node) { - TableHandle table = node.getTable(); - - if (node.getLayout().isPresent()) { - // TODO: find a better way to do this - ConnectorTableLayoutHandle layout = node.getLayout().get().getConnectorHandle(); - if (!table.getConnectorHandle().toString().equals(layout.toString())) { - nodeOutput.appendDetailsLine("LAYOUT: %s", layout); - } - } - TupleDomain predicate = node.getCurrentConstraint(); if (predicate.isNone()) { nodeOutput.appendDetailsLine(":: NONE"); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java index 37bb2e3787c..f0cdccad121 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -33,7 +33,7 @@ public class NoDuplicatePlanNodeIdsChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { Map planNodeIds = new HashMap<>(); searchFrom(planNode) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java index ddb9e03c65d..a7d0f133cac 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java @@ -17,8 +17,8 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.sql.analyzer.ExpressionTreeUtils; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.tree.Identifier; @@ -29,7 +29,7 @@ public final class NoIdentifierLeftChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { List identifiers = ExpressionTreeUtils.extractExpressions(ExpressionExtractor.extractExpressions(plan), Identifier.class); if (!identifiers.isEmpty()) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java index 916235efc1b..88c4e94c63b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java @@ -16,8 +16,8 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.tree.DefaultTraversalVisitor; @@ -30,7 +30,7 @@ public final class NoSubqueryExpressionLeftChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { for (Expression expression : ExpressionExtractor.extractExpressions(plan)) { new DefaultTraversalVisitor() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java index cc7a468aad4..6c6cb14d1de 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java @@ -18,7 +18,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; @@ -56,19 +56,19 @@ public PlanSanityChecker(boolean forceSingleNode) .build(); } - public void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types, warningCollector)); + checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, typeAnalyzer, types, warningCollector)); } - public void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types, warningCollector)); + checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, typeAnalyzer, types, warningCollector)); } public interface Checker { - void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector); + void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector); } private enum Stage diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java index daf7934aa45..30c239dc29f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java @@ -21,9 +21,9 @@ import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.SimplePlanVisitor; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -33,16 +33,13 @@ import io.prestosql.sql.planner.plan.WindowNode; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.SymbolReference; import java.util.List; import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.type.UnknownType.UNKNOWN; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -54,9 +51,9 @@ public final class TypeValidator public TypeValidator() {} @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - plan.accept(new Visitor(session, metadata, sqlParser, types, warningCollector), null); + plan.accept(new Visitor(session, metadata, typeAnalyzer, types, warningCollector), null); } private static class Visitor @@ -64,15 +61,15 @@ private static class Visitor { private final Session session; private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; private final WarningCollector warningCollector; - public Visitor(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } @@ -119,8 +116,7 @@ public Void visitProject(ProjectNode node, Void context) verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); continue; } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); + Type actualType = typeAnalyzer.getType(session, types, entry.getValue()); verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } @@ -165,8 +161,7 @@ private void checkSignature(Symbol symbol, Signature signature) private void checkCall(Symbol symbol, FunctionCall call) { Type expectedType = types.get(symbol); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, call, emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(call)); + Type actualType = typeAnalyzer.getType(session, types, call); verifyTypeSignature(symbol, expectedType.getTypeSignature(), actualType.getTypeSignature()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java index a6f760773aa..22152e412d3 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.ActualProperties; import io.prestosql.sql.planner.optimizations.PropertyDerivations; @@ -60,9 +60,9 @@ public ValidateAggregationsWithDefaultValues(boolean forceSingleNode) } @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata, sqlParser, types), null); + planNode.accept(new Visitor(session, metadata, typeAnalyzer, types), null); } private class Visitor @@ -70,14 +70,14 @@ private class Visitor { final Session session; final Metadata metadata; - final SqlParser parser; + final TypeAnalyzer typeAnalyzer; final TypeProvider types; - Visitor(Session session, Metadata metadata, SqlParser parser, TypeProvider types) + Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); } @@ -115,14 +115,14 @@ public Optional visitAggregation(AggregationNode node, Void conte // No remote repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed on a single node. - ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, parser); + ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, typeAnalyzer); checkArgument(forceSingleNode || globalProperties.isSingleNode(), "Final aggregation with default value not separated from partial aggregation by remote hash exchange"); if (!seenExchanges.localRepartitionExchange) { // No local repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed by single thread. - StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, parser); + StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, typeAnalyzer); checkArgument(localProperties.isSingleStream(), "Final aggregation with default value not separated from partial aggregation by local hash exchange"); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java index 56a8175897b..ea0949bbc71 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java @@ -18,9 +18,9 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -85,7 +85,7 @@ public final class ValidateDependenciesChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { validate(plan); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java index 3d2c37400ba..bf0b9631041 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java @@ -20,8 +20,8 @@ import io.prestosql.metadata.Metadata; import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.LocalProperties; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -44,9 +44,9 @@ public class ValidateStreamingAggregations implements Checker { @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata, sqlParser, types, warningCollector), null); + planNode.accept(new Visitor(session, metadata, typeAnalyzer, types, warningCollector), null); } private static final class Visitor @@ -54,15 +54,15 @@ private static final class Visitor { private final Session session; private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; private final WarningCollector warningCollector; - private Visitor(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + private Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { this.session = session; this.metadata = metadata; - this.sqlParser = sqlParser; + this.typeAnalyzer = typeAnalyzer; this.types = types; this.warningCollector = warningCollector; } @@ -81,7 +81,7 @@ public Void visitAggregation(AggregationNode node, Void context) return null; } - StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, types, sqlParser); + StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, types, typeAnalyzer); List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedSymbols())); Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java index 80999b7024d..479e11487d8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.PlanNode; @@ -27,7 +27,7 @@ public final class VerifyNoFilteredAggregations implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { searchFrom(plan) .where(AggregationNode.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java index db860491ae8..9c81a94d61a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanNode; @@ -28,7 +28,7 @@ public final class VerifyOnlyOneOutputNode implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { int outputPlanNodesCount = searchFrom(plan) .where(OutputNode.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java index 720182a3366..7dc1d044c6a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java @@ -31,6 +31,7 @@ import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; import io.prestosql.spi.type.VarcharType; +import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.relational.optimizer.ExpressionOptimizer; import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.ArithmeticUnaryExpression; @@ -141,6 +142,7 @@ public static RowExpression translate( Expression expression, FunctionKind functionKind, Map, Type> types, + Map layout, FunctionRegistry functionRegistry, TypeManager typeManager, Session session, @@ -150,6 +152,7 @@ public static RowExpression translate( functionKind, types, typeManager, + layout, session.getTimeZoneKey(), isLegacyRowFieldOrdinalAccessEnabled(session), SystemSessionProperties.isLegacyTimestamp(session)); @@ -171,6 +174,7 @@ private static class Visitor private final FunctionKind functionKind; private final Map, Type> types; private final TypeManager typeManager; + private final Map layout; private final TimeZoneKey timeZoneKey; private final boolean legacyRowFieldOrdinalAccess; @Deprecated @@ -180,6 +184,7 @@ private Visitor( FunctionKind functionKind, Map, Type> types, TypeManager typeManager, + Map layout, TimeZoneKey timeZoneKey, boolean legacyRowFieldOrdinalAccess, boolean isLegacyTimestamp) @@ -187,6 +192,7 @@ private Visitor( this.functionKind = functionKind; this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null")); this.typeManager = typeManager; + this.layout = layout; this.timeZoneKey = timeZoneKey; this.legacyRowFieldOrdinalAccess = legacyRowFieldOrdinalAccess; this.isLegacyTimestamp = isLegacyTimestamp; @@ -363,6 +369,11 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context) @Override protected RowExpression visitSymbolReference(SymbolReference node, Void context) { + Integer field = layout.get(Symbol.from(node)); + if (field != null) { + return field(field, getType(node)); + } + return new VariableReferenceExpression(node.getName(), getType(node)); } diff --git a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java index 6a599189f8a..e1a63a7389e 100644 --- a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java @@ -85,7 +85,6 @@ import io.prestosql.metadata.SchemaPropertyManager; import io.prestosql.metadata.SessionPropertyManager; import io.prestosql.metadata.Split; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.metadata.TablePropertyManager; import io.prestosql.metadata.ViewDefinition; import io.prestosql.operator.Driver; @@ -136,6 +135,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -300,7 +300,6 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, notificationExecutor); this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler); - this.splitManager = new SplitManager(new QueryManagerConfig()); this.blockEncodingManager = new BlockEncodingManager(typeRegistry); this.metadata = new MetadataManager( featuresConfig, @@ -312,6 +311,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ColumnPropertyManager(), new AnalyzePropertyManager(), transactionManager); + this.splitManager = new SplitManager(new QueryManagerConfig(), metadata); this.planFragmenter = new PlanFragmenter(this.metadata, this.nodePartitioningManager, new QueryManagerConfig()); this.joinCompiler = new JoinCompiler(metadata, featuresConfig); this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); @@ -695,7 +695,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), Optional.empty(), pageSourceManager, indexManager, @@ -731,11 +731,9 @@ private List createDrivers(Session session, Plan plan, OutputFactory out List sources = new ArrayList<>(); long sequenceId = 0; for (TableScanNode tableScan : findTableScanNodes(subplan.getFragment().getRoot())) { - TableLayoutHandle layout = tableScan.getLayout().get(); - SplitSource splitSource = splitManager.getSplits( session, - layout, + tableScan.getTable(), stageExecutionDescriptor.isScanGroupedExecution(tableScan.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING); ImmutableSet.Builder scheduledSplits = ImmutableSet.builder(); @@ -809,7 +807,7 @@ public List getPlanOptimizers(boolean forceSingleNode) { return new PlanOptimizers( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), featuresConfig, taskManagerConfig, forceSingleNode, @@ -847,7 +845,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { - return ImmutableList.of(); + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(TestingHandle.INSTANCE), TupleDomain.all())); } @Override public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) { - throw new UnsupportedOperationException(); + return new ConnectorTableLayout(TestingHandle.INSTANCE); } @Override diff --git a/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java b/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java index 189a56608cf..1aee4584c58 100644 --- a/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java +++ b/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java @@ -39,7 +39,9 @@ import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; +import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.transaction.IsolationLevel; +import io.prestosql.testing.TestingHandle; import java.util.List; import java.util.Map; @@ -197,7 +199,7 @@ public Map> listTableColumns(ConnectorSess @Override public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { - return ImmutableList.of(); + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(TestingHandle.INSTANCE), TupleDomain.all())); } @Override diff --git a/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java b/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java index cd199d8f21d..31f1abcceff 100644 --- a/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/io/prestosql/cost/TestCostCalculator.java @@ -29,7 +29,6 @@ import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; import io.prestosql.plugin.tpch.TpchTableLayoutHandle; @@ -747,10 +746,9 @@ private TableScanNode tableScan(String id, String... symbols) TpchTableHandle tableHandle = new TpchTableHandle("orders", 1.0); return new TableScanNode( new PlanNodeId(id), - new TableHandle(new ConnectorId("tpch"), new TpchTableHandle("orders", 1.0)), + new TableHandle(new ConnectorId("tpch"), tableHandle, INSTANCE, Optional.of(new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))), symbolsList, assignments.build(), - Optional.of(new TableLayoutHandle(new ConnectorId("tpch"), INSTANCE, new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))), TupleDomain.all(), TupleDomain.all()); } diff --git a/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java b/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java index e4f51d3ba8c..61c59f4e9e4 100644 --- a/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java +++ b/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java @@ -49,8 +49,10 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.testing.TestingHandle; import io.prestosql.testing.TestingMetadata.TestingColumnHandle; import io.prestosql.testing.TestingMetadata.TestingTableHandle; +import io.prestosql.testing.TestingTransactionHandle; import org.joda.time.DateTime; import javax.annotation.concurrent.GuardedBy; @@ -108,9 +110,9 @@ public MockRemoteTask createTableScanTask(TaskId taskId, Node newNode, List getSystemTable(Session session, QualifiedObjectName } @Override - public List getLayouts(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns) + public Optional getLayout(Session session, TableHandle tableHandle, Constraint constraint, Optional> desiredColumns) { throw new UnsupportedOperationException(); } @Override - public TableLayout getLayout(Session session, TableLayoutHandle handle) + public TableLayout getLayout(Session session, TableHandle handle) { throw new UnsupportedOperationException(); } @Override - public TableLayoutHandle makeCompatiblePartitioning(Session session, TableLayoutHandle tableLayoutHandle, PartitioningHandle partitioningHandle) + public TableHandle makeCompatiblePartitioning(Session session, TableHandle table, PartitioningHandle partitioningHandle) { throw new UnsupportedOperationException(); } @@ -138,7 +139,7 @@ public Optional getCommonPartitioning(Session session, Parti } @Override - public Optional getInfo(Session session, TableLayoutHandle handle) + public Optional getInfo(Session session, TableHandle handle) { throw new UnsupportedOperationException(); } @@ -306,13 +307,13 @@ public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tabl } @Override - public boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public boolean supportsMetadataDelete(Session session, TableHandle tableHandle) { throw new UnsupportedOperationException(); } @Override - public OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + public OptionalLong metadataDelete(Session session, TableHandle tableHandle) { throw new UnsupportedOperationException(); } @@ -508,4 +509,10 @@ public boolean catalogExists(Session session, String catalogName) { throw new UnsupportedOperationException(); } + + @Override + public Optional applyFilter(TableHandle table, ConnectorExpression expression) + { + return Optional.empty(); + } } diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java index 439402947f4..fced32b682b 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java @@ -19,7 +19,6 @@ import io.prestosql.SequencePageBuilder; import io.prestosql.Session; import io.prestosql.connector.ConnectorId; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Split; import io.prestosql.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -33,13 +32,12 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.testing.TestingMetadata.TestingColumnHandle; import io.prestosql.testing.TestingSession; import io.prestosql.testing.TestingTaskContext; @@ -81,9 +79,7 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static io.prestosql.testing.TestingSplit.createLocalSplit; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -102,8 +98,8 @@ public class BenchmarkScanFilterAndProjectOperator private static final Map TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR); private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); - private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA); private static final int TOTAL_POSITIONS = 1_000_000; private static final DataSize FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = new DataSize(500, KILOBYTE); @@ -203,10 +199,10 @@ private List createInputPages(List types) private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR); + return rowExpression("cast(varchar0 as bigint) % 2 = 0"); } if (type == BIGINT) { - return rowExpression("bigint0 % 2 = 0", BIGINT); + return rowExpression("bigint0 % 2 = 0"); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -216,32 +212,32 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression("bigint" + i + " + 5", type)); + builder.add(rowExpression("bigint" + i + " + 5")); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type)); + builder.add(rowExpression("concat(varchar" + i + ", 'foo')")); } } return builder.build(); } - private RowExpression rowExpression(String expression, Type type) + private RowExpression rowExpression(String value) { - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes))); - - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i = 0; i < columnCount; i++) { - builder.put(i, type); - } - Map types = builder.build(); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); + + return SqlToRowExpressionTranslator.translate( + expression, + SCALAR, + TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + METADATA.getFunctionRegistry(), + METADATA.getTypeManager(), + TEST_SESSION, + true); } private static Page createPage(List types, int positions, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index 00d181faa4b..f536c574d13 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -37,9 +37,6 @@ import io.prestosql.operator.SourceOperator; import io.prestosql.operator.SourceOperatorFactory; import io.prestosql.operator.project.CursorProcessor; -import io.prestosql.operator.project.InterpretedPageFilter; -import io.prestosql.operator.project.InterpretedPageProjection; -import io.prestosql.operator.project.PageFilter; import io.prestosql.operator.project.PageProcessor; import io.prestosql.operator.project.PageProjection; import io.prestosql.spi.ErrorCodeSupplier; @@ -56,6 +53,7 @@ import io.prestosql.spi.connector.InMemoryRecordSet; import io.prestosql.spi.connector.RecordPageSource; import io.prestosql.spi.connector.RecordSet; +import io.prestosql.spi.predicate.Utils; import io.prestosql.spi.type.TimeZoneKey; import io.prestosql.spi.type.Type; import io.prestosql.split.PageSourceProvider; @@ -65,8 +63,9 @@ import io.prestosql.sql.analyzer.SemanticException; import io.prestosql.sql.gen.ExpressionCompiler; import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -128,12 +127,10 @@ import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressionsWithSymbols; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressions; import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; -import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.prestosql.testing.TestingTaskContext.createTaskContext; import static io.prestosql.type.UnknownType.UNKNOWN; import static java.lang.String.format; @@ -168,17 +165,17 @@ public final class FunctionAssertions private static final Page ZERO_CHANNEL_PAGE = new Page(1); - private static final Map INPUT_TYPES = ImmutableMap.builder() - .put(0, BIGINT) - .put(1, VARCHAR) - .put(2, DOUBLE) - .put(3, BOOLEAN) - .put(4, BIGINT) - .put(5, VARCHAR) - .put(6, VARCHAR) - .put(7, TIMESTAMP_WITH_TIME_ZONE) - .put(8, VARBINARY) - .put(9, INTEGER) + private static final Map INPUT_TYPES = ImmutableMap.builder() + .put(new Symbol("bound_long"), BIGINT) + .put(new Symbol("bound_string"), VARCHAR) + .put(new Symbol("bound_double"), DOUBLE) + .put(new Symbol("bound_boolean"), BOOLEAN) + .put(new Symbol("bound_timestamp"), BIGINT) + .put(new Symbol("bound_pattern"), VARCHAR) + .put(new Symbol("bound_null_string"), VARCHAR) + .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) + .put(new Symbol("bound_binary_literal"), VARBINARY) + .put(new Symbol("bound_integer"), INTEGER) .build(); private static final Map INPUT_MAPPING = ImmutableMap.builder() @@ -213,6 +210,7 @@ public final class FunctionAssertions private final Session session; private final LocalQueryRunner runner; private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; private final ExpressionCompiler compiler; public FunctionAssertions() @@ -231,6 +229,7 @@ public FunctionAssertions(Session session, FeaturesConfig featuresConfig) runner = new LocalQueryRunner(session, featuresConfig); metadata = runner.getMetadata(); compiler = runner.getExpressionCompiler(); + typeAnalyzer = new TypeAnalyzer(SQL_PARSER, metadata); } public TypeRegistry getTypeRegistry() @@ -599,8 +598,7 @@ private List executeProjectionWithAll(String projection, Type expectedTy results.add(directOperatorValue); // interpret - Operator interpretedFilterProject = interpretedFilterProject(Optional.empty(), projectionExpression, expectedType, session); - Object interpretedValue = selectSingleValue(interpretedFilterProject, expectedType); + Object interpretedValue = interpret(projectionExpression, expectedType, session); results.add(interpretedValue); // execute over normal operator @@ -630,16 +628,7 @@ private List executeProjectionWithAll(String projection, Type expectedTy private RowExpression toRowExpression(Session session, Expression projectionExpression) { - Expression translatedProjection = new SymbolToInputRewriter(INPUT_MAPPING).rewrite(projectionExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( - session, - metadata, - SQL_PARSER, - INPUT_TYPES, - ImmutableList.of(translatedProjection), - ImmutableList.of(), - WarningCollector.NOOP); - return toRowExpression(translatedProjection, expressionTypes); + return toRowExpression(projectionExpression, typeAnalyzer.getTypes(session, TypeProvider.copyOf(INPUT_TYPES), projectionExpression), INPUT_MAPPING); } private Object selectSingleValue(OperatorFactory operatorFactory, Type type, Session session) @@ -706,7 +695,10 @@ private List executeFilterWithAll(String filter, Session session, boole } // interpret - boolean interpretedValue = executeFilter(interpretedFilterProject(Optional.of(filterExpression), TRUE_LITERAL, BOOLEAN, session)); + Boolean interpretedValue = (Boolean) interpret(filterExpression, BOOLEAN, session); + if (interpretedValue == null) { + interpretedValue = false; + } results.add(interpretedValue); // execute over normal operator @@ -749,7 +741,7 @@ public static Expression createExpression(Session session, String expression, Me parsedExpression = rewriteIdentifiersToSymbolReferences(parsedExpression); - final ExpressionAnalysis analysis = analyzeExpressionsWithSymbols( + final ExpressionAnalysis analysis = analyzeExpressions( session, metadata, SQL_PARSER, @@ -869,29 +861,46 @@ protected Void visitSymbolReference(SymbolReference node, Void context) return hasSymbolReferences.get(); } - private Operator interpretedFilterProject(Optional filter, Expression projection, Type expectedType, Session session) + private Object interpret(Expression expression, Type expectedType, Session session) { - Optional pageFilter = filter - .map(expression -> new InterpretedPageFilter( - expression, - SYMBOL_TYPES, - INPUT_MAPPING, - metadata, - SQL_PARSER, - session)); - - PageProjection pageProjection = new InterpretedPageProjection(projection, SYMBOL_TYPES, INPUT_MAPPING, metadata, SQL_PARSER, session); - assertEquals(pageProjection.getType(), expectedType); - - PageProcessor processor = new PageProcessor(pageFilter, ImmutableList.of(pageProjection)); - OperatorFactory operatorFactory = new FilterAndProjectOperatorFactory( - 0, - new PlanNodeId("test"), - () -> processor, - ImmutableList.of(pageProjection.getType()), - new DataSize(0, BYTE), - 0); - return operatorFactory.createOperator(createDriverContext(session)); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, SYMBOL_TYPES, expression); + ExpressionInterpreter evaluator = ExpressionInterpreter.expressionInterpreter(expression, metadata, session, expressionTypes); + + Object result = evaluator.evaluate(symbol -> { + int position = 0; + int channel = INPUT_MAPPING.get(symbol); + Type type = SYMBOL_TYPES.get(symbol); + + Block block = SOURCE_PAGE.getBlock(channel); + + if (block.isNull(position)) { + return null; + } + + Class javaType = type.getJavaType(); + if (javaType == boolean.class) { + return type.getBoolean(block, position); + } + else if (javaType == long.class) { + return type.getLong(block, position); + } + else if (javaType == double.class) { + return type.getDouble(block, position); + } + else if (javaType == Slice.class) { + return type.getSlice(block, position); + } + else if (javaType == Block.class) { + return type.getObject(block, position); + } + else { + throw new UnsupportedOperationException("not yet implemented"); + } + }); + + // convert result from stack type to Type ObjectValue + Block block = Utils.nativeValueToBlock(expectedType, result); + return expectedType.getObjectValue(session.toConnectorSession(), block, 0); } private static OperatorFactory compileFilterWithNoInputColumns(RowExpression filter, ExpressionCompiler compiler) @@ -955,9 +964,9 @@ private static SourceOperatorFactory compileScanFilterProject(Optional, Type> expressionTypes) + private RowExpression toRowExpression(Expression projection, Map, Type> expressionTypes, Map layout) { - return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + return translate(projection, SCALAR, expressionTypes, layout, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); } private static Page getAtMostOnePage(Operator operator, Page sourcePage) diff --git a/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java b/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java index 3a281004962..203d2900537 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.scalar.FunctionAssertions; @@ -31,6 +30,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; @@ -69,13 +69,11 @@ import static io.prestosql.sql.ExpressionFormatter.formatExpression; import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.ExpressionInterpreter.expressionInterpreter; import static io.prestosql.sql.planner.ExpressionInterpreter.expressionOptimizer; import static io.prestosql.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static io.prestosql.util.DateTimeZoneIndex.getDateTimeZone; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; @@ -116,6 +114,7 @@ public class TestExpressionInterpreter private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(SQL_PARSER, METADATA); @Test public void testAnd() @@ -1454,7 +1453,7 @@ private static Object optimize(@Language("SQL") String expression) Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, parsedExpression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, SYMBOL_TYPES, parsedExpression); ExpressionInterpreter interpreter = expressionOptimizer(parsedExpression, METADATA, TEST_SESSION, expressionTypes); return interpreter.optimize(symbol -> { switch (symbol.getName().toLowerCase(ENGLISH)) { @@ -1511,7 +1510,7 @@ private static void assertRoundTrip(String expression) private static Object evaluate(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, SYMBOL_TYPES, expression); ExpressionInterpreter interpreter = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes); return interpreter.evaluate(); diff --git a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java index b8071a11f17..20a52f4d2d5 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java @@ -92,7 +92,7 @@ private RowExpression translateAndOptimize(Expression expression) private RowExpression translateAndOptimize(Expression expression, Map, Type> types) { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, ImmutableMap.of(), metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); } private Expression simplifyExpression(Expression expression) diff --git a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java index 2a556a9d576..7c98ae1e869 100644 --- a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.SequencePageBuilder; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.DriverYieldSignal; @@ -30,7 +29,7 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -66,8 +65,6 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.stream.Collectors.toList; @@ -81,8 +78,8 @@ public class PageProcessorBenchmark { private static final Map TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR); - private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA); private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); private static final int POSITIONS = 1024; @@ -151,10 +148,10 @@ public List> columnOriented() private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR); + return rowExpression("cast(varchar0 as bigint) % 2 = 0"); } if (type == BIGINT) { - return rowExpression("bigint0 % 2 = 0", BIGINT); + return rowExpression("bigint0 % 2 = 0"); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -164,32 +161,25 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression("bigint" + i + " + 5", type)); + builder.add(rowExpression("bigint" + i + " + 5")); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type)); + builder.add(rowExpression("concat(varchar" + i + ", 'foo')")); } } return builder.build(); } - private RowExpression rowExpression(String expression, Type type) + private RowExpression rowExpression(String value) { - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i = 0; i < columnCount; i++) { - builder.put(i, type); - } - Map types = builder.build(); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); } private static Page createPage(List types, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java index 5657d336f7b..ccc92b70772 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java @@ -27,7 +27,6 @@ import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.Domain; @@ -89,11 +88,7 @@ @Test(singleThreaded = true) public class TestEffectivePredicateExtractor { - private static final TableHandle DUAL_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle()); - private static final TableLayoutHandle TESTING_TABLE_LAYOUT = new TableLayoutHandle( - new ConnectorId("x"), - TestingTransactionHandle.create(), - TestingHandle.INSTANCE); + private static final TableHandle DUAL_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)); private static final Symbol A = new Symbol("a"); private static final Symbol B = new Symbol("b"); @@ -130,7 +125,7 @@ public void setUp() .build(); Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C, D, E, F))); - baseTableScan = new TableScanNode( + baseTableScan = TableScanNode.newInstance( newId(), DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), @@ -323,7 +318,7 @@ public void testTableScan() { // Effective predicate is True if there is no effective predicate Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(A, B, C, D))); - PlanNode node = new TableScanNode( + PlanNode node = TableScanNode.newInstance( newId(), DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), @@ -336,7 +331,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.of(TESTING_TABLE_LAYOUT), TupleDomain.none(), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); @@ -347,7 +341,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.of(TESTING_TABLE_LAYOUT), TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(A), Domain.singleValue(BIGINT, 1L))), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); @@ -358,7 +351,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.of(TESTING_TABLE_LAYOUT), TupleDomain.withColumnDomains(ImmutableMap.of( scanAssignments.get(A), Domain.singleValue(BIGINT, 1L), scanAssignments.get(B), Domain.singleValue(BIGINT, 2L))), @@ -371,7 +363,6 @@ public void testTableScan() DUAL_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.empty(), TupleDomain.all(), TupleDomain.all()); effectivePredicate = effectivePredicateExtractor.extract(node); @@ -717,7 +708,6 @@ private static TableScanNode tableScanNode(Map scanAssignm DUAL_TABLE_HANDLE, ImmutableList.copyOf(scanAssignments.keySet()), scanAssignments, - Optional.empty(), TupleDomain.all(), TupleDomain.all()); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java deleted file mode 100644 index c724efa16bf..00000000000 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageFilterFunction.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.MetadataManager; -import io.prestosql.operator.project.InterpretedPageFilter; -import io.prestosql.operator.project.SelectedPositions; -import io.prestosql.spi.Page; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.tree.ComparisonExpression; -import org.testng.annotations.Test; - -import static io.prestosql.SessionTestUtils.TEST_SESSION; -import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; -import static java.lang.String.format; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertTrue; - -public class TestInterpretedPageFilterFunction -{ - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - - @Test - public void testNullLiteral() - { - assertFilter("null", false); - } - - @Test - public void testBooleanLiteral() - { - assertFilter("true", true); - assertFilter("false", false); - } - - @Test - public void testNotExpression() - { - assertFilter("not true", false); - assertFilter("not false", true); - assertFilter("not null", false); - } - - @Test - public void testAndExpression() - { - assertFilter("true and true", true); - assertFilter("true and false", false); - assertFilter("true and null", false); - - assertFilter("false and true", false); - assertFilter("false and false", false); - assertFilter("false and null", false); - - assertFilter("null and true", false); - assertFilter("null and false", false); - assertFilter("null and null", false); - } - - @Test - public void testORExpression() - { - assertFilter("true or true", true); - assertFilter("true or false", true); - assertFilter("true or null", true); - - assertFilter("false or true", true); - assertFilter("false or false", false); - assertFilter("false or null", false); - - assertFilter("null or true", true); - assertFilter("null or false", false); - assertFilter("null or null", false); - } - - @Test - public void testIsNullExpression() - { - assertFilter("null is null", true); - assertFilter("42 is null", false); - } - - @Test - public void testIsNotNullExpression() - { - assertFilter("42 is not null", true); - assertFilter("null is not null", false); - } - - @Test - public void testComparisonExpression() - { - assertFilter("42 = 42", true); - assertFilter("42 = 42.0", true); - assertFilter("42.42 = 42.42", true); - assertFilter("'foo' = 'foo'", true); - - assertFilter("42 = 87", false); - assertFilter("42 = 22.2", false); - assertFilter("42.42 = 22.2", false); - assertFilter("'foo' = 'bar'", false); - - assertFilter("42 != 87", true); - assertFilter("42 != 22.2", true); - assertFilter("42.42 != 22.22", true); - assertFilter("'foo' != 'bar'", true); - - assertFilter("42 != 42", false); - assertFilter("42 != 42.0", false); - assertFilter("42.42 != 42.42", false); - assertFilter("'foo' != 'foo'", false); - - assertFilter("42 < 88", true); - assertFilter("42 < 88.8", true); - assertFilter("42.42 < 88.8", true); - assertFilter("'bar' < 'foo'", true); - - assertFilter("88 < 42", false); - assertFilter("88 < 42.42", false); - assertFilter("88.8 < 42.42", false); - assertFilter("'foo' < 'bar'", false); - - assertFilter("42 <= 88", true); - assertFilter("42 <= 88.8", true); - assertFilter("42.42 <= 88.8", true); - assertFilter("'bar' <= 'foo'", true); - - assertFilter("42 <= 42", true); - assertFilter("42 <= 42.0", true); - assertFilter("42.42 <= 42.42", true); - assertFilter("'foo' <= 'foo'", true); - - assertFilter("88 <= 42", false); - assertFilter("88 <= 42.42", false); - assertFilter("88.8 <= 42.42", false); - assertFilter("'foo' <= 'bar'", false); - - assertFilter("88 >= 42", true); - assertFilter("88.8 >= 42.0", true); - assertFilter("88.8 >= 42.42", true); - assertFilter("'foo' >= 'bar'", true); - - assertFilter("42 >= 88", false); - assertFilter("42.42 >= 88.0", false); - assertFilter("42.42 >= 88.88", false); - assertFilter("'bar' >= 'foo'", false); - - assertFilter("88 >= 42", true); - assertFilter("88.8 >= 42.0", true); - assertFilter("88.8 >= 42.42", true); - assertFilter("'foo' >= 'bar'", true); - assertFilter("42 >= 42", true); - assertFilter("42 >= 42.0", true); - assertFilter("42.42 >= 42.42", true); - assertFilter("'foo' >= 'foo'", true); - - assertFilter("42 >= 88", false); - assertFilter("42.42 >= 88.0", false); - assertFilter("42.42 >= 88.88", false); - assertFilter("'bar' >= 'foo'", false); - } - - @Test - public void testComparisonExpressionWithNulls() - { - for (ComparisonExpression.Operator operator : ComparisonExpression.Operator.values()) { - if (operator == ComparisonExpression.Operator.IS_DISTINCT_FROM) { - // IS DISTINCT FROM has different NULL semantics - continue; - } - - assertFilter(format("NULL %s NULL", operator.getValue()), false); - - assertFilter(format("42 %s NULL", operator.getValue()), false); - assertFilter(format("NULL %s 42", operator.getValue()), false); - - assertFilter(format("11.1 %s NULL", operator.getValue()), false); - assertFilter(format("NULL %s 11.1", operator.getValue()), false); - } - } - - private static void assertFilter(String expression, boolean expectedValue) - { - InterpretedPageFilter filterFunction = new InterpretedPageFilter( - createExpression(expression, METADATA, TypeProvider.empty()), - TypeProvider.empty(), - ImmutableMap.of(), - METADATA, - SQL_PARSER, - TEST_SESSION); - - SelectedPositions selectedPositions = filterFunction.filter(TEST_SESSION.toConnectorSession(), new Page(1)); - assertEquals(selectedPositions.size(), expectedValue ? 1 : 0); - if (expectedValue) { - if (selectedPositions.isList()) { - assertEquals(selectedPositions.getPositions()[selectedPositions.getOffset()], 0); - } - else { - assertEquals(selectedPositions.getOffset(), 0); - } - } - else { - assertTrue(selectedPositions.isEmpty()); - } - } -} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java deleted file mode 100644 index 3e0d1b7b143..00000000000 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestInterpretedPageProjectionFunction.java +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner; - -import com.google.common.collect.ImmutableMap; -import io.prestosql.block.BlockAssertions; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.MetadataManager; -import io.prestosql.operator.DriverYieldSignal; -import io.prestosql.operator.Work; -import io.prestosql.operator.project.InterpretedPageProjection; -import io.prestosql.spi.Page; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; -import io.prestosql.sql.tree.ArithmeticBinaryExpression; -import org.testng.annotations.AfterClass; -import org.testng.annotations.Test; - -import javax.annotation.Nullable; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; - -import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.prestosql.SessionTestUtils.TEST_SESSION; -import static io.prestosql.operator.project.SelectedPositions.positionsList; -import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; -import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.spi.type.TypeUtils.writeNativeValue; -import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; - -public class TestInterpretedPageProjectionFunction -{ - // todo add cases for decimal - - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - private static final ScheduledExecutorService executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("test-%s")); - - @AfterClass(alwaysRun = true) - public void tearDown() - { - executor.shutdownNow(); - } - - @Test - public void testBooleanExpression() - { - assertProjection("true", true); - assertProjection("false", false); - assertProjection("1 = 1", true); - assertProjection("1 = 0", false); - assertProjection("true and false", false); - } - - @Test - public void testArithmeticExpression() - { - assertProjection("42 + 87", 42 + 87); - assertProjection("42 + 22.2E0", 42 + 22.2); - assertProjection("11.1E0 + 22.2E0", 11.1 + 22.2); - - assertProjection("42 - 87", 42 - 87); - assertProjection("42 - 22.2E0", 42 - 22.2); - assertProjection("11.1E0 - 22.2E0", 11.1 - 22.2); - - assertProjection("42 * 87", 42 * 87); - assertProjection("42 * 22.2E0", 42 * 22.2); - assertProjection("11.1E0 * 22.2E0", 11.1 * 22.2); - - assertProjection("42 / 87", 42 / 87); - assertProjection("42 / 22.2E0", 42 / 22.2); - assertProjection("11.1E0 / 22.2E0", 11.1 / 22.2); - - assertProjection("42 % 87", 42 % 87); - assertProjection("42 % 22.2E0", 42 % 22.2); - assertProjection("11.1E0 % 22.2E0", 11.1 % 22.2); - - assertProjection("42 + BIGINT '87'", 42 + 87L); - assertProjection("BIGINT '42' - 22.2E0", 42L - 22.2); - assertProjection("42 * BIGINT '87'", 42 * 87L); - assertProjection("BIGINT '11' / 22.2E0", 11L / 22.2); - assertProjection("11.1E0 % BIGINT '22'", 11.1 % 22L); - } - - @Test - public void testArithmeticExpressionWithNulls() - { - for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) { - assertProjection("CAST(NULL AS INTEGER) " + operator.getValue() + " CAST(NULL AS INTEGER)", null); - - assertProjection("42 " + operator.getValue() + " NULL", null); - assertProjection("NULL " + operator.getValue() + " 42", null); - - assertProjection("11.1 " + operator.getValue() + " CAST(NULL AS INTEGER)", null); - assertProjection("CAST(NULL AS INTEGER) " + operator.getValue() + " 11.1", null); - } - } - - @Test - public void testCoalesceExpression() - { - assertProjection("COALESCE(42, 87, 100)", 42); - assertProjection("COALESCE(NULL, 87, 100)", 87); - assertProjection("COALESCE(42, NULL, 100)", 42); - assertProjection("COALESCE(42, NULL, BIGINT '100')", 42L); - assertProjection("COALESCE(NULL, NULL, 100)", 100); - assertProjection("COALESCE(NULL, NULL, BIGINT '100')", 100L); - - assertProjection("COALESCE(42.2E0, 87.2E0, 100.2E0)", 42.2); - assertProjection("COALESCE(NULL, 87.2E0, 100.2E0)", 87.2); - assertProjection("COALESCE(42.2E0, NULL, 100.2E0)", 42.2); - assertProjection("COALESCE(NULL, NULL, 100.2E0)", 100.2); - - assertProjection("COALESCE('foo', 'bar', 'zah')", "foo"); - assertProjection("COALESCE(NULL, 'bar', 'zah')", "bar"); - assertProjection("COALESCE('foo', NULL, 'zah')", "foo"); - assertProjection("COALESCE(NULL, NULL, 'zah')", "zah"); - - assertProjection("COALESCE(NULL, NULL, NULL)", null); - } - - @Test - public void testNullIf() - { - assertProjection("NULLIF(42, 42)", null); - assertProjection("NULLIF(42, 42.0E0)", null); - assertProjection("NULLIF(42.42E0, 42.42E0)", null); - assertProjection("NULLIF('foo', 'foo')", null); - - assertProjection("NULLIF(42, 87)", 42); - assertProjection("NULLIF(42, 22.2E0)", 42); - assertProjection("NULLIF(42, BIGINT '87')", 42); - assertProjection("NULLIF(BIGINT '42', 22.2E0)", 42L); - assertProjection("NULLIF(42.42E0, 22.2E0)", 42.42); - assertProjection("NULLIF('foo', 'bar')", "foo"); - - assertProjection("NULLIF(NULL, NULL)", null); - - assertProjection("NULLIF(42, NULL)", 42); - assertProjection("NULLIF(NULL, 42)", null); - - assertProjection("NULLIF(11.1E0, NULL)", 11.1); - assertProjection("NULLIF(NULL, 11.1E0)", null); - } - - @Test - public void testSymbolReference() - { - Symbol symbol = new Symbol("symbol"); - ImmutableMap symbolToInputMappings = ImmutableMap.of(symbol, 0); - assertProjection("symbol", true, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BOOLEAN)), 0, createBlock(BOOLEAN, true)); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BOOLEAN)), 0, createNullBlock(BOOLEAN)); - - assertProjection("symbol", 42L, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BIGINT)), 0, createBlock(BIGINT, 42)); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, BIGINT)), 0, createNullBlock(BIGINT)); - - assertProjection("symbol", 11.1, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, DOUBLE)), 0, createBlock(DOUBLE, 11.1)); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, DOUBLE)), 0, createNullBlock(DOUBLE)); - - assertProjection("symbol", "foo", symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, VARCHAR)), 0, createBlock(VARCHAR, "foo")); - assertProjection("symbol", null, symbolToInputMappings, TypeProvider.copyOf(ImmutableMap.of(symbol, VARCHAR)), 0, createNullBlock(VARCHAR)); - } - - private static void assertProjection(String expression, @Nullable Object expectedValue) - { - assertProjection( - expression, - expectedValue, - ImmutableMap.of(), - TypeProvider.empty(), - 0); - } - - private static void assertProjection( - String expression, - @Nullable Object expectedValue, - Map symbolToInputMappings, - TypeProvider symbolTypes, - int position, - Block... blocks) - { - assertProjection(expression, new Object[] {expectedValue}, symbolToInputMappings, symbolTypes, new int[] {position}, blocks); - } - - private static void assertProjection( - String expression, - Object[] expectedValues, - Map symbolToInputMappings, - TypeProvider symbolTypes, - int[] positions, - Block... blocks) - { - InterpretedPageProjection projectionFunction = new InterpretedPageProjection( - createExpression(expression, METADATA, symbolTypes), - symbolTypes, - symbolToInputMappings, - METADATA, - SQL_PARSER, - TEST_SESSION); - - // project with yield - DriverYieldSignal yieldSignal = new DriverYieldSignal(); - Work work = projectionFunction.project( - TEST_SESSION.toConnectorSession(), - yieldSignal, - new Page(positions.length, blocks), - positionsList(positions, 0, positions.length)); - - Block block; - // Get nothing for the first position.length compute due to yield - // Currently we enforce a yield check for every position; free feel to adjust the number if the behavior changes - for (int i = 0; i < positions.length; i++) { - yieldSignal.setWithDelay(1, executor); - yieldSignal.forceYieldForTesting(); - assertFalse(work.process()); - yieldSignal.reset(); - } - // the next yield is not going to prevent a block to be produced - yieldSignal.setWithDelay(1, executor); - yieldSignal.forceYieldForTesting(); - yieldSignal.reset(); - assertTrue(work.process()); - block = work.getResult(); - - List actualValues = BlockAssertions.toValues(projectionFunction.getType(), block); - assertEquals(actualValues.size(), positions.length); - assertEquals(expectedValues.length, positions.length); - for (int i = 0; i < positions.length; i++) { - assertEquals(actualValues.get(i), expectedValues[i]); - } - - // project without yield - work = projectionFunction.project( - TEST_SESSION.toConnectorSession(), - new DriverYieldSignal(), - new Page(positions.length, blocks), - positionsList(positions, 0, positions.length)); - assertTrue(work.process()); - block = work.getResult(); - - actualValues = BlockAssertions.toValues(projectionFunction.getType(), block); - assertEquals(actualValues.size(), positions.length); - assertEquals(expectedValues.length, positions.length); - for (int i = 0; i < positions.length; i++) { - assertEquals(actualValues.get(i), expectedValues[i]); - } - } - - private static Block createBlock(Type type, Object value) - { - return createBlock(type, new Object[] {value}); - } - - private static Block createNullBlock(Type type) - { - return createBlock(type, new Object[] {null}); - } - - private static Block createBlock(Type type, Object[] values) - { - BlockBuilder blockBuilder = type.createBlockBuilder(null, values.length); - for (Object value : values) { - writeNativeValue(type, blockBuilder, value); - } - return blockBuilder.build(); - } -} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java index 3fda22c4962..2d26b5d0b95 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java @@ -21,6 +21,7 @@ import io.prestosql.connector.ConnectorId; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.FunctionKind; +import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; import io.prestosql.spi.connector.ColumnHandle; @@ -44,8 +45,10 @@ import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.WindowFrame; +import io.prestosql.testing.TestingHandle; import io.prestosql.testing.TestingMetadata.TestingColumnHandle; import io.prestosql.testing.TestingMetadata.TestingTableHandle; +import io.prestosql.testing.TestingTransactionHandle; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -66,7 +69,7 @@ @Test(singleThreaded = true) public class TestTypeValidator { - private static final TableHandle TEST_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle()); + private static final TableHandle TEST_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)); private static final SqlParser SQL_PARSER = new SqlParser(); private static final TypeValidator TYPE_VALIDATOR = new TypeValidator(); @@ -101,7 +104,6 @@ public void setUp() TEST_TABLE_HANDLE, ImmutableList.copyOf(assignments.keySet()), assignments, - Optional.empty(), TupleDomain.all(), TupleDomain.all()); } @@ -392,7 +394,8 @@ public void testInvalidUnion() private void assertTypesValid(PlanNode node) { - TYPE_VALIDATOR.validate(node, TEST_SESSION, createTestMetadataManager(), SQL_PARSER, symbolAllocator.getTypes(), WarningCollector.NOOP); + MetadataManager metadata = createTestMetadataManager(); + TYPE_VALIDATOR.validate(node, TEST_SESSION, metadata, new TypeAnalyzer(SQL_PARSER, metadata), symbolAllocator.getTypes(), WarningCollector.NOOP); } private static PlanNodeId newId() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java index 3f7aa740f1b..ba690e3482c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/TableScanMatcher.java @@ -61,13 +61,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return new MatchResult( expectedTableName.equalsIgnoreCase(actualTableName) && ((!expectedConstraint.isPresent()) || - domainsMatch(expectedConstraint, tableScanNode.getCurrentConstraint(), tableScanNode.getTable(), session, metadata)) && - hasTableLayout(tableScanNode)); - } - - private boolean hasTableLayout(TableScanNode tableScanNode) - { - return !hasTableLayout.isPresent() || hasTableLayout.get() == tableScanNode.getLayout().isPresent(); + domainsMatch(expectedConstraint, tableScanNode.getCurrentConstraint(), tableScanNode.getTable(), session, metadata))); } @Override diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java index 77003755825..5654be2f016 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -19,6 +19,7 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.spi.type.BigintType; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; @@ -29,6 +30,8 @@ import io.prestosql.sql.tree.SymbolReference; import org.testng.annotations.Test; +import java.util.Optional; + import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCALE_FACTOR; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; @@ -143,7 +146,9 @@ public void testDoesNotFireOnNestedNonCountAggregate() p.tableScan( new TableHandle( new ConnectorId("local"), - new TpchTableHandle("orders", TINY_SCALE_FACTOR)), + new TpchTableHandle("orders", TINY_SCALE_FACTOR), + TpchTransactionHandle.INSTANCE, + Optional.empty()), ImmutableList.of(totalPrice), ImmutableMap.of(totalPrice, new TpchColumnHandle(totalPrice.getName(), DOUBLE)))))); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java index 9b9511478d7..1589d0ad5e9 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java @@ -21,6 +21,7 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; @@ -31,6 +32,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import org.testng.annotations.Test; +import java.util.Optional; import java.util.function.Predicate; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -77,6 +79,7 @@ private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate rule : pickTableLayout.rules()) { - tester().assertThat(rule) - .on(p -> p.values(p.symbol("a", BIGINT))) - .doesNotFire(); - } - } - - @Test - public void doesNotFireIfTableScanHasTableLayout() - { - tester().assertThat(pickTableLayout.pickTableLayoutWithoutPredicate()) - .on(p -> p.tableScan( - nationTableHandle, - ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle))) + tester().assertThat(pushPredicateIntoTableScan) + .on(p -> p.values(p.symbol("a", BIGINT))) .doesNotFire(); } @Test public void eliminateTableScanWhenNoLayoutExist() { - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'G'"), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1))), - Optional.of(ordersTableLayoutHandle)))) + ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches(values("A")); } @@ -116,13 +97,12 @@ public void eliminateTableScanWhenNoLayoutExist() public void replaceWithExistsWhenNoLayoutExist() { ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("nationkey = BIGINT '44'"), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), ImmutableMap.of(p.symbol("nationkey", BIGINT), columnHandle), - Optional.of(nationTableLayoutHandle), TupleDomain.none(), TupleDomain.none()))) .matches(values("A")); @@ -131,37 +111,24 @@ public void replaceWithExistsWhenNoLayoutExist() @Test public void doesNotFireIfRuleNotChangePlan() { - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("nationkey % 17 = BIGINT '44' AND nationkey % 15 = BIGINT '43'"), p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle), TupleDomain.all(), TupleDomain.all()))) .doesNotFire(); } - @Test - public void ruleAddedTableLayoutToTableScan() - { - tester().assertThat(pickTableLayout.pickTableLayoutWithoutPredicate()) - .on(p -> p.tableScan( - nationTableHandle, - ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))) - .matches( - constrainedTableScanWithTableLayout("nation", ImmutableMap.of(), ImmutableMap.of("nationkey", "nationkey"))); - } - @Test public void ruleAddedTableLayoutToFilterTableScan() { Map filterConstraint = ImmutableMap.builder() .put("orderstatus", singleValue(createVarcharType(1), utf8Slice("F"))) .build(); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = CAST ('F' AS VARCHAR(1))"), p.tableScan( ordersTableHandle, @@ -174,13 +141,12 @@ public void ruleAddedTableLayoutToFilterTableScan() @Test public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() { - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'F'"), p.tableScan( ordersTableHandle, ImmutableList.of(p.symbol("orderstatus", createVarcharType(1))), - ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1))), - Optional.of(ordersTableLayoutHandle)))) + ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1)))))) .matches( constrainedTableScanWithTableLayout( "orders", @@ -192,7 +158,7 @@ public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint() public void ruleWithPushdownableToTableLayoutPredicate() { Type orderStatusType = createVarcharType(1); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'O'"), p.tableScan( ordersTableHandle, @@ -208,7 +174,7 @@ public void ruleWithPushdownableToTableLayoutPredicate() public void nonDeterministicPredicate() { Type orderStatusType = createVarcharType(1); - tester().assertThat(pickTableLayout.pickTableLayoutForPredicate()) + tester().assertThat(pushPredicateIntoTableScan) .on(p -> p.filter(expression("orderstatus = 'O' AND rand() = 0"), p.tableScan( ordersTableHandle, diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java index 8e95676875e..7fc51213d52 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestRemoveEmptyDelete.java @@ -17,12 +17,15 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.type.BigintType; import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.Test; +import java.util.Optional; + import static io.prestosql.sql.planner.iterative.rule.test.RuleTester.CONNECTOR_ID; public class TestRemoveEmptyDelete @@ -35,7 +38,7 @@ public void testDoesNotFire() .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), p.tableScan( - new TableHandle(CONNECTOR_ID, new TpchTableHandle("nation", 1.0)), + new TableHandle(CONNECTOR_ID, new TpchTableHandle("nation", 1.0), TpchTransactionHandle.INSTANCE, Optional.empty()), ImmutableList.of(), ImmutableMap.of()), p.symbol("a", BigintType.BIGINT))) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java index ae0cbaa33e3..28eb9868829 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -20,6 +20,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; @@ -118,7 +119,7 @@ private static void assertSimplifies(String expression, String expected) { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected)); - Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), METADATA, LITERAL_ENCODER, SQL_PARSER); + Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), METADATA, LITERAL_ENCODER, new TypeAnalyzer(SQL_PARSER, METADATA)); assertEquals( normalize(rewritten), normalize(expectedExpression)); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 82e1cc7f18f..7557675fc89 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -19,11 +19,14 @@ import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.Assignments; import org.testng.annotations.Test; +import java.util.Optional; + import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCALE_FACTOR; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; @@ -48,9 +51,13 @@ public void testRewrite() .on(p -> p.lateral( ImmutableList.of(p.symbol("l_nationkey")), - p.tableScan(new TableHandle( + p.tableScan( + new TableHandle( new ConnectorId("local"), - new TpchTableHandle("nation", TINY_SCALE_FACTOR)), ImmutableList.of(p.symbol("l_nationkey")), + new TpchTableHandle("nation", TINY_SCALE_FACTOR), + TpchTransactionHandle.INSTANCE, + Optional.empty()), + ImmutableList.of(p.symbol("l_nationkey")), ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", BIGINT))), p.project( diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index 68cab31431b..587d25e340c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -24,7 +24,6 @@ import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.SchemaTableName; @@ -77,7 +76,9 @@ import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.NullLiteral; +import io.prestosql.testing.TestingHandle; import io.prestosql.testing.TestingMetadata.TestingTableHandle; +import io.prestosql.testing.TestingTransactionHandle; import java.util.ArrayList; import java.util.Arrays; @@ -363,29 +364,24 @@ public LateralJoinNode lateral(List correlation, PlanNode input, PlanNod public TableScanNode tableScan(List symbols, Map assignments) { - TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); - return tableScan(tableHandle, symbols, assignments, Optional.empty(), TupleDomain.all(), TupleDomain.all()); - } - - public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) - { - return tableScan(tableHandle, symbols, assignments, Optional.empty()); + return tableScan( + new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)), + symbols, + assignments); } public TableScanNode tableScan( TableHandle tableHandle, List symbols, - Map assignments, - Optional tableLayout) + Map assignments) { - return tableScan(tableHandle, symbols, assignments, tableLayout, TupleDomain.all(), TupleDomain.all()); + return tableScan(tableHandle, symbols, assignments, TupleDomain.all(), TupleDomain.all()); } public TableScanNode tableScan( TableHandle tableHandle, List symbols, Map assignments, - Optional tableLayout, TupleDomain currentConstraint, TupleDomain enforcedConstraint) { @@ -394,7 +390,6 @@ public TableScanNode tableScan( tableHandle, symbols, assignments, - tableLayout, currentConstraint, enforcedConstraint); } @@ -404,7 +399,9 @@ public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode del TableWriterNode.DeleteHandle deleteHandle = new TableWriterNode.DeleteHandle( new TableHandle( new ConnectorId("testConnector"), - new TestingTableHandle()), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.of(TestingHandle.INSTANCE)), schemaTableName); return new TableFinishNode( idAllocator.getNextId(), @@ -488,7 +485,6 @@ public IndexSourceNode indexSource( TestingConnectorTransactionHandle.INSTANCE, TestingConnectorIndexHandle.INSTANCE), tableHandle, - Optional.empty(), lookupSymbols, outputSymbols, assignments, diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java index 7b6c8f7212f..b7cd186dc8c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java @@ -22,7 +22,7 @@ import io.prestosql.spi.Plugin; import io.prestosql.split.PageSourceManager; import io.prestosql.split.SplitManager; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.transaction.TransactionManager; @@ -48,7 +48,7 @@ public class RuleTester private final SplitManager splitManager; private final PageSourceManager pageSourceManager; private final AccessControl accessControl; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; public RuleTester() { @@ -91,7 +91,7 @@ public RuleTester(List plugins, Map sessionProperties, O this.splitManager = queryRunner.getSplitManager(); this.pageSourceManager = queryRunner.getPageSourceManager(); this.accessControl = queryRunner.getAccessControl(); - this.sqlParser = queryRunner.getSqlParser(); + this.typeAnalyzer = new TypeAnalyzer(queryRunner.getSqlParser(), metadata); } public RuleAssert assertThat(Rule rule) @@ -120,12 +120,9 @@ public PageSourceManager getPageSourceManager() return pageSourceManager; } - // TODO: this is only being used by rules that need to get the type of an expression - // In the short term, it should be encapsulated into something that knows how to provide types - // Rules should *not* need to use the parser otherwise. - public SqlParser getSqlParser() + public TypeAnalyzer getTypeAnalyzer() { - return sqlParser; + return typeAnalyzer; } public ConnectorId getCurrentConnectorId() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java deleted file mode 100644 index 78c343a81e0..00000000000 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestAddExchanges.java +++ /dev/null @@ -1,795 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.sql.planner.optimizations; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; -import io.prestosql.spi.block.SortOrder; -import io.prestosql.spi.connector.ConstantProperty; -import io.prestosql.spi.connector.GroupingProperty; -import io.prestosql.spi.connector.SortingProperty; -import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.optimizations.ActualProperties.Global; -import org.testng.annotations.Test; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Optional; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.prestosql.spi.block.SortOrder.ASC_NULLS_FIRST; -import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; -import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; -import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.partitionedOn; -import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.singleStreamPartition; -import static io.prestosql.sql.planner.optimizations.ActualProperties.builder; -import static io.prestosql.sql.planner.optimizations.AddExchanges.streamingExecutionPreference; -import static org.testng.Assert.assertEquals; - -/** - * These are unit test for the internal logic in AddExchanges. - * For plan tests see {@link TestAddExchangesPlans} - */ -public class TestAddExchanges -{ - @Test - public void testPickLayoutAnyPreference() - { - Comparator preference = streamingExecutionPreference(PreferredProperties.any()); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a", "b")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - // Given no preferences, the original input order should be maintained - assertEquals(stableSort(input, preference), input); - } - - @Test - public void testPickLayoutPartitionedPreference() - { - Comparator preference = streamingExecutionPreference(PreferredProperties.distributed()); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutUnpartitionedPreference() - { - Comparator preference = streamingExecutionPreference(PreferredProperties.undistributed()); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutPartitionedOnSingle() - { - Comparator preference = streamingExecutionPreference( - PreferredProperties.partitioned(ImmutableSet.of(symbol("a")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutPartitionedOnMultiple() - { - Comparator preference = streamingExecutionPreference( - PreferredProperties.partitioned(ImmutableSet.of(symbol("a"), symbol("b")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGrouped() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGroupedMultiple() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a", "b")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGroupedMultipleProperties() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a"), grouped("b")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutGroupedWithSort() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.local(ImmutableList.of(grouped("a"), sorted("b", ASC_NULLS_FIRST)))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutUnpartitionedWithGroupAndSort() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.undistributedWithLocal(ImmutableList.of(grouped("a"), sorted("b", ASC_NULLS_FIRST)))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - @Test - public void testPickLayoutPartitionedWithGroup() - { - Comparator preference = streamingExecutionPreference - (PreferredProperties.partitionedWithLocal( - ImmutableSet.of(symbol("a")), - ImmutableList.of(grouped("a")))); - - List input = ImmutableList.builder() - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .build(); - - List expected = ImmutableList.builder() - .add(builder() - .global(singleStream()) - .local(ImmutableList.of(constant("a"), sorted("b", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(singleStreamPartition()) - .local(ImmutableList.of(sorted("a", ASC_NULLS_FIRST))) - .build()) - .add(builder() - .global(streamPartitionedOn("a")) - .build()) - .add(builder() - .global(singleStreamPartition()) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(hashDistributedOn("a")) - .build()) - .add(builder() - .global(arbitraryPartition()) - .local(ImmutableList.of(grouped("a", "b"))) - .build()) - .add(builder() - .global(arbitraryPartition()) - .build()) - .build(); - assertEquals(stableSort(input, preference), expected); - } - - private static List stableSort(List list, Comparator comparator) - { - ArrayList copy = Lists.newArrayList(list); - Collections.sort(copy, comparator); - return copy; - } - - private static Global hashDistributedOn(String... columnNames) - { - return partitionedOn(FIXED_HASH_DISTRIBUTION, arguments(columnNames), Optional.of(arguments(columnNames))); - } - - public static Global singleStream() - { - return Global.streamPartitionedOn(ImmutableList.of()); - } - - private static Global streamPartitionedOn(String... columnNames) - { - return Global.streamPartitionedOn(arguments(columnNames)); - } - - private static ConstantProperty constant(String column) - { - return new ConstantProperty<>(symbol(column)); - } - - private static GroupingProperty grouped(String... columns) - { - return new GroupingProperty<>(Lists.transform(Arrays.asList(columns), Symbol::new)); - } - - private static SortingProperty sorted(String column, SortOrder order) - { - return new SortingProperty<>(symbol(column), order); - } - - private static Symbol symbol(String name) - { - return new Symbol(name); - } - - private static List arguments(String[] columnNames) - { - return Arrays.asList(columnNames).stream() - .map(Symbol::new) - .collect(toImmutableList()); - } -} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java index e659209b6b4..6cdc7591fbe 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java @@ -19,6 +19,7 @@ import io.prestosql.spi.block.SortOrder; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.assertions.ExpectedValueProvider; import io.prestosql.sql.planner.assertions.PlanMatchPattern; @@ -89,7 +90,7 @@ public void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new AddExchanges(getQueryRunner().getMetadata(), new SqlParser()), + new AddExchanges(getQueryRunner().getMetadata(), new TypeAnalyzer(new SqlParser(), getQueryRunner().getMetadata())), new PruneUnreferencedOutputs(), new IterativeOptimizer( new RuleStatsRecorder(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java index 1c9d617c91a..a61aad2080c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java @@ -21,6 +21,7 @@ import io.prestosql.sql.parser.ParsingOptions; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.tree.Expression; import org.intellij.lang.annotations.Language; @@ -41,7 +42,7 @@ public class TestExpressionEquivalence { private static final SqlParser SQL_PARSER = new SqlParser(); private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); - private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, SQL_PARSER); + private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, new TypeAnalyzer(SQL_PARSER, METADATA)); @Test public void testEquivalent() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java index dfaa88c1eb4..c4a46ae629d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.spi.block.SortOrder; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.assertions.ExpectedValueProvider; import io.prestosql.sql.planner.assertions.PlanMatchPattern; @@ -322,7 +323,7 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new PredicatePushDown(getQueryRunner().getMetadata(), getQueryRunner().getSqlParser()), + new PredicatePushDown(getQueryRunner().getMetadata(), new TypeAnalyzer(getQueryRunner().getSqlParser(), getQueryRunner().getMetadata())), new IterativeOptimizer( new RuleStatsRecorder(), getQueryRunner().getStatsCalculator(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 928e2a95ec3..2bfffc331b8 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -20,14 +20,12 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; -import io.prestosql.plugin.tpch.TpchTableLayoutHandle; -import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; @@ -51,8 +49,6 @@ public class TestValidateAggregationsWithDefaultValues extends BasePlanTest { - private static final SqlParser SQL_PARSER = new SqlParser(); - private Metadata metadata; private PlanBuilder builder; private Symbol symbol; @@ -66,13 +62,12 @@ public void setup() ConnectorId connectorId = getCurrentConnectorId(); TableHandle nationTableHandle = new TableHandle( connectorId, - new TpchTableHandle("nation", 1.0)); - TableLayoutHandle nationTableLayoutHandle = new TableLayoutHandle(connectorId, + new TpchTableHandle("nation", 1.0), TestingTransactionHandle.create(), - new TpchTableLayoutHandle((TpchTableHandle) nationTableHandle.getConnectorHandle(), TupleDomain.all())); + Optional.empty()); TpchColumnHandle nationkeyColumnHandle = new TpchColumnHandle("nationkey", BIGINT); symbol = new Symbol("nationkey"); - tableScanNode = builder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableMap.of(symbol, nationkeyColumnHandle), Optional.of(nationTableLayoutHandle)); + tableScanNode = builder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableMap.of(symbol, nationkeyColumnHandle)); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by remote hash exchange") @@ -196,7 +191,7 @@ private void validatePlan(PlanNode root, boolean forceSingleNode) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, SQL_PARSER, TypeProvider.empty(), WarningCollector.NOOP); + new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, new TypeAnalyzer(new SqlParser(), metadata), TypeProvider.empty(), WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java index 218d4fb7c47..23e5e1684c3 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java @@ -19,18 +19,15 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; -import io.prestosql.metadata.TableLayoutHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; -import io.prestosql.plugin.tpch.TpchTableLayoutHandle; -import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.plugin.tpch.TpchTransactionHandle; import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; import io.prestosql.sql.planner.plan.PlanNode; -import io.prestosql.testing.TestingTransactionHandle; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -44,25 +41,22 @@ public class TestValidateStreamingAggregations extends BasePlanTest { private Metadata metadata; - private SqlParser sqlParser; + private TypeAnalyzer typeAnalyzer; private PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private TableHandle nationTableHandle; - private TableLayoutHandle nationTableLayoutHandle; @BeforeClass public void setup() { metadata = getQueryRunner().getMetadata(); - sqlParser = getQueryRunner().getSqlParser(); + typeAnalyzer = new TypeAnalyzer(getQueryRunner().getSqlParser(), metadata); ConnectorId connectorId = getCurrentConnectorId(); nationTableHandle = new TableHandle( connectorId, - new TpchTableHandle("nation", 1.0)); - - nationTableLayoutHandle = new TableLayoutHandle(connectorId, - TestingTransactionHandle.create(), - new TpchTableLayoutHandle((TpchTableHandle) nationTableHandle.getConnectorHandle(), TupleDomain.all())); + new TpchTableHandle("nation", 1.0), + TpchTransactionHandle.INSTANCE, + Optional.empty()); } @Test @@ -76,8 +70,7 @@ public void testValidateSuccessful() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle))))); + ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))))); validatePlan( p -> p.aggregation( @@ -89,8 +82,7 @@ public void testValidateSuccessful() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle)))))); + ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT))))))); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Streaming aggregation with input not grouped on the grouping keys") @@ -105,8 +97,7 @@ public void testValidateFailed() p.tableScan( nationTableHandle, ImmutableList.of(p.symbol("nationkey", BIGINT)), - ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)), - Optional.of(nationTableLayoutHandle))))); + ImmutableMap.of(p.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)))))); } private void validatePlan(Function planProvider) @@ -118,7 +109,7 @@ private void validatePlan(Function planProvider) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateStreamingAggregations().validate(planNode, session, metadata, sqlParser, types, WarningCollector.NOOP); + new ValidateStreamingAggregations().validate(planNode, session, metadata, typeAnalyzer, types, WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java index e6402826a62..046898b6172 100644 --- a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.prestosql.RowPagesBuilder; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.DriverYieldSignal; import io.prestosql.operator.project.PageProcessor; @@ -30,12 +29,11 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Measurement; @@ -71,14 +69,11 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DecimalType.createDecimalType; import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static io.prestosql.testing.TestingConnectorSession.SESSION; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.math.BigInteger.ONE; import static java.math.BigInteger.ZERO; -import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toMap; import static org.openjdk.jmh.annotations.Scope.Thread; @State(Scope.Thread) @@ -548,6 +543,7 @@ private Object execute(BaseState state) private static class BaseState { private final MetadataManager metadata = createTestMetadataManager(); + private final TypeAnalyzer typeAnalyzer = new TypeAnalyzer(new SqlParser(), metadata); private final Session session = testSessionBuilder().build(); private final Random random = new Random(); @@ -611,15 +607,19 @@ protected void setDoubleMaxValue(double doubleMaxValue) this.doubleMaxValue = doubleMaxValue; } - private RowExpression rowExpression(String expression) + private RowExpression rowExpression(String value) { - Expression inputReferenceExpression = new SymbolToInputRewriter(sourceLayout).rewrite(createExpression(expression, metadata, TypeProvider.copyOf(symbolTypes))); - - Map types = sourceLayout.entrySet().stream() - .collect(toMap(Map.Entry::getValue, entry -> symbolTypes.get(entry.getKey()))); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, metadata, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); + + return SqlToRowExpressionTranslator.translate( + expression, + SCALAR, + typeAnalyzer.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + metadata.getFunctionRegistry(), + metadata.getTypeManager(), + TEST_SESSION, + true); } private Object generateRandomValue(Type type) diff --git a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java index a11c3c7abed..45d55e5bb99 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java @@ -34,6 +34,7 @@ import io.prestosql.sql.planner.Plan; import io.prestosql.sql.planner.PlanFragmenter; import io.prestosql.sql.planner.PlanOptimizers; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.tree.ExplainType; import io.prestosql.testing.MaterializedResult; @@ -345,7 +346,7 @@ private QueryExplainer getQueryExplainer() CostCalculator costCalculator = new CostCalculatorUsingExchanges(taskCountEstimator); List optimizers = new PlanOptimizers( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), featuresConfig, new TaskManagerConfig(), forceSingleNode, diff --git a/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java b/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java index 164b82fe520..2864fc4d4bc 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/TestLocalQueries.java @@ -14,32 +14,18 @@ package io.prestosql.tests; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.prestosql.Session; import io.prestosql.connector.ConnectorId; import io.prestosql.metadata.SessionPropertyManager; import io.prestosql.plugin.tpch.TpchConnectorFactory; -import io.prestosql.spi.connector.CatalogSchemaTableName; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.ColumnConstraint; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.FormattedDomain; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.FormattedMarker; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.FormattedRange; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.IoPlan; -import io.prestosql.sql.planner.planPrinter.IoPlanPrinter.IoPlan.TableColumnInfo; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.testing.MaterializedResult; import org.testng.annotations.Test; -import java.util.Optional; - -import static com.google.common.collect.Iterables.getOnlyElement; -import static io.airlift.json.JsonCodec.jsonCodec; import static io.prestosql.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.prestosql.spi.predicate.Marker.Bound.EXACTLY; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.spi.type.VarcharType.createVarcharType; import static io.prestosql.testing.MaterializedResult.resultBuilder; import static io.prestosql.testing.TestingSession.TESTING_CATALOG; import static io.prestosql.testing.TestingSession.testSessionBuilder; @@ -115,34 +101,6 @@ public void testDecimal() assertQuery("SELECT 0.1", "SELECT CAST('0.1' AS DECIMAL)"); } - @Test - public void testIOExplain() - { - String query = "SELECT * FROM orders"; - MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) " + query); - TableColumnInfo input = new TableColumnInfo( - new CatalogSchemaTableName("local", "sf0.01", "orders"), - ImmutableSet.of( - new ColumnConstraint( - "orderstatus", - createVarcharType(1).getTypeSignature(), - new FormattedDomain( - false, - ImmutableSet.of( - new FormattedRange( - new FormattedMarker(Optional.of("F"), EXACTLY), - new FormattedMarker(Optional.of("F"), EXACTLY)), - new FormattedRange( - new FormattedMarker(Optional.of("O"), EXACTLY), - new FormattedMarker(Optional.of("O"), EXACTLY)), - new FormattedRange( - new FormattedMarker(Optional.of("P"), EXACTLY), - new FormattedMarker(Optional.of("P"), EXACTLY))))))); - assertEquals( - jsonCodec(IoPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet())), - new IoPlan(ImmutableSet.of(input), Optional.empty())); - } - @Test public void testHueQueries() { @@ -152,4 +110,13 @@ public void testHueQueries() // https://github.com/cloudera/hue/blob/b49e98c1250c502be596667ce1f0fe118983b432/desktop/libs/notebook/src/notebook/connectors/jdbc.py#L213 assertQuerySucceeds(getSession(), "SELECT column_name, data_type, column_comment FROM information_schema.columns WHERE table_schema='local' AND TABLE_NAME='nation'"); } + + + @Test + public void testX() + { + ((LocalQueryRunner) getQueryRunner()).printPlan(); + computeActual("SELECT * FROM orders WHERE orderkey = BIGINT '1'"); + } + } diff --git a/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java b/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java index ac3442e47fd..7ff64aa3fe2 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/TestTpchDistributedQueries.java @@ -14,10 +14,21 @@ package io.prestosql.tests; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableSet; +import io.prestosql.spi.connector.CatalogSchemaTableName; +import io.prestosql.sql.planner.planPrinter.IoPlanPrinter; +import io.prestosql.testing.MaterializedResult; import io.prestosql.tests.tpch.TpchQueryRunnerBuilder; import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; +import java.util.Optional; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.json.JsonCodec.jsonCodec; +import static io.prestosql.spi.predicate.Marker.Bound.EXACTLY; +import static io.prestosql.spi.type.VarcharType.createVarcharType; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; public class TestTpchDistributedQueries @@ -28,6 +39,34 @@ public TestTpchDistributedQueries() super(() -> TpchQueryRunnerBuilder.builder().build()); } + @Test + public void testIOExplain() + { + String query = "SELECT * FROM orders"; + MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) " + query); + IoPlanPrinter.IoPlan.TableColumnInfo input = new IoPlanPrinter.IoPlan.TableColumnInfo( + new CatalogSchemaTableName("tpch", "sf0.01", "orders"), + ImmutableSet.of( + new IoPlanPrinter.ColumnConstraint( + "orderstatus", + createVarcharType(1).getTypeSignature(), + new IoPlanPrinter.FormattedDomain( + false, + ImmutableSet.of( + new IoPlanPrinter.FormattedRange( + new IoPlanPrinter.FormattedMarker(Optional.of("F"), EXACTLY), + new IoPlanPrinter.FormattedMarker(Optional.of("F"), EXACTLY)), + new IoPlanPrinter.FormattedRange( + new IoPlanPrinter.FormattedMarker(Optional.of("O"), EXACTLY), + new IoPlanPrinter.FormattedMarker(Optional.of("O"), EXACTLY)), + new IoPlanPrinter.FormattedRange( + new IoPlanPrinter.FormattedMarker(Optional.of("P"), EXACTLY), + new IoPlanPrinter.FormattedMarker(Optional.of("P"), EXACTLY))))))); + assertEquals( + jsonCodec(IoPlanPrinter.IoPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet())), + new IoPlanPrinter.IoPlan(ImmutableSet.of(input), Optional.empty())); + } + @Test public void testTooLongQuery() {