From 743cd5b420cfc67c45fc18578a2287f004256821 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:59:04 +0100 Subject: [PATCH] New IR -- WIP --- .../trino/sql/dialect/trino/Attributes.java | 127 +++++++ .../io/trino/sql/dialect/trino/Context.java | 77 ++++ .../sql/dialect/trino/ProgramBuilder.java | 86 +++++ .../trino/RelationalProgramBuilder.java | 354 ++++++++++++++++++ .../dialect/trino/ScalarProgramBuilder.java | 168 +++++++++ .../sql/dialect/trino/TypeConstraint.java | 64 ++++ .../sql/dialect/trino/operation/Array.java | 120 ++++++ .../sql/dialect/trino/operation/Between.java | 122 ++++++ .../sql/dialect/trino/operation/Constant.java | 98 +++++ .../trino/operation/CorrelatedJoin.java | 174 +++++++++ .../trino/operation/FieldSelection.java | 134 +++++++ .../sql/dialect/trino/operation/Filter.java | 130 +++++++ .../sql/dialect/trino/operation/Logical.java | 117 ++++++ .../sql/dialect/trino/operation/Output.java | 129 +++++++ .../sql/dialect/trino/operation/Query.java | 111 ++++++ .../sql/dialect/trino/operation/Return.java | 102 +++++ .../sql/dialect/trino/operation/Row.java | 122 ++++++ .../sql/dialect/trino/operation/Values.java | 157 ++++++++ .../main/java/io/trino/sql/newir/Block.java | 153 ++++++++ .../java/io/trino/sql/newir/Operation.java | 98 +++++ .../io/trino/sql/newir/PrinterOptions.java | 21 ++ .../main/java/io/trino/sql/newir/Program.java | 93 +++++ .../main/java/io/trino/sql/newir/README.md | 89 +++++ .../main/java/io/trino/sql/newir/Region.java | 63 ++++ .../java/io/trino/sql/newir/SourceNode.java | 27 ++ .../main/java/io/trino/sql/newir/Value.java | 39 ++ .../sql/dialect/trino/TestProgramBuilder.java | 104 +++++ .../java/io/trino/spi/StandardErrorCode.java | 1 + .../java/io/trino/spi/type/EmptyRowType.java | 196 ++++++++++ .../java/io/trino/spi/type/MultisetType.java | 95 +++++ .../java/io/trino/spi/type/StandardTypes.java | 3 + .../main/java/io/trino/spi/type/VoidType.java | 197 ++++++++++ 
32 files changed, 3571 insertions(+) create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/Context.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/RelationalProgramBuilder.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/ScalarProgramBuilder.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/TypeConstraint.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Array.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Between.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Constant.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/CorrelatedJoin.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldSelection.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Filter.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Logical.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Output.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Query.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Return.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Row.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Values.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/Block.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/Operation.java create mode 100644 
core/trino-main/src/main/java/io/trino/sql/newir/PrinterOptions.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/Program.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/README.md create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/Region.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/SourceNode.java create mode 100644 core/trino-main/src/main/java/io/trino/sql/newir/Value.java create mode 100644 core/trino-main/src/test/java/io/trino/sql/dialect/trino/TestProgramBuilder.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/type/EmptyRowType.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/type/MultisetType.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/type/VoidType.java diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java new file mode 100644 index 000000000000..dfa103a2c2f5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.sql.dialect.trino;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import io.trino.spi.type.Type;
+import io.trino.sql.ir.Logical;
+
+import java.util.List;
+import java.util.Map;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Registry of the attributes known to the Trino dialect.
+ * Each constant is an {@link AttributeMetadata} describing one attribute: the key under which
+ * the attribute is stored in an attribute map, and the Java class of its value.
+ */
+public class Attributes
+{
+    public static final AttributeMetadata<Long> CARDINALITY = new AttributeMetadata<>("cardinality", Long.class, true);
+    public static final AttributeMetadata<ConstantResult> CONSTANT_RESULT = new AttributeMetadata<>("constant_result", ConstantResult.class, true);
+    public static final AttributeMetadata<String> FIELD_NAME = new AttributeMetadata<>("field_name", String.class, false);
+    public static final AttributeMetadata<JoinType> JOIN_TYPE = new AttributeMetadata<>("join_type", JoinType.class, false);
+    public static final AttributeMetadata<LogicalOperator> LOGICAL_OPERATOR = new AttributeMetadata<>("logical_operator", LogicalOperator.class, false);
+    public static final AttributeMetadata<OutputNames> OUTPUT_NAMES = new AttributeMetadata<>("output_names", OutputNames.class, false);
+
+    // TODO define attributes for deeply nested fields, not just top level or column level
+
+    private Attributes() {}
+
+    /**
+     * Typed descriptor of a single attribute, used to read and write that attribute in an attribute map.
+     *
+     * @param <T> Java type of the attribute value
+     */
+    public static class AttributeMetadata<T>
+    {
+        // key under which the attribute is stored in an attribute map
+        private final String name;
+        // class used to cast values read back from an attribute map
+        private final Class<T> type;
+        // NOTE(review): this flag is stored but never read in this class; its meaning is not visible here -- TODO confirm and document
+        private final boolean external;
+
+        private AttributeMetadata(String name, Class<T> type, boolean external)
+        {
+            this.name = requireNonNull(name, "name is null");
+            this.type = requireNonNull(type, "type is null");
+            this.external = external;
+        }
+
+        /**
+         * Reads this attribute from the given map.
+         *
+         * @param map attribute map keyed by attribute name
+         * @return the value stored under this attribute's name, or null if the map has no such entry
+         * @throws ClassCastException if the stored value is not an instance of the attribute's class
+         */
+        public T getAttribute(Map<String, Object> map)
+        {
+            return this.type.cast(map.get(this.name));
+        }
+
+        /**
+         * Stores this attribute in the given map, replacing any previous value.
+         *
+         * @param map attribute map to modify; it must support {@code put}
+         * @param attribute the value to store
+         * @return the value previously stored under this attribute's name, or null if there was none
+         * @throws ClassCastException if the previous value is not an instance of the attribute's class
+         */
+        public T putAttribute(Map<String, Object> map, T attribute)
+        {
+            return this.type.cast(map.put(name, attribute));
+        }
+
+        /**
+         * Wraps the given value in an immutable single-entry map keyed by this attribute's name.
+         */
+        public Map<String, Object> asMap(T attribute)
+        {
+            return ImmutableMap.of(name, attribute);
+        }
+    }
+
+    /**
+     * Value type of the {@code constant_result} attribute: a value together with its type.
+     * Only {@code type} is null-checked, so {@code value} may be null.
+     */
+    public record ConstantResult(Type type, Object value)
+    {
+        public ConstantResult
+        {
+            requireNonNull(type, "type is null");
+        }
+
+        /**
+         * Formats the constant as {@code value:type}.
+         */
+        @Override
+        public String toString()
+        {
+            // value may be null (it is not validated by the constructor), so it must not be dereferenced;
+            // String.valueOf renders a null value as "null" instead of throwing NullPointerException
+            return String.valueOf(value) + ":" + type;
+        }
+    }
+
+    /**
+     * Join type as represented in the Trino dialect.
+     */
+    public enum JoinType
+    {
+        INNER,
+        LEFT,
+        RIGHT,
+        FULL;
+
+        /**
+         * Maps the planner's join type to the dialect constant of the same name.
+         */
+        public static JoinType of(io.trino.sql.planner.plan.JoinType joinType)
+        {
+            return switch (joinType) {
+                case INNER -> INNER;
+                case LEFT -> LEFT;
+                case RIGHT -> RIGHT;
+                case FULL -> FULL;
+            };
+        }
+    }
+
+    /**
+     * Value type of the {@code output_names} attribute. Holds an immutable copy of the given list.
+     */
+    public record OutputNames(List<String> outputNames)
+    {
+        public OutputNames(List<String> outputNames)
+        {
+            this.outputNames = ImmutableList.copyOf(requireNonNull(outputNames, "outputNames is null"));
+        }
+
+        @Override
+        public String toString()
+        {
+            return outputNames.toString();
+        }
+    }
+
+    /**
+     * Logical connective as represented in the Trino dialect.
+     */
+    public enum LogicalOperator
+    {
+        AND,
+        OR;
+
+        /**
+         * Maps the IR logical operator to the dialect constant of the same name.
+         */
+        public static LogicalOperator of(Logical.Operator operator)
+        {
+            return switch (operator) {
+                case AND -> AND;
+                case OR -> OR;
+            };
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/Context.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/Context.java
new file mode 100644
index 000000000000..f621d9c7e939
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/Context.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.dialect.trino;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import io.trino.sql.newir.Block;
+import io.trino.sql.planner.Symbol;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.util.HashMap.newHashMap;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Rewrite context passed down while translating a plan into the new IR.
+ *
+ * @param block builder of the block that receives the operations being created
+ * @param symbolMapping for each visible symbol, the block parameter and field name it is read from;
+ * stored as an immutable copy
+ */
+public record Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
+{
+    /**
+     * Creates a context with no symbol mapping.
+     */
+    public Context(Block.Builder block)
+    {
+        this(block, Map.of());
+    }
+
+    public Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
+    {
+        this.block = requireNonNull(block, "block is null");
+        this.symbolMapping = ImmutableMap.copyOf(requireNonNull(symbolMapping, "symbolMapping is null"));
+    }
+
+    /**
+     * Binds a symbol-to-field-name mapping to a block parameter.
+     *
+     * @param parameter the block parameter holding the row
+     * @param symbolMapping mapping symbol --> field name within that row
+     * @return immutable mapping symbol --> (parameter, field name)
+     */
+    public static Map<Symbol, RowField> argumentMapping(Block.Parameter parameter, Map<Symbol, String> symbolMapping)
+    {
+        return symbolMapping.entrySet().stream()
+                .collect(toImmutableMap(
+                        Map.Entry::getKey,
+                        entry -> new RowField(parameter, entry.getValue())));
+    }
+
+    /**
+     * Single-mapping convenience overload of {@link #composedMapping(Context, List)}.
+     */
+    public static Map<Symbol, RowField> composedMapping(Context context, Map<Symbol, RowField> newMapping)
+    {
+        return composedMapping(context, ImmutableList.of(newMapping));
+    }
+
+    /**
+     * Compose the correlated mapping from the context with symbol mappings for the current block parameters.
+     * The context mapping is added first and the new mappings are added in list order, so on a symbol
+     * collision the later entry replaces the earlier one.
+     * The result is a mutable HashMap.
+     *
+     * @param context rewrite context containing symbol mapping from all levels of correlation
+     * @param newMappings list of symbol mappings for current block parameters
+     * @return composed symbol mapping to rewrite the current block
+     */
+    public static Map<Symbol, RowField> composedMapping(Context context, List<Map<Symbol, RowField>> newMappings)
+    {
+        // presize for the worst case where no symbol appears in more than one mapping
+        Map<Symbol, RowField> composed = newHashMap(context.symbolMapping().size() + newMappings.stream().mapToInt(Map::size).sum());
+        composed.putAll(context.symbolMapping());
+        newMappings.forEach(composed::putAll);
+        return composed;
+    }
+
+    /**
+     * A field of a row held by a block parameter.
+     *
+     * @param row the block parameter holding the row
+     * @param field name of the field within the row
+     */
+    public record RowField(Block.Parameter row, String field)
+    {
+        public RowField
+        {
+            requireNonNull(row, "row is null");
+            requireNonNull(field, "field is null");
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java
new file mode 100644
index 000000000000..2221bf7b8503
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.dialect.trino;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import io.trino.spi.TrinoException;
+import io.trino.sql.dialect.trino.operation.Query;
+import io.trino.sql.newir.Block;
+import io.trino.sql.newir.Program;
+import io.trino.sql.newir.SourceNode;
+import io.trino.sql.newir.Value;
+import io.trino.sql.planner.plan.OutputNode;
+import io.trino.sql.planner.plan.PlanNode;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.IntStream;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static io.trino.spi.StandardErrorCode.IR_ERROR;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * ProgramBuilder builds an MLIR program from a PlanNode tree.
+ * For now, it builds a program for a single query, and assumes that OutputNode is the root PlanNode.
+ * In the future, we might support multiple statements.
+ * The resulting program has the special Query operation as the top-level operation.
+ * It encloses all query computations in one block.
+ */
+public class ProgramBuilder
+{
+    private ProgramBuilder() {}
+
+    /**
+     * Translates a plan into a Program whose top-level operation is a Query wrapping a single "^query" block.
+     *
+     * @param root root of the plan; must be an OutputNode
+     * @return the program, together with the map from each value to the source node registered for it
+     * @throws NullPointerException if root is null
+     * @throws IllegalArgumentException if root is not an OutputNode
+     * @throws TrinoException with IR_ERROR if the set of allocated value names differs from the set of mapped value names
+     */
+    public static Program buildProgram(PlanNode root)
+    {
+        // fail with a clear message instead of an anonymous NPE from root.getClass() below
+        requireNonNull(root, "root is null");
+        checkArgument(root instanceof OutputNode, "Expected root to be an OutputNode. Actual: %s", root.getClass().getSimpleName());
+
+        ValueNameAllocator nameAllocator = new ValueNameAllocator();
+        ImmutableMap.Builder<Value, SourceNode> valueMapBuilder = ImmutableMap.builder();
+        Block.Builder rootBlock = new Block.Builder(Optional.of("^query"), ImmutableList.of());
+
+        // For now, the return value is ignored. It could be worth remembering it as the final terminal Operation in the Program.
+        root.accept(new RelationalProgramBuilder(nameAllocator, valueMapBuilder), new Context(rootBlock));
+
+        // verify that every allocated value name is mapped, and that nothing else is
+        Set<String> allocatedValues = IntStream.range(0, nameAllocator.label)
+                .mapToObj(index -> "%" + index)
+                .collect(toImmutableSet());
+        Map<Value, SourceNode> valueMap = valueMapBuilder.buildOrThrow();
+        Set<String> mappedValues = valueMap.keySet().stream()
+                .map(Value::name)
+                .collect(toImmutableSet());
+        if (!Sets.symmetricDifference(allocatedValues, mappedValues).isEmpty()) {
+            throw new TrinoException(IR_ERROR, "allocated values differ from mapped values");
+        }
+
+        // allocating this name last to avoid stealing the "%0" label. This label won't be printed.
+        String resultName = nameAllocator.newName();
+
+        return new Program(new Query(resultName, rootBlock.build()), valueMap);
+    }
+
+    /**
+     * Hands out sequential value names: "%0", "%1", "%2", ...
+     */
+    public static class ValueNameAllocator
+    {
+        // index of the next name to hand out; equals the number of names allocated so far
+        private int label = 0;
+
+        public String newName()
+        {
+            return "%" + label++;
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/RelationalProgramBuilder.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/RelationalProgramBuilder.java
new file mode 100644
index 000000000000..6196a6b9d739
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/RelationalProgramBuilder.java
@@ -0,0 +1,354 @@
+package io.trino.sql.dialect.trino;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import io.trino.spi.TrinoException;
+import io.trino.spi.type.MultisetType;
+import io.trino.spi.type.RowType;
+import io.trino.spi.type.Type;
+import io.trino.sql.dialect.trino.Attributes.JoinType;
+import io.trino.sql.dialect.trino.operation.CorrelatedJoin;
+import io.trino.sql.dialect.trino.operation.FieldSelection;
+import io.trino.sql.dialect.trino.operation.Filter;
+import io.trino.sql.dialect.trino.operation.Output;
+import io.trino.sql.dialect.trino.operation.Return;
+import io.trino.sql.dialect.trino.operation.Row;
+import
io.trino.sql.dialect.trino.operation.Values;
+import io.trino.sql.newir.Block;
+import io.trino.sql.newir.Operation;
+import io.trino.sql.newir.SourceNode;
+import io.trino.sql.newir.Value;
+import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.plan.CorrelatedJoinNode;
+import io.trino.sql.planner.plan.FilterNode;
+import io.trino.sql.planner.plan.OutputNode;
+import io.trino.sql.planner.plan.PlanNode;
+import io.trino.sql.planner.plan.PlanVisitor;
+import io.trino.sql.planner.plan.ValuesNode;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.spi.StandardErrorCode.IR_ERROR;
+import static io.trino.spi.type.EmptyRowType.EMPTY_ROW;
+import static io.trino.sql.dialect.trino.Context.argumentMapping;
+import static io.trino.sql.dialect.trino.Context.composedMapping;
+import static io.trino.sql.dialect.trino.RelationalProgramBuilder.OperationAndMapping;
+import static io.trino.sql.dialect.trino.TypeConstraint.IS_RELATION;
+import static io.trino.sql.dialect.trino.TypeConstraint.IS_RELATION_ROW;
+import static io.trino.sql.dialect.trino.operation.Values.valuesWithoutFields;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * A rewriter transforming a tree of PlanNodes into an MLIR program based on `PlanVisitor`.
+ * For scalar expressions, uses another rewriter `ScalarProgramBuilder` based on `IrVisitor`.
+ * <p>
+ * Every visit method appends the operation it creates to the block of the passed Context and
+ * registers each created value (operation result or block parameter) in the shared value map.
+ * Value names are taken from the shared name allocator in statement order.
+ */
+public class RelationalProgramBuilder
+        extends PlanVisitor<OperationAndMapping, Context>
+{
+    private final ProgramBuilder.ValueNameAllocator nameAllocator;
+    private final ImmutableMap.Builder<Value, SourceNode> valueMap;
+
+    public RelationalProgramBuilder(ProgramBuilder.ValueNameAllocator nameAllocator, ImmutableMap.Builder<Value, SourceNode> valueMap)
+    {
+        this.nameAllocator = requireNonNull(nameAllocator, "nameAllocator is null");
+        this.valueMap = requireNonNull(valueMap, "valueMap is null");
+    }
+
+    /**
+     * Fallback for plan nodes that have no dedicated visit method in this class.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    protected OperationAndMapping visitPlan(PlanNode node, Context context)
+    {
+        throw new UnsupportedOperationException("The new IR does not support " + node.getClass().getSimpleName() + " yet");
+    }
+
+    /**
+     * Translates a CorrelatedJoinNode into a CorrelatedJoin operation with three nested blocks:
+     * a selector of the correlation fields, the subquery, and the join filter.
+     */
+    @Override
+    public OperationAndMapping visitCorrelatedJoin(CorrelatedJoinNode node, Context context)
+    {
+        OperationAndMapping input = node.getInput().accept(this, context);
+        String resultName = nameAllocator.newName();
+
+        Type inputRowType = relationRowType(input.operation().result().type());
+
+        // model correlation as field selection (lambda)
+        Block.Parameter correlationParameter = new Block.Parameter(
+                nameAllocator.newName(),
+                inputRowType);
+        Block correlation = fieldSelectorBlock("^correlationSelector", correlationParameter, input.mapping(), node.getCorrelation());
+        valueMap.put(correlationParameter, correlation);
+
+        // model subquery as a lambda
+        Block.Parameter subqueryParameter = new Block.Parameter(
+                nameAllocator.newName(),
+                inputRowType);
+        Block.Builder subqueryBuilder = new Block.Builder(Optional.of("^subquery"), ImmutableList.of(subqueryParameter));
+        // the subquery sees the symbols of the enclosing context plus the input symbols, read from the subquery parameter
+        node.getSubquery().accept(
+                this,
+                new Context(subqueryBuilder, composedMapping(context, argumentMapping(subqueryParameter, input.mapping()))));
+        addReturnOperation(subqueryBuilder);
+        Map<Symbol, String> subqueryMapping = deriveOutputMapping(relationRowType(subqueryBuilder.recentOperation().result().type()), node.getSubquery().getOutputSymbols());
+        Map<String, Object> subqueryAttributes = subqueryBuilder.recentOperation().attributes();
+        Block subquery = subqueryBuilder.build();
+        valueMap.put(subqueryParameter, subquery);
+
+        // model filter as a lambda
+        Block.Parameter firstFilterParameter = new Block.Parameter(
+                nameAllocator.newName(),
+                inputRowType);
+        Block.Parameter secondFilterParameter = new Block.Parameter(
+                nameAllocator.newName(),
+                relationRowType(subquery.getReturnedType()));
+        Block.Builder filterBuilder = new Block.Builder(Optional.of("^filter"), ImmutableList.of(firstFilterParameter, secondFilterParameter));
+        // the filter reads input symbols from its first parameter and subquery symbols from its second parameter
+        node.getFilter().accept(
+                new ScalarProgramBuilder(nameAllocator, valueMap),
+                new Context(filterBuilder, composedMapping(context, ImmutableList.of(
+                        argumentMapping(firstFilterParameter, input.mapping()),
+                        argumentMapping(secondFilterParameter, subqueryMapping)))));
+        addReturnOperation(filterBuilder);
+        Block filter = filterBuilder.build();
+        valueMap.put(firstFilterParameter, filter);
+        valueMap.put(secondFilterParameter, filter);
+
+        CorrelatedJoin correlatedJoin = new CorrelatedJoin(
+                resultName,
+                input.operation().result(),
+                correlation,
+                subquery,
+                filter,
+                JoinType.of(node.getType()),
+                input.operation().attributes(),
+                subqueryAttributes);
+        valueMap.put(correlatedJoin.result(), correlatedJoin);
+
+        Map<Symbol, String> outputMapping = deriveOutputMapping(relationRowType(correlatedJoin.result().type()), node.getOutputSymbols());
+        context.block().addOperation(correlatedJoin);
+        return new OperationAndMapping(correlatedJoin, outputMapping);
+    }
+
+    /**
+     * Translates a FilterNode into a Filter operation whose predicate is a nested one-parameter block.
+     */
+    @Override
+    public OperationAndMapping visitFilter(FilterNode node, Context context)
+    {
+        OperationAndMapping input = node.getSource().accept(this, context);
+        String resultName = nameAllocator.newName();
+
+        // model filter predicate as a lambda (Block)
+        Block.Parameter predicateParameter = new Block.Parameter(
+                nameAllocator.newName(),
+                relationRowType(input.operation().result().type()));
+        Block.Builder predicateBuilder = new Block.Builder(Optional.of("^predicate"), ImmutableList.of(predicateParameter));
+
+        node.getPredicate().accept(
+                new ScalarProgramBuilder(nameAllocator, valueMap),
+                new Context(predicateBuilder, composedMapping(context, argumentMapping(predicateParameter, input.mapping()))));
+
+        addReturnOperation(predicateBuilder);
+
+        Block predicate = predicateBuilder.build();
+        valueMap.put(predicateParameter, predicate);
+
+        Filter filter = new Filter(resultName, input.operation().result(), predicate, input.operation().attributes());
+        valueMap.put(filter.result(), filter);
+        Map<Symbol, String> outputMapping = deriveOutputMapping(relationRowType(filter.result().type()), node.getOutputSymbols());
+        context.block().addOperation(filter);
+        return new OperationAndMapping(filter, outputMapping);
+    }
+
+    /**
+     * Translates an OutputNode into an Output operation whose output fields are chosen by a nested field selector block.
+     * The returned mapping is empty.
+     */
+    @Override
+    public OperationAndMapping visitOutput(OutputNode node, Context context)
+    {
+        OperationAndMapping input = node.getSource().accept(this, context);
+        String resultName = nameAllocator.newName();
+
+        // model output fields selection as a lambda (Block)
+        Block.Parameter fieldSelectorParameter = new Block.Parameter(
+                nameAllocator.newName(),
+                relationRowType(input.operation().result().type()));
+        Block fieldSelectorBlock = fieldSelectorBlock("^outputFieldSelector", fieldSelectorParameter, input.mapping(), node.getOutputSymbols());
+        valueMap.put(fieldSelectorParameter, fieldSelectorBlock);
+
+        Output output = new Output(resultName, input.operation().result(), fieldSelectorBlock, node.getColumnNames());
+        valueMap.put(output.result(), output);
+        context.block().addOperation(output);
+        return new OperationAndMapping(output, ImmutableMap.of()); // unlike OutputNode, the Output operation returns Void, not a relation
+    }
+
+    /**
+     * Translates a ValuesNode into a Values operation.
+     * A node without output symbols is built from its row count only; otherwise each row expression becomes a nested block.
+     */
+    @Override
+    public OperationAndMapping visitValues(ValuesNode node, Context context)
+    {
+        String resultName = nameAllocator.newName();
+
+        Values values;
+        if (node.getOutputSymbols().isEmpty()) {
+            values = valuesWithoutFields(resultName, node.getRowCount());
+        }
+        else {
+            // model each component row of Values as a no-argument Block
+            // TODO we could use sequential names for blocks: "row_1", "row_2"...
+            List<Block> rows = node.getRows().orElseThrow().stream()
+                    .map(rowExpression -> {
+                        Block.Builder rowBlock = new Block.Builder(Optional.of("^row"), ImmutableList.of());
+                        rowExpression.accept(
+                                new ScalarProgramBuilder(nameAllocator, valueMap),
+                                new Context(rowBlock, context.symbolMapping()));
+                        addReturnOperation(rowBlock);
+                        return rowBlock.build();
+                    })
+                    .collect(toImmutableList());
+            RowType rowType = RowType.anonymous(node.getOutputSymbols().stream()
+                    .map(Symbol::type)
+                    .collect(toImmutableList()));
+            values = new Values(resultName, rowType, rows);
+        }
+        valueMap.put(values.result(), values);
+        Map<Symbol, String> outputMapping = deriveOutputMapping(relationRowType(values.result().type()), node.getOutputSymbols());
+        context.block().addOperation(values);
+        return new OperationAndMapping(values, outputMapping);
+    }
+
+    /**
+     * A type of relation is represented as MultisetType of row type being either RowType or EmptyRowType. This method extracts the element type.
+     *
+     * @throws TrinoException with IR_ERROR if the type does not satisfy the IS_RELATION constraint
+     */
+    public static Type relationRowType(Type relationType)
+    {
+        if (!IS_RELATION.test(relationType)) {
+            throw new TrinoException(IR_ERROR, "not a relation type. expected multiset of row");
+        }
+
+        return ((MultisetType) relationType).getElementType();
+    }
+
+    /**
+     * Map each output symbol of the PlanNode to a corresponding field name in the Operation's output row type.
+     * Symbols are matched to fields by position.
+     *
+     * @throws TrinoException with IR_ERROR if the type is not a relation row type, or if it does not match the output symbols
+     */
+    private static Map<Symbol, String> deriveOutputMapping(Type relationRowType, List<Symbol> outputSymbols)
+    {
+        if (!IS_RELATION_ROW.test(relationRowType)) {
+            throw new TrinoException(IR_ERROR, "not a relation row type. expected RowType or EmptyRowType");
+        }
+
+        if (relationRowType.equals(EMPTY_ROW)) {
+            if (!outputSymbols.isEmpty()) {
+                throw new TrinoException(IR_ERROR, "relation row type mismatch: output symbols present for EmptyRowType");
+            }
+            return ImmutableMap.of();
+        }
+
+        RowType rowType = (RowType) relationRowType;
+        if (rowType.getFields().size() != outputSymbols.size()) {
+            throw new TrinoException(IR_ERROR, "relation RowType does not match output symbols");
+        }
+
+        // Using a HashMap because it can handle duplicates.
+        // If a PlanNode outputs some symbol twice, we will use the first occurrence for mapping.
+        // As a result, the downstream references to the symbol will be mapped to the same output field,
+        // and the other field will be unused and eligible for pruning.
+        Map<Symbol, String> mapping = HashMap.newHashMap(outputSymbols.size());
+        for (int i = 0; i < outputSymbols.size(); i++) {
+            Symbol symbol = outputSymbols.get(i);
+            String fieldName = rowType.getFields().get(i).getName().orElseThrow();
+            mapping.putIfAbsent(symbol, fieldName);
+        }
+        return mapping;
+    }
+
+    /**
+     * Single-input convenience overload of the list-based fieldSelectorBlock.
+     */
+    private Block fieldSelectorBlock(String blockName, Block.Parameter inputRow, Map<Symbol, String> inputSymbolMapping, List<Symbol> selectedSymbolsList)
+    {
+        return fieldSelectorBlock(
+                blockName,
+                ImmutableList.of(inputRow),
+                ImmutableList.of(inputSymbolMapping),
+                ImmutableList.of(selectedSymbolsList));
+    }
+
+    /**
+     * A helper method to express input field selection as a lambda.
+     * Useful for operations which pass selected input fields on output, for example join, unnest.
+     * Also useful for selecting input columns necessary for the operation's logic, for example ordering columns, partitioning columns.
+     *
+     * @param blockName name for the result Block
+     * @param inputRows arguments to the lambda representing input rows from which we want to select fields
+     * @param inputSymbolMappings mapping symbol --> row field name for each input row
+     * @param selectedSymbolsLists list of symbols to select from each input row
+     * @return a block returning a row containing selected fields from all input rows
+     * @throws TrinoException with IR_ERROR if the three lists differ in size, or if a selected symbol has no field mapping
+     */
+    private Block fieldSelectorBlock(String blockName, List<Block.Parameter> inputRows, List<Map<Symbol, String>> inputSymbolMappings, List<List<Symbol>> selectedSymbolsLists)
+    {
+        if (inputRows.size() != inputSymbolMappings.size()) {
+            throw new TrinoException(IR_ERROR, "inputs and input symbol mappings do not match");
+        }
+        if (inputRows.size() != selectedSymbolsLists.size()) {
+            throw new TrinoException(IR_ERROR, "inputs and symbol lists do not match");
+        }
+
+        ImmutableList.Builder<Operation> selections = ImmutableList.builder();
+
+        Block.Builder selectorBlock = new Block.Builder(Optional.of(blockName), inputRows);
+        for (int i = 0; i < inputRows.size(); i++) {
+            Block.Parameter parameter = inputRows.get(i);
+            Map<Symbol, String> symbolMapping = inputSymbolMappings.get(i);
+            List<Symbol> symbols = selectedSymbolsLists.get(i);
+            for (Symbol symbol : symbols) {
+                // report an unmapped symbol the same way ScalarProgramBuilder.visitReference does,
+                // instead of handing a null field name to FieldSelection
+                String fieldName = symbolMapping.get(symbol);
+                if (fieldName == null) {
+                    throw new TrinoException(IR_ERROR, "no mapping for symbol " + symbol.name());
+                }
+                String value = nameAllocator.newName();
+                FieldSelection fieldSelection = new FieldSelection(value, parameter, fieldName, ImmutableMap.of()); // TODO pass appropriate row-specific input attributes through lambda arguments
+                valueMap.put(fieldSelection.result(), fieldSelection);
+                selectorBlock.addOperation(fieldSelection);
+                selections.add(fieldSelection);
+            }
+        }
+        // build a row of selected items
+        String rowValue = nameAllocator.newName();
+        List<Operation> selectedFields = selections.build();
+        Row rowConstructor = new Row(
+                rowValue,
+                selectedFields.stream().map(Operation::result).collect(toImmutableList()),
+                selectedFields.stream().map(Operation::attributes).collect(toImmutableList()));
+        valueMap.put(rowConstructor.result(), rowConstructor);
+        selectorBlock.addOperation(rowConstructor);
+
+        addReturnOperation(selectorBlock);
+
+        return selectorBlock.build();
+    }
+
+    /**
+     * Return the value of the recent operation in the builder.
+     * Appends a Return operation to the builder and registers its result in the value map.
+     */
+    private void addReturnOperation(Block.Builder builder)
+    {
+        String returnValue = nameAllocator.newName();
+        Operation recentOperation = builder.recentOperation();
+        Return returnOperation = new Return(returnValue, recentOperation.result(), recentOperation.attributes());
+        valueMap.put(returnOperation.result(), returnOperation);
+        builder.addOperation(returnOperation);
+    }
+
+    /**
+     * Assigns unique lowercase names, compliant with IS_RELATION_ROW type constraint: f_1, f_2, ...
+     * Indexing from 1 for similarity with SQL indexing.
+     */
+    public static RowType assignRelationRowTypeFieldNames(RowType relationRowType)
+    {
+        ImmutableList.Builder<RowType.Field> fields = ImmutableList.builder();
+        for (int i = 0; i < relationRowType.getTypeParameters().size(); i++) {
+            fields.add(new RowType.Field(
+                    Optional.of(String.format("f_%s", i + 1)),
+                    relationRowType.getTypeParameters().get(i)));
+        }
+        return RowType.from(fields.build());
+    }
+
+    /**
+     * A result of transforming a PlanNode into an Operation.
+     * Maps each output symbol of the PlanNode to a corresponding field name in the Operation's output RowType.
+     * The mapping is stored as an immutable copy.
+     */
+    public record OperationAndMapping(Operation operation, Map<Symbol, String> mapping)
+    {
+        public OperationAndMapping(Operation operation, Map<Symbol, String> mapping)
+        {
+            this.operation = requireNonNull(operation, "operation is null");
+            this.mapping = ImmutableMap.copyOf(requireNonNull(mapping, "mapping is null"));
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/ScalarProgramBuilder.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/ScalarProgramBuilder.java
new file mode 100644
index 000000000000..2730a7412600
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/ScalarProgramBuilder.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package io.trino.sql.dialect.trino; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.TrinoException; +import io.trino.sql.dialect.trino.ProgramBuilder.ValueNameAllocator; +import io.trino.sql.dialect.trino.operation.Array; +import io.trino.sql.dialect.trino.operation.Between; +import io.trino.sql.dialect.trino.operation.Constant; +import io.trino.sql.dialect.trino.operation.FieldSelection; +import io.trino.sql.dialect.trino.operation.Logical; +import io.trino.sql.dialect.trino.operation.Row; +import io.trino.sql.ir.Expression; +import io.trino.sql.ir.IrVisitor; +import io.trino.sql.ir.Reference; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.SourceNode; +import io.trino.sql.newir.Value; +import io.trino.sql.planner.Symbol; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static java.util.Objects.requireNonNull; + +/** + * A rewriter based on `IrVisitor`, transforming a tree of scalar Expressions into an MLIR program. + * It is called from `RelationalProgramBuilder` to transform predicates etc. + * Note: we don't need to pass a `RelationalProgramBuilder` for recursive calls because the + * IR Expressions don't have nested relations. 
+ */ +public class ScalarProgramBuilder + extends IrVisitor +{ + private final ValueNameAllocator nameAllocator; + private final ImmutableMap.Builder valueMap; + + public ScalarProgramBuilder(ValueNameAllocator nameAllocator, ImmutableMap.Builder valueMap) + { + this.nameAllocator = requireNonNull(nameAllocator, "nameAllocator is null"); + this.valueMap = requireNonNull(valueMap, "valueMap is null"); + } + + @Override + protected Operation visitExpression(Expression node, Context context) + { + throw new UnsupportedOperationException("The new IR does not support " + node.getClass().getSimpleName() + " yet"); + } + + @Override + protected Operation visitArray(io.trino.sql.ir.Array node, Context context) + { + // lowering alert! Unrolls the array constructor into element operations followed by a final Array operation. + ImmutableList.Builder elementsBuilder = ImmutableList.builder(); + for (Expression element : node.elements()) { + elementsBuilder.add(element.accept(this, context)); + } + List elements = elementsBuilder.build(); + + String resultName = nameAllocator.newName(); + Array array = new Array( + resultName, + node.elementType(), + elements.stream().map(Operation::result).collect(toImmutableList()), + elements.stream().map(Operation::attributes).collect(toImmutableList())); + valueMap.put(array.result(), array); + context.block().addOperation(array); + return array; + } + + @Override + protected Operation visitBetween(io.trino.sql.ir.Between node, Context context) + { + Operation input = node.value().accept(this, context); + Operation min = node.min().accept(this, context); + Operation max = node.max().accept(this, context); + + String resultName = nameAllocator.newName(); + Between between = new Between( + resultName, + input.result(), + min.result(), + max.result(), + ImmutableList.of(input.attributes(), min.attributes(), max.attributes())); + valueMap.put(between.result(), between); + context.block().addOperation(between); + return between; + } + + @Override + protected 
Operation visitConstant(io.trino.sql.ir.Constant node, Context context) + { + String resultName = nameAllocator.newName(); + Constant constant = new Constant(resultName, node.type(), node.value()); + valueMap.put(constant.result(), constant); + context.block().addOperation(constant); + return constant; + } + + @Override + protected Operation visitLogical(io.trino.sql.ir.Logical node, Context context) + { + // lowering alert! Unrolls the logical expression into term operations followed by a final Logical operation. + ImmutableList.Builder termsBuilder = ImmutableList.builder(); + for (Expression term : node.terms()) { + termsBuilder.add(term.accept(this, context)); + } + List terms = termsBuilder.build(); + + String resultName = nameAllocator.newName(); + Logical logical = new Logical( + resultName, + terms.stream().map(Operation::result).collect(toImmutableList()), + Attributes.LogicalOperator.of(node.operator()), + terms.stream().map(Operation::attributes).collect(toImmutableList())); + valueMap.put(logical.result(), logical); + context.block().addOperation(logical); + return logical; + } + + @Override + protected Operation visitReference(Reference node, Context context) + { + Context.RowField rowField = context.symbolMapping().get(new Symbol(node.type(), node.name())); + if (rowField == null) { + throw new TrinoException(IR_ERROR, "no mapping for symbol " + node.name()); + } + String resultName = nameAllocator.newName(); + FieldSelection fieldSelection = new FieldSelection(resultName, rowField.row(), rowField.field(), ImmutableMap.of()); // TODO pass attributes through correlation / block argument + valueMap.put(fieldSelection.result(), fieldSelection); + context.block().addOperation(fieldSelection); + return fieldSelection; + } + + @Override + protected Operation visitRow(io.trino.sql.ir.Row node, Context context) + { + // lowering alert! Unrolls the row into field assignments and a final row constructor. 
+ ImmutableList.Builder itemsBuilder = ImmutableList.builder(); + for (Expression item : node.items()) { + itemsBuilder.add(item.accept(this, context)); + } + List items = itemsBuilder.build(); + + String resultName = nameAllocator.newName(); + Row row = new Row( + resultName, + items.stream().map(Operation::result).collect(toImmutableList()), + items.stream().map(Operation::attributes).collect(toImmutableList())); + valueMap.put(row.result(), row); + context.block().addOperation(row); + return row; + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/TypeConstraint.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/TypeConstraint.java new file mode 100644 index 000000000000..265852b4293f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/TypeConstraint.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino; + +import io.trino.spi.type.EmptyRowType; +import io.trino.spi.type.MultisetType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; + +import java.util.HashSet; +import java.util.Set; +import java.util.function.Predicate; + +import static java.util.Objects.requireNonNull; + +public record TypeConstraint(Predicate constraint) +{ + // Intermediate result row type. + // Row without fields is supported and represented as EmptyRowType. + // If row fields are present, they must all be named, and the names must be unique. 
+ public static final TypeConstraint IS_RELATION_ROW = new TypeConstraint(type -> { + if (type instanceof EmptyRowType) { + return true; + } + if (type instanceof RowType rowType) { + Set uniqueFieldNames = new HashSet<>(); + for (RowType.Field field : rowType.getFields()) { + if (field.getName().isEmpty()) { + return false; + } + if (!uniqueFieldNames.add(field.getName().orElseThrow())) { + return false; + } + } + return true; + } + return false; + }); + + // Intermediate result type: a MultisetType whose element type satisfies IS_RELATION_ROW. + public static final TypeConstraint IS_RELATION = new TypeConstraint( + type -> type instanceof MultisetType multisetType && IS_RELATION_ROW.test(multisetType.getElementType())); + + public TypeConstraint + { + requireNonNull(constraint, "constraint is null"); + } + + public boolean test(Type t) + { + return constraint.test(t); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Array.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Array.java new file mode 100644 index 000000000000..22bc6d3ffc3c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Array.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.TrinoException; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.Type; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.VoidType.VOID; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class Array + implements Operation +{ + private static final String NAME = "array"; + + private final Result result; + private final List elements; + private final Map attributes; + + public Array(String resultName, Type elementType, List elements, List> sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(elementType, "elementType is null"); + requireNonNull(elements, "elements is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + if (elementType.equals(VOID)) { + throw new TrinoException(IR_ERROR, "cannot use void type for array elements"); + } + + this.result = new Result(resultName, new ArrayType(elementType)); + + elements.stream() + .forEach(element -> { + if (!element.type().equals(elementType)) { + throw new TrinoException(IR_ERROR, format("type of array element: %s does not match the declared type: %s", element.type().getDisplayName(), elementType.getDisplayName())); + } + }); + this.elements = ImmutableList.copyOf(elements); + + // TODO derive attributes from source attributes (sourceAttributes is currently only null-checked) + this.attributes = ImmutableMap.of(); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return elements; + } + + @Override + public List regions() + { + 
return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "array :)"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Array) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.elements, that.elements) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, elements, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Between.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Between.java new file mode 100644 index 000000000000..e8e45b1fa88c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Between.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.TrinoException; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class Between + implements Operation +{ + private static final String NAME = "between"; + + private final Result result; + private final Value input; + private final Value min; + private final Value max; + private final Map attributes; + + public Between(String resultName, Value input, Value min, Value max, List> sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(input, "input is null"); + requireNonNull(min, "min is null"); + requireNonNull(max, "max is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + this.result = new Result(resultName, BOOLEAN); + + if (!input.type().equals(min.type())) { + throw new TrinoException(IR_ERROR, format("lower range of between operation has mismatching type. expected: %s, actual: %s", input.type().getDisplayName(), min.type().getDisplayName())); + } + if (!input.type().equals(max.type())) { + throw new TrinoException(IR_ERROR, format("upper range of between operation has mismatching type. 
expected: %s, actual: %s", input.type().getDisplayName(), max.type().getDisplayName())); + } + + this.input = input; + this.min = min; + this.max = max; + + // TODO derive attributes from source attributes (sourceAttributes is currently only null-checked) + this.attributes = ImmutableMap.of(); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(input, min, max); + } + + @Override + public List regions() + { + return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty between"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Between) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.input, that.input) && + Objects.equals(this.min, that.min) && + Objects.equals(this.max, that.max) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, input, min, max, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Constant.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Constant.java new file mode 100644 index 000000000000..ad340e2ae53c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Constant.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.type.Type; +import io.trino.sql.dialect.trino.Attributes.ConstantResult; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.sql.dialect.trino.Attributes.CONSTANT_RESULT; +import static java.util.Objects.requireNonNull; + +public final class Constant + implements Operation +{ + private static final String NAME = "constant"; + + private final Result result; + private final Map attributes; + + public Constant(String resultName, Type type, Object value) + { + requireNonNull(resultName, "resultName is null"); + + this.result = new Result(resultName, type); // NOTE(review): type is not null-checked here, unlike resultName - confirm that Result validates it + + this.attributes = CONSTANT_RESULT.asMap(new ConstantResult(type, value)); // the constant's type and value are recorded in the attributes map + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(); + } + + @Override + public List regions() + { + return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty constant"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Constant) obj; + return 
Objects.equals(this.result, that.result) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/CorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/CorrelatedJoin.java new file mode 100644 index 000000000000..a9285efcfd63 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/CorrelatedJoin.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.type.MultisetType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.dialect.trino.Attributes.JoinType; +import io.trino.sql.newir.Block; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.EmptyRowType.EMPTY_ROW; +import static io.trino.sql.dialect.trino.Attributes.JOIN_TYPE; +import static io.trino.sql.dialect.trino.RelationalProgramBuilder.assignRelationRowTypeFieldNames; +import static io.trino.sql.dialect.trino.RelationalProgramBuilder.relationRowType; +import static io.trino.sql.dialect.trino.TypeConstraint.IS_RELATION; +import static io.trino.sql.newir.Region.singleBlockRegion; +import static java.util.Objects.requireNonNull; + +public final class CorrelatedJoin + implements Operation +{ + + private static final String NAME = "correlated_join"; + + private final Result result; + private final Value input; + // Correlation is modeled as a field selector. Later we should model correlation through the uses graph? + private final Region correlation; + private final Region subquery; + private final Region filter; + private final Map attributes; + // TODO the PlanNode has origin subquery for debug. 
skipping it for now + + public CorrelatedJoin( + String resultName, + Value input, + Block correlation, + Block subquery, + Block filter, + JoinType joinType, + Map sourceAttributes, + Map subqueryAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(input, "input is null"); + requireNonNull(correlation, "correlation is null"); + requireNonNull(subquery, "subquery is null"); + requireNonNull(filter, "filter is null"); + requireNonNull(joinType, "joinType is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + requireNonNull(subqueryAttributes, "subqueryAttributes is null"); + + if (!IS_RELATION.test(input.type()) || !IS_RELATION.test(subquery.getReturnedType())) { + throw new TrinoException(IR_ERROR, "input and subquery of CorrelatedJoin must be of relation type"); + } + + List outputTypes = ImmutableList.builder() + .addAll((relationRowType(input.type())).getTypeParameters()) + .addAll((relationRowType(subquery.getReturnedType())).getTypeParameters()) + .build(); + + if (outputTypes.isEmpty()) { + this.result = new Result(resultName, new MultisetType(EMPTY_ROW)); + } + else { + this.result = new Result(resultName, new MultisetType(assignRelationRowTypeFieldNames(RowType.anonymous(outputTypes)))); + } + + this.input = input; + + if (correlation.parameters().size() != 1 || + !correlation.parameters().getFirst().type().equals(relationRowType(input.type())) || + !(correlation.getReturnedType() instanceof RowType || correlation.getReturnedType().equals(EMPTY_ROW))) { + throw new TrinoException(IR_ERROR, "invalid correlation for CorrelatedJoin operation"); + } + this.correlation = singleBlockRegion(correlation); + + if (subquery.parameters().size() != 1 || + !subquery.parameters().getFirst().type().equals(relationRowType(input.type())) || + !IS_RELATION.test(subquery.getReturnedType())) { + throw new TrinoException(IR_ERROR, "invalid subquery for CorrelatedJoin operation"); + } + this.subquery = 
singleBlockRegion(subquery); + + if (filter.parameters().size() != 2 || + !filter.parameters().get(0).type().equals(relationRowType(input.type())) || + !filter.parameters().get(1).type().equals(relationRowType(subquery.getReturnedType())) || + !(filter.getReturnedType().equals(BOOLEAN))) { + throw new TrinoException(IR_ERROR, "invalid filter for CorrelatedJoin operation"); + } + this.filter = singleBlockRegion(filter); + + // TODO also derive attributes from source and subquery attributes (currently only the join type is recorded) + this.attributes = JOIN_TYPE.asMap(joinType); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(input); + } + + @Override + public List regions() + { + return ImmutableList.of(correlation, subquery, filter); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty correlated join"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (CorrelatedJoin) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.input, that.input) && + Objects.equals(this.correlation, that.correlation) && + Objects.equals(this.subquery, that.subquery) && + Objects.equals(this.filter, that.filter) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, input, correlation, subquery, filter, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldSelection.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldSelection.java new file mode 100644 index 000000000000..350aa5e63178 --- /dev/null +++ 
b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldSelection.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.type.RowType; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.EmptyRowType.EMPTY_ROW; +import static io.trino.sql.dialect.trino.Attributes.FIELD_NAME; +import static io.trino.sql.dialect.trino.TypeConstraint.IS_RELATION_ROW; +import static java.util.Objects.requireNonNull; + +/** + * FieldSelection operation selects a row field by name. + *

+ * We compare field names case-sensitively, although in TranslationMap field references are resolved case-insensitively. + * Explanation: + * In TranslationMap, all user-provided field references by name are resolved case-insensitively and translated to field references by index. + * Operations, like this one, are created after TranslationMap, so there are no user-provided field references by name. + * At this point, the only field references by name are added programmatically (the FieldSelection Operation), and they refer to RowTypes created programmatically. + * Those RowTypes have lower-case unique field names which can be safely compared case-sensitively. + * When we add a Parser to create the IR from text, we should assume that the text is a printout of a valid query program, + * and thus all field references by name are case-safe. + */ +public final class FieldSelection + implements Operation +{ + private static final String NAME = "field_selection"; + + private final Result result; + private final Value input; + private final Map attributes; + + public FieldSelection(String resultName, Value input, String fieldName, Map sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(input, "input is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + if (!IS_RELATION_ROW.test(input.type()) || input.type().equals(EMPTY_ROW)) { + throw new TrinoException(IR_ERROR, "input to the FieldSelection operation must be a relation row type with fields"); + } + Optional matchingField = ((RowType) input.type()).getFields().stream() + .filter(field -> fieldName.equals(field.getName().orElseThrow())).findFirst(); + if (matchingField.isEmpty()) { + throw new TrinoException(IR_ERROR, "invalid row field selection: no matching field name"); + } + this.result = new Result(resultName, matchingField.orElseThrow().getType()); + + this.input = input; + + this.attributes = deriveAttributes(fieldName, sourceAttributes); + } + + private Map 
deriveAttributes(String fieldName, Map sourceAttributes) + { + return FIELD_NAME.asMap(fieldName); + // TODO add source attributes for the selected field (sourceAttributes is currently ignored) + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(input); + } + + @Override + public List regions() + { + return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty field selection"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (FieldSelection) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.input, that.input) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, input, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Filter.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Filter.java new file mode 100644 index 000000000000..b9040ffd5a77 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Filter.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.TrinoException; +import io.trino.sql.newir.Block; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.dialect.trino.RelationalProgramBuilder.relationRowType; +import static io.trino.sql.dialect.trino.TypeConstraint.IS_RELATION; +import static io.trino.sql.newir.Region.singleBlockRegion; +import static java.util.Objects.requireNonNull; + +public final class Filter + implements Operation +{ + private static final String NAME = "filter"; + + private final Result result; + private final Value input; + private final Region predicate; + private final Map attributes; + + public Filter(String resultName, Value input, Block predicate, Map sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(input, "input is null"); + requireNonNull(predicate, "predicate is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + if (!IS_RELATION.test(input.type())) { + throw new TrinoException(IR_ERROR, "input to the Filter operation must be of relation type"); + } + + this.result = new Result(resultName, input.type()); // derives output type: same as input type + + this.input = input; + + if (predicate.parameters().size() != 1 || + !predicate.parameters().getFirst().type().equals(relationRowType(input.type())) || + !(predicate.getReturnedType().equals(BOOLEAN))) { + throw new TrinoException(IR_ERROR, "invalid predicate for Filter operation"); + } + + this.predicate = singleBlockRegion(predicate); + + this.attributes = deriveAttributes(sourceAttributes); + } + + // TODO derive attributes from source attributes (currently always empty) + private Map 
deriveAttributes(Map sourceAttributes) + { + return ImmutableMap.of(); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(input); + } + + @Override + public List regions() + { + return ImmutableList.of(predicate); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty filter"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Filter) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.input, that.input) && + Objects.equals(this.predicate, that.predicate) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, input, predicate, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Logical.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Logical.java new file mode 100644 index 000000000000..f80715a4c8c9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Logical.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.sql.dialect.trino.Attributes.LogicalOperator; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.dialect.trino.Attributes.LOGICAL_OPERATOR; +import static java.util.Objects.requireNonNull; + +public final class Logical + implements Operation +{ + private static final String NAME = "logical"; + + private final Result result; + private final List terms; + private final Map attributes; + + public Logical(String resultName, List terms, LogicalOperator logicalOperator, List> sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(terms, "terms is null"); + requireNonNull(logicalOperator, "logicalOperator is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + this.result = new Result(resultName, BOOLEAN); + + if (terms.size() < 2) { + throw new TrinoException(IR_ERROR, "logical operation must have at least 2 terms. actual: " + terms.size()); + } + terms.stream() + .forEach(term -> { + if (!term.type().equals(BOOLEAN)) { + throw new TrinoException(IR_ERROR, "all terms of a logical operation must be of boolean type. 
found: " + term.type().getDisplayName()); + } + }); + this.terms = ImmutableList.copyOf(terms); + + // TODO also derive attributes from source attributes + this.attributes = LOGICAL_OPERATOR.asMap(logicalOperator); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return terms; + } + + @Override + public List regions() + { + return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "logical :)"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Logical) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.terms, that.terms) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, terms, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Output.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Output.java new file mode 100644 index 000000000000..906b4277d973 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Output.java @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.type.RowType; +import io.trino.sql.dialect.trino.Attributes.OutputNames; +import io.trino.sql.newir.Block; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.EmptyRowType.EMPTY_ROW; +import static io.trino.spi.type.VoidType.VOID; +import static io.trino.sql.dialect.trino.Attributes.OUTPUT_NAMES; +import static io.trino.sql.dialect.trino.RelationalProgramBuilder.relationRowType; +import static io.trino.sql.dialect.trino.TypeConstraint.IS_RELATION; +import static io.trino.sql.newir.Region.singleBlockRegion; +import static java.util.Objects.requireNonNull; + +public final class Output + implements Operation +{ + private static final String NAME = "output"; + + private final Result result; + private final Value input; + private final Region fieldSelector; + private final List outputNames; + + public Output(String resultName, Value input, Block fieldSelector, List outputNames) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(input, "input is null"); + requireNonNull(fieldSelector, "fieldSelector is null"); + requireNonNull(outputNames, "outputNames is null"); + + if (!IS_RELATION.test(input.type())) { + throw new TrinoException(IR_ERROR, "input to the Output operation must be of relation type"); + } + + this.result = new Result(resultName, VOID); + + this.input = input; + + if (fieldSelector.parameters().size() != 1 || + !fieldSelector.parameters().getFirst().type().equals(relationRowType(input.type())) || + fieldSelector.parameters().getFirst().type().equals(EMPTY_ROW) || + !(fieldSelector.getReturnedType() instanceof RowType) || + ((RowType) 
fieldSelector.getReturnedType()).getTypeParameters().size() != outputNames.size()) { + throw new TrinoException(IR_ERROR, "invalid field selection for Output operation"); + } + + this.fieldSelector = singleBlockRegion(fieldSelector); + + this.outputNames = outputNames; + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(input); + } + + @Override + public List regions() + { + return ImmutableList.of(fieldSelector); + } + + @Override + public Map attributes() + { + return OUTPUT_NAMES.asMap(new OutputNames(outputNames)); + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty output"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Output) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.input, that.input) && + Objects.equals(this.fieldSelector, that.fieldSelector) && + Objects.equals(this.outputNames, that.outputNames); + } + + @Override + public int hashCode() + { + return Objects.hash(result, input, fieldSelector, outputNames); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Query.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Query.java new file mode 100644 index 000000000000..f611a852cb02 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Query.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.TrinoException; +import io.trino.sql.newir.Block; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.VoidType.VOID; +import static io.trino.sql.newir.Region.singleBlockRegion; +import static java.util.Objects.requireNonNull; + +public final class Query + implements Operation +{ + private static final String NAME = "query"; + + private final Result result; + private final Region query; + + public Query(String resultName, Block query) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(query, "query is null"); + + this.result = new Result(resultName, VOID); + + if (!(query.getTerminalOperation() instanceof Output)) { + throw new TrinoException(IR_ERROR, "query block must end in Output operation"); + } + this.query = singleBlockRegion(query); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(); + } + + @Override + public List regions() + { + return ImmutableList.of(query); + } + + @Override + public Map attributes() + { + return ImmutableMap.of(); + } + + @Override + public String print(int 
indentLevel) + { + return query.print(indentLevel); + } + + @Override + public String prettyPrint(int indentLevel) + { + return "♡♡♡ query ♡♡♡"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Query) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.query, that.query); + } + + @Override + public int hashCode() + { + return Objects.hash(result, query); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Return.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Return.java new file mode 100644 index 000000000000..b97c52b892c0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Return.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public final class Return + implements Operation +{ + private static final String NAME = "return"; + + private final Result result; + private final Value input; + private final Map attributes; + + public Return(String resultName, Value input, Map sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(input, "input is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + this.result = new Result(resultName, input.type()); + + this.input = input; + + this.attributes = ImmutableMap.of(); // TODO pass relevant attributes (skip attrs like join type) + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(input); + } + + @Override + public List regions() + { + return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty return"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Return) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.input, that.input) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, input, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Row.java 
b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Row.java new file mode 100644 index 000000000000..1845a2bc0329 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Row.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.TrinoException; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.VoidType.VOID; +import static java.util.Objects.requireNonNull; + +public final class Row + implements Operation +{ + private static final String NAME = "row"; + + private final Result result; + private final List fields; + private final Map attributes; + + public Row(String resultName, List fields, List> sourceAttributes) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(fields, "fields is null"); + requireNonNull(sourceAttributes, "sourceAttributes is null"); + + Type resultType = RowType.anonymous( // fails if there are no fields + 
fields.stream().map(value -> { + if (value.type().equals(VOID)) { + throw new TrinoException(IR_ERROR, "cannot use void type for a row field"); + } + return value.type(); + }).collect(toImmutableList())); + + this.result = new Result(resultName, resultType); + + this.fields = ImmutableList.copyOf(fields); + + this.attributes = deriveAttributes(sourceAttributes); + } + + // TODO + private Map deriveAttributes(List> sourceAttributes) + { + return ImmutableMap.of(); + } + + @Override + public String name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return fields; + } + + @Override + public List regions() + { + return ImmutableList.of(); + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "row :)"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Row) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.fields, that.fields) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, fields, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Values.java b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Values.java new file mode 100644 index 000000000000..4e1914810f6d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Values.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino.operation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.type.MultisetType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.newir.Block; +import io.trino.sql.newir.Operation; +import io.trino.sql.newir.Region; +import io.trino.sql.newir.Value; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.spi.type.EmptyRowType.EMPTY_ROW; +import static io.trino.sql.dialect.trino.Attributes.CARDINALITY; +import static io.trino.sql.dialect.trino.RelationalProgramBuilder.assignRelationRowTypeFieldNames; +import static io.trino.sql.newir.Region.singleBlockRegion; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class Values + implements Operation +{ + private static final String NAME = "values"; + + private final Result result; + private final List rows; + private final Map attributes; + + // TODO Values can be constant or correlated. If it is constant, it should be folded to Constant operation + + public Values(String resultName, RowType rowType, List rows) + { + requireNonNull(resultName, "resultName is null"); + requireNonNull(rows, "rows is null"); + + // Create output type with unique field names. 
+ // Field names of the passed RowType and of the individual rows (if present) will be ignored. + // This is consistent with the Trino behavior in StatementAnalyzer: the RelationType + // for Values has anonymous fields even if individual rows had named fields. + RowType outputType = assignRelationRowTypeFieldNames(rowType); + this.result = new Result(resultName, new MultisetType(outputType)); + + // Verify that each row matches the output type. Check field types only. + // Field names are ignored. They will be overridden by the output type. + this.rows = rows.stream() + .map(rowBlock -> { + // verify that blocks have no parameters + if (!rowBlock.parameters().isEmpty()) { + throw new TrinoException(IR_ERROR, format("no block parameters expected. got %s parameters", rowBlock.parameters().size())); + } + Type blockType = rowBlock.getReturnedType(); + if (!(blockType instanceof RowType)) { + throw new TrinoException(IR_ERROR, "block should return RowType. actual: " + blockType.getDisplayName()); + } + if (!((RowType) blockType).getTypeParameters().equals(rowType.getTypeParameters())) { + throw new TrinoException(IR_ERROR, format("type of row: %s does not match the declared output type: %s", blockType.getDisplayName(), rowType.getDisplayName())); + } + return singleBlockRegion(rowBlock); + }) + .collect(toImmutableList()); + // TODO all Blocks representing rows could be combined into one Block returning a multiset + + this.attributes = CARDINALITY.asMap((long) rows.size()); + } + + private Values(String resultName, int rows) + { + requireNonNull(resultName, "resultName is null"); + + if (rows < 0) { + throw new TrinoException(IR_ERROR, "negative row count: " + rows); + } + + this.result = new Result(resultName, EMPTY_ROW); + + this.rows = ImmutableList.of(); + + this.attributes = CARDINALITY.asMap((long) rows); + } + + public static Values valuesWithoutFields(String resultName, int rows) + { + return new Values(resultName, rows); + } + + @Override + public String 
name() + { + return NAME; + } + + @Override + public Result result() + { + return result; + } + + @Override + public List arguments() + { + return ImmutableList.of(); + } + + @Override + public List regions() + { + return rows; + } + + @Override + public Map attributes() + { + return attributes; + } + + @Override + public String prettyPrint(int indentLevel) + { + return "pretty values"; + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) {return true;} + if (obj == null || obj.getClass() != this.getClass()) {return false;} + var that = (Values) obj; + return Objects.equals(this.result, that.result) && + Objects.equals(this.rows, that.rows) && + Objects.equals(this.attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hash(result, rows, attributes); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/Block.java b/core/trino-main/src/main/java/io/trino/sql/newir/Block.java new file mode 100644 index 000000000000..c28e7b86d148 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/Block.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.newir; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Optional; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.sql.newir.PrinterOptions.INDENT; +import static io.trino.sql.newir.Value.validateValueName; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +/** + * A Block is a unit of code consisting of a list of Operations. + * A Block belongs to a Region, and together with other Blocks within a Region it forms a Control Flow Graph. + * Note that for now, we only support single-block Regions, so no control flow is involved. + * Blocks define Parameters being typed values. + * Blocks must end with a terminal operation such as Return. + * Blocks have an optional name (label). They should be accessed by their position in the list within a Region. + */ +public record Block(Optional name, List parameters, List operations) + implements SourceNode +{ + public record Parameter(String name, Type type) + implements Value + { + public Parameter + { + validateValueName(name); + } + + @Override + public Block source(Program program) + { + return program.getBlock(this); + } + } + + public Block(Optional name, List parameters, List operations) + { + this.name = requireNonNull(name, "name is null"); + validateBlockName(name); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + this.operations = ImmutableList.copyOf(requireNonNull(operations, "operations is null")); + if (operations.isEmpty()) { + throw new TrinoException(IR_ERROR, "invalid block: empty operations list"); + } + // TODO verify that block ends with a terminal operation. Define a top-level attribute to mark terminal operations? 
+ } + + private static void validateBlockName(Optional optionalName) + { + optionalName.ifPresent(name -> { + if (!name.startsWith("^")) { + throw new TrinoException(IR_ERROR, format("invalid block name: \"%s\"", name)); + } + }); + } + + public String print(int indentLevel) + { + StringBuilder builder = new StringBuilder(); + String indent = INDENT.repeat(indentLevel); + + builder.append(indent) + .append(name().orElse("")); + + if (!parameters().isEmpty()) { + builder.append(parameters().stream() + .map(parameter -> parameter.name() + " : " + parameter.type()) + .collect(joining(", ", " (", ")"))); + } + + builder.append(operations().stream() + .map(operation -> operation.print(indentLevel + 1)) + .collect(joining("\n", "\n", ""))); + + return builder.toString(); + } + + public String prettyPrint(int indentLevel) + { + return print(indentLevel); + } + + public int getIndex(Parameter parameter) + { + int index = parameters.indexOf(parameter); + if (index < 0) { + throw new TrinoException(IR_ERROR, parameter.name() + "is not a parameter of given block"); + } + return index; + } + + public Operation getTerminalOperation() + { + return operations.getLast(); + } + + public Type getReturnedType() + { + return getTerminalOperation().result().type(); + } + + public static class Builder + { + private final Optional name; + private final List parameters; + private final ImmutableList.Builder operations = ImmutableList.builder(); + private Optional recentOperation = Optional.empty(); + + public Builder(Optional name, List parameters) + { + this.name = name; + this.parameters = parameters; + } + + public Builder addOperation(Operation operation) + { + operations.add(operation); + recentOperation = Optional.of(operation); + return this; + } + + // access to the recently added operation allows the caller to append a return operation or a navigating operation (in the future) + public Operation recentOperation() + { + return recentOperation.orElseThrow(() -> new 
TrinoException(IR_ERROR, "no operations added yet")); + } + + public Block build() + { + return new Block(name, parameters, operations.build()); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/Operation.java b/core/trino-main/src/main/java/io/trino/sql/newir/Operation.java new file mode 100644 index 000000000000..cf2f095a995d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/Operation.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.newir; + +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Map; + +import static io.trino.sql.newir.PrinterOptions.INDENT; +import static io.trino.sql.newir.Value.validateValueName; +import static java.util.stream.Collectors.joining; + +/** + * Operation is the main building block of a program. + * Operation, as well as other code elements, does not use JSON serialization. + * The serialized format is obtained through the print() method. 
+ */ +public non-sealed interface Operation + extends SourceNode +{ + record Result(String name, Type type) + implements Value + { + public Result + { + validateValueName(name); + } + + @Override + public Operation source(Program program) + { + return program.getOperation(this); + } + } + + String name(); + + Result result(); + + List arguments(); + + List regions(); + + Map attributes(); + + default String print(int indentLevel) + { + StringBuilder builder = new StringBuilder(); + String indent = INDENT.repeat(indentLevel); + + builder.append(indent) + .append(result().name()) + .append(" = ") + .append(name()) + .append(arguments().stream() + .map(Value::name) + .collect(joining(", ", "(", ")"))) + .append(" : ") + .append(arguments().stream() + .map(Value::type) + .map(Type::toString) + .collect(joining(", ", "(", ")"))) + .append(" -> ") + .append(result().type().toString()) + .append(regions().stream() + .map(region -> region.print(indentLevel + 1)) + .collect(joining(", ", " (", ")"))); + + // do not render empty attributes list + if (!attributes().isEmpty()) { + builder.append("\n") + .append(indent) + .append(INDENT) + .append(attributes().entrySet().stream() + .map(entry -> entry.getKey() + " = " + entry.getValue().toString()) + .collect(joining(", ", "{", "}"))); + } + + return builder.toString(); + } + + default String prettyPrint(int indentLevel) + { + return print(indentLevel); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/PrinterOptions.java b/core/trino-main/src/main/java/io/trino/sql/newir/PrinterOptions.java new file mode 100644 index 000000000000..237ff5ba848a --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/PrinterOptions.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/**
 * Constants used when printing programs of the new IR as text.
 * This is a constants holder and cannot be instantiated or extended.
 */
public final class PrinterOptions
{
    // One level of indentation in the printed program.
    // NOTE(review): the literal is reproduced as it appears in this view of the file, where
    // whitespace seems to have been collapsed. Confirm the intended indent width.
    public static final String INDENT = " ";

    private PrinterOptions() {}
}
+ * The top-level entity is an Operation which has a Block enclosing the logic of the program. + * The valueMap is a mapping of each Value to its source. Value names are unique. + */ +@Immutable +public final class Program +{ + private final Operation root; + + // each Operation.Result is mapped to its returning Operation + // each Block.Parameter is mapped to its declaring Block + private final Map valueMap; + + public Program(Operation root, Map valueMap) + { + this.root = requireNonNull(root, "root is null"); + this.valueMap = ImmutableMap.copyOf(requireNonNull(valueMap, "valueMap is null")); + } + + public Operation getOperation(Result value) + { + SourceNode source = valueMap.get(value); + if (source == null) { + throw new TrinoException(IR_ERROR, format("value %s is not defined", value.name())); + } + + if (source instanceof Operation operation) { + if (!value.type().equals(operation.result().type())) { + throw new TrinoException(IR_ERROR, format("value %s type mismatch. expected: %s, actual: %s", value.name(), value.type(), operation.result().type())); + } + return operation; + } + + throw new TrinoException(IR_ERROR, format("value %s is not an operation result", value.name())); + } + + public Block getBlock(Parameter value) + { + SourceNode source = valueMap.get(value); + if (source == null) { + throw new TrinoException(IR_ERROR, format("value %s is not defined", value.name())); + } + + if (source instanceof Block block) { + Type parameterType = block.parameters().get(block.getIndex(value)).type(); + if (!value.type().equals(parameterType)) { + throw new TrinoException(IR_ERROR, format("value %s type mismatch. 
expected: %s, actual: %s", value.name(), value.type(), parameterType)); + } + return block; + } + + throw new TrinoException(IR_ERROR, format("value %s is not a block parameter", value.name())); + } + + public Operation getRoot() + { + return root; + } + + public String print() + { + return root.print(0); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/README.md b/core/trino-main/src/main/java/io/trino/sql/newir/README.md new file mode 100644 index 000000000000..756bb2abfa24 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/README.md @@ -0,0 +1,89 @@ +# The new IR + +The new IR for Trino is conceptually based on [MLIR](https://mlir.llvm.org/).\ +This is a prototype implementation. It will change over time. + +## Abstractions + +### Operation + +[Operation](Operation.java) is the main abstraction. It is used to represent +relational and scalar SQL operations in a uniform way. Some examples of +operations are: + +- [Constant](Constant.java): represents a constant scalar value or a constant relation, +- [FieldSelection](FieldSelection.java): selects a row field by name, +- [Filter](Filter.java): relational filter operation, +- [Output](Output.java): query output, +- [Query](Query.java): represents a SELECT statement, +- [Return](Return.java): terminal operation, +- [Row](Row.java): row constructor, +- [Values](Values.java): SQL values. + +All operations share common features: + +1. Operations return a single result, being a typed value. An operation derives its + output type based on its inputs and attributes. +2. Operations can have a list of [Regions](Region.java) containing [Blocks](Block.java) + to express nested logic as a lambda. For example, [Filter](Filter.java) has a block + with the predicate, CorrelatedJoin has a block with the subquery program. A block + has a list of operations, forming a recursive structure. +3. 
Operations can take arguments, being either results of another operation,
+   or parameters of an enclosing block. Operations can require that arguments
+   adhere to [TypeConstraints](../dialect/trino/TypeConstraint.java).
+4. Operations can have [Attributes](../dialect/trino/Attributes.java). Attributes represent
+   logical properties of the operation's result, like cardinality. They can also represent
+   properties of the operation itself, like join type.
+
+### Value
+
+[Value](Value.java) is another abstraction important for MLIR-style modeling.
+A value is either an operation result or a block parameter.\
+Values follow the
+[SSA form](https://en.wikipedia.org/wiki/Static_single-assignment_form), which means
+that each value is assigned exactly once.\
+Values follow the usual visibility rules: a value is visible if it was assigned
+earlier in the same block, or it is a parameter of some enclosing block.
+Visibility across nested blocks is useful for modeling nested lambdas, like
+deeply correlated subqueries.\
+In MLIR, blocks can invoke other blocks and pass values. In our initial model,
+we operate on a higher abstraction level, and do not use this feature.
+
+In the current Trino IR, the PlanNodes take other PlanNodes as sources.
+Using explicit values introduces indirection between the operation which
+produces the value and the operation which consumes it.
+
+### Program
+
+Program represents a query plan.\
+It has one top-level [Query](../dialect/trino/operation/Query.java) operation.
+This operation has one block which contains the full query program, ending with a
+terminal [Output](../dialect/trino/operation/Output.java) operation.\
+Program also has a map which links each value to its source. The source is either
+an operation returning the value or a block declaring the value as a parameter.
+
+In the future, this model can be easily extended to represent a set of statements.
+
+## Conversions between old and new IRs
+
+Initially, the old and new IRs will coexist in Trino.
We need a way to go from one +representation to the other.\ +[ProgramBuilder](ProgramBuilder.java) converts a tree of PlanNodes into a Program. +For now, it supports a small subset of PlanNodes.\ +For the conversion, it uses two visitor-based rewriters: + +- [RelationalProgramBuilder](RelationalProgramBuilder.java) rewrites PlanNodes based on [PlanVisitor](../planner/plan/PlanVisitor.java) +- [ScalarProgramBuilder](ScalarProgramBuilder.java) rewrites scalar expressions based on [IrVisitor](../ir/IrVisitor.java) + +We need the two rewriters because the old IR has different representations for scalar +and relational operations. In the new IR, all operations are represented in a uniform way. + +## Evolution of the project + +#### Phase 1: new plan representation (in progress) + +- Define the abstractions in the spirit of MLIR. +- Support programmatic creation. +- Implement serialization-deserialization using the MLIR assembly format. +- Implement the scalar and relational operations which are present in the optimized Trino plan. + Use high level of abstraction, similar to that of PlanNodes. diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/Region.java b/core/trino-main/src/main/java/io/trino/sql/newir/Region.java new file mode 100644 index 000000000000..4093c49b168e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/Region.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.newir; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.TrinoException; + +import java.util.List; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static io.trino.sql.newir.PrinterOptions.INDENT; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +/** + * A Region is a list of Blocks forming a Control Flow Graph. + * A Region is attached to an Operation, and an Operation can have multiple Regions forming an ordered list. + * Unlike Blocks, Regions do not have an optional name (label). They can be accessed by their position in the list. + * The semantics of a Region is defined by the enclosing Operation. It means that control flow across Regions + * is implicitly managed by the enclosing Operation. + *

+ * For now, we only support single-block Regions. + */ +public record Region(List blocks) +{ + public Region(List blocks) + { + this.blocks = ImmutableList.copyOf(requireNonNull(blocks, "blocks is null")); + if (blocks.size() != 1) { // TODO when we lift the single block restriction, verify that the blocks list is not empty + throw new TrinoException(IR_ERROR, "expected 1 block, actual: " + blocks.size()); + } + } + + public static Region singleBlockRegion(Block block) + { + return new Region(ImmutableList.of(block)); + } + + public String print(int indentLevel) + { + String indent = INDENT.repeat(indentLevel); + + return blocks().stream() + .map(block -> block.print(indentLevel)) + .collect(joining("\n", "{\n", "\n" + indent + "}")); + } + + public String prettyPrint(int indentLevel) + { + return print(indentLevel); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/SourceNode.java b/core/trino-main/src/main/java/io/trino/sql/newir/SourceNode.java new file mode 100644 index 000000000000..439caaa41978 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/SourceNode.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.newir; + +import com.google.errorprone.annotations.Immutable; + +/** + * A Node is a code element which can be a source of a Value: + * - an Operation is a source of its Result. + * - a Block is a source of all its Parameters. 
+ */ +@Immutable +public sealed interface SourceNode + permits Operation, Block +{ +} diff --git a/core/trino-main/src/main/java/io/trino/sql/newir/Value.java b/core/trino-main/src/main/java/io/trino/sql/newir/Value.java new file mode 100644 index 000000000000..a87ecce496e7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/newir/Value.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.newir; + +import com.google.errorprone.annotations.Immutable; +import io.trino.spi.TrinoException; +import io.trino.spi.type.Type; + +import static io.trino.spi.StandardErrorCode.IR_ERROR; +import static java.lang.String.format; + +@Immutable +public sealed interface Value + permits Operation.Result, Block.Parameter +{ + String name(); + + Type type(); + + SourceNode source(Program program); + + static void validateValueName(String name) + { + if (!name.startsWith("%")) { + throw new TrinoException(IR_ERROR, format("invalid value name: \"%s\"", name)); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/dialect/trino/TestProgramBuilder.java b/core/trino-main/src/test/java/io/trino/sql/dialect/trino/TestProgramBuilder.java new file mode 100644 index 000000000000..40b122f1fe2e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/dialect/trino/TestProgramBuilder.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.dialect.trino; + +import com.google.common.collect.ImmutableList; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Row; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.JoinType; +import io.trino.sql.planner.plan.OutputNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.ValuesNode; +import io.trino.sql.tree.Identifier; +import org.junit.jupiter.api.Test; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.ir.Logical.Operator.AND; +import static org.assertj.core.api.Assertions.assertThat; + +class TestProgramBuilder +{ + @Test + public void testProgramBuilderAndPrinter() + { + assertThat(ProgramBuilder + .buildProgram( + new OutputNode( + new PlanNodeId("100"), + new FilterNode( + new PlanNodeId("101"), + new ValuesNode( + new PlanNodeId("102"), + ImmutableList.of(new Symbol(BIGINT, "a"), new Symbol(BOOLEAN, "b")), + ImmutableList.of( + new io.trino.sql.ir.Row(ImmutableList.of(new io.trino.sql.ir.Constant(BIGINT, 3L), new io.trino.sql.ir.Constant(BOOLEAN, true))), + new io.trino.sql.ir.Row(ImmutableList.of(new io.trino.sql.ir.Constant(BIGINT, 5L), new io.trino.sql.ir.Constant(BOOLEAN, false))))), + new io.trino.sql.ir.Constant(BOOLEAN, true)), + ImmutableList.of("col_a"), + ImmutableList.of(new Symbol(BIGINT, "a")))) 
+ .print()) + .isEqualTo("whatever, give me the printout in the error message"); + } + + @Test + public void testCorrelation() + { + assertThat(ProgramBuilder + .buildProgram( + new OutputNode( + new PlanNodeId("100"), + new CorrelatedJoinNode( + new PlanNodeId("101"), + new ValuesNode( + new PlanNodeId("102"), + ImmutableList.of(new Symbol(BIGINT, "a"), new Symbol(BOOLEAN, "b")), + ImmutableList.of( + new io.trino.sql.ir.Row(ImmutableList.of(new io.trino.sql.ir.Constant(BIGINT, 1L), new io.trino.sql.ir.Constant(BOOLEAN, true))), + new io.trino.sql.ir.Row(ImmutableList.of(new io.trino.sql.ir.Constant(BIGINT, 2L), new io.trino.sql.ir.Constant(BOOLEAN, false))))), + new CorrelatedJoinNode( + new PlanNodeId("103"), + new ValuesNode( + new PlanNodeId("104"), + ImmutableList.of(new Symbol(BIGINT, "c"), new Symbol(BOOLEAN, "d")), + ImmutableList.of( + new io.trino.sql.ir.Row(ImmutableList.of(new io.trino.sql.ir.Constant(BIGINT, 3L), new io.trino.sql.ir.Constant(BOOLEAN, true))), + new io.trino.sql.ir.Row(ImmutableList.of(new io.trino.sql.ir.Constant(BIGINT, 4L), new io.trino.sql.ir.Constant(BOOLEAN, false))))), + new FilterNode( + new PlanNodeId("105"), + new ValuesNode( + new PlanNodeId("106"), + ImmutableList.of(new Symbol(BIGINT, "e"), new Symbol(BOOLEAN, "f")), + ImmutableList.of( + new io.trino.sql.ir.Row(ImmutableList.of(new Reference(BIGINT, "a"), new Reference(BOOLEAN, "b"))), // correlated level 1 + new Row(ImmutableList.of(new Reference(BIGINT, "c"), new Reference(BOOLEAN, "d"))))), // correlated level 2 + new Logical(AND, ImmutableList.of(new Reference(BOOLEAN, "b"), new Reference(BOOLEAN, "d")))), // correlated on 2 levels + ImmutableList.of(new Symbol(BIGINT, "c"), new Symbol(BOOLEAN, "d")), + JoinType.INNER, + new Reference(BOOLEAN, "b"), + new Identifier("bla")), // origin subquery, whatever + ImmutableList.of(new Symbol(BIGINT, "a"), new Symbol(BOOLEAN, "b")), + JoinType.INNER, + new Reference(BOOLEAN, "b"), + new Identifier("bla")), // origin subquery, 
whatever + ImmutableList.of("col_1", "col_2", "col_3", "col_4", "col_5", "col_6"), + ImmutableList.of(new Symbol(BIGINT, "a"), new Symbol(BOOLEAN, "b"), new Symbol(BIGINT, "c"), new Symbol(BOOLEAN, "d"), new Symbol(BIGINT, "e"), new Symbol(BOOLEAN, "f")))) + .print()) + .isEqualTo("whatever, give me the printout in the error message"); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index f4397d4d013d..5824884741b9 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -186,6 +186,7 @@ public enum StandardErrorCode EXCHANGE_MANAGER_NOT_CONFIGURED(65564, INTERNAL_ERROR), CATALOG_NOT_AVAILABLE(65565, INTERNAL_ERROR), CATALOG_STORE_ERROR(65566, INTERNAL_ERROR), + IR_ERROR(65567, INTERNAL_ERROR), GENERIC_INSUFFICIENT_RESOURCES(131072, INSUFFICIENT_RESOURCES), EXCEEDED_GLOBAL_MEMORY_LIMIT(131073, INSUFFICIENT_RESOURCES), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/EmptyRowType.java b/core/trino-spi/src/main/java/io/trino/spi/type/EmptyRowType.java new file mode 100644 index 000000000000..a8c66c2c0c10 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/type/EmptyRowType.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.type; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.connector.ConnectorSession; + +import java.util.List; +import java.util.Optional; + +public class EmptyRowType + implements Type +{ + public static final EmptyRowType EMPTY_ROW = new EmptyRowType(); + + private final TypeSignature signature = new TypeSignature(StandardTypes.EMPTY_ROW); + + private EmptyRowType() {} + + @Override + public TypeSignature getTypeSignature() + { + return signature; + } + + @Override + public String getDisplayName() + { + return signature.toString(); + } + + @Override + public boolean isComparable() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean isOrderable() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Class getJavaType() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Class getValueBlockType() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public List getTypeParameters() + { + return List.of(); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObjectValue(ConnectorSession session, Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean getBoolean(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); 
+ } + + @Override + public long getLong(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public double getDouble(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Slice getSlice(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObject(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeBoolean(BlockBuilder blockBuilder, boolean value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeLong(BlockBuilder blockBuilder, long value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeDouble(BlockBuilder blockBuilder, double value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void appendTo(Block block, int position, BlockBuilder blockBuilder) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Optional getNextValue(Object value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int getFlatFixedSize() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean isFlatVariableWidth() + { + throw new 
UnsupportedOperationException(getClass().getName()); + } + + @Override + public int getFlatVariableWidthSize(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int relocateFlatVariableWidthOffsets(byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) + { + throw new UnsupportedOperationException(getClass().getName()); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/MultisetType.java b/core/trino-spi/src/main/java/io/trino/spi/type/MultisetType.java new file mode 100644 index 000000000000..081aefff9f9e --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/type/MultisetType.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.type; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.MapBlock; +import io.trino.spi.block.SqlMap; +import io.trino.spi.connector.ConnectorSession; + +import static io.trino.spi.type.StandardTypes.MULTISET; + +public class MultisetType + extends AbstractType // TODO complete implementation needed to implement constant Values +{ + private final Type elementType; + + public MultisetType(Type elementType) + { + // TODO same as map type? 
+ super(new TypeSignature(MULTISET, TypeSignatureParameter.typeParameter(elementType.getTypeSignature())), SqlMap.class, MapBlock.class); + this.elementType = elementType; + } + + public Type getElementType() + { + return elementType; + } + + @Override + public String getDisplayName() + { + return MULTISET + "(" + elementType.getDisplayName() + ")"; + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObjectValue(ConnectorSession session, Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void appendTo(Block block, int position, BlockBuilder blockBuilder) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int getFlatFixedSize() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean isFlatVariableWidth() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int getFlatVariableWidthSize(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int relocateFlatVariableWidthOffsets(byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) + { + throw new UnsupportedOperationException(getClass().getName()); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java b/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java index e39e7ba6bde9..9345a23edd86 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java +++ 
b/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java @@ -45,6 +45,9 @@ public final class StandardTypes public static final String IPADDRESS = "ipaddress"; public static final String GEOMETRY = "Geometry"; public static final String UUID = "uuid"; + public static final String MULTISET = "multiset"; + public static final String EMPTY_ROW = "empty row"; + public static final String VOID = "void"; private StandardTypes() {} } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/VoidType.java b/core/trino-spi/src/main/java/io/trino/spi/type/VoidType.java new file mode 100644 index 000000000000..1aba8611afd5 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/type/VoidType.java @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.type; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.connector.ConnectorSession; + +import java.util.List; +import java.util.Optional; + +public class VoidType + implements Type +{ + public static final VoidType VOID = new VoidType(); + + private final TypeSignature signature = new TypeSignature(StandardTypes.VOID); + + private VoidType() {} + + @Override + public TypeSignature getTypeSignature() + { + return signature; + } + + @Override + public String getDisplayName() + { + return signature.toString(); + } + + @Override + public boolean isComparable() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean isOrderable() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Class getJavaType() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Class getValueBlockType() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public List getTypeParameters() + { + return List.of(); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObjectValue(ConnectorSession session, Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean getBoolean(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public 
long getLong(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public double getDouble(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Slice getSlice(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObject(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeBoolean(BlockBuilder blockBuilder, boolean value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeLong(BlockBuilder blockBuilder, long value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeDouble(BlockBuilder blockBuilder, double value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void appendTo(Block block, int position, BlockBuilder blockBuilder) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Optional getNextValue(Object value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int getFlatFixedSize() + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public boolean isFlatVariableWidth() + { + throw new 
UnsupportedOperationException(getClass().getName()); + } + + @Override + public int getFlatVariableWidthSize(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public int relocateFlatVariableWidthOffsets(byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) + { + throw new UnsupportedOperationException(getClass().getName()); + } +} +