Skip to content

Commit

Permalink
Always require default value for Switch and Case
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Mar 27, 2024
1 parent e96e06e commit 4e9403b
Show file tree
Hide file tree
Showing 26 changed files with 192 additions and 235 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public final class Booleans
{
public static final Constant TRUE = new Constant(BOOLEAN, true);
public static final Constant FALSE = new Constant(BOOLEAN, false);
public static final Constant NULL_BOOLEAN = new Constant(BOOLEAN, null);

private Booleans() {}
}
12 changes: 5 additions & 7 deletions core/trino-main/src/main/java/io/trino/sql/ir/Case.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
public record Case(List<WhenClause> whenClauses, Optional<Expression> defaultValue)
public record Case(List<WhenClause> whenClauses, Expression defaultValue)
implements Expression
{
public Case
Expand All @@ -42,9 +41,7 @@ public record Case(List<WhenClause> whenClauses, Optional<Expression> defaultVal
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
}

if (defaultValue.isPresent()) {
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
}
validateType(whenClauses.getFirst().getResult().type(), defaultValue);
}

@Override
Expand All @@ -67,7 +64,8 @@ public List<? extends Expression> children()
builder.add(clause.getOperand());
builder.add(clause.getResult());
});
defaultValue.ifPresent(builder::add);

builder.add(defaultValue);

return builder.build();
}
Expand All @@ -79,6 +77,6 @@ public String toString()
whenClauses.stream()
.map(WhenClause::toString)
.collect(Collectors.joining(", ")),
defaultValue.map(Expression::toString).orElse("null"));
defaultValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ protected Void visitSwitch(Switch node, C context)
process(clause.getResult(), context);
}

node.defaultValue()
.ifPresent(value -> process(value, context));
process(node.defaultValue(), context);

return null;
}
Expand Down Expand Up @@ -130,8 +129,8 @@ protected Void visitCase(Case node, C context)
process(clause.getOperand(), context);
process(clause.getResult(), context);
}
node.defaultValue()
.ifPresent(value -> process(value, context));

process(node.defaultValue(), context);

return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ protected String visitCase(Case node, Void context)
parts.add(format(whenClause, context));
}

node.defaultValue()
.ifPresent(value -> parts.add("ELSE").add(process(value, context)));

parts.add("ELSE").add(process(node.defaultValue(), context));
parts.add("END");

return "(" + Joiner.on(' ').join(parts.build()) + ")";
Expand All @@ -211,9 +209,7 @@ protected String visitSwitch(Switch node, Void context)
parts.add(format(whenClause, context));
}

node.defaultValue()
.ifPresent(value -> parts.add("ELSE").add(process(value, context)));

parts.add("ELSE").add(process(node.defaultValue(), context));
parts.add("END");

return "(" + Joiner.on(' ').join(parts.build()) + ")";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,9 @@ protected Expression visitCase(Case node, Context<C> context)
builder.add(rewriteWhenClause(expression, context));
}

Optional<Expression> defaultValue = node.defaultValue()
.map(value -> rewrite(value, context.get()));
Expression defaultValue = rewrite(node.defaultValue(), context.get());

if (!sameElements(node.defaultValue(), defaultValue) || !sameElements(node.whenClauses(), builder.build())) {
if (node.defaultValue() != defaultValue || !sameElements(node.whenClauses(), builder.build())) {
return new Case(builder.build(), defaultValue);
}

Expand All @@ -275,11 +274,10 @@ protected Expression visitSwitch(Switch node, Context<C> context)
builder.add(rewriteWhenClause(expression, context));
}

Optional<Expression> defaultValue = node.defaultValue()
.map(value -> rewrite(value, context.get()));
Expression defaultValue = rewrite(node.defaultValue(), context.get());

if (operand != node.operand() ||
!sameElements(node.defaultValue(), defaultValue) ||
node.defaultValue() != defaultValue ||
!sameElements(node.whenClauses(), builder.build())) {
return new Switch(operand, builder.build(), defaultValue);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@

import com.google.common.collect.ImmutableList;

import java.util.Optional;

public class IrExpressions
{
private IrExpressions() {}

public static Expression ifExpression(Expression condition, Expression trueCase)
{
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.empty());
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), new Constant(trueCase.type(), null));
}

public static Expression ifExpression(Expression condition, Expression trueCase, Expression falseCase)
{
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.of(falseCase));
return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), falseCase);
}
}
10 changes: 3 additions & 7 deletions core/trino-main/src/main/java/io/trino/sql/ir/Switch.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;

import static io.trino.sql.ir.IrUtils.validateType;
import static java.util.Objects.requireNonNull;

@JsonSerialize
public record Switch(Expression operand, List<WhenClause> whenClauses, Optional<Expression> defaultValue)
public record Switch(Expression operand, List<WhenClause> whenClauses, Expression defaultValue)
implements Expression
{
public Switch
Expand All @@ -40,9 +39,7 @@ public record Switch(Expression operand, List<WhenClause> whenClauses, Optional<
validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult());
}

if (defaultValue.isPresent()) {
validateType(whenClauses.getFirst().getResult().type(), defaultValue.get());
}
validateType(whenClauses.getFirst().getResult().type(), defaultValue);
}

@Override
Expand All @@ -68,8 +65,7 @@ public List<? extends Expression> children()
builder.add(clause.getResult());
});

defaultValue.ifPresent(builder::add);

builder.add(defaultValue);
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,14 @@ else if (Boolean.TRUE.equals(whenOperand)) {
defaultResult = newDefault;
}
else {
defaultResult = processWithExceptionHandling(node.defaultValue().orElse(null), context);
defaultResult = processWithExceptionHandling(node.defaultValue(), context);
}

if (whenClauses.isEmpty()) {
return defaultResult;
}

Expression defaultExpression;
defaultExpression = defaultResult == null ? null : toExpression(defaultResult, ((Expression) node).type());
return new Case(whenClauses, Optional.ofNullable(defaultExpression));
return new Case(whenClauses, toExpression(defaultResult, ((Expression) node).type()));
}

@Override
Expand All @@ -250,7 +248,7 @@ protected Object visitSwitch(Switch node, SymbolResolver context)

// if operand is null, return defaultValue
if (operand == null) {
return processWithExceptionHandling(node.defaultValue().orElse(null), context);
return processWithExceptionHandling(node.defaultValue(), context);
}

Object newDefault = null;
Expand Down Expand Up @@ -281,16 +279,15 @@ protected Object visitSwitch(Switch node, SymbolResolver context)
defaultResult = newDefault;
}
else {
defaultResult = processWithExceptionHandling(node.defaultValue().orElse(null), context);
defaultResult = processWithExceptionHandling(node.defaultValue(), context);
}

if (whenClauses.isEmpty()) {
return defaultResult;
}

Expression defaultExpression;
defaultExpression = defaultResult == null ? null : toExpression(defaultResult, ((Expression) node).type());
return new Switch(toExpression(operand, node.operand().type()), whenClauses, Optional.ofNullable(defaultExpression));
Expression defaultExpression = toExpression(defaultResult, ((Expression) node).type());
return new Switch(toExpression(operand, node.operand().type()), whenClauses, defaultExpression);
}

private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ public MergeWriterNode plan(Merge merge)
// The case number
rowBuilder.add(new Constant(INTEGER, -1L));

Case caseExpression = new Case(whenClauses.build(), Optional.of(new Row(rowBuilder.build())));
Case caseExpression = new Case(whenClauses.build(), new Row(rowBuilder.build()));

Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", mergeAnalysis.getMergeRowType());
Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,9 @@ private io.trino.sql.ir.Expression translate(SearchedCaseExpression expression)
translateExpression(clause.getOperand()),
translateExpression(clause.getResult())))
.collect(toImmutableList()),
expression.getDefaultValue().map(this::translateExpression));
expression.getDefaultValue()
.map(this::translateExpression)
.orElse(new Constant(analysis.getType(expression), null)));
}

private io.trino.sql.ir.Expression translate(SimpleCaseExpression expression)
Expand All @@ -417,7 +419,9 @@ private io.trino.sql.ir.Expression translate(SimpleCaseExpression expression)
translateExpression(clause.getOperand()),
translateExpression(clause.getResult())))
.collect(toImmutableList()),
expression.getDefaultValue().map(this::translateExpression));
expression.getDefaultValue()
.map(this::translateExpression)
.orElse(new Constant(analysis.getType(expression), null)));
}

private io.trino.sql.ir.Expression translate(InPredicate expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,33 +381,27 @@ private Optional<CaseAggregation> extractCaseAggregation(Symbol aggregationSymbo
return Optional.empty();
}

Optional<Expression> cumulativeAggregationDefaultValue = Optional.empty();
if (caseExpression.defaultValue().isPresent()) {
Expression defaultValue = optimizeExpression(caseExpression.defaultValue().get(), context);
if (defaultValue instanceof Constant(Type type, Object value) && value != null) {
if (!name.equals(SUM)) {
return Optional.empty();
}
Expression defaultValue = optimizeExpression(caseExpression.defaultValue(), context);
if (defaultValue instanceof Constant(Type type, Object value) && value != null) {
if (!name.equals(SUM)) {
return Optional.empty();
}

// sum aggregation is only supported if default value is null or 0, otherwise it wouldn't be cumulative
if (type instanceof BigintType
|| type == INTEGER
|| type == SMALLINT
|| type == TINYINT
|| type == DOUBLE
|| type == REAL
|| type instanceof DecimalType) {
if (!value.equals(0L) && !value.equals(0.0d) && !value.equals(Int128.ZERO)) {
return Optional.empty();
}
}
else {
// sum aggregation is only supported if default value is null or 0, otherwise it wouldn't be cumulative
if (type instanceof BigintType
|| type == INTEGER
|| type == SMALLINT
|| type == TINYINT
|| type == DOUBLE
|| type == REAL
|| type instanceof DecimalType) {
if (!value.equals(0L) && !value.equals(0.0d) && !value.equals(Int128.ZERO)) {
return Optional.empty();
}
}

// cumulative aggregation default value need to be CAST to cumulative aggregation input type
cumulativeAggregationDefaultValue = Optional.of(new Cast(caseExpression.defaultValue().get(), aggregationType));
else {
return Optional.empty();
}
}

return Optional.of(new CaseAggregation(
Expand All @@ -417,7 +411,7 @@ private Optional<CaseAggregation> extractCaseAggregation(Symbol aggregationSymbo
name,
caseExpression.whenClauses().get(0).getOperand(),
caseExpression.whenClauses().get(0).getResult(),
cumulativeAggregationDefaultValue));
new Cast(caseExpression.defaultValue(), aggregationType)));
}

private Type getType(Expression expression)
Expand Down Expand Up @@ -445,7 +439,7 @@ private static class CaseAggregation
// CASE expression only result expression
private final Expression result;
// default value of cumulative aggregation
private final Optional<Expression> cumulativeAggregationDefaultValue;
private final Expression cumulativeAggregationDefaultValue;

public CaseAggregation(
Symbol aggregationSymbol,
Expand All @@ -454,7 +448,7 @@ public CaseAggregation(
CatalogSchemaFunctionName name,
Expression operand,
Expression result,
Optional<Expression> cumulativeAggregationDefaultValue)
Expression cumulativeAggregationDefaultValue)
{
this.aggregationSymbol = requireNonNull(aggregationSymbol, "aggregationSymbol is null");
this.function = requireNonNull(function, "function is null");
Expand Down Expand Up @@ -495,7 +489,7 @@ public Expression getResult()
return result;
}

public Optional<Expression> getCumulativeAggregationDefaultValue()
public Expression getCumulativeAggregationDefaultValue()
{
return cumulativeAggregationDefaultValue;
}
Expand Down
Loading

0 comments on commit 4e9403b

Please sign in to comment.