Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR cleanups #21286

Merged
merged 6 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.util.DisjointSet;
import jakarta.annotation.Nullable;
Expand Down Expand Up @@ -99,20 +98,14 @@ public PlanNodeStatsEstimate filterStats(
private Expression simplifyExpression(Session session, Expression predicate)
{
// TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite
Expression value = new IrExpressionInterpreter(predicate, plannerContext, session).optimize();

IrExpressionInterpreter interpreter = new IrExpressionInterpreter(predicate, plannerContext, session);
Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);

if (value instanceof Expression expression) {
return expression;
}

if (value == null) {
if (value instanceof Constant constant && constant.value() == null) {
// Expression evaluates to SQL null, which in Filter is equivalent to false. This assumes the expression is a top-level expression (eg. not in NOT).
value = false;
value = Booleans.FALSE;
}

return new Constant(BOOLEAN, value);
return value;
}

private class FilterExpressionStatsCalculatingVisitor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;

import java.util.OptionalDouble;
Expand Down Expand Up @@ -130,23 +129,20 @@ else if (node.function().getName().equals(builtinFunctionName(ADD)) ||
return processArithmetic(node);
}

IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session);
Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
Expression value = new IrExpressionInterpreter(node, plannerContext, session).optimize();

if (value == null) {
if (value instanceof Constant constant && constant.value() == null) {
return nullStatsEstimate();
}

if (value instanceof Expression) {
// value is not a constant
return SymbolStatsEstimate.unknown();
if (value instanceof Constant) {
return SymbolStatsEstimate.builder()
.setNullsFraction(0)
.setDistinctValuesCount(1)
.build();
}

// value is a constant
return SymbolStatsEstimate.builder()
.setNullsFraction(0)
.setDistinctValuesCount(1)
.build();
return SymbolStatsEstimate.unknown();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3699,7 +3699,7 @@ private void createMergeAnalysis(Table table, TableHandle handle, TableSchema ta
// create the RowType that holds all column values
List<RowType.Field> fields = new ArrayList<>();
for (ColumnSchema schema : dataColumnSchemas) {
fields.add(new RowType.Field(Optional.of(schema.getName()), schema.getType()));
fields.add(RowType.field(schema.getType()));
}
fields.add(new RowType.Field(Optional.empty(), BOOLEAN)); // present
fields.add(new RowType.Field(Optional.empty(), TINYINT)); // operation_number
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public final class Booleans
{
public static final Constant TRUE = new Constant(BOOLEAN, true);
public static final Constant FALSE = new Constant(BOOLEAN, false);
public static final Constant NULL_BOOLEAN = new Constant(BOOLEAN, null);

private Booleans() {}
}
12 changes: 5 additions & 7 deletions core/trino-main/src/main/java/io/trino/sql/ir/Case.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
public record Case(List<WhenClause> whenClauses, Optional<Expression> defaultValue)
public record Case(List<WhenClause> whenClauses, Expression defaultValue)
implements Expression
{
public Case
Expand All @@ -42,9 +41,7 @@ public record Case(List<WhenClause> whenClauses, Optional<Expression> defaultVal
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
}

if (defaultValue.isPresent()) {
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
}
validateType(whenClauses.getFirst().getResult().type(), defaultValue);
}

@Override
Expand All @@ -67,7 +64,8 @@ public List<? extends Expression> children()
builder.add(clause.getOperand());
builder.add(clause.getResult());
});
defaultValue.ifPresent(builder::add);

builder.add(defaultValue);

return builder.build();
}
Expand All @@ -79,6 +77,6 @@ public String toString()
whenClauses.stream()
.map(WhenClause::toString)
.collect(Collectors.joining(", ")),
defaultValue.map(Expression::toString).orElse("null"));
defaultValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ protected Void visitSwitch(Switch node, C context)
process(clause.getResult(), context);
}

node.defaultValue()
.ifPresent(value -> process(value, context));
process(node.defaultValue(), context);

return null;
}
Expand Down Expand Up @@ -130,8 +129,8 @@ protected Void visitCase(Case node, C context)
process(clause.getOperand(), context);
process(clause.getResult(), context);
}
node.defaultValue()
.ifPresent(value -> process(value, context));

process(node.defaultValue(), context);

return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ protected String visitCase(Case node, Void context)
parts.add(format(whenClause, context));
}

node.defaultValue()
.ifPresent(value -> parts.add("ELSE").add(process(value, context)));

parts.add("ELSE").add(process(node.defaultValue(), context));
parts.add("END");

return "(" + Joiner.on(' ').join(parts.build()) + ")";
Expand All @@ -211,9 +209,7 @@ protected String visitSwitch(Switch node, Void context)
parts.add(format(whenClause, context));
}

node.defaultValue()
.ifPresent(value -> parts.add("ELSE").add(process(value, context)));

parts.add("ELSE").add(process(node.defaultValue(), context));
parts.add("END");

return "(" + Joiner.on(' ').join(parts.build()) + ")";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,9 @@ protected Expression visitCase(Case node, Context<C> context)
builder.add(rewriteWhenClause(expression, context));
}

Optional<Expression> defaultValue = node.defaultValue()
.map(value -> rewrite(value, context.get()));
Expression defaultValue = rewrite(node.defaultValue(), context.get());

if (!sameElements(node.defaultValue(), defaultValue) || !sameElements(node.whenClauses(), builder.build())) {
if (node.defaultValue() != defaultValue || !sameElements(node.whenClauses(), builder.build())) {
return new Case(builder.build(), defaultValue);
}

Expand All @@ -275,11 +274,10 @@ protected Expression visitSwitch(Switch node, Context<C> context)
builder.add(rewriteWhenClause(expression, context));
}

Optional<Expression> defaultValue = node.defaultValue()
.map(value -> rewrite(value, context.get()));
Expression defaultValue = rewrite(node.defaultValue(), context.get());

if (operand != node.operand() ||
!sameElements(node.defaultValue(), defaultValue) ||
node.defaultValue() != defaultValue ||
!sameElements(node.whenClauses(), builder.build())) {
return new Switch(operand, builder.build(), defaultValue);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@

import com.google.common.collect.ImmutableList;

import java.util.Optional;

public class IrExpressions
{
private IrExpressions() {}

public static Expression ifExpression(Expression condition, Expression trueCase)
{
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.empty());
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), new Constant(trueCase.type(), null));
}

public static Expression ifExpression(Expression condition, Expression trueCase, Expression falseCase)
{
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.of(falseCase));
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), falseCase);
}
}
10 changes: 3 additions & 7 deletions core/trino-main/src/main/java/io/trino/sql/ir/Switch.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;

import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
public record Switch(Expression operand, List<WhenClause> whenClauses, Optional<Expression> defaultValue)
public record Switch(Expression operand, List<WhenClause> whenClauses, Expression defaultValue)
implements Expression
{
public Switch
Expand All @@ -40,9 +39,7 @@ public record Switch(Expression operand, List<WhenClause> whenClauses, Optional<
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
}

if (defaultValue.isPresent()) {
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
}
validateType(whenClauses.getFirst().getResult().type(), defaultValue);
}

@Override
Expand All @@ -68,8 +65,7 @@ public List<? extends Expression> children()
builder.add(clause.getResult());
});

defaultValue.ifPresent(builder::add);

builder.add(defaultValue);
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Row;
Expand Down Expand Up @@ -352,28 +353,27 @@ public Expression visitValues(ValuesNode node, Void context)
nonDeterministic[i] = true;
}
else {
IrExpressionInterpreter interpreter = new IrExpressionInterpreter(value, plannerContext, session);
Object item = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
if (item instanceof Expression) {
Expression item = new IrExpressionInterpreter(value, plannerContext, session).optimize();
if (!(item instanceof Constant constant)) {
return TRUE;
}
if (item == null) {
if (constant.value() == null) {
hasNull[i] = true;
}
else {
Type type = node.getOutputSymbols().get(i).getType();
if (!type.isComparable() && !type.isOrderable()) {
return TRUE;
}
if (hasNestedNulls(type, item)) {
if (hasNestedNulls(type, ((Constant) item).value())) {
// Workaround solution to deal with array and row comparisons don't support null elements currently.
// TODO: remove when comparisons are fixed
return TRUE;
}
if (isFloatingPointNaN(type, item)) {
if (isFloatingPointNaN(type, ((Constant) item).value())) {
hasNaN[i] = true;
}
valuesBuilders.get(i).add(item);
valuesBuilders.get(i).add(((Constant) item).value());
}
}
}
Expand All @@ -382,12 +382,11 @@ public Expression visitValues(ValuesNode node, Void context)
if (!DeterminismEvaluator.isDeterministic(expression)) {
return TRUE;
}
IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, plannerContext, session);
Object evaluated = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
if (evaluated instanceof Expression) {
Expression evaluated = new IrExpressionInterpreter(expression, plannerContext, session).optimize();
if (!(evaluated instanceof Constant constant)) {
return TRUE;
}
SqlRow sqlRow = (SqlRow) evaluated;
SqlRow sqlRow = (SqlRow) constant.value();
int rawIndex = sqlRow.getRawIndex();
for (int i = 0; i < node.getOutputSymbols().size(); i++) {
Type type = node.getOutputSymbols().get(i).getType();
Expand Down
Loading
Loading