Skip to content

Commit

Permalink
Clean up GenericLiteral
Browse files Browse the repository at this point in the history
Rename it to Constant and remove the Literal interface,
as there's only a single implementation now.
  • Loading branch information
martint committed Mar 17, 2024
1 parent 3c3e178 commit e21ed01
Show file tree
Hide file tree
Showing 225 changed files with 3,109 additions and 3,184 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import io.trino.sql.ir.BetweenPredicate;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IrVisitor;
Expand Down Expand Up @@ -124,7 +124,7 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ
value = false;
}

return GenericLiteral.constant(BOOLEAN, value);
return new Constant(BOOLEAN, value);
}

private class FilterExpressionStatsCalculatingVisitor
Expand Down Expand Up @@ -265,10 +265,10 @@ private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms)
}

@Override
protected PlanNodeStatsEstimate visitGenericLiteral(GenericLiteral node, Void context)
protected PlanNodeStatsEstimate visitConstant(Constant node, Void context)
{
if (node.getType().equals(BOOLEAN) && node.getRawValue() != null) {
if ((boolean) node.getRawValue()) {
if (node.getType().equals(BOOLEAN) && node.getValue() != null) {
if ((boolean) node.getValue()) {
return input;
}

Expand All @@ -278,7 +278,7 @@ protected PlanNodeStatsEstimate visitGenericLiteral(GenericLiteral node, Void co
return result.build();
}

return super.visitGenericLiteral(node, context);
return super.visitConstant(node, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import io.trino.sql.ir.ArithmeticUnaryExpression;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.NodeRef;
import io.trino.sql.ir.SymbolReference;
Expand Down Expand Up @@ -95,10 +95,10 @@ protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void co
}

@Override
protected SymbolStatsEstimate visitGenericLiteral(GenericLiteral node, Void context)
protected SymbolStatsEstimate visitConstant(Constant node, Void context)
{
Type type = node.getType();
Object value = node.getRawValue();
Object value = node.getValue();
if (value == null) {
return nullStatsEstimate();
}
Expand Down
20 changes: 10 additions & 10 deletions core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.BuiltinFunctionCallBuilder;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -83,8 +83,8 @@ public static Expression createDynamicFilterExpression(
return BuiltinFunctionCallBuilder.resolve(metadata)
.setName(nullAllowed ? NullableFunction.NAME : Function.NAME)
.addArgument(inputType, input)
.addArgument(GenericLiteral.constant(VarcharType.VARCHAR, Slices.utf8Slice(operator.toString())))
.addArgument(GenericLiteral.constant(VarcharType.VARCHAR, Slices.utf8Slice(id.toString())))
.addArgument(new Constant(VarcharType.VARCHAR, Slices.utf8Slice(operator.toString())))
.addArgument(new Constant(VarcharType.VARCHAR, Slices.utf8Slice(id.toString())))
.addArgument(BooleanType.BOOLEAN, nullAllowed ? TRUE_LITERAL : FALSE_LITERAL)
.build();
}
Expand Down Expand Up @@ -145,7 +145,7 @@ public static Expression replaceDynamicFilterId(FunctionCall dynamicFilterFuncti
ImmutableList.of(
dynamicFilterFunctionCall.getArguments().get(0),
dynamicFilterFunctionCall.getArguments().get(1),
GenericLiteral.constant(VarcharType.VARCHAR, Slices.utf8Slice(newId.toString())), // dynamic filter id is the 3rd argument
new Constant(VarcharType.VARCHAR, Slices.utf8Slice(newId.toString())), // dynamic filter id is the 3rd argument
dynamicFilterFunctionCall.getArguments().get(3)));
}

Expand All @@ -170,17 +170,17 @@ public static Optional<Descriptor> getDescriptor(Expression expression)
Expression probeSymbol = arguments.get(0);

Expression operatorExpression = arguments.get(1);
checkArgument(operatorExpression instanceof GenericLiteral literal && literal.getType().equals(VarcharType.VARCHAR), "operatorExpression is expected to be a varchar: %s", operatorExpression.getClass().getSimpleName());
String operatorExpressionString = ((Slice) ((GenericLiteral) operatorExpression).getRawValue()).toStringUtf8();
checkArgument(operatorExpression instanceof Constant literal && literal.getType().equals(VarcharType.VARCHAR), "operatorExpression is expected to be a varchar: %s", operatorExpression.getClass().getSimpleName());
String operatorExpressionString = ((Slice) ((Constant) operatorExpression).getValue()).toStringUtf8();
ComparisonExpression.Operator operator = ComparisonExpression.Operator.valueOf(operatorExpressionString);

Expression idExpression = arguments.get(2);
checkArgument(idExpression instanceof GenericLiteral literal && literal.getType().equals(VarcharType.VARCHAR), "id is expected to be a varchar: %s", idExpression.getClass().getSimpleName());
String id = ((Slice) ((GenericLiteral) idExpression).getRawValue()).toStringUtf8();
checkArgument(idExpression instanceof Constant literal && literal.getType().equals(VarcharType.VARCHAR), "id is expected to be a varchar: %s", idExpression.getClass().getSimpleName());
String id = ((Slice) ((Constant) idExpression).getValue()).toStringUtf8();

Expression nullAllowedExpression = arguments.get(3);
checkArgument(nullAllowedExpression instanceof GenericLiteral literal && literal.getType().equals(BooleanType.BOOLEAN), "nullAllowedExpression is expected to be a boolean constant: %s", nullAllowedExpression.getClass().getSimpleName());
boolean nullAllowed = (boolean) ((GenericLiteral) nullAllowedExpression).getRawValue();
checkArgument(nullAllowedExpression instanceof Constant literal && literal.getType().equals(BooleanType.BOOLEAN), "nullAllowedExpression is expected to be a boolean constant: %s", nullAllowedExpression.getClass().getSimpleName());
boolean nullAllowed = (boolean) ((Constant) nullAllowedExpression).getValue();
return Optional.of(new Descriptor(new DynamicFilterId(id), probeSymbol, operator, nullAllowed));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

public final class BooleanLiteral
{
public static final GenericLiteral TRUE_LITERAL = GenericLiteral.constant(BOOLEAN, true);
public static final GenericLiteral FALSE_LITERAL = GenericLiteral.constant(BOOLEAN, false);
public static final Constant TRUE_LITERAL = new Constant(BOOLEAN, true);
public static final Constant FALSE_LITERAL = new Constant(BOOLEAN, false);

private BooleanLiteral() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,28 @@
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.spi.type.TypeUtils.writeNativeValue;

public final class GenericLiteral
extends Literal
public final class Constant
extends Expression
{
private final Type type;
private final Object rawValue;

public static GenericLiteral constant(Type type, Object rawValue)
{
return new GenericLiteral(type, rawValue);
}
private final Object value;

@JsonCreator
@DoNotCall // For JSON deserialization only
public static GenericLiteral fromJson(
public static Constant fromJson(
@JsonProperty Type type,
@JsonProperty Block rawValueAsBlock)
@JsonProperty Block valueAsBlock)
{
return new GenericLiteral(type, readNativeValue(type, rawValueAsBlock, 0));
return new Constant(type, readNativeValue(type, valueAsBlock, 0));
}

public GenericLiteral(Type type, Object rawValue)
public Constant(Type type, Object value)
{
if (rawValue != null && !Primitives.wrap(type.getJavaType()).isAssignableFrom(rawValue.getClass())) {
throw new IllegalArgumentException("Improper Java type (%s) for type '%s'".formatted(rawValue.getClass().getName(), type));
if (value != null && !Primitives.wrap(type.getJavaType()).isAssignableFrom(value.getClass())) {
throw new IllegalArgumentException("Improper Java type (%s) for type '%s'".formatted(value.getClass().getName(), type));
}
this.type = type;
this.rawValue = rawValue;
this.value = value;
}

@JsonProperty
Expand All @@ -64,22 +59,22 @@ public Type getType()
}

@JsonProperty
public Block getRawValueAsBlock()
public Block getValueAsBlock()
{
BlockBuilder blockBuilder = type.createBlockBuilder(null, 1);
writeNativeValue(type, blockBuilder, rawValue);
writeNativeValue(type, blockBuilder, value);
return blockBuilder.build();
}

public Object getRawValue()
public Object getValue()
{
return rawValue;
return value;
}

@Override
public <R, C> R accept(IrVisitor<R, C> visitor, C context)
{
return visitor.visitGenericLiteral(this, context);
return visitor.visitConstant(this, context);
}

@Override
Expand All @@ -97,21 +92,21 @@ public boolean equals(Object o)
if (o == null || getClass() != o.getClass()) {
return false;
}
GenericLiteral that = (GenericLiteral) o;
return Objects.equals(type, that.type) && Objects.equals(rawValue, that.rawValue);
Constant that = (Constant) o;
return Objects.equals(type, that.type) && Objects.equals(value, that.value);
}

@Override
public int hashCode()
{
return Objects.hash(type, rawValue);
return Objects.hash(type, value);
}

@Override
public String toString()
{
return "Literal[%s, %s]".formatted(
return "Constant[%s, %s]".formatted(
type,
rawValue == null ? "<null>" : type.getObjectValue(null, getRawValueAsBlock(), 0));
value == null ? "<null>" : type.getObjectValue(null, getValueAsBlock(), 0));
}
}
4 changes: 2 additions & 2 deletions core/trino-main/src/main/java/io/trino/sql/ir/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
@JsonSubTypes.Type(value = CoalesceExpression.class, name = "coalesce"),
@JsonSubTypes.Type(value = ComparisonExpression.class, name = "comparison"),
@JsonSubTypes.Type(value = FunctionCall.class, name = "call"),
@JsonSubTypes.Type(value = GenericLiteral.class, name = "constant"),
@JsonSubTypes.Type(value = Constant.class, name = "constant"),
@JsonSubTypes.Type(value = IfExpression.class, name = "if"),
@JsonSubTypes.Type(value = InPredicate.class, name = "in"),
@JsonSubTypes.Type(value = IsNotNullPredicate.class, name = "isNotNull"),
Expand All @@ -52,7 +52,7 @@ public abstract sealed class Expression
permits ArithmeticBinaryExpression, ArithmeticUnaryExpression, Array, BetweenPredicate,
BindExpression, Cast, CoalesceExpression, ComparisonExpression, FunctionCall,
IfExpression, InPredicate, IsNotNullPredicate, IsNullPredicate,
LambdaExpression, Literal, LogicalExpression,
LambdaExpression, Constant, LogicalExpression,
NotExpression, NullIfExpression, Row, SearchedCaseExpression, SimpleCaseExpression,
SubscriptExpression, SymbolReference, WhenClause
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public static String formatExpression(Expression expression)
public static class Formatter
extends IrVisitor<String, Void>
{
private final Optional<Function<Literal, String>> literalFormatter;
private final Optional<Function<Constant, String>> literalFormatter;
private final Optional<Function<SymbolReference, String>> symbolReferenceFormatter;

public Formatter(
Optional<Function<Literal, String>> literalFormatter,
Optional<Function<Constant, String>> literalFormatter,
Optional<Function<SymbolReference, String>> symbolReferenceFormatter)
{
this.literalFormatter = requireNonNull(literalFormatter, "literalFormatter is null");
Expand Down Expand Up @@ -76,16 +76,16 @@ protected String visitSubscriptExpression(SubscriptExpression node, Void context
}

@Override
protected String visitGenericLiteral(GenericLiteral node, Void context)
protected String visitConstant(Constant node, Void context)
{
return literalFormatter
.map(formatter -> formatter.apply(node))
.orElseGet(() -> {
if (node.getRawValue() == null) {
if (node.getValue() == null) {
return "null::" + node.getType();
}
else {
return node.getType() + " '" + node.getType().getObjectValue(null, node.getRawValueAsBlock(), 0) + "'";
return node.getType() + " '" + node.getType().getObjectValue(null, node.getValueAsBlock(), 0) + "'";
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public Expression rewriteInPredicate(InPredicate node, C context, ExpressionTree
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteLiteral(Literal node, C context, ExpressionTreeRewriter<C> treeRewriter)
public Expression rewriteConstant(Constant node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,10 @@ public Expression visitInPredicate(InPredicate node, Context<C> context)
}

@Override
public Expression visitLiteral(Literal node, Context<C> context)
public Expression visitConstant(Constant node, Context<C> context)
{
if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteLiteral(node, context.get(), ExpressionTreeRewriter.this);
Expression result = rewriter.rewriteConstant(node, context.get(), ExpressionTreeRewriter.this);
if (result != null) {
return result;
}
Expand Down
6 changes: 3 additions & 3 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,11 @@ public static Function<Expression, Expression> expressionOrNullSymbols(Predicate
*/
public static boolean isEffectivelyLiteral(PlannerContext plannerContext, Session session, Expression expression)
{
if (expression instanceof Literal) {
if (expression instanceof Constant) {
return true;
}
if (expression instanceof Cast) {
return ((Cast) expression).getExpression() instanceof Literal
if (expression instanceof Cast cast) {
return cast.getExpression() instanceof Constant
// a Cast(Literal(...)) can fail, so this requires verification
&& constantExpressionEvaluatesSuccessfully(plannerContext, session, expression);
}
Expand Down
7 changes: 1 addition & 6 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,11 @@ protected R visitComparisonExpression(ComparisonExpression node, C context)
return visitExpression(node, context);
}

protected R visitLiteral(Literal node, C context)
protected R visitConstant(Constant node, C context)
{
return visitExpression(node, context);
}

protected R visitGenericLiteral(GenericLiteral node, C context)
{
return visitLiteral(node, context);
}

protected R visitWhenClause(WhenClause node, C context)
{
return visitExpression(node, context);
Expand Down
25 changes: 0 additions & 25 deletions core/trino-main/src/main/java/io/trino/sql/ir/Literal.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -51,7 +51,7 @@ public BuiltinFunctionCallBuilder setName(String name)
return this;
}

public BuiltinFunctionCallBuilder addArgument(GenericLiteral value)
public BuiltinFunctionCallBuilder addArgument(Constant value)
{
requireNonNull(value, "value is null");
return addArgument(value.getType().getTypeSignature(), value);
Expand Down
Loading

0 comments on commit e21ed01

Please sign in to comment.