Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize symbol names #21204

Merged
merged 6 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ public RelationPlan planExpand(Query query)
List<NodeAndMappings> recursionStepsToUnion = recursionSteps.build();

List<Symbol> unionOutputSymbols = anchorPlan.getFieldMappings().stream()
.map(symbol -> symbolAllocator.newSymbol(symbol, "_expanded"))
.map(symbol -> symbolAllocator.newSymbol("expanded_" + symbol.getName(), symbol.getType()))
.collect(toImmutableList());

ImmutableListMultimap.Builder<Symbol, Symbol> unionSymbolMapping = ImmutableListMultimap.builder();
Expand Down Expand Up @@ -1144,7 +1144,7 @@ private GroupingSetsPlan planGroupingSets(PlanBuilder subPlan, QuerySpecificatio
Symbol[] fields = new Symbol[subPlan.getTranslations().getFieldSymbols().size()];
for (FieldId field : groupingSetAnalysis.getAllFields()) {
Symbol input = subPlan.getTranslations().getFieldSymbols().get(field.getFieldIndex());
Symbol output = symbolAllocator.newSymbol(input, "gid");
Symbol output = symbolAllocator.newSymbol(input.getName() + "_gid", input.getType());
fields[field.getFieldIndex()] = output;
groupingSetMappings.put(output, input);
}
Expand All @@ -1153,7 +1153,7 @@ private GroupingSetsPlan planGroupingSets(PlanBuilder subPlan, QuerySpecificatio
for (io.trino.sql.tree.Expression expression : groupingSetAnalysis.getComplexExpressions()) {
if (!complexExpressions.containsKey(scopeAwareKey(expression, analysis, subPlan.getScope()))) {
Symbol input = subPlan.translate(expression);
Symbol output = symbolAllocator.newSymbol("expr", analysis.getType(expression), "gid");
Symbol output = symbolAllocator.newSymbol("gid", analysis.getType(expression));
complexExpressions.put(scopeAwareKey(expression, analysis, subPlan.getScope()), output);
groupingSetMappings.put(output, input);
}
Expand Down Expand Up @@ -1550,7 +1550,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp
coercions.get(sortKey).toSymbolReference(),
expectedType,
false);
sortKeyCoercedForFrameBoundCalculation = symbolAllocator.newSymbol(cast, expectedType);
sortKeyCoercedForFrameBoundCalculation = symbolAllocator.newSymbol(cast);
sortKeyCoercions.put(expectedType, sortKeyCoercedForFrameBoundCalculation);
subPlan = subPlan.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
Expand All @@ -1570,7 +1570,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp
ImmutableList.of(
sortKeyCoercedForFrameBoundCalculation.toSymbolReference(),
offsetSymbol.toSymbolReference()));
Symbol frameBoundSymbol = symbolAllocator.newSymbol(functionCall, function.getSignature().getReturnType());
Symbol frameBoundSymbol = symbolAllocator.newSymbol(functionCall);
subPlan = subPlan.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
subPlan.getRoot(),
Expand All @@ -1593,7 +1593,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp
coercions.get(sortKey).toSymbolReference(),
expectedType,
false);
Symbol castSymbol = symbolAllocator.newSymbol(cast, expectedType);
Symbol castSymbol = symbolAllocator.newSymbol(cast);
sortKeyCoercions.put(expectedType, castSymbol);
subPlan = subPlan.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
Expand Down Expand Up @@ -1663,7 +1663,7 @@ else if (actualPrecision > MAX_BIGINT_PRECISION) {
false);
}

Symbol coercedOffsetSymbol = symbolAllocator.newSymbol(offsetToBigint, BIGINT);
Symbol coercedOffsetSymbol = symbolAllocator.newSymbol(offsetToBigint);
subPlan = subPlan.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
subPlan.getRoot(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1506,7 +1506,7 @@ else if (jsonTable.getPlan().orElseThrow() instanceof JsonTableDefaultPlan defau
result = new Cast(result, expectedType);
}

Symbol output = symbolAllocator.newSymbol(result, expectedType);
Symbol output = symbolAllocator.newSymbol(result);
outputLayout.add(output);
assignments.put(output, result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
*/
package io.trino.sql.planner;

import com.google.common.base.CharMatcher;
import com.google.common.primitives.Ints;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.Field;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import jakarta.annotation.Nullable;

import java.util.Collection;
import java.util.HashMap;
Expand All @@ -32,6 +31,11 @@

public class SymbolAllocator
{
public static final CharMatcher EXCLUDED_CHARACTERS = CharMatcher.inRange('a', 'z')
.or(CharMatcher.inRange('0', '9'))
.negate()
.precomputed();

private final Map<String, Symbol> symbols;
private int nextId;

Expand All @@ -48,31 +52,15 @@ public SymbolAllocator(Collection<Symbol> initial)

public Symbol newSymbol(Symbol symbolHint)
{
return newSymbol(symbolHint, null);
}

public Symbol newSymbol(Symbol symbolHint, String suffix)
{
return newSymbol(symbolHint.getName(), symbolHint.getType(), suffix);
return newSymbol(symbolHint.getName(), symbolHint.getType());
}

public Symbol newSymbol(String nameHint, Type type)
{
return newSymbol(nameHint, type, null);
}

public Symbol newHashSymbol()
{
return newSymbol("$hashValue", BigintType.BIGINT);
}

public Symbol newSymbol(String nameHint, Type type, @Nullable String suffix)
{
requireNonNull(nameHint, "nameHint is null");
requireNonNull(type, "type is null");

// TODO: workaround for the fact that QualifiedName lowercases parts
nameHint = nameHint.toLowerCase(ENGLISH);
nameHint = EXCLUDED_CHARACTERS.trimAndCollapseFrom(nameHint.toLowerCase(ENGLISH), '_');

// don't strip the tail if the only _ is the first character
int index = nameHint.lastIndexOf("_");
Expand All @@ -85,37 +73,23 @@ public Symbol newSymbol(String nameHint, Type type, @Nullable String suffix)
}
}

String unique = nameHint;

if (suffix != null) {
unique = unique + "$" + suffix;
}

Symbol symbol = new Symbol(type, unique);
Symbol symbol = new Symbol(type, nameHint);
while (symbols.putIfAbsent(symbol.getName(), symbol) != null) {
symbol = new Symbol(type, unique + "_" + nextId());
symbol = new Symbol(type, nameHint + "_" + nextId());
}

return symbol;
}

public Symbol newSymbol(Expression expression, Type type)
public Symbol newSymbol(Expression expression)
{
return newSymbol(expression, type, null);
}

public Symbol newSymbol(Expression expression, Type type, String suffix)
{
String nameHint = "expr";
if (expression instanceof Call call) {
// symbol allocation can happen during planning, before function calls are rewritten
nameHint = call.function().getName().getFunctionName();
}
else if (expression instanceof Reference reference) {
nameHint = reference.name();
}
String nameHint = switch (expression) {
case Call call -> call.function().getName().getFunctionName();
case Reference reference -> reference.name();
default -> "expr";
};

return newSymbol(nameHint, type, suffix);
return newSymbol(nameHint, expression.type());
}

public Symbol newSymbol(Field field)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.DynamicFilter;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeSignature;
import io.trino.split.PageSourceManager;
import io.trino.split.SplitManager;
Expand Down Expand Up @@ -384,8 +382,8 @@ private static Result tryCreateSpatialJoin(
return Result.empty();
}

Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument, plannerContext.getTypeManager());
Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument, plannerContext.getTypeManager());
Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument);
Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument);

PlanNode leftNode = joinNode.getLeft();
PlanNode rightNode = joinNode.getRight();
Expand Down Expand Up @@ -551,13 +549,13 @@ private static Expression toExpression(Optional<Symbol> optionalSymbol, Expressi
return optionalSymbol.map(symbol -> (Expression) symbol.toSymbolReference()).orElse(defaultExpression);
}

private static Optional<Symbol> newGeometrySymbol(Context context, Expression expression, TypeManager typeManager)
private static Optional<Symbol> newGeometrySymbol(Context context, Expression expression)
{
if (expression instanceof Reference) {
return Optional.empty();
}

return Optional.of(context.getSymbolAllocator().newSymbol(expression, typeManager.getType(GEOMETRY_TYPE_SIGNATURE)));
return Optional.of(context.getSymbolAllocator().newSymbol(expression));
}

private static Optional<Symbol> newRadiusSymbol(Context context, Expression expression)
Expand All @@ -566,7 +564,7 @@ private static Optional<Symbol> newRadiusSymbol(Context context, Expression expr
return Optional.empty();
}

return Optional.of(context.getSymbolAllocator().newSymbol(expression, DOUBLE));
return Optional.of(context.getSymbolAllocator().newSymbol(expression));
}

private static PlanNode addProjection(Context context, PlanNode node, Symbol symbol, Expression expression)
Expand Down Expand Up @@ -595,7 +593,7 @@ private static PlanNode addPartitioningNodes(PlannerContext plannerContext, Cont
radius.ifPresent(value -> spatialPartitionsCall.addArgument(DOUBLE, value));
Call partitioningFunction = spatialPartitionsCall.build();

Symbol partitionsSymbol = context.getSymbolAllocator().newSymbol(partitioningFunction, new ArrayType(INTEGER));
Symbol partitionsSymbol = context.getSymbolAllocator().newSymbol(partitioningFunction);
projections.put(partitionsSymbol, partitioningFunction);

return new UnnestNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context)

Symbol marker = markers.get(inputs);
if (marker == null) {
marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName() + "_distinct", BOOLEAN);
markers.put(inputs, marker);

ImmutableSet.Builder<Symbol> distinctSymbols = ImmutableSet.<Symbol>builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ private Map<PreAggregationKey, PreAggregation> getPreAggregations(List<CaseAggre
preProjection = ifExpression(unionConditions, preProjection);
}

Symbol preProjectionSymbol = context.getSymbolAllocator().newSymbol(preProjection, preProjectionType);
Symbol preProjectionSymbol = context.getSymbolAllocator().newSymbol(preProjection);
Symbol preAggregationSymbol = context.getSymbolAllocator().newSymbol(caseAggregations.iterator().next().getAggregationSymbol());
return new PreAggregation(preAggregationSymbol, preProjection, preProjectionSymbol);
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
Expand All @@ -35,7 +34,6 @@
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -142,7 +140,6 @@ private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers exp
.filter(Optional::isPresent)
.map(Optional::get)
.collect(toImmutableSet());
List<Type> argumentTypes = pointer.getFunction().getSignature().getArgumentTypes();

ImmutableList.Builder<Expression> rewrittenArguments = ImmutableList.builder();
for (int i = 0; i < pointer.getArguments().size(); i++) {
Expand All @@ -152,7 +149,7 @@ private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers exp
rewrittenArguments.add(argument);
}
else {
Symbol symbol = context.getSymbolAllocator().newSymbol(argument, argumentTypes.get(i));
Symbol symbol = context.getSymbolAllocator().newSymbol(argument);
assignments.put(symbol, argument);
rewrittenArguments.add(symbol.toSymbolReference());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ private Assignments buildAssignments(PlanNode source, Map<Symbol, Expression> ne
private Symbol symbolForExpression(Context context, Expression expression)
{
checkArgument(!(expression instanceof Reference), "expression '%s' is a SymbolReference", expression);
return context.getSymbolAllocator().newSymbol(expression, expression.type());
return context.getSymbolAllocator().newSymbol(expression);
}

private class PushFilterExpressionBelowJoinFilterRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.PartitioningScheme;
Expand Down Expand Up @@ -136,8 +135,7 @@ public Result apply(ProjectNode project, Captures captures, Context context)
continue;
}
Expression translatedExpression = inlineSymbols(translationMap, projection.getValue());
Type type = projection.getKey().getType();
Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type);
Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression);
projections.put(symbol, translatedExpression);
inputs.add(symbol);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -77,8 +76,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context)
// Translate the assignments in the ProjectNode using symbols of the source of the UnionNode
for (Map.Entry<Symbol, Expression> entry : parent.getAssignments().entrySet()) {
Expression translatedExpression = inlineSymbols(outputToInput, entry.getValue());
Type type = entry.getKey().getType();
Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type);
Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression);
assignments.put(symbol, translatedExpression);
projectSymbolMapping.put(entry.getKey(), symbol);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet pa
List<HashComputation> hashSymbolOrder = ImmutableList.copyOf(preference.getHashes());
Map<HashComputation, Symbol> newHashSymbols = new HashMap<>();
for (HashComputation preferredHashSymbol : hashSymbolOrder) {
newHashSymbols.put(preferredHashSymbol, symbolAllocator.newHashSymbol());
newHashSymbols.put(preferredHashSymbol, symbolAllocator.newSymbol("$hashValue", BIGINT));
}

// rewrite partition function to include new symbols (and precomputed hash)
Expand Down Expand Up @@ -576,7 +576,7 @@ public PlanWithProperties visitUnion(UnionNode node, HashComputationSet parentPr
// create new hash symbols
Map<HashComputation, Symbol> newHashSymbols = new HashMap<>();
for (HashComputation preferredHashSymbol : preference.getHashes()) {
newHashSymbols.put(preferredHashSymbol, symbolAllocator.newHashSymbol());
newHashSymbols.put(preferredHashSymbol, symbolAllocator.newSymbol("$hashValue", BIGINT));
}

// add hash symbols to sources
Expand Down Expand Up @@ -628,7 +628,7 @@ public PlanWithProperties visitProject(ProjectNode node, HashComputationSet pare
Symbol hashSymbol = child.getHashSymbols().get(hashComputation);
Expression hashExpression;
if (hashSymbol == null) {
hashSymbol = symbolAllocator.newHashSymbol();
hashSymbol = symbolAllocator.newSymbol("$hashValue", BIGINT);
hashExpression = hashComputation.getHashExpression(metadata);
}
else {
Expand Down Expand Up @@ -746,7 +746,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo
for (HashComputation hashComputation : requiredHashes.getHashes()) {
if (!planWithProperties.getHashSymbols().containsKey(hashComputation)) {
Expression hashExpression = hashComputation.getHashExpression(metadata);
Symbol hashSymbol = symbolAllocator.newHashSymbol();
Symbol hashSymbol = symbolAllocator.newSymbol("$hashValue", BIGINT);
assignments.put(hashSymbol, hashExpression);
outputHashSymbols.put(hashComputation, hashSymbol);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ private AggregationNode createNonDistinctAggregation(
for (Map.Entry<Symbol, Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
Aggregation aggregation = entry.getValue();
if (aggregation.getMask().isEmpty()) {
Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), entry.getKey().getType());
Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference());
aggregationOutputSymbolsMapBuilder.put(newSymbol, entry.getKey());
if (!duplicatedDistinctSymbol.equals(distinctSymbol)) {
// Handling for cases when mask symbol appears in non distinct aggregations too
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ private Symbol symbolForExpression(Expression expression)
return Symbol.from(expression);
}

return symbolAllocator.newSymbol(expression, expression.type());
return symbolAllocator.newSymbol(expression);
}

private OuterJoinPushDownResult processLimitedOuterJoin(
Expand Down
Loading
Loading