Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New IR -- WIP #24466

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.dialect.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Logical;

import java.util.List;
import java.util.Map;

import static java.util.Objects.requireNonNull;

public class Attributes
{
public static final AttributeMetadata<Long> CARDINALITY = new AttributeMetadata<>("cardinality", Long.class, true);
public static final AttributeMetadata<ConstantResult> CONSTANT_RESULT = new AttributeMetadata<>("constant_result", ConstantResult.class, true);
public static final AttributeMetadata<String> FIELD_NAME = new AttributeMetadata<>("field_name", String.class, false);
public static final AttributeMetadata<JoinType> JOIN_TYPE = new AttributeMetadata<>("join_type", JoinType.class, false);
public static final AttributeMetadata<LogicalOperator> LOGICAL_OPERATOR = new AttributeMetadata<>("logical_operator", LogicalOperator.class, false);
public static final AttributeMetadata<OutputNames> OUTPUT_NAMES = new AttributeMetadata<>("output_names", OutputNames.class, false);

// TODO define attributes for deeply nested fields, not just top level or column level

private Attributes() {}

public static class AttributeMetadata<T>
{
private final String name;
private final Class<T> type;
private final boolean external;

Check failure on line 43 in core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java

View workflow job for this annotation

GitHub Actions / error-prone-checks

The field 'external' is never read.

Check failure on line 43 in core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java

View workflow job for this annotation

GitHub Actions / error-prone-checks

The field 'external' is never read.

private AttributeMetadata(String name, Class<T> type, boolean external)
{
this.name = requireNonNull(name, "name is null");
this.type = requireNonNull(type, "type is null");
this.external = external;
}

public T getAttribute(Map<String, Object> map)
{
return this.type.cast(map.get(this.name));
}

public T putAttribute(Map<String, Object> map, T attribute)
{
return this.type.cast(map.put(name, attribute));
}

public Map<String, Object> asMap(T attribute)
{
return ImmutableMap.of(name, attribute);
}
}

public record ConstantResult(Type type, Object value)
{
public ConstantResult
{
requireNonNull(type, "type is null");
}

@Override
public String toString()
{
return value.toString() + ":" + type.toString();
}
}

public enum JoinType
{
INNER,
LEFT,
RIGHT,
FULL;

public static JoinType of(io.trino.sql.planner.plan.JoinType joinType)
{
return switch (joinType) {
case INNER -> INNER;
case LEFT -> LEFT;
case RIGHT -> RIGHT;
case FULL -> FULL;
};
}
}

public record OutputNames(List<String> outputNames)
{
public OutputNames(List<String> outputNames)
{
this.outputNames = ImmutableList.copyOf(requireNonNull(outputNames, "outputNames is null"));
}

@Override
public String toString()
{
return outputNames.toString();
}
}

public enum LogicalOperator
{
AND,
OR;

public static LogicalOperator of(Logical.Operator operator)
{
return switch (operator) {
case AND -> AND;
case OR -> OR;
};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.dialect.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.sql.newir.Block;
import io.trino.sql.planner.Symbol;

import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.HashMap.newHashMap;
import static java.util.Objects.requireNonNull;

public record Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
{
public Context(Block.Builder block)
{
this(block, Map.of());
}

public Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
{
this.block = requireNonNull(block, "block is null");
this.symbolMapping = ImmutableMap.copyOf(requireNonNull(symbolMapping, "symbolMapping is null"));
}

public static Map<Symbol, RowField> argumentMapping(Block.Parameter parameter, Map<Symbol, String> symbolMapping)
{
return symbolMapping.entrySet().stream()
.collect(toImmutableMap(
Map.Entry::getKey,
entry -> new RowField(parameter, entry.getValue())));
}

public static Map<Symbol, RowField> composedMapping(Context context, Map<Symbol, RowField> newMapping)
{
return composedMapping(context, ImmutableList.of(newMapping));
}

/**
* Compose the correlated mapping from the context with symbol mappings for the current block parameters.
*
* @param context rewrite context containing symbol mapping from all levels of correlation
* @param newMappings list of symbol mappings for current block parameters
* @return composed symbol mapping to rewrite the current block
*/
public static Map<Symbol, RowField> composedMapping(Context context, List<Map<Symbol, RowField>> newMappings)
{
Map<Symbol, RowField> composed = newHashMap(context.symbolMapping().size() + newMappings.stream().mapToInt(Map::size).sum());
composed.putAll(context.symbolMapping());
newMappings.stream().forEach(composed::putAll);
return composed;
}

public record RowField(Block.Parameter row, String field)
{
public RowField
{
requireNonNull(row, "row is null");
requireNonNull(field, "field is null");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.dialect.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.trino.spi.TrinoException;
import io.trino.sql.dialect.trino.operation.Query;
import io.trino.sql.newir.Block;
import io.trino.sql.newir.Program;
import io.trino.sql.newir.SourceNode;
import io.trino.sql.newir.Value;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanNode;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.spi.StandardErrorCode.IR_ERROR;

/**
* ProgramBuilder builds a MLIR program from a PlanNode tree.
* For now, it builds a program for a single query, and assumes that OutputNode is the root PlanNode.
* In the future, we might support multiple statements.
* The resulting program has the special Query operation as the top-level operation.
* It encloses all query computations in one block.
*/
public class ProgramBuilder
{
private ProgramBuilder() {}

public static Program buildProgram(PlanNode root)
{
checkArgument(root instanceof OutputNode, "Expected root to be an OutputNode. Actual: " + root.getClass().getSimpleName());

ValueNameAllocator nameAllocator = new ValueNameAllocator();
ImmutableMap.Builder<Value, SourceNode> valueMapBuilder = ImmutableMap.builder();
Block.Builder rootBlock = new Block.Builder(Optional.of("^query"), ImmutableList.of());

// for now, ignoring return value. Could be worth to remember it as the final terminal Operation in the Program.
root.accept(new RelationalProgramBuilder(nameAllocator, valueMapBuilder), new Context(rootBlock));

// verify if all values are mapped
Set<String> allocatedValues = IntStream.range(0, nameAllocator.label)
.mapToObj(index -> "%" + index)
.collect(toImmutableSet());
Map<Value, SourceNode> valueMap = valueMapBuilder.buildOrThrow();
Set<String> mappedValues = valueMap.keySet().stream()
.map(Value::name)
.collect(toImmutableSet());
if (!Sets.symmetricDifference(allocatedValues, mappedValues).isEmpty()) {
throw new TrinoException(IR_ERROR, "allocated values differ from mapped values");
}

// allocating this name last to avoid stealing the "%0" label. This label won't be printed.
String resultName = nameAllocator.newName();

return new Program(new Query(resultName, rootBlock.build()), valueMap);
}

public static class ValueNameAllocator
{
private int label = 0;

public String newName()
{
return "%" + label++;
}
}
}
Loading
Loading