Skip to content

Commit

Permalink
Clean up IR class names
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Mar 22, 2024
1 parent 05d7474 commit 3bf2818
Show file tree
Hide file tree
Showing 410 changed files with 7,429 additions and 7,422 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.cost;

import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Comparison;
import io.trino.sql.planner.Symbol;

import java.util.Optional;
Expand Down Expand Up @@ -45,7 +45,7 @@ public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison(
SymbolStatsEstimate expressionStatistics,
Optional<Symbol> expressionSymbol,
OptionalDouble literalValue,
ComparisonExpression.Operator operator)
Comparison.Operator operator)
{
switch (operator) {
case EQUAL:
Expand Down Expand Up @@ -160,7 +160,7 @@ public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison(
Optional<Symbol> leftExpressionSymbol,
SymbolStatsEstimate rightExpressionStatistics,
Optional<Symbol> rightExpressionSymbol,
ComparisonExpression.Operator operator)
Comparison.Operator operator)
{
switch (operator) {
case EQUAL:
Expand Down Expand Up @@ -255,7 +255,7 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
}

private static PlanNodeStatsEstimate estimateExpressionToExpressionInequality(
ComparisonExpression.Operator operator,
Comparison.Operator operator,
PlanNodeStatsEstimate inputStatistics,
SymbolStatsEstimate leftExpressionStatistics,
Optional<Symbol> leftExpressionSymbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@
import io.trino.Session;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.BetweenPredicate;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
Expand All @@ -58,9 +58,9 @@
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.DynamicFilters.isDynamicFilter;
import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.ir.IrUtils.and;
import static io.trino.sql.planner.IrExpressionInterpreter.evaluateConstantExpression;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
Expand Down Expand Up @@ -148,10 +148,10 @@ protected PlanNodeStatsEstimate visitExpression(Expression node, Void context)
}

@Override
protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context)
protected PlanNodeStatsEstimate visitNot(Not node, Void context)
{
if (node.value() instanceof IsNullPredicate inner) {
if (inner.value() instanceof SymbolReference) {
if (node.value() instanceof IsNull inner) {
if (inner.value() instanceof Reference) {
Symbol symbol = Symbol.from(inner.value());
SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
Expand All @@ -165,7 +165,7 @@ protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void cont
}

@Override
protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, Void context)
protected PlanNodeStatsEstimate visitLogical(Logical node, Void context)
{
switch (node.operator()) {
case AND:
Expand Down Expand Up @@ -277,9 +277,9 @@ protected PlanNodeStatsEstimate visitConstant(Constant node, Void context)
}

@Override
protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context)
protected PlanNodeStatsEstimate visitIsNull(IsNull node, Void context)
{
if (node.value() instanceof SymbolReference) {
if (node.value() instanceof Reference) {
Symbol symbol = Symbol.from(node.value());
SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
Expand All @@ -296,7 +296,7 @@ protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void
}

@Override
protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context)
protected PlanNodeStatsEstimate visitBetween(Between node, Void context)
{
SymbolStatsEstimate valueStats = getExpressionStats(node.value());
if (valueStats.isUnknown()) {
Expand All @@ -309,8 +309,8 @@ protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Voi
return PlanNodeStatsEstimate.unknown();
}

Expression lowerBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.value(), node.min());
Expression upperBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.value(), node.max());
Expression lowerBound = new Comparison(GREATER_THAN_OR_EQUAL, node.value(), node.min());
Expression upperBound = new Comparison(LESS_THAN_OR_EQUAL, node.value(), node.max());

Expression transformed;
if (isInfinite(valueStats.getLowValue())) {
Expand All @@ -325,10 +325,10 @@ protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Voi
}

@Override
protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
protected PlanNodeStatsEstimate visitIn(In node, Void context)
{
ImmutableList<PlanNodeStatsEstimate> equalityEstimates = node.valueList().stream()
.map(inValue -> process(new ComparisonExpression(EQUAL, node.value(), inValue)))
.map(inValue -> process(new Comparison(EQUAL, node.value(), inValue)))
.collect(toImmutableList());

if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) {
Expand All @@ -353,7 +353,7 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
result.setOutputRowCount(min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn));

if (node.value() instanceof SymbolReference) {
if (node.value() instanceof Reference) {
Symbol valueSymbol = Symbol.from(node.value());
SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol)
.mapDistinctValuesCount(newDistinctValuesCount -> min(newDistinctValuesCount, valueStats.getDistinctValuesCount()));
Expand All @@ -364,30 +364,30 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)

@SuppressWarnings("ArgumentSelectionDefectChecker")
@Override
protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context)
protected PlanNodeStatsEstimate visitComparison(Comparison node, Void context)
{
ComparisonExpression.Operator operator = node.operator();
Comparison.Operator operator = node.operator();
Expression left = node.left();
Expression right = node.right();

checkArgument(!(left instanceof Constant && right instanceof Constant), "Literal-to-literal not supported here, should be eliminated earlier");

if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
if (!(left instanceof Reference) && right instanceof Reference) {
// normalize so that symbol is on the left
return process(new ComparisonExpression(operator.flip(), right, left));
return process(new Comparison(operator.flip(), right, left));
}

if (left instanceof Constant) {
// normalize so that literal is on the right
return process(new ComparisonExpression(operator.flip(), right, left));
return process(new Comparison(operator.flip(), right, left));
}

if (left instanceof SymbolReference && left.equals(right)) {
return process(new NotExpression(new IsNullPredicate(left)));
if (left instanceof Reference && left.equals(right)) {
return process(new Not(new IsNull(left)));
}

SymbolStatsEstimate leftStats = getExpressionStats(left);
Optional<Symbol> leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty();
Optional<Symbol> leftSymbol = left instanceof Reference ? Optional.of(Symbol.from(left)) : Optional.empty();
if (right instanceof Constant) {
Type type = left.type();
Object literalValue = evaluateConstantExpression(right, plannerContext, session);
Expand All @@ -405,22 +405,22 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n
return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator);
}

Optional<Symbol> rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty();
Optional<Symbol> rightSymbol = right instanceof Reference ? Optional.of(Symbol.from(right)) : Optional.empty();
return estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator);
}

@Override
protected PlanNodeStatsEstimate visitFunctionCall(FunctionCall node, Void context)
protected PlanNodeStatsEstimate visitCall(Call node, Void context)
{
if (isDynamicFilter(node)) {
return process(BooleanLiteral.TRUE_LITERAL, context);
return process(Booleans.TRUE, context);
}
return PlanNodeStatsEstimate.unknown();
}

private SymbolStatsEstimate getExpressionStats(Expression expression)
{
if (expression instanceof SymbolReference) {
if (expression instanceof Reference) {
Symbol symbol = Symbol.from(expression);
return requireNonNull(input.getSymbolStatistics(symbol), () -> format("No statistics for symbol %s", symbol));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io.trino.Session;
import io.trino.cost.StatsCalculator.Context;
import io.trino.matching.Pattern;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.JoinNode;
Expand All @@ -36,7 +36,7 @@
import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT;
import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount;
import static io.trino.cost.SymbolStatsEstimate.buildFrom;
import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.IrUtils.extractConjuncts;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.util.MoreMath.firstNonNaN;
Expand Down Expand Up @@ -183,7 +183,7 @@ private PlanNodeStatsEstimate filterByEquiJoinClauses(
// clause separately because stats estimates would be way off.
List<PlanNodeStatsEstimateWithClause> knownEstimates = clauses.stream()
.map(clause -> {
ComparisonExpression predicate = new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference());
Comparison predicate = new Comparison(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference());
return new PlanNodeStatsEstimateWithClause(filterStatsCalculator.filterStats(stats, predicate, session), clause);
})
.collect(toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.ArithmeticBinaryExpression;
import io.trino.sql.ir.ArithmeticNegation;
import io.trino.sql.ir.Arithmetic;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.Negation;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -80,7 +80,7 @@ protected SymbolStatsEstimate visitExpression(Expression node, Void context)
}

@Override
protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context)
protected SymbolStatsEstimate visitReference(Reference node, Void context)
{
return input.getSymbolStatistics(Symbol.from(node));
}
Expand All @@ -107,7 +107,7 @@ protected SymbolStatsEstimate visitConstant(Constant node, Void context)
}

@Override
protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context)
protected SymbolStatsEstimate visitCall(Call node, Void context)
{
IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session);
Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
Expand Down Expand Up @@ -176,7 +176,7 @@ private boolean isIntegralType(Type type)
}

@Override
protected SymbolStatsEstimate visitArithmeticNegation(ArithmeticNegation node, Void context)
protected SymbolStatsEstimate visitNegation(Negation node, Void context)
{
SymbolStatsEstimate stats = process(node.value());
return SymbolStatsEstimate.buildFrom(stats)
Expand All @@ -186,7 +186,7 @@ protected SymbolStatsEstimate visitArithmeticNegation(ArithmeticNegation node, V
}

@Override
protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
protected SymbolStatsEstimate visitArithmetic(Arithmetic node, Void context)
{
requireNonNull(node, "node is null");
SymbolStatsEstimate left = process(node.left());
Expand All @@ -208,11 +208,11 @@ protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression n
result.setLowValue(NaN)
.setHighValue(NaN);
}
else if (node.operator() == ArithmeticBinaryExpression.Operator.DIVIDE && rightLow < 0 && rightHigh > 0) {
else if (node.operator() == Arithmetic.Operator.DIVIDE && rightLow < 0 && rightHigh > 0) {
result.setLowValue(Double.NEGATIVE_INFINITY)
.setHighValue(Double.POSITIVE_INFINITY);
}
else if (node.operator() == ArithmeticBinaryExpression.Operator.MODULUS) {
else if (node.operator() == Arithmetic.Operator.MODULUS) {
double maxDivisor = max(abs(rightLow), abs(rightHigh));
if (leftHigh <= 0) {
result.setLowValue(max(-maxDivisor, leftLow))
Expand Down Expand Up @@ -242,7 +242,7 @@ else if (leftLow >= 0) {
return result.build();
}

private double operate(ArithmeticBinaryExpression.Operator operator, double left, double right)
private double operate(Arithmetic.Operator operator, double left, double right)
{
switch (operator) {
case ADD:
Expand All @@ -260,7 +260,7 @@ private double operate(ArithmeticBinaryExpression.Operator operator, double left
}

@Override
protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context)
protected SymbolStatsEstimate visitCoalesce(Coalesce node, Void context)
{
requireNonNull(node, "node is null");
SymbolStatsEstimate result = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import io.trino.cost.StatsCalculator.Context;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
Expand Down Expand Up @@ -134,15 +134,15 @@ private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression pr
Expression remainingPredicate = combineConjuncts(conjuncts.stream()
.filter(conjunct -> conjunct != semiJoinOutputReference)
.collect(toImmutableList()));
boolean negated = semiJoinOutputReference instanceof NotExpression;
boolean negated = semiJoinOutputReference instanceof Not;
return Optional.of(new SemiJoinOutputFilter(negated, remainingPredicate));
}

private static boolean isSemiJoinOutputReference(Expression conjunct, Symbol semiJoinOutput)
{
SymbolReference semiJoinOutputSymbolReference = semiJoinOutput.toSymbolReference();
Reference semiJoinOutputSymbolReference = semiJoinOutput.toSymbolReference();
return conjunct.equals(semiJoinOutputSymbolReference) ||
(conjunct instanceof NotExpression && ((NotExpression) conjunct).value().equals(semiJoinOutputSymbolReference));
(conjunct instanceof Not && ((Not) conjunct).value().equals(semiJoinOutputSymbolReference));
}

private static class SemiJoinOutputFilter
Expand Down
Loading

0 comments on commit 3bf2818

Please sign in to comment.