Skip to content

Commit

Permalink
New IR -- WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi committed Dec 20, 2024
1 parent 67c941f commit 03bdcf5
Show file tree
Hide file tree
Showing 35 changed files with 4,169 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.dialect.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Logical;

import java.util.List;
import java.util.Map;

import static java.util.Objects.requireNonNull;

public class Attributes
{
public static final AttributeMetadata<Long> CARDINALITY = new AttributeMetadata<>("cardinality", Long.class, true);
public static final AttributeMetadata<ConstantResult> CONSTANT_RESULT = new AttributeMetadata<>("constant_result", ConstantResult.class, true);
public static final AttributeMetadata<String> FIELD_NAME = new AttributeMetadata<>("field_name", String.class, false);
public static final AttributeMetadata<JoinType> JOIN_TYPE = new AttributeMetadata<>("join_type", JoinType.class, false);
public static final AttributeMetadata<LogicalOperator> LOGICAL_OPERATOR = new AttributeMetadata<>("logical_operator", LogicalOperator.class, false);
public static final AttributeMetadata<OutputNames> OUTPUT_NAMES = new AttributeMetadata<>("output_names", OutputNames.class, false);

// TODO define attributes for deeply nested fields, not just top level or column level

private Attributes() {}

public static class AttributeMetadata<T>
{
private final String name;
private final Class<T> type;
private final boolean external;

private AttributeMetadata(String name, Class<T> type, boolean external)
{
this.name = requireNonNull(name, "name is null");
this.type = requireNonNull(type, "type is null");
this.external = external;
}

public T getAttribute(Map<String, Object> map)
{
return this.type.cast(map.get(this.name));
}

public T putAttribute(Map<String, Object> map, T attribute)
{
return this.type.cast(map.put(name, attribute));
}

public Map<String, Object> asMap(T attribute)
{
return ImmutableMap.of(name, attribute);
}
}

public record ConstantResult(Type type, Object value)
{
public ConstantResult
{
requireNonNull(type, "type is null");
}

@Override
public String toString()
{
return value.toString() + ":" + type.toString();
}
}

public enum JoinType
{
INNER,
LEFT,
RIGHT,
FULL;

public static JoinType of(io.trino.sql.planner.plan.JoinType joinType)
{
return switch (joinType) {
case INNER -> INNER;
case LEFT -> LEFT;
case RIGHT -> RIGHT;
case FULL -> FULL;
};
}
}

public record OutputNames(List<String> outputNames)
{
public OutputNames(List<String> outputNames)
{
this.outputNames = ImmutableList.copyOf(requireNonNull(outputNames, "outputNames is null"));
}

@Override
public String toString()
{
return outputNames.toString();
}
}

public enum LogicalOperator
{
AND,
OR;

public static LogicalOperator of(Logical.Operator operator)
{
return switch (operator) {
case AND -> AND;
case OR -> OR;
};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.dialect.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.sql.newir.Block;
import io.trino.sql.planner.Symbol;

import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.HashMap.newHashMap;
import static java.util.Objects.requireNonNull;

public record Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
{
public Context(Block.Builder block)
{
this(block, Map.of());
}

public Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
{
this.block = requireNonNull(block, "block is null");
this.symbolMapping = ImmutableMap.copyOf(requireNonNull(symbolMapping, "symbolMapping is null"));
}

public static Map<Symbol, RowField> argumentMapping(Block.Parameter parameter, Map<Symbol, String> symbolMapping)
{
return symbolMapping.entrySet().stream()
.collect(toImmutableMap(
Map.Entry::getKey,
entry -> new RowField(parameter, entry.getValue())));
}

public static Map<Symbol, RowField> composedMapping(Context context, Map<Symbol, RowField> newMapping)
{
return composedMapping(context, ImmutableList.of(newMapping));
}

/**
* Compose the correlated mapping from the context with symbol mappings for the current block parameters.
*
* @param context rewrite context containing symbol mapping from all levels of correlation
* @param newMappings list of symbol mappings for current block parameters
* @return composed symbol mapping to rewrite the current block
*/
public static Map<Symbol, RowField> composedMapping(Context context, List<Map<Symbol, RowField>> newMappings)
{
Map<Symbol, RowField> composed = newHashMap(context.symbolMapping().size() + newMappings.stream().mapToInt(Map::size).sum());
composed.putAll(context.symbolMapping());
newMappings.stream().forEach(composed::putAll);
return composed;
}

public record RowField(Block.Parameter row, String field)
{
public RowField
{
requireNonNull(row, "row is null");
requireNonNull(field, "field is null");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.dialect.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.trino.spi.TrinoException;
import io.trino.sql.dialect.trino.operation.Query;
import io.trino.sql.newir.Block;
import io.trino.sql.newir.Program;
import io.trino.sql.newir.SourceNode;
import io.trino.sql.newir.Value;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanNode;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.spi.StandardErrorCode.IR_ERROR;

/**
* ProgramBuilder builds a MLIR program from a PlanNode tree.
* For now, it builds a program for a single query, and assumes that OutputNode is the root PlanNode.
* In the future, we might support multiple statements.
* The resulting program has the special Query operation as the top-level operation.
* It encloses all query computations in one block.
*/
public class ProgramBuilder
{
private ProgramBuilder() {}

public static Program buildProgram(PlanNode root)
{
checkArgument(root instanceof OutputNode, "Expected root to be an OutputNode. Actual: " + root.getClass().getSimpleName());

ValueNameAllocator nameAllocator = new ValueNameAllocator();
ImmutableMap.Builder<Value, SourceNode> valueMapBuilder = ImmutableMap.builder();
Block.Builder rootBlock = new Block.Builder(Optional.of("^query"), ImmutableList.of());

// for now, ignoring return value. Could be worth to remember it as the final terminal Operation in the Program.
root.accept(new RelationalProgramBuilder(nameAllocator, valueMapBuilder), new Context(rootBlock));

// verify if all values are mapped
Set<String> allocatedValues = IntStream.range(0, nameAllocator.label)
.mapToObj(index -> "%" + index)
.collect(toImmutableSet());
Map<Value, SourceNode> valueMap = valueMapBuilder.buildOrThrow();
Set<String> mappedValues = valueMap.keySet().stream()
.map(Value::name)
.collect(toImmutableSet());
if (!Sets.symmetricDifference(allocatedValues, mappedValues).isEmpty()) {
throw new TrinoException(IR_ERROR, "allocated values differ from mapped values");
}

// allocating this name last to avoid stealing the "%0" label. This label won't be printed.
String resultName = nameAllocator.newName();

return new Program(new Query(resultName, rootBlock.build()), valueMap);
}

public static class ValueNameAllocator
{
private int label = 0;

public String newName()
{
return "%" + label++;
}
}
}
Loading

0 comments on commit 03bdcf5

Please sign in to comment.