From 671fc975746b079e1ba13629e3b50aea0c5483f8 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Sun, 24 Mar 2024 17:12:19 -0700 Subject: [PATCH 1/6] Encapsulate no-bindings optimize call --- .../src/main/java/io/trino/cost/FilterStatsCalculator.java | 3 +-- .../src/main/java/io/trino/cost/ScalarStatsCalculator.java | 3 +-- .../io/trino/sql/planner/EffectivePredicateExtractor.java | 4 ++-- .../java/io/trino/sql/planner/IrExpressionInterpreter.java | 5 +++++ .../iterative/rule/PushAggregationIntoTableScan.java | 4 +--- .../planner/iterative/rule/PushPredicateIntoTableScan.java | 4 +--- .../iterative/rule/PushProjectionIntoTableScan.java | 4 +--- .../planner/iterative/rule/RemoveRedundantDateTrunc.java | 4 +--- .../sql/planner/iterative/rule/SimplifyExpressions.java | 3 +-- .../sql/planner/iterative/rule/UnwrapCastInComparison.java | 4 +--- .../iterative/rule/UnwrapDateTruncInComparison.java | 7 ++----- .../sql/planner/iterative/rule/UnwrapYearInComparison.java | 4 +--- .../trino/sql/planner/optimizations/PredicatePushDown.java | 3 +-- .../sql/planner/optimizations/PropertyDerivations.java | 3 +-- .../main/java/io/trino/sql/routine/SqlRoutinePlanner.java | 3 +-- .../io/trino/sql/TestSqlToRowExpressionTranslator.java | 3 +-- 16 files changed, 22 insertions(+), 39 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index cb0d457870a1..97257f53eddf 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -33,7 +33,6 @@ import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import io.trino.util.DisjointSet; import jakarta.annotation.Nullable; @@ -101,7 +100,7 @@ private Expression simplifyExpression(Session session, Expression predicate) // TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite IrExpressionInterpreter interpreter = new IrExpressionInterpreter(predicate, plannerContext, session); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object value = interpreter.optimize(); if (value instanceof Expression expression) { return expression; diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index 7d726dcbeccd..9990c026bdca 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -31,7 +31,6 @@ import io.trino.sql.ir.IrVisitor; import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import java.util.OptionalDouble; @@ -131,7 +130,7 @@ else if (node.function().getName().equals(builtinFunctionName(ADD)) || } IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object value = interpreter.optimize(); if (value == null) { return nullStatsEstimate(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java index 634ad05bc940..853c809a0204 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java @@ -353,7 +353,7 @@ public Expression visitValues(ValuesNode node, Void context) } else { IrExpressionInterpreter interpreter = new IrExpressionInterpreter(value, plannerContext, session); - Object item = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object item = interpreter.optimize(); if (item instanceof Expression) { return TRUE; } @@ -383,7 +383,7 @@ public Expression visitValues(ValuesNode node, Void context) return TRUE; } IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, plannerContext, session); - Object evaluated = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object evaluated = interpreter.optimize(); if (evaluated instanceof Expression) { return TRUE; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index e0578e21d320..fb0c64dec95a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -129,6 +129,11 @@ public Object optimize(SymbolResolver inputs) return new Visitor(true).processWithExceptionHandling(expression, inputs); } + public Object optimize() + { + return new Visitor(true).processWithExceptionHandling(expression, NoOpSymbolResolver.INSTANCE); + } + private class Visitor extends IrVisitor { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 55d2a129671f..16e43bfcf0c4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -35,7 +35,6 @@ import io.trino.sql.ir.Reference; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -189,8 +188,7 @@ public static Optional pushAggregationIntoTableScan( Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Object optimized = new IrExpressionInterpreter(translated, plannerContext, session) - .optimize(NoOpSymbolResolver.INSTANCE); + Object optimized = new IrExpressionInterpreter(translated, plannerContext, session).optimize(); return optimized instanceof Expression optimizedExpression ? optimizedExpression : diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 201d7c04cec6..931936e13c5f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -40,7 +40,6 @@ import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.LayoutConstraintEvaluator; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.FilterNode; @@ -265,8 +264,7 @@ public static Optional pushFilterIntoTableScan( Expression translatedExpression = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression.get(), plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Object optimized = new IrExpressionInterpreter(translatedExpression, plannerContext, session) - .optimize(NoOpSymbolResolver.INSTANCE); + Object optimized = new IrExpressionInterpreter(translatedExpression, plannerContext, session).optimize(); translatedExpression = optimized instanceof Expression optimizedExpression ? optimizedExpression : diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index a4695b80541c..a8cbd38e6af4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -36,7 +36,6 @@ import io.trino.sql.ir.NodeRef; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; @@ -149,8 +148,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Object optimized = new IrExpressionInterpreter(translated, plannerContext, session) - .optimize(NoOpSymbolResolver.INSTANCE); + Object optimized = new IrExpressionInterpreter(translated, plannerContext, session).optimize(); return optimized instanceof Expression optimizedExpression ? optimizedExpression : diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java index c0914662c180..17b2d9c4483c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java @@ -24,7 +24,6 @@ import io.trino.sql.ir.ExpressionTreeRewriter; import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import java.util.Locale; @@ -70,8 +69,7 @@ public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter innerSymbolsForOuterJoin, Ex private Expression simplifyExpression(Expression expression) { IrExpressionInterpreter optimizer = new IrExpressionInterpreter(expression, plannerContext, session); - Object object = optimizer.optimize(NoOpSymbolResolver.INSTANCE); + Object object = optimizer.optimize(); return object instanceof Expression optimized ? optimized : diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index afb19c0bdb69..f2a4d4b9102c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -35,7 +35,6 @@ import io.trino.sql.ir.Reference; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.optimizations.ActualProperties.Global; @@ -774,7 +773,7 @@ public ActualProperties visitProject(ProjectNode node, List in // to take advantage of constant-folding for complex expressions // However, that currently causes errors when those expressions operate on arrays or row types // ("ROW comparison not supported for fields with null elements", etc) - Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); + Object value = optimizer.optimize(); if (value instanceof Reference) { Symbol symbol = Symbol.from((Reference) value); diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java index d1d0a254f614..db74e1c5a0ca 100644 --- a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java @@ -29,7 +29,6 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TranslationMap; @@ -329,7 +328,7 @@ private RowExpression toRowExpression(Context context, Expression expression) // optimize the expression IrExpressionInterpreter interpreter = new IrExpressionInterpreter(lambdaCaptureDesugared, plannerContext, session); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object value = interpreter.optimize(); io.trino.sql.ir.Expression optimized = value instanceof io.trino.sql.ir.Expression optimizedExpression ? optimizedExpression : diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java index db77a79acc9d..34923a43b2b0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java @@ -20,7 +20,6 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.IrExpressionInterpreter; -import io.trino.sql.planner.NoOpSymbolResolver; import io.trino.sql.relational.RowExpression; import io.trino.sql.relational.SqlToRowExpressionTranslator; import org.junit.jupiter.api.Test; @@ -89,7 +88,7 @@ private Expression simplifyExpression(Expression expression) { // Testing simplified expressions is important, since simplification may create CASTs or function calls that cannot be simplified by the ExpressionOptimizer IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION); - Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); + Object value = interpreter.optimize(); return value instanceof Expression optimized ? optimized : From f42a06cf6ea6467cda87991644ee39e2bb42f292 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Sun, 24 Mar 2024 17:31:58 -0700 Subject: [PATCH 2/6] Improve clarity of SymbolResolver API Previously, the result value could simultaneously encode a missing binding, a null value and a non-null value. It was up to the caller to check whether the result was an Expression, a Constant, or another non-IR value. The updated version encodes the presence or absence via an Optional, and the null vs non-null via the IR Constant encoding. --- .../io/trino/sql/planner/IrExpressionInterpreter.java | 11 ++++++++--- .../io/trino/sql/planner/LookupSymbolResolver.java | 8 +++++--- .../java/io/trino/sql/planner/NoOpSymbolResolver.java | 8 ++++++-- .../java/io/trino/sql/planner/SymbolResolver.java | 6 +++++- .../iterative/rule/PreAggregateCaseAggregations.java | 2 +- .../sql/planner/optimizations/PredicatePushDown.java | 2 +- .../java/io/trino/sql/TestExpressionInterpreter.java | 4 ++-- 7 files changed, 28 insertions(+), 13 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index fb0c64dec95a..e50bd08e04b8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -169,7 +169,12 @@ private Object processWithExceptionHandling(Expression expression, Object contex @Override protected Object visitReference(Reference node, Object context) { - return ((SymbolResolver) context).getValue(Symbol.from(node)); + Optional binding = ((SymbolResolver) context).getValue(Symbol.from(node)); + if (binding.isPresent()) { + return binding.get().value(); + } + + return node; } @Override @@ -921,10 +926,10 @@ public LambdaSymbolResolver(Map values) } @Override - public Object getValue(Symbol symbol) + public Optional getValue(Symbol symbol) { checkState(values.containsKey(symbol.getName()), "values does not contain %s", symbol); - return values.get(symbol.getName()); + return Optional.of(new Constant(symbol.getType(), values.get(symbol.getName()))); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LookupSymbolResolver.java b/core/trino-main/src/main/java/io/trino/sql/planner/LookupSymbolResolver.java index e2f0b1c65953..223e41813faa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LookupSymbolResolver.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LookupSymbolResolver.java @@ -16,8 +16,10 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.NullableValue; +import io.trino.sql.ir.Constant; import java.util.Map; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -37,14 +39,14 @@ public LookupSymbolResolver(Map assignments, Map getValue(Symbol symbol) { ColumnHandle column = assignments.get(symbol); if (column == null || !bindings.containsKey(column)) { - return symbol.toSymbolReference(); + return Optional.empty(); } - return bindings.get(column).getValue(); + return Optional.of(new Constant(symbol.getType(), bindings.get(column).getValue())); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NoOpSymbolResolver.java b/core/trino-main/src/main/java/io/trino/sql/planner/NoOpSymbolResolver.java index 32e4268c424f..8cf5173af216 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NoOpSymbolResolver.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NoOpSymbolResolver.java @@ -13,14 +13,18 @@ */ package io.trino.sql.planner; +import io.trino.sql.ir.Constant; + +import java.util.Optional; + public class NoOpSymbolResolver implements SymbolResolver { public static final NoOpSymbolResolver INSTANCE = new NoOpSymbolResolver(); @Override - public Object getValue(Symbol symbol) + public Optional getValue(Symbol symbol) { - return symbol.toSymbolReference(); + return Optional.empty(); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SymbolResolver.java b/core/trino-main/src/main/java/io/trino/sql/planner/SymbolResolver.java index 89afc0e39550..f3dc5f09877f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SymbolResolver.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SymbolResolver.java @@ -13,7 +13,11 @@ */ package io.trino.sql.planner; +import io.trino.sql.ir.Constant; + +import java.util.Optional; + public interface SymbolResolver { - Object getValue(Symbol symbol); + Optional getValue(Symbol symbol); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java index 0ac04f1c4b8b..eccbf70549af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java @@ -429,7 +429,7 @@ private Type getType(Expression expression) private Object optimizeExpression(Expression expression, Context context) { IrExpressionInterpreter expressionInterpreter = new IrExpressionInterpreter(expression, plannerContext, context.getSession()); - return expressionInterpreter.optimize(Symbol::toSymbolReference); + return expressionInterpreter.optimize(); } private static class CaseAggregation diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index e947aff75e1b..992959f36cf9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -1193,7 +1193,7 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r private Object nullInputEvaluator(Collection nullSymbols, Expression expression) { return new IrExpressionInterpreter(expression, plannerContext, session) - .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); + .optimize(symbol -> nullSymbols.contains(symbol) ? Optional.of(new Constant(symbol.getType(), null)) : Optional.empty()); } private boolean joinEqualityExpression(Expression expression, Collection leftSymbols, Collection rightSymbols) diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index 5a6eece76c48..f2df8a2ee9d9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -78,10 +78,10 @@ public class TestExpressionInterpreter private static final SymbolResolver INPUTS = symbol -> { if (symbol.getName().toLowerCase(ENGLISH).equals("bound_value")) { - return 1234L; + return Optional.of(new Constant(INTEGER, 1234L)); } - return symbol.toSymbolReference(); + return Optional.empty(); }; private static final TestingTransactionManager TRANSACTION_MANAGER = new TestingTransactionManager(); From c45646e3f6d392b7bb530d5f157cb7610e315f99 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Sun, 24 Mar 2024 18:04:18 -0700 Subject: [PATCH 3/6] Use precise type for context object --- .../sql/planner/IrExpressionInterpreter.java | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index e50bd08e04b8..0bc837f15318 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -135,7 +135,7 @@ public Object optimize() } private class Visitor - extends IrVisitor + extends IrVisitor { private final boolean optimize; @@ -144,7 +144,7 @@ private Visitor(boolean optimize) this.optimize = optimize; } - private Object processWithExceptionHandling(Expression expression, Object context) + private Object processWithExceptionHandling(Expression expression, SymbolResolver context) { if (expression == null) { return null; @@ -167,9 +167,9 @@ private Object processWithExceptionHandling(Expression expression, Object contex } @Override - protected Object visitReference(Reference node, Object context) + protected Object visitReference(Reference node, SymbolResolver context) { - Optional binding = ((SymbolResolver) context).getValue(Symbol.from(node)); + Optional binding = context.getValue(Symbol.from(node)); if (binding.isPresent()) { return binding.get().value(); } @@ -178,13 +178,13 @@ protected Object visitReference(Reference node, Object context) } @Override - protected Object visitConstant(Constant node, Object context) + protected Object visitConstant(Constant node, SymbolResolver context) { return node.value(); } @Override - protected Object visitIsNull(IsNull node, Object context) + protected Object visitIsNull(IsNull node, SymbolResolver context) { Object value = processWithExceptionHandling(node.value(), context); @@ -196,7 +196,7 @@ protected Object visitIsNull(IsNull node, Object context) } @Override - protected Object visitCase(Case node, Object context) + protected Object visitCase(Case node, SymbolResolver context) { Object newDefault = null; boolean foundNewDefault = false; @@ -237,7 +237,7 @@ else if (Boolean.TRUE.equals(whenOperand)) { } @Override - protected Object visitSwitch(Switch node, Object context) + protected Object visitSwitch(Switch node, SymbolResolver context) { Object operand = processWithExceptionHandling(node.operand(), context); Type operandType = node.operand().type(); @@ -293,7 +293,7 @@ private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2 } @Override - protected Object visitCoalesce(Coalesce node, Object context) + protected Object visitCoalesce(Coalesce node, SymbolResolver context) { List newOperands = processOperands(node, context); if (newOperands.isEmpty()) { @@ -307,7 +307,7 @@ protected Object visitCoalesce(Coalesce node, Object context) .collect(toImmutableList())); } - private List processOperands(Coalesce node, Object context) + private List processOperands(Coalesce node, SymbolResolver context) { List newOperands = new ArrayList<>(); Set uniqueNewOperands = new HashSet<>(); @@ -342,7 +342,7 @@ else if (value != null) { } @Override - protected Object visitIn(In node, Object context) + protected Object visitIn(In node, SymbolResolver context) { Object value = processWithExceptionHandling(node.value(), context); @@ -454,7 +454,7 @@ else if (!found && result) { } @Override - protected Object visitComparison(Comparison node, Object context) + protected Object visitComparison(Comparison node, SymbolResolver context) { Operator operator = node.operator(); Expression left = node.left(); @@ -487,7 +487,7 @@ protected Object visitComparison(Comparison node, Object context) return processComparisonExpression(context, operator, left, right); } - private Object processIsDistinctFrom(Object context, Expression leftExpression, Expression rightExpression) + private Object processIsDistinctFrom(SymbolResolver context, Expression leftExpression, Expression rightExpression) { Object left = processWithExceptionHandling(leftExpression, context); Object right = processWithExceptionHandling(rightExpression, context); @@ -507,7 +507,7 @@ private Object processIsDistinctFrom(Object context, Expression leftExpression, return invokeOperator(OperatorType.valueOf(Operator.IS_DISTINCT_FROM.name()), types(leftExpression, rightExpression), Arrays.asList(left, right)); } - private Object processComparisonExpression(Object context, Operator operator, Expression leftExpression, Expression rightExpression) + private Object processComparisonExpression(SymbolResolver context, Operator operator, Expression leftExpression, Expression rightExpression) { Object left = processWithExceptionHandling(leftExpression, context); if (left == null) { @@ -541,7 +541,7 @@ private Comparison flipComparison(Comparison comparison) } @Override - protected Object visitBetween(Between node, Object context) + protected Object visitBetween(Between node, SymbolResolver context) { Object value = processWithExceptionHandling(node.value(), context); if (value == null) { @@ -576,7 +576,7 @@ protected Object visitBetween(Between node, Object context) } @Override - protected Object visitNullIf(NullIf node, Object context) + protected Object visitNullIf(NullIf node, SymbolResolver context) { Object first = processWithExceptionHandling(node.first(), context); if (first == null) { @@ -614,7 +614,7 @@ protected Object visitNullIf(NullIf node, Object context) } @Override - protected Object visitNot(Not node, Object context) + protected Object visitNot(Not node, SymbolResolver context) { Object value = processWithExceptionHandling(node.value(), context); if (value == null) { @@ -629,7 +629,7 @@ protected Object visitNot(Not node, Object context) } @Override - protected Object visitLogical(Logical node, Object context) + protected Object visitLogical(Logical node, SymbolResolver context) { List terms = new ArrayList<>(); List types = new ArrayList<>(); @@ -682,7 +682,7 @@ protected Object visitLogical(Logical node, Object context) } @Override - protected Object visitCall(Call node, Object context) + protected Object visitCall(Call node, SymbolResolver context) { if (node.function().getName().getFunctionName().equals(mangleOperatorName(NEGATION))) { return processNegation(node, context); @@ -718,7 +718,7 @@ protected Object visitCall(Call node, Object context) return functionInvoker.invoke(resolvedFunction, connectorSession, argumentValues); } - private Object processNegation(Call negation, Object context) + private Object processNegation(Call negation, SymbolResolver context) { Object value = processWithExceptionHandling(negation.arguments().getFirst(), context); @@ -731,7 +731,7 @@ private Object processNegation(Call negation, Object context) } @Override - protected Object visitLambda(Lambda node, Object context) + protected Object visitLambda(Lambda node, SymbolResolver context) { if (optimize) { // TODO: enable optimization related to lambda expression @@ -768,7 +768,7 @@ protected Object visitLambda(Lambda node, Object context) } @Override - protected Object visitBind(Bind node, Object context) + protected Object visitBind(Bind node, SymbolResolver context) { List values = node.values().stream() .map(value -> processWithExceptionHandling(value, context)) @@ -788,7 +788,7 @@ protected Object visitBind(Bind node, Object context) } @Override - public Object visitCast(Cast node, Object context) + public Object visitCast(Cast node, SymbolResolver context) { Object value = processWithExceptionHandling(node.expression(), context); Type targetType = node.type(); @@ -819,7 +819,7 @@ public Object visitCast(Cast node, Object context) } @Override - protected Object visitRow(Row node, Object context) + protected Object visitRow(Row node, SymbolResolver context) { RowType rowType = (RowType) ((Expression) node).type(); List parameterTypes = rowType.getTypeParameters(); @@ -841,7 +841,7 @@ protected Object visitRow(Row node, Object context) } @Override - protected Object visitFieldReference(FieldReference node, Object context) + protected Object visitFieldReference(FieldReference node, SymbolResolver context) { Object base = processWithExceptionHandling(node.base(), context); if (base == null) { @@ -858,7 +858,7 @@ protected Object visitFieldReference(FieldReference node, Object context) } @Override - protected Object visitExpression(Expression node, Object context) + protected Object visitExpression(Expression node, SymbolResolver context) { throw new TrinoException(NOT_SUPPORTED, "not yet implemented: " + node.getClass().getName()); } From 3a13ee01385d7e6a545d508d8da49c4642970c04 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Mon, 25 Mar 2024 19:29:23 -0700 Subject: [PATCH 4/6] Make expression optimizer always return Expression --- .../io/trino/cost/FilterStatsCalculator.java | 14 +++----- .../io/trino/cost/ScalarStatsCalculator.java | 19 +++++------ .../planner/EffectivePredicateExtractor.java | 21 ++++++------ .../sql/planner/IrExpressionInterpreter.java | 14 +++++--- .../planner/LayoutConstraintEvaluator.java | 6 ++-- .../rule/PreAggregateCaseAggregations.java | 26 +++++++-------- .../rule/PushAggregationIntoTableScan.java | 7 +--- .../rule/PushPredicateIntoTableScan.java | 8 +---- .../rule/PushProjectionIntoTableScan.java | 7 +--- .../rule/RemoveRedundantDateTrunc.java | 2 +- .../iterative/rule/SimplifyExpressions.java | 8 +---- .../rule/UnwrapCastInComparison.java | 22 ++++++------- .../rule/UnwrapDateTruncInComparison.java | 17 ++++------ .../rule/UnwrapYearInComparison.java | 8 ++--- .../optimizations/PredicatePushDown.java | 13 +++----- .../optimizations/PropertyDerivations.java | 10 +++--- .../trino/sql/routine/SqlRoutinePlanner.java | 8 +---- .../trino/sql/TestExpressionInterpreter.java | 32 +++++++++---------- .../sql/TestSqlToRowExpressionTranslator.java | 7 +--- 19 files changed, 102 insertions(+), 147 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index 97257f53eddf..42bc358df0f7 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -98,20 +98,14 @@ public PlanNodeStatsEstimate filterStats( private Expression simplifyExpression(Session session, Expression predicate) { // TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite + Expression value = new IrExpressionInterpreter(predicate, plannerContext, session).optimize(); - IrExpressionInterpreter interpreter = new IrExpressionInterpreter(predicate, plannerContext, session); - Object value = interpreter.optimize(); - - if (value instanceof Expression expression) { - return expression; - } - - if (value == null) { + if (value instanceof Constant constant && constant.value() == null) { // Expression evaluates to SQL null, which in Filter is equivalent to false. This assumes the expression is a top-level expression (eg. not in NOT). - value = false; + value = Booleans.FALSE; } - return new Constant(BOOLEAN, value); + return value; } private class FilterExpressionStatsCalculatingVisitor diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index 9990c026bdca..8e117f810067 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -129,23 +129,20 @@ else if (node.function().getName().equals(builtinFunctionName(ADD)) || return processArithmetic(node); } - IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session); - Object value = interpreter.optimize(); + Expression value = new IrExpressionInterpreter(node, plannerContext, session).optimize(); - if (value == null) { + if (value instanceof Constant constant && constant.value() == null) { return nullStatsEstimate(); } - if (value instanceof Expression) { - // value is not a constant - return SymbolStatsEstimate.unknown(); + if (value instanceof Constant) { + return SymbolStatsEstimate.builder() + .setNullsFraction(0) + .setDistinctValuesCount(1) + .build(); } - // value is a constant - return SymbolStatsEstimate.builder() - .setNullsFraction(0) - .setDistinctValuesCount(1) - .build(); + return SymbolStatsEstimate.unknown(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java index 853c809a0204..bb3a74074b43 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java @@ -30,6 +30,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; @@ -352,12 +353,11 @@ public Expression visitValues(ValuesNode node, Void context) nonDeterministic[i] = true; } else { - IrExpressionInterpreter interpreter = new IrExpressionInterpreter(value, plannerContext, session); - Object item = interpreter.optimize(); - if (item instanceof Expression) { + Expression item = new IrExpressionInterpreter(value, plannerContext, session).optimize(); + if (!(item instanceof Constant constant)) { return TRUE; } - if (item == null) { + if (constant.value() == null) { hasNull[i] = true; } else { @@ -365,15 +365,15 @@ public Expression visitValues(ValuesNode node, Void context) if (!type.isComparable() && !type.isOrderable()) { return TRUE; } - if (hasNestedNulls(type, item)) { + if (hasNestedNulls(type, ((Constant) item).value())) { // Workaround solution to deal with array and row comparisons don't support null elements currently. // TODO: remove when comparisons are fixed return TRUE; } - if (isFloatingPointNaN(type, item)) { + if (isFloatingPointNaN(type, ((Constant) item).value())) { hasNaN[i] = true; } - valuesBuilders.get(i).add(item); + valuesBuilders.get(i).add(((Constant) item).value()); } } } @@ -382,12 +382,11 @@ public Expression visitValues(ValuesNode node, Void context) if (!DeterminismEvaluator.isDeterministic(expression)) { return TRUE; } - IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, plannerContext, session); - Object evaluated = interpreter.optimize(); - if (evaluated instanceof Expression) { + Expression evaluated = new IrExpressionInterpreter(expression, plannerContext, session).optimize(); + if (!(evaluated instanceof Constant constant)) { return TRUE; } - SqlRow sqlRow = (SqlRow) evaluated; + SqlRow sqlRow = (SqlRow) constant.value(); int rawIndex = sqlRow.getRawIndex(); for (int i = 0; i < node.getOutputSymbols().size(); i++) { Type type = node.getOutputSymbols().get(i).getType(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index 0bc837f15318..55c0f1dbb523 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -124,14 +124,20 @@ public Object evaluate() return result; } - public Object optimize(SymbolResolver inputs) + public Expression optimize(SymbolResolver inputs) { - return new Visitor(true).processWithExceptionHandling(expression, inputs); + Object result = new Visitor(true).processWithExceptionHandling(expression, inputs); + + if (result instanceof Expression expression) { + return expression; + } + + return new Constant(expression.type(), result); } - public Object optimize() + public Expression optimize() { - return new Visitor(true).processWithExceptionHandling(expression, NoOpSymbolResolver.INSTANCE); + return optimize(NoOpSymbolResolver.INSTANCE); } private class Visitor diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java index 8b31d32a09dc..bc8497d1b65a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java @@ -19,6 +19,8 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.NullableValue; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import java.util.Map; @@ -57,9 +59,9 @@ public boolean isCandidate(Map bindings) // Skip pruning if evaluation fails in a recoverable way. Failing here can cause // spurious query failures for partitions that would otherwise be filtered out. - Object optimized = TryFunction.evaluate(() -> evaluator.optimize(inputs), true); + Expression optimized = TryFunction.evaluate(() -> evaluator.optimize(inputs), Booleans.TRUE); // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned - return !(Boolean.FALSE.equals(optimized) || optimized == null); + return !(optimized instanceof Constant constant) || Boolean.TRUE.equals(constant.value()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java index eccbf70549af..e48588202218 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java @@ -383,22 +383,21 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo Optional cumulativeAggregationDefaultValue = Optional.empty(); if (caseExpression.defaultValue().isPresent()) { - Type defaultType = getType(caseExpression.defaultValue().get()); - Object defaultValue = optimizeExpression(caseExpression.defaultValue().get(), context); - if (defaultValue != null) { + Expression defaultValue = optimizeExpression(caseExpression.defaultValue().get(), context); + if (defaultValue instanceof Constant(Type type, Object value) && value != null) { if (!name.equals(SUM)) { return Optional.empty(); } // sum aggregation is only supported if default value is null or 0, otherwise it wouldn't be cumulative - if (defaultType instanceof BigintType - || defaultType == INTEGER - || defaultType == SMALLINT - || defaultType == TINYINT - || defaultType == DOUBLE - || defaultType == REAL - || defaultType instanceof DecimalType) { - if (!defaultValue.equals(0L) && !defaultValue.equals(0.0d) && !defaultValue.equals(Int128.ZERO)) { + if (type instanceof BigintType + || type == INTEGER + || type == SMALLINT + || type == TINYINT + || type == DOUBLE + || type == REAL + || type instanceof DecimalType) { + if (!value.equals(0L) && !value.equals(0.0d) && !value.equals(Int128.ZERO)) { return Optional.empty(); } } @@ -426,10 +425,9 @@ private Type getType(Expression expression) return expression.type(); } - private Object optimizeExpression(Expression expression, Context context) + private Expression optimizeExpression(Expression expression, Context context) { - IrExpressionInterpreter expressionInterpreter = new IrExpressionInterpreter(expression, plannerContext, context.getSession()); - return expressionInterpreter.optimize(); + return new IrExpressionInterpreter(expression, plannerContext, context.getSession()).optimize(); } private static class CaseAggregation diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 16e43bfcf0c4..467d0602dc1c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -30,7 +30,6 @@ import io.trino.spi.function.BoundSignature; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Reference; import io.trino.sql.planner.ConnectorExpressionTranslator; @@ -188,11 +187,7 @@ public static Optional pushAggregationIntoTableScan( Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Object optimized = new IrExpressionInterpreter(translated, plannerContext, session).optimize(); - - return optimized instanceof Expression optimizedExpression ? - optimizedExpression : - new Constant(translated.type(), optimized); + return new IrExpressionInterpreter(translated, plannerContext, session).optimize(); }) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 931936e13c5f..ec477e979224 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -33,7 +33,6 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Booleans; -import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation; @@ -264,12 +263,7 @@ public static Optional pushFilterIntoTableScan( Expression translatedExpression = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression.get(), plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Object optimized = new IrExpressionInterpreter(translatedExpression, plannerContext, session).optimize(); - - translatedExpression = optimized instanceof Expression optimizedExpression ? - optimizedExpression : - new Constant(translatedExpression.type(), optimized); - + translatedExpression = new IrExpressionInterpreter(translatedExpression, plannerContext, session).optimize(); remainingDecomposedPredicate = combineConjuncts(translatedExpression, expressionTranslation.remainingExpression()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index a8cbd38e6af4..4592acbf3de7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -31,7 +31,6 @@ import io.trino.spi.expression.Variable; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; -import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.NodeRef; import io.trino.sql.planner.ConnectorExpressionTranslator; @@ -148,11 +147,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Object optimized = new IrExpressionInterpreter(translated, plannerContext, session).optimize(); - - return optimized instanceof Expression optimizedExpression ? - optimizedExpression : - new Constant(translated.type(), optimized); + return new IrExpressionInterpreter(translated, plannerContext, session).optimize(); }) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java index 17b2d9c4483c..27c123c25192 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java @@ -69,7 +69,7 @@ public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter new Constant(BOOLEAN, null); case IS_DISTINCT_FROM -> new Not(new IsNull(cast)); }; } - if (right instanceof Expression) { + if (!(right instanceof Constant(Type type, Object rightValue))) { return expression; } @@ -184,21 +184,21 @@ private Expression unwrapCast(Comparison expression) Type targetType = expression.right().type(); if (sourceType instanceof TimestampType && targetType == DATE) { - return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.expression(), (long) right).orElse(expression); + return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.expression(), (long) rightValue).orElse(expression); } if (targetType instanceof TimestampWithTimeZoneType) { // Note: two TIMESTAMP WITH TIME ZONE values differing in zone only (same instant) are considered equal. - right = withTimeZone(((TimestampWithTimeZoneType) targetType), right, session.getTimeZoneKey()); + rightValue = withTimeZone(((TimestampWithTimeZoneType) targetType), rightValue, session.getTimeZoneKey()); } - if (!hasInjectiveImplicitCoercion(sourceType, targetType, right)) { + if (!hasInjectiveImplicitCoercion(sourceType, targetType, rightValue)) { return expression; } // Handle comparison against NaN. // It must be done before source type range bounds are compared to target value. - if (isFloatingPointNaN(targetType, right)) { + if (isFloatingPointNaN(targetType, rightValue)) { switch (operator) { case EQUAL: case GREATER_THAN: @@ -234,7 +234,7 @@ private Expression unwrapCast(Comparison expression) if (maxInTargetType != null) { // NaN values of `right` are excluded at this point. Otherwise, NaN would be recognized as // greater than source type upper bound, and incorrect expression might be derived. - int upperBoundComparison = compare(targetType, right, maxInTargetType); + int upperBoundComparison = compare(targetType, rightValue, maxInTargetType); if (upperBoundComparison > 0) { // larger than maximum representable value return switch (operator) { @@ -259,7 +259,7 @@ private Expression unwrapCast(Comparison expression) Object min = sourceRange.get().getMin(); Object minInTargetType = coerce(min, sourceToTarget); - int lowerBoundComparison = compare(targetType, right, minInTargetType); + int lowerBoundComparison = compare(targetType, rightValue, minInTargetType); if (lowerBoundComparison < 0) { // smaller than minimum representable value return switch (operator) { @@ -294,7 +294,7 @@ private Expression unwrapCast(Comparison expression) Object literalInSourceType; try { - literalInSourceType = coerce(right, targetToSource); + literalInSourceType = coerce(rightValue, targetToSource); } catch (TrinoException e) { // A failure to cast from target -> source type could be because: @@ -309,7 +309,7 @@ private Expression unwrapCast(Comparison expression) if (targetType.isOrderable()) { Object roundtripLiteral = coerce(literalInSourceType, sourceToTarget); - int literalVsRoundtripped = compare(targetType, right, roundtripLiteral); + int literalVsRoundtripped = compare(targetType, rightValue, roundtripLiteral); if (literalVsRoundtripped > 0) { // cast rounded down diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java index 90dbfa608265..2bb0bb94fe65 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java @@ -146,7 +146,7 @@ private Expression unwrapDateTrunc(Comparison expression) if (!(unitExpression.type() instanceof VarcharType) || !(unitExpression instanceof Constant)) { return expression; } - Slice unitName = (Slice) new IrExpressionInterpreter(unitExpression, plannerContext, session).optimize(); + Slice unitName = (Slice) new IrExpressionInterpreter(unitExpression, plannerContext, session).evaluate(); if (unitName == null) { return expression; } @@ -154,19 +154,16 @@ private Expression unwrapDateTrunc(Comparison expression) Expression argument = call.arguments().get(1); Type argumentType = argument.type(); - Type rightType = expression.right().type(); - verify(argumentType.equals(rightType), "Mismatched types: %s and %s", argumentType, rightType); + Expression right = new IrExpressionInterpreter(expression.right(), plannerContext, session).optimize(); - Object right = new IrExpressionInterpreter(expression.right(), plannerContext, session).optimize(); - - if (right == null) { + if (right instanceof Constant constant && constant.value() == null) { return switch (expression.operator()) { case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Constant(BOOLEAN, null); case IS_DISTINCT_FROM -> new Not(new IsNull(argument)); }; } - if (right instanceof Expression) { + if (!(right instanceof Constant(Type rightType, Object rightValue))) { return expression; } if (rightType instanceof TimestampWithTimeZoneType) { @@ -187,9 +184,9 @@ private Expression unwrapDateTrunc(Comparison expression) return expression; } - Object rangeLow = functionInvoker.invoke(resolvedFunction, session.toConnectorSession(), ImmutableList.of(unitName, right)); - int compare = compare(rightType, rangeLow, right); - verify(compare <= 0, "Truncation of %s value %s resulted in a bigger value %s", rightType, right, rangeLow); + Object rangeLow = functionInvoker.invoke(resolvedFunction, session.toConnectorSession(), ImmutableList.of(unitName, rightValue)); + int compare = compare(rightType, rangeLow, rightValue); + verify(compare <= 0, "Truncation of %s value %s resulted in a bigger value %s", rightType, rightValue, rangeLow); boolean rightValueAtRangeLow = compare == 0; return switch (expression.operator()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java index 13d6955f6814..e64424cf173c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java @@ -150,16 +150,16 @@ private Expression unwrapYear(Comparison expression) Expression argument = getOnlyElement(call.arguments()); Type argumentType = argument.type(); - Object right = new IrExpressionInterpreter(expression.right(), plannerContext, session).optimize(); + Expression right = new IrExpressionInterpreter(expression.right(), plannerContext, session).optimize(); - if (right == null) { + if (right instanceof Constant constant && constant.value() == null) { return switch (expression.operator()) { case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Constant(BOOLEAN, null); case IS_DISTINCT_FROM -> new Not(new IsNull(argument)); }; } - if (right instanceof Expression) { + if (!(right instanceof Constant(Type rightType, Object rightValue))) { return expression; } if (argumentType instanceof TimestampWithTimeZoneType) { @@ -172,7 +172,7 @@ private Expression unwrapYear(Comparison expression) return expression; } - int year = toIntExact((Long) right); + int year = toIntExact((Long) rightValue); return switch (expression.operator()) { case EQUAL -> between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType)); case NOT_EQUAL -> new Not(between(argument, argumentType, calculateRangeStartInclusive(year, argumentType), calculateRangeEndInclusive(year, argumentType))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 992959f36cf9..a00e46276859 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -1159,8 +1159,8 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex for (Expression conjunct : extractConjuncts(inheritedPredicate)) { if (isDeterministic(conjunct)) { // Ignore a conjunct for this test if we cannot deterministically get responses from it - Object response = nullInputEvaluator(innerSymbols, conjunct); - if (response == null || Boolean.FALSE.equals(response)) { + Expression response = nullInputEvaluator(innerSymbols, conjunct); + if (response instanceof Constant constant && (constant.value() == null || Boolean.FALSE.equals(constant.value()))) { // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the inner side symbols of an outer join // then this conjunct removes all effects of the outer join, and effectively turns this into an equivalent of an inner join. // So, let's just rewrite this join as an INNER join @@ -1174,12 +1174,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex // Temporary implementation for joins because the SimplifyExpressions optimizers cannot run properly on join clauses private Expression simplifyExpression(Expression expression) { - IrExpressionInterpreter optimizer = new IrExpressionInterpreter(expression, plannerContext, session); - Object object = optimizer.optimize(); - - return object instanceof Expression optimized ? - optimized : - new Constant(expression.type(), object); + return new IrExpressionInterpreter(expression, plannerContext, session).optimize(); } private boolean areExpressionsEquivalent(Expression leftExpression, Expression rightExpression) @@ -1190,7 +1185,7 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r /** * Evaluates an expression's response to binding the specified input symbols to NULL */ - private Object nullInputEvaluator(Collection nullSymbols, Expression expression) + private Expression nullInputEvaluator(Collection nullSymbols, Expression expression) { return new IrExpressionInterpreter(expression, plannerContext, session) .optimize(symbol -> nullSymbols.contains(symbol) ? Optional.of(new Constant(symbol.getType(), null)) : Optional.empty()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index f2a4d4b9102c..c10e3bbd3882 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -31,6 +31,7 @@ import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Coalesce; +import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Reference; import io.trino.sql.planner.DomainTranslator; @@ -767,23 +768,22 @@ public ActualProperties visitProject(ProjectNode node, List in Expression expression = assignment.getValue(); Type type = expression.type(); - IrExpressionInterpreter optimizer = new IrExpressionInterpreter(expression, plannerContext, session); // TODO: // We want to use a symbol resolver that looks up in the constants from the input subplan // to take advantage of constant-folding for complex expressions // However, that currently causes errors when those expressions operate on arrays or row types // ("ROW comparison not supported for fields with null elements", etc) - Object value = optimizer.optimize(); + Expression value = new IrExpressionInterpreter(expression, plannerContext, session).optimize(); if (value instanceof Reference) { - Symbol symbol = Symbol.from((Reference) value); + Symbol symbol = Symbol.from(value); NullableValue existingConstantValue = constants.get(symbol); if (existingConstantValue != null) { constants.put(assignment.getKey(), new NullableValue(type, value)); } } - else if (!(value instanceof Expression)) { - constants.put(assignment.getKey(), new NullableValue(type, value)); + else if (value instanceof Constant constant) { + constants.put(assignment.getKey(), new NullableValue(type, constant.value())); } } constants.putAll(translatedProperties.getConstants()); diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java index db74e1c5a0ca..3006f4a1ba96 100644 --- a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java @@ -26,7 +26,6 @@ import io.trino.sql.analyzer.RelationId; import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; -import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; import io.trino.sql.planner.IrExpressionInterpreter; import io.trino.sql.planner.Symbol; @@ -327,12 +326,7 @@ private RowExpression toRowExpression(Context context, Expression expression) io.trino.sql.ir.Expression lambdaCaptureDesugared = LambdaCaptureDesugaringRewriter.rewrite(translated, symbolAllocator); // optimize the expression - IrExpressionInterpreter interpreter = new IrExpressionInterpreter(lambdaCaptureDesugared, plannerContext, session); - Object value = interpreter.optimize(); - - io.trino.sql.ir.Expression optimized = value instanceof io.trino.sql.ir.Expression optimizedExpression ? - optimizedExpression : - new Constant(lambdaCaptureDesugared.type(), value); + io.trino.sql.ir.Expression optimized = new IrExpressionInterpreter(lambdaCaptureDesugared, plannerContext, session).optimize(); // translate to RowExpression TranslationVisitor translator = new TranslationVisitor(plannerContext.getMetadata(), plannerContext.getTypeManager(), ImmutableMap.of(), context.variables()); diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index f2df8a2ee9d9..314243d7861d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -167,7 +167,7 @@ public void testComparison() { assertOptimizedEquals( new Comparison(EQUAL, new Constant(UNKNOWN, null), new Constant(UNKNOWN, null)), - new Constant(UNKNOWN, null)); + new Constant(BOOLEAN, null)); assertOptimizedEquals( new Comparison(EQUAL, new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("b"))), @@ -253,15 +253,15 @@ public void testNullIf() { assertOptimizedEquals( new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("a"))), - new Constant(UNKNOWN, null)); + new Constant(VARCHAR, null)); assertOptimizedEquals( new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, Slices.utf8Slice("b"))), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( - new NullIf(new Constant(UNKNOWN, null), new Constant(VARCHAR, Slices.utf8Slice("b"))), - new Constant(UNKNOWN, null)); + new NullIf(new Constant(VARCHAR, null), new Constant(VARCHAR, Slices.utf8Slice("b"))), + new Constant(VARCHAR, null)); assertOptimizedEquals( - new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(UNKNOWN, null)), + new NullIf(new Constant(VARCHAR, Slices.utf8Slice("a")), new Constant(VARCHAR, null)), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( new NullIf(new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), @@ -363,10 +363,10 @@ public void testIn() assertOptimizedEquals( new In(new Constant(INTEGER, null), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), - new Constant(UNKNOWN, null)); + new Constant(BOOLEAN, null)); assertOptimizedEquals( new In(new Constant(INTEGER, 3L), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, null))), - new Constant(UNKNOWN, null)); + new Constant(BOOLEAN, null)); assertOptimizedEquals( new In(new Reference(INTEGER, "bound_value"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 1234L), new Constant(INTEGER, 3L), new Constant(INTEGER, 5L))), @@ -433,13 +433,13 @@ public void testTryCast() { assertOptimizedEquals( new Cast(new Constant(UNKNOWN, null), BIGINT, true), - new Constant(UNKNOWN, null)); + new Constant(BIGINT, null)); assertOptimizedEquals( new Cast(new Constant(INTEGER, 123L), BIGINT, true), - new Constant(INTEGER, 123L)); + new Constant(BIGINT, 123L)); assertOptimizedEquals( new Cast(new Constant(UNKNOWN, null), INTEGER, true), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); assertOptimizedEquals( new Cast(new Constant(INTEGER, 123L), INTEGER, true), new Constant(INTEGER, 123L)); @@ -558,7 +558,7 @@ public void testSimpleCase() ImmutableList.of( new WhenClause(TRUE, new Constant(INTEGER, 33L))), Optional.empty()), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); for (Switch aSwitch : Arrays.asList(new Switch( new Constant(BOOLEAN, null), ImmutableList.of( @@ -693,7 +693,7 @@ public void testSimpleCase() new WhenClause(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), new WhenClause(new Constant(INTEGER, 3L), new Constant(INTEGER, 3L))), Optional.empty()), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); assertEvaluatedEquals( new Switch( @@ -799,19 +799,19 @@ public void testIf() new Constant(INTEGER, 3L)); assertOptimizedEquals( ifExpression(FALSE, new Constant(INTEGER, 3L), new Constant(INTEGER, null)), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); assertOptimizedEquals( ifExpression(TRUE, new Constant(INTEGER, null), new Constant(INTEGER, 4L)), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); assertOptimizedEquals( ifExpression(FALSE, new Constant(INTEGER, null), new Constant(INTEGER, 4L)), new Constant(INTEGER, 4L)); assertOptimizedEquals( ifExpression(TRUE, new Constant(INTEGER, null), new Constant(INTEGER, null)), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); assertOptimizedEquals( ifExpression(FALSE, new Constant(INTEGER, null), new Constant(INTEGER, null)), - new Constant(UNKNOWN, null)); + new Constant(INTEGER, null)); assertOptimizedEquals( ifExpression(TRUE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java index 34923a43b2b0..5ed8b14fb2d0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java @@ -87,11 +87,6 @@ private RowExpression translateAndOptimize(Expression expression) private Expression simplifyExpression(Expression expression) { // Testing simplified expressions is important, since simplification may create CASTs or function calls that cannot be simplified by the ExpressionOptimizer - IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION); - Object value = interpreter.optimize(); - - return value instanceof Expression optimized ? - optimized : - new Constant(expression.type(), value); + return new IrExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION).optimize(); } } From e96e06e6adfeab6799d7758d5ec6a121724598c3 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Tue, 26 Mar 2024 11:52:14 -0700 Subject: [PATCH 5/6] Verify types in assignments --- .../trino/sql/analyzer/StatementAnalyzer.java | 2 +- .../rule/UnwrapDateTruncInComparison.java | 2 - .../trino/sql/planner/plan/Assignments.java | 3 ++ ...TestFilterProjectAggregationStatsRule.java | 2 +- .../trino/sql/planner/TestTypeValidator.java | 19 -------- .../rule/TestArraySortAfterArrayDistinct.java | 2 +- ...elateInnerUnnestWithGlobalAggregation.java | 24 +++++----- ...relateLeftUnnestWithGlobalAggregation.java | 18 +++---- .../iterative/rule/TestDecorrelateUnnest.java | 11 +++-- .../rule/TestEliminateCrossJoins.java | 4 +- .../rule/TestExpressionRewriteRuleSet.java | 17 +------ .../TestImplementTableFunctionSource.java | 2 +- .../rule/TestInlineProjectIntoFilter.java | 18 +++---- .../iterative/rule/TestInlineProjections.java | 36 +++++++------- .../rule/TestMergeAdjacentWindows.java | 5 +- .../rule/TestMergeProjectWithValues.java | 47 ++++++++++--------- ...ipleDistinctAggregationToMarkDistinct.java | 9 ++-- ...TestOptimizeDuplicateInsensitiveJoins.java | 2 +- .../rule/TestPruneTableScanColumns.java | 6 +-- .../rule/TestPruneValuesColumns.java | 6 +-- .../iterative/rule/TestPushCastIntoRow.java | 2 +- .../rule/TestPushDownDereferencesRules.java | 12 ++--- .../rule/TestPushLimitThroughProject.java | 3 +- .../rule/TestPushOffsetThroughProject.java | 3 +- .../rule/TestPushProjectionIntoTableScan.java | 10 ++-- .../TestPushProjectionThroughExchange.java | 6 +-- .../rule/TestPushProjectionThroughUnion.java | 2 +- .../rule/TestRemoveRedundantExists.java | 9 ++-- ...estReplaceJoinOverConstantWithProject.java | 42 ++++++++--------- ...estSingleDistinctAggregationToGroupBy.java | 3 +- ...atedDistinctAggregationWithProjection.java | 2 +- ...elatedGlobalAggregationWithProjection.java | 16 +++---- ...tedGlobalAggregationWithoutProjection.java | 6 +-- ...latedGroupedAggregationWithProjection.java | 12 ++--- ...TestTransformCorrelatedScalarSubquery.java | 6 +-- ...mCorrelatedSingleRowSubqueryToProject.java | 2 +- .../rule/TestUnwrapRowSubscript.java | 2 +- .../iterative/rule/test/TestRuleTester.java | 2 +- .../TestConnectorPushdownRulesWithHive.java | 2 +- ...TestConnectorPushdownRulesWithIceberg.java | 2 +- 40 files changed, 177 insertions(+), 202 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 6559424346fc..82b72ca125b8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -3699,7 +3699,7 @@ private void createMergeAnalysis(Table table, TableHandle handle, TableSchema ta // create the RowType that holds all column values List fields = new ArrayList<>(); for (ColumnSchema schema : dataColumnSchemas) { - fields.add(new RowType.Field(Optional.of(schema.getName()), schema.getType())); + fields.add(RowType.field(schema.getType())); } fields.add(new RowType.Field(Optional.empty(), BOOLEAN)); // present fields.add(new RowType.Field(Optional.empty(), TINYINT)); // operation_number diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java index 2bb0bb94fe65..7f0ab9ae364e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java @@ -152,8 +152,6 @@ private Expression unwrapDateTrunc(Comparison expression) } Expression argument = call.arguments().get(1); - Type argumentType = argument.type(); - Expression right = new IrExpressionInterpreter(expression.right(), plannerContext, session).optimize(); if (right instanceof Constant constant && constant.value() == null) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java index a20716fa107a..56fc4b2a6802 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java @@ -34,6 +34,7 @@ import java.util.function.Predicate; import java.util.stream.Collector; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; @@ -254,6 +255,8 @@ public Builder putAll(Map assignments) public Builder put(Symbol symbol, Expression expression) { + checkArgument(symbol.getType().equals(expression.type()), "Types don't match: %s vs %s, for %s and %s", symbol.getType(), expression.type(), symbol, expression); + if (assignments.containsKey(symbol)) { Expression assignment = assignments.get(symbol); checkState( diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java index 7bb9f3a6b752..66bb664137bf 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterProjectAggregationStatsRule.java @@ -145,7 +145,7 @@ public void testFilterAndProjectOverAggregationStats() return pb.filter( new Comparison(GREATER_THAN, new Reference(INTEGER, "count_on_x"), new Constant(INTEGER, 0L)), // Non-narrowing projection - pb.project(Assignments.of(pb.symbol("x_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))), aggregatedOutput, aggregatedOutput.toSymbolReference()), + pb.project(Assignments.of(pb.symbol("x_1", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))), aggregatedOutput, aggregatedOutput.toSymbolReference()), pb.aggregation(ab -> ab .addAggregation(aggregatedOutput, aggregation("count", ImmutableList.of(new Reference(BIGINT, "x"))), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java index 39d4557366ef..abf40c00f72d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java @@ -179,25 +179,6 @@ public void testValidAggregation() assertTypesValid(node); } - @Test - public void testInvalidProject() - { - Expression expression1 = new Cast(columnB.toSymbolReference(), BIGINT); - Expression expression2 = new Cast(columnA.toSymbolReference(), INTEGER); - Assignments assignments = Assignments.builder() - .put(symbolAllocator.newSymbol(expression1), expression1) // should be INTEGER - .put(symbolAllocator.newSymbol(expression1), expression2) - .build(); - PlanNode node = new ProjectNode( - newId(), - baseTableScan, - assignments); - - assertThatThrownBy(() -> assertTypesValid(node)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageMatching("type of symbol 'expr(_[0-9]+)?' is expected to be bigint, but the actual type is integer"); - } - @Test public void testInvalidAggregationFunctionCall() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java index 4d0a32a89f04..86bb40e3625a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java @@ -73,7 +73,7 @@ private void test(Expression original, Expression rewritten) tester().assertThat(new ArraySortAfterArrayDistinct(tester().getPlannerContext()).projectExpressionRewrite()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("output"), original) + .put(p.symbol("output", original.type()), original) .build(), p.values())) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java index 4bd7f0dabe56..4594dbff11f3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java @@ -19,6 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; +import io.trino.spi.type.ArrayType; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; @@ -62,7 +63,8 @@ public class TestDecorrelateInnerUnnestWithGlobalAggregation private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction REGEXP_EXTRACT_ALL = FUNCTIONS.resolveFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); - private static final ResolvedFunction MODULUS_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(BIGINT, BIGINT)); private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); @Test @@ -282,7 +284,7 @@ public void testProjectOverGlobalAggregation() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("sum_1", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(innerBuilder -> innerBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) @@ -331,11 +333,11 @@ public void testPreprojectUnnestSymbol() .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(VARCHAR, "unnested_corr"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), - ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array"), ImmutableList.of(p.symbol("unnested_corr", VARCHAR)))), + ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array", new ArrayType(VARCHAR)), ImmutableList.of(p.symbol("unnested_corr", VARCHAR)))), Optional.empty(), INNER, p.project( - Assignments.of(p.symbol("char_array"), regexpExtractAll), + Assignments.of(p.symbol("char_array", new ArrayType(VARCHAR)), regexpExtractAll), p.values(ImmutableList.of(), ImmutableList.of(ImmutableList.of()))))))); }) .matches( @@ -369,22 +371,22 @@ public void testMultipleNodesOverUnnestInSubquery() ImmutableList.of(p.symbol("groups"), p.symbol("numbers")), p.values(p.symbol("groups"), p.symbol("numbers")), p.project( - Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("sum_1", BIGINT), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "sum"), new Constant(BIGINT, 1L)))), p.aggregation(globalBuilder -> globalBuilder .globalGrouping() - .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("sum", BIGINT), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) .source(p.project( Assignments.builder() - .put(p.symbol("negate"), new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "max")))) + .put(p.symbol("negate", BIGINT), new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "max")))) .build(), p.aggregation(groupedBuilder -> groupedBuilder .singleGroupingSet(p.symbol("group")) - .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "modulo"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("max", BIGINT), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "modulo"))), ImmutableList.of(BIGINT)) .source( p.project( Assignments.builder() .putIdentities(ImmutableList.of(p.symbol("group"), p.symbol("number"))) - .put(p.symbol("modulo"), new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))) + .put(p.symbol("modulo", BIGINT), new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "number"), new Constant(BIGINT, 10L)))) .build(), p.unnest( ImmutableList.of(), @@ -397,7 +399,7 @@ public void testMultipleNodesOverUnnestInSubquery() .matches( project( project( - ImmutableMap.of("sum_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))))), + ImmutableMap.of("sum_1", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "sum"), new Constant(BIGINT, 1L))))), aggregation( singleGroupingSet("groups", "numbers", "unique"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("negated"))), @@ -415,7 +417,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("modulo", expression(new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))))), + ImmutableMap.of("modulo", expression(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "number"), new Constant(BIGINT, 10L))))), project( ImmutableMap.of("mask", expression(new Not(new IsNull(new Reference(BIGINT, "ordinality"))))), unnest( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java index e8d0f2b57d5b..ba6515b0676f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java @@ -19,6 +19,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.function.OperatorType; +import io.trino.spi.type.ArrayType; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; @@ -56,7 +57,8 @@ public class TestDecorrelateLeftUnnestWithGlobalAggregation private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); private static final ResolvedFunction REGEXP_EXTRACT_ALL = FUNCTIONS.resolveFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)); private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); - private static final ResolvedFunction MODULUS_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(BIGINT, BIGINT)); private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BIGINT)); @Test @@ -261,7 +263,7 @@ public void testProjectOverGlobalAggregation() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("sum_1", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(innerBuilder -> innerBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "unnested_corr"))), ImmutableList.of(BIGINT)) @@ -310,11 +312,11 @@ public void testPreprojectUnnestSymbol() .addAggregation(p.symbol("max"), PlanBuilder.aggregation("max", ImmutableList.of(new Reference(BIGINT, "unnested_char"))), ImmutableList.of(BIGINT)) .source(p.unnest( ImmutableList.of(), - ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array"), ImmutableList.of(p.symbol("unnested_char")))), + ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array", new ArrayType(VARCHAR)), ImmutableList.of(p.symbol("unnested_char")))), Optional.empty(), LEFT, p.project( - Assignments.of(p.symbol("char_array"), regexpExtractAll), + Assignments.of(p.symbol("char_array", new ArrayType(VARCHAR)), regexpExtractAll), p.values(ImmutableList.of(), ImmutableList.of(ImmutableList.of()))))))); }) .matches( @@ -345,7 +347,7 @@ public void testMultipleNodesOverUnnestInSubquery() ImmutableList.of(p.symbol("groups"), p.symbol("numbers")), p.values(p.symbol("groups"), p.symbol("numbers")), p.project( - Assignments.of(p.symbol("sum_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("sum_1", BIGINT), new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "sum"), new Constant(BIGINT, 1L)))), p.aggregation(globalBuilder -> globalBuilder .globalGrouping() .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "negate"))), ImmutableList.of(BIGINT)) @@ -360,7 +362,7 @@ public void testMultipleNodesOverUnnestInSubquery() p.project( Assignments.builder() .putIdentities(ImmutableList.of(p.symbol("group"), p.symbol("number"))) - .put(p.symbol("modulo"), new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L)))) + .put(p.symbol("modulo", BIGINT), new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "number"), new Constant(BIGINT, 10L)))) .build(), p.unnest( ImmutableList.of(), @@ -373,7 +375,7 @@ public void testMultipleNodesOverUnnestInSubquery() .matches( project( project( - ImmutableMap.of("sum_1", expression(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))))), + ImmutableMap.of("sum_1", expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "sum"), new Constant(BIGINT, 1L))))), aggregation( singleGroupingSet("groups", "numbers", "unique"), ImmutableMap.of(Optional.of("sum"), aggregationFunction("sum", ImmutableList.of("negated"))), @@ -390,7 +392,7 @@ public void testMultipleNodesOverUnnestInSubquery() Optional.empty(), SINGLE, project( - ImmutableMap.of("modulo", expression(new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "number"), new Constant(INTEGER, 10L))))), + ImmutableMap.of("modulo", expression(new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "number"), new Constant(BIGINT, 10L))))), unnest( ImmutableList.of("groups", "numbers", "unique"), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java index 1316234e5f2f..9cf1c307b87c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java @@ -18,6 +18,7 @@ import io.airlift.slice.Slices; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.type.ArrayType; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -348,7 +349,7 @@ public void testProject() JoinType.LEFT, TRUE, p.project( - Assignments.of(p.symbol("boolean_result"), new IsNull(new Reference(BIGINT, "unnested_corr"))), + Assignments.of(p.symbol("boolean_result", BOOLEAN), new IsNull(new Reference(BIGINT, "unnested_corr"))), p.unnest( ImmutableList.of(), ImmutableList.of(new UnnestNode.Mapping(p.symbol("corr"), ImmutableList.of(p.symbol("unnested_corr")))), @@ -414,11 +415,11 @@ public void testDifferentNodesInSubquery() TRUE, p.enforceSingleRow( p.project( - Assignments.of(p.symbol("integer_result"), ifExpression(new Reference(BOOLEAN, "boolean_result"), new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))), + Assignments.of(p.symbol("integer_result", INTEGER), ifExpression(new Reference(BOOLEAN, "boolean_result"), new Constant(INTEGER, 1L), new Constant(INTEGER, 1L))), p.limit( 5, p.project( - Assignments.of(p.symbol("boolean_result"), new IsNull(new Reference(BIGINT, "unnested_corr"))), + Assignments.of(p.symbol("boolean_result", BOOLEAN), new IsNull(new Reference(BIGINT, "unnested_corr"))), p.topN( 10, ImmutableList.of(p.symbol("unnested_corr")), @@ -509,11 +510,11 @@ public void testPreprojectUnnestSymbol() TRUE, p.unnest( ImmutableList.of(), - ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array"), ImmutableList.of(p.symbol("unnested_char")))), + ImmutableList.of(new UnnestNode.Mapping(p.symbol("char_array", new ArrayType(VARCHAR)), ImmutableList.of(p.symbol("unnested_char")))), Optional.empty(), LEFT, p.project( - Assignments.of(p.symbol("char_array"), regexpExtractAll), + Assignments.of(p.symbol("char_array", new ArrayType(VARCHAR)), regexpExtractAll), p.values(ImmutableList.of(), ImmutableList.of(ImmutableList.of()))))); }) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 0aa7ac47dc8c..104d80835a7e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -318,8 +318,8 @@ private PlanNode projectNode(PlanNode source, String symbol1, Expression express idAllocator.getNextId(), source, Assignments.of( - new Symbol(UNKNOWN, symbol1), expression1, - new Symbol(UNKNOWN, symbol2), expression2)); + new Symbol(INTEGER, symbol1), expression1, + new Symbol(INTEGER, symbol2), expression2)); } private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java index 254b6d9df633..a9d36657f5b9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java @@ -19,8 +19,6 @@ import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.IsNull; -import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.ir.Row; import io.trino.sql.planner.Symbol; @@ -34,10 +32,8 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.Booleans.TRUE; -import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.patternRecognition; -import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; import static io.trino.sql.planner.rowpattern.Patterns.label; @@ -61,23 +57,12 @@ public Expression rewriteRow(Row node, Void context, ExpressionTreeRewriter p.project( - Assignments.of(p.symbol("y"), new Not(new IsNull(new Reference(BIGINT, "x")))), - p.values(p.symbol("x")))) - .matches( - project(ImmutableMap.of("y", expression(new Constant(INTEGER, 0L))), values("x"))); - } - @Test public void testProjectionExpressionNotRewritten() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.symbol("y"), new Constant(INTEGER, 0L)), + Assignments.of(p.symbol("y", INTEGER), new Constant(INTEGER, 0L)), p.values(p.symbol("x")))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java index 4a183b1067b0..a1d4961b3fc0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java @@ -1217,7 +1217,7 @@ public void testCoerceForCopartitioning() // coerce column c for co-partitioning p.project( Assignments.builder() - .put(c, new Reference(BIGINT, "c")) + .put(c, new Reference(TINYINT, "c")) .put(d, new Reference(BIGINT, "d")) .put(cCoerced, new Cast(new Reference(BIGINT, "c"), INTEGER)) .build(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java index 5942ee689918..2b82e1de8016 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java @@ -46,9 +46,9 @@ public void testInlineProjection() { tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( - new Reference(INTEGER, "a"), + new Reference(BOOLEAN, "a"), p.project( - Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b", INTEGER))))) .matches( project( @@ -92,7 +92,7 @@ public void testNoSimpleConjuncts() .on(p -> p.filter( new Logical(OR, ImmutableList.of(new Reference(BOOLEAN, "a"), FALSE)), p.project( - Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), + Assignments.of(p.symbol("a", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), p.values(p.symbol("b", INTEGER))))) .doesNotFire(); } @@ -168,8 +168,8 @@ public void testTrivialProjection() .on(p -> p.filter( new Reference(INTEGER, "a"), p.project( - Assignments.of(p.symbol("a"), new Reference(INTEGER, "a")), - p.values(p.symbol("a"))))) + Assignments.of(p.symbol("a", INTEGER), new Reference(INTEGER, "a")), + p.values(p.symbol("a", INTEGER))))) .doesNotFire(); // renaming projection @@ -177,8 +177,8 @@ public void testTrivialProjection() .on(p -> p.filter( new Reference(INTEGER, "a"), p.project( - Assignments.of(p.symbol("a"), new Reference(INTEGER, "b")), - p.values(p.symbol("b"))))) + Assignments.of(p.symbol("a", INTEGER), new Reference(INTEGER, "b")), + p.values(p.symbol("b", INTEGER))))) .doesNotFire(); } @@ -189,8 +189,8 @@ public void testCorrelationSymbol() .on(p -> p.filter( new Reference(INTEGER, "corr"), p.project( - Assignments.of(p.symbol("a"), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), - p.values(p.symbol("b"))))) + Assignments.of(p.symbol("a", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L))), + p.values(p.symbol("b", INTEGER))))) .doesNotFire(); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java index 4b188f9abe5f..8e235aac0761 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjections.java @@ -58,22 +58,22 @@ public void test() .on(p -> p.project( Assignments.builder() - .put(p.symbol("identity"), new Reference(BIGINT, "symbol")) // identity - .put(p.symbol("multi_complex_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 1L)))) // complex expression referenced multiple times - .put(p.symbol("multi_complex_2"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L)))) // complex expression referenced multiple times - .put(p.symbol("multi_literal_1"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "literal"), new Constant(INTEGER, 1L)))) // literal referenced multiple times - .put(p.symbol("multi_literal_2"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "literal"), new Constant(INTEGER, 2L)))) // literal referenced multiple times - .put(p.symbol("single_complex"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex_2"), new Constant(INTEGER, 2L)))) // complex expression reference only once - .put(p.symbol("msg_xx"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "z"), new Constant(INTEGER, 1L)))) - .put(p.symbol("multi_symbol_reference"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "v"), new Reference(INTEGER, "v")))) + .put(p.symbol("identity", INTEGER), new Reference(INTEGER, "symbol")) // identity + .put(p.symbol("multi_complex_1", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 1L)))) // complex expression referenced multiple times + .put(p.symbol("multi_complex_2", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex"), new Constant(INTEGER, 2L)))) // complex expression referenced multiple times + .put(p.symbol("multi_literal_1", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "literal"), new Constant(INTEGER, 1L)))) // literal referenced multiple times + .put(p.symbol("multi_literal_2", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "literal"), new Constant(INTEGER, 2L)))) // literal referenced multiple times + .put(p.symbol("single_complex", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "complex_2"), new Constant(INTEGER, 2L)))) // complex expression reference only once + .put(p.symbol("msg_xx", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "z"), new Constant(INTEGER, 1L)))) + .put(p.symbol("multi_symbol_reference", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "v"), new Reference(INTEGER, "v")))) .build(), p.project(Assignments.builder() - .put(p.symbol("symbol"), new Reference(INTEGER, "x")) - .put(p.symbol("complex"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))) - .put(p.symbol("literal"), new Constant(INTEGER, 1L)) - .put(p.symbol("complex_2"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)))) - .put(p.symbol("z"), new FieldReference(new Reference(MSG_TYPE, "msg"), 0)) - .put(p.symbol("v"), new Reference(INTEGER, "x")) + .put(p.symbol("symbol", INTEGER), new Reference(INTEGER, "x")) + .put(p.symbol("complex", INTEGER), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))) + .put(p.symbol("literal", INTEGER), new Constant(INTEGER, 1L)) + .put(p.symbol("complex_2", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L)))) + .put(p.symbol("z", MSG_TYPE.getFields().get(0).getType()), new FieldReference(new Reference(MSG_TYPE, "msg"), 0)) + .put(p.symbol("v", INTEGER), new Reference(INTEGER, "x")) .build(), p.values(p.symbol("x", INTEGER), p.symbol("msg", MSG_TYPE))))) .matches( @@ -109,8 +109,8 @@ public void testInlineEffectivelyLiteral() p.project( Assignments.builder() // Use the literal-like expression multiple times. Single-use expression may be inlined regardless of whether it's a literal - .put(p.symbol("decimal_multiplication"), new Call(MULTIPLY_DECIMAL_8_4, ImmutableList.of(new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal")))) - .put(p.symbol("decimal_addition"), new Call(ADD_DECIMAL_8_4, ImmutableList.of(new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal")))) + .put(p.symbol("decimal_multiplication", createDecimalType(16, 8)), new Call(MULTIPLY_DECIMAL_8_4, ImmutableList.of(new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal")))) + .put(p.symbol("decimal_addition", createDecimalType(9, 4)), new Call(ADD_DECIMAL_8_4, ImmutableList.of(new Reference(createDecimalType(8, 4), "decimal_literal"), new Reference(createDecimalType(8, 4), "decimal_literal")))) .build(), p.project(Assignments.builder() .put(p.symbol("decimal_literal", createDecimalType(8, 4)), new Constant(createDecimalType(8, 4), Decimals.valueOfShort(new BigDecimal("12.5")))) @@ -189,9 +189,9 @@ public void testSubqueryProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value_1")), + Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value_1", INTEGER)), p.project( - Assignments.of(p.symbol("value_1"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "value"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("value_1", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "value"), new Constant(INTEGER, 1L)))), p.values(p.symbol("value"))))) .matches( project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 029cfa8a335d..7beec2843274 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; @@ -176,7 +175,7 @@ public void testIntermediateProjectNodes() ImmutableMap.of(p.symbol("lagOutput"), newWindowNodeFunction(LAG, new Symbol(DOUBLE, "a"), new Symbol(INTEGER, "one"))), p.project( Assignments.builder() - .put(p.symbol("one", INTEGER), new Cast(new Constant(INTEGER, 1L), BIGINT)) + .put(p.symbol("one", INTEGER), new Constant(INTEGER, 1L)) .putIdentities(ImmutableList.of(new Symbol(DOUBLE, "a"), p.symbol("avgOutput", DOUBLE))) .build(), p.project( @@ -198,7 +197,7 @@ public void testIntermediateProjectNodes() .addFunction(avgOutputAlias, windowFunction(AVG.getSignature().getName().getFunctionName(), ImmutableList.of(columnAAlias), DEFAULT_FRAME)), strictProject( ImmutableMap.of( - oneAlias, PlanMatchPattern.expression(new Cast(new Constant(INTEGER, 1L), BIGINT)), + oneAlias, PlanMatchPattern.expression(new Constant(INTEGER, 1L)), columnAAlias, PlanMatchPattern.expression(new Reference(BIGINT, columnAAlias)), unusedAlias, PlanMatchPattern.expression(new Reference(BIGINT, unusedAlias))), strictProject( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java index 258951431ac6..e3130ce1a0f4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java @@ -28,11 +28,11 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; -import io.trino.type.UnknownType; import org.junit.jupiter.api.Test; import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.CharType.createCharType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; @@ -40,6 +40,7 @@ import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.LogicalPlanner.failFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.type.UnknownType.UNKNOWN; public class TestMergeProjectWithValues extends BaseRuleTest @@ -59,7 +60,7 @@ public void testDoesNotFireOnNonRowType() p.valuesOfExpressions( ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Cast( - new Row(ImmutableList.of(new Constant(UnknownType.UNKNOWN, null), new Constant(UnknownType.UNKNOWN, null))), + new Row(ImmutableList.of(new Constant(UNKNOWN, null), new Constant(UNKNOWN, null))), RowType.anonymous(ImmutableList.of(BIGINT, BIGINT))))))) .doesNotFire(); } @@ -117,7 +118,7 @@ public void testValuesWithoutOutputSymbols() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE), + Assignments.of(p.symbol("a", createCharType(1)), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b", BOOLEAN), TRUE), p.values( ImmutableList.of(), ImmutableList.of(ImmutableList.of(), ImmutableList.of())))) @@ -131,7 +132,7 @@ public void testValuesWithoutOutputSymbols() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE), + Assignments.of(p.symbol("a", createCharType(1)), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b", BOOLEAN), TRUE), p.values( ImmutableList.of(), ImmutableList.of()))) @@ -205,18 +206,18 @@ public void testDoNotFireOnNonDeterministicValues() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of( - p.symbol("x"), new Reference(DOUBLE, "rand"), - p.symbol("y"), new Reference(DOUBLE, "rand")), + p.symbol("x", DOUBLE), new Reference(DOUBLE, "rand"), + p.symbol("y", DOUBLE), new Reference(DOUBLE, "rand")), p.valuesOfExpressions( - ImmutableList.of(p.symbol("rand")), + ImmutableList.of(p.symbol("rand", DOUBLE)), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) .doesNotFire(); tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand")))), + Assignments.of(p.symbol("x", DOUBLE), new Call(ADD_DOUBLE, ImmutableList.of(new Reference(DOUBLE, "rand"), new Reference(DOUBLE, "rand")))), p.valuesOfExpressions( - ImmutableList.of(p.symbol("rand")), + ImmutableList.of(p.symbol("rand", DOUBLE)), ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) .doesNotFire(); } @@ -245,7 +246,7 @@ public void testCorrelation() // correlation symbol is not present in the resulting expression tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Constant(INTEGER, 1L)), + Assignments.of(p.symbol("x", INTEGER), new Constant(INTEGER, 1L)), p.valuesOfExpressions( ImmutableList.of(p.symbol("a")), ImmutableList.of(new Row(ImmutableList.of(new Reference(INTEGER, "corr"))))))) @@ -259,7 +260,7 @@ public void testFailingExpression() tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( - Assignments.of(p.symbol("x"), failFunction), + Assignments.of(p.symbol("x", UNKNOWN), failFunction), p.valuesOfExpressions( ImmutableList.of(p.symbol("a")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L))))))) @@ -271,12 +272,12 @@ public void testMergeProjectWithValues() { tester().assertThat(new MergeProjectWithValues()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol c = p.symbol("c"); - Symbol d = p.symbol("d"); - Symbol e = p.symbol("e"); - Symbol f = p.symbol("f"); + Symbol a = p.symbol("a", BOOLEAN); + Symbol b = p.symbol("b", BOOLEAN); + Symbol c = p.symbol("c", BOOLEAN); + Symbol d = p.symbol("d", BOOLEAN); + Symbol e = p.symbol("e", BOOLEAN); + Symbol f = p.symbol("f", INTEGER); Assignments.Builder assignments = Assignments.builder(); assignments.putIdentity(a); // identity assignment assignments.put(d, b.toSymbolReference()); // renaming assignment @@ -301,12 +302,12 @@ public void testMergeProjectWithValues() // ValuesNode has no rows tester().assertThat(new MergeProjectWithValues()) .on(p -> { - Symbol a = p.symbol("a"); - Symbol b = p.symbol("b"); - Symbol c = p.symbol("c"); - Symbol d = p.symbol("d"); - Symbol e = p.symbol("e"); - Symbol f = p.symbol("f"); + Symbol a = p.symbol("a", BOOLEAN); + Symbol b = p.symbol("b", BOOLEAN); + Symbol c = p.symbol("c", BOOLEAN); + Symbol d = p.symbol("d", BOOLEAN); + Symbol e = p.symbol("e", BOOLEAN); + Symbol f = p.symbol("f", INTEGER); Assignments.Builder assignments = Assignments.builder(); assignments.putIdentity(a); // identity assignment assignments.put(d, b.toSymbolReference()); // renaming assignment diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java index 603c5d2f8da5..74c8e0a855ee 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java @@ -36,6 +36,7 @@ import static io.trino.SystemSessionProperties.OPTIMIZE_DISTINCT_AGGREGATIONS; import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -112,8 +113,8 @@ public void testDistinctWithFilter() Assignments.builder() .putIdentity(p.symbol("input1")) .putIdentity(p.symbol("input2")) - .put(p.symbol("filter1"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) - .put(p.symbol("filter2"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input1"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter1", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter2", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "input1"), new Constant(INTEGER, 0L))) .build(), p.values( p.symbol("input1"), @@ -130,8 +131,8 @@ public void testDistinctWithFilter() Assignments.builder() .putIdentity(p.symbol("input1")) .putIdentity(p.symbol("input2")) - .put(p.symbol("filter1"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) - .put(p.symbol("filter2"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input1"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter1", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter2", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "input1"), new Constant(INTEGER, 0L))) .build(), p.values( p.symbol("input1"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java index 4704ae655c4e..cf376dc91eea 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java @@ -205,7 +205,7 @@ public void testNondeterministicProjection() .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); - Symbol symbolC = p.symbol("c"); + Symbol symbolC = p.symbol("c", DOUBLE); return p.aggregation(a -> a .singleGroupingSet(symbolA) .source(p.project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java index fe85b75174e6..3b2d2772036c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -66,7 +66,7 @@ public void testNotAllOutputsReferenced() Symbol orderdate = p.symbol("orderdate", DATE); Symbol totalprice = p.symbol("totalprice", DOUBLE); return p.project( - Assignments.of(p.symbol("x"), totalprice.toSymbolReference()), + Assignments.of(p.symbol("x", DOUBLE), totalprice.toSymbolReference()), p.tableScan( tester().getCurrentCatalogTableHandle(TINY_SCHEMA_NAME, "orders"), ImmutableList.of(orderdate, totalprice), @@ -90,7 +90,7 @@ public void testPruneEnforcedConstraint() TpchColumnHandle orderdateHandle = new TpchColumnHandle(orderdate.getName(), DATE); TpchColumnHandle totalpriceHandle = new TpchColumnHandle(totalprice.getName(), DOUBLE); return p.project( - Assignments.of(p.symbol("x"), totalprice.toSymbolReference()), + Assignments.of(p.symbol("x", DOUBLE), totalprice.toSymbolReference()), p.tableScan( tester().getCurrentCatalogTableHandle(TINY_SCHEMA_NAME, "orders"), List.of(orderdate, totalprice), @@ -152,7 +152,7 @@ public void testPushColumnPruningProjection() Symbol symbolA = p.symbol("cola", DATE); Symbol symbolB = p.symbol("colb", DOUBLE); return p.project( - Assignments.of(p.symbol("x"), symbolB.toSymbolReference()), + Assignments.of(p.symbol("x", DOUBLE), symbolB.toSymbolReference()), p.tableScan( ruleTester.getCurrentCatalogTableHandle(testSchema, testTable), ImmutableList.of(symbolA, symbolB), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java index 830026d3cbe7..cfbd182ef68f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneValuesColumns.java @@ -42,7 +42,7 @@ public void testNotAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), new Reference(INTEGER, "x")), + Assignments.of(p.symbol("y", INTEGER), new Reference(INTEGER, "x")), p.values( ImmutableList.of(p.symbol("unused"), p.symbol("x")), ImmutableList.of( @@ -105,9 +105,9 @@ public void testDoNotPruneWhenValuesExpressionIsNotRow() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Reference(INTEGER, "x")), + Assignments.of(p.symbol("x", INTEGER), new Reference(INTEGER, "x")), p.valuesOfExpressions( - ImmutableList.of(p.symbol("x"), p.symbol("y")), + ImmutableList.of(p.symbol("x", INTEGER), p.symbol("y")), ImmutableList.of(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")))), anonymousRow(BIGINT, createCharType(2))))))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java index ef64e5e33b14..059fa2d316d2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java @@ -79,7 +79,7 @@ private void test(Expression original, Expression unwrapped) tester().assertThat(new PushCastIntoRow().projectExpressionRewrite()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("output"), original) + .put(p.symbol("output", original.type()), original) .build(), p.values())) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java index 907c0c0e4e5f..27ce5ad88f7f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -112,7 +112,7 @@ public void testDoesNotFire() .on(p -> p.project( Assignments.of( - p.symbol("expr_1"), new FieldReference(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), 0), + p.symbol("expr_1", rowType(field("x", BIGINT), field("y", BIGINT))), new FieldReference(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), 0), p.symbol("expr_2"), new FieldReference(new FieldReference(new Cast(new Row(ImmutableList.of(new Reference(ROW_TYPE, "a"), new Reference(BIGINT, "b"))), rowType(field("f1", rowType(field("x", BIGINT), field("y", BIGINT))), field("f2", BIGINT))), 0), 1)), p.project( Assignments.of( @@ -200,8 +200,8 @@ public void testPushDownDereferenceThroughJoin() .on(p -> p.project( Assignments.of( - p.symbol("expr"), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0), - p.symbol("expr_2"), new Reference(ROW_TYPE, "msg2")), + p.symbol("expr", ROW_TYPE.getFields().get(0).getType()), new FieldReference(new Reference(ROW_TYPE, "msg1"), 0), + p.symbol("expr_2", ROW_TYPE), new Reference(ROW_TYPE, "msg2")), p.join(INNER, p.values(p.symbol("msg1", ROW_TYPE)), p.values(p.symbol("msg2", ROW_TYPE)), @@ -293,8 +293,8 @@ public void testPushdownDereferencesThroughUnnest() .on(p -> p.project( Assignments.of( - p.symbol("deref_replicate", BIGINT), new FieldReference(new Reference(rowType, "replicate"), 1), - p.symbol("deref_unnest", BIGINT), new Call(subscript, ImmutableList.of(new Reference(nestedColumnType, "unnested_row"), new Constant(BIGINT, 2L)))), + p.symbol("deref_replicate", rowType.getFields().get(1).getType()), new FieldReference(new Reference(rowType, "replicate"), 1), + p.symbol("deref_unnest", nestedColumnType.getElementType()), new Call(subscript, ImmutableList.of(new Reference(nestedColumnType, "unnested_row"), new Constant(BIGINT, 2L)))), p.unnest( ImmutableList.of(p.symbol("replicate", rowType)), ImmutableList.of( @@ -731,7 +731,7 @@ public void testMultiLevelPushdown() .on(p -> p.project( Assignments.of( - p.symbol("expr_1"), new FieldReference(new Reference(complexType, "a"), 0), + p.symbol("expr_1", complexType.getFields().get(0).getType()), new FieldReference(new Reference(complexType, "a"), 0), p.symbol("expr_2"), new Call( ADD_BIGINT, ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java index eb32a25409f3..0d2411fe270a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -30,6 +30,7 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.limit; @@ -51,7 +52,7 @@ public void testPushdownLimitNonIdentityProjection() { tester().assertThat(new PushLimitThroughProject()) .on(p -> { - Symbol a = p.symbol("a"); + Symbol a = p.symbol("a", BOOLEAN); return p.limit(1, p.project( Assignments.of(a, TRUE), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java index 882fe2936539..0e7369ce7b75 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushOffsetThroughProject.java @@ -19,6 +19,7 @@ import io.trino.sql.planner.plan.Assignments; import org.junit.jupiter.api.Test; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.offset; @@ -33,7 +34,7 @@ public void testPushdownOffsetNonIdentityProjection() { tester().assertThat(new PushOffsetThroughProject()) .on(p -> { - Symbol a = p.symbol("a"); + Symbol a = p.symbol("a", BOOLEAN); return p.offset( 5, p.project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index d75c2d72e804..6411df6d049d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -60,7 +60,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; @@ -70,7 +69,6 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; -import static io.trino.type.UnknownType.UNKNOWN; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -124,9 +122,9 @@ public void testPushProjection() MockConnectorFactory factory = createMockFactory(ImmutableMap.of(columnName, columnHandle), Optional.of(this::mockApplyProjection)); try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) { // Prepare project node symbols and types - Symbol identity = new Symbol(UNKNOWN, "symbol_identity"); - Symbol dereference = new Symbol(UNKNOWN, "symbol_dereference"); - Symbol constant = new Symbol(UNKNOWN, "symbol_constant"); + Symbol identity = new Symbol(ROW_TYPE, "symbol_identity"); + Symbol dereference = new Symbol(BIGINT, "symbol_dereference"); + Symbol constant = new Symbol(BIGINT, "symbol_constant"); ImmutableMap types = ImmutableMap.of( baseColumn, ROW_TYPE, identity, ROW_TYPE, @@ -137,7 +135,7 @@ public void testPushProjection() Assignments inputProjections = Assignments.builder() .put(identity, baseColumn.toSymbolReference()) .put(dereference, new FieldReference(baseColumn.toSymbolReference(), 0)) - .put(constant, new Constant(INTEGER, 5L)) + .put(constant, new Constant(BIGINT, 5L)) .build(); // Compute expected symbols after applyProjection diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index c566150bbba4..a6ea06b9a40b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -54,7 +54,7 @@ public void testDoesNotFireNoExchange() tester().assertThat(new PushProjectionThroughExchange()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Constant(INTEGER, 3L)), + Assignments.of(p.symbol("x", INTEGER), new Constant(INTEGER, 3L)), p.values(p.symbol("a")))) .doesNotFire(); } @@ -90,7 +90,7 @@ public void testSimpleMultipleInputs() Symbol b = p.symbol("b"); Symbol c = p.symbol("c"); Symbol c2 = p.symbol("c2"); - Symbol x = p.symbol("x"); + Symbol x = p.symbol("x", INTEGER); return p.project( Assignments.of( x, new Constant(INTEGER, 3L), @@ -126,7 +126,7 @@ public void testHashMapping() Symbol h1 = p.symbol("h_1"); Symbol c = p.symbol("c", INTEGER); Symbol h = p.symbol("h"); - Symbol cTimes5 = p.symbol("c_times_5"); + Symbol cTimes5 = p.symbol("c_times_5", INTEGER); return p.project( Assignments.of( cTimes5, new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "c"), new Constant(INTEGER, 5L)))), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index 5b9b358470a3..a88d2aa56413 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -51,7 +51,7 @@ public void testDoesNotFire() tester().assertThat(new PushProjectionThroughUnion()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new Constant(INTEGER, 3L)), + Assignments.of(p.symbol("x", INTEGER), new Constant(INTEGER, 3L)), p.values(p.symbol("a")))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java index 75212b470eb8..8cad4c88e6fe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantExists.java @@ -21,6 +21,7 @@ import io.trino.testing.TestingMetadata; import org.junit.jupiter.api.Test; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -34,7 +35,7 @@ public class TestRemoveRedundantExists public void testExistsFalse() { tester().assertThat(new RemoveRedundantExists()) - .on(p -> p.apply(ImmutableMap.of(p.symbol("exists"), new ApplyNode.Exists()), + .on(p -> p.apply(ImmutableMap.of(p.symbol("exists", BOOLEAN), new ApplyNode.Exists()), ImmutableList.of(), p.values(1), p.values(0))) @@ -48,7 +49,7 @@ public void testExistsFalse() public void testExistsTrue() { tester().assertThat(new RemoveRedundantExists()) - .on(p -> p.apply(ImmutableMap.of(p.symbol("exists"), new ApplyNode.Exists()), + .on(p -> p.apply(ImmutableMap.of(p.symbol("exists", BOOLEAN), new ApplyNode.Exists()), ImmutableList.of(), p.values(1), p.values(1))) @@ -62,7 +63,7 @@ public void testExistsTrue() public void testDoesNotFire() { tester().assertThat(new RemoveRedundantExists()) - .on(p -> p.apply(ImmutableMap.of(p.symbol("exists"), new ApplyNode.Exists()), + .on(p -> p.apply(ImmutableMap.of(p.symbol("exists", BOOLEAN), new ApplyNode.Exists()), ImmutableList.of(), p.values(1), p.tableScan(ImmutableList.of(), ImmutableMap.of()))) @@ -71,7 +72,7 @@ public void testDoesNotFire() tester().assertThat(new RemoveRedundantExists()) .on(p -> p.apply( ImmutableMap.builder() - .put(p.symbol("exists"), new ApplyNode.Exists()) + .put(p.symbol("exists", BOOLEAN), new ApplyNode.Exists()) .put(p.symbol("other"), new ApplyNode.In(p.symbol("value"), p.symbol("list"))) .buildOrThrow(), ImmutableList.of(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java index 53386549b28c..59a568f43893 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; -import io.trino.spi.type.VarcharType; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; @@ -35,6 +34,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; @@ -120,7 +120,7 @@ public void testDoesNotFireOnValuesWithNonRowExpression() .on(p -> p.join( INNER, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a")), ImmutableList.of(new Cast(new Row(ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("true")))), rowType(field("b", BOOLEAN))))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a")), ImmutableList.of(new Cast(new Row(ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("true")))), rowType(field("b", BOOLEAN))))), p.values(5, p.symbol("b")))) .doesNotFire(); } @@ -176,13 +176,13 @@ public void testReplaceInnerJoinWithProject() .on(p -> p.join( INNER, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), p.values(5, p.symbol("c")))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); @@ -191,12 +191,12 @@ public void testReplaceInnerJoinWithProject() p.join( INNER, p.values(5, p.symbol("c")), - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))))) + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -208,13 +208,13 @@ public void testReplaceLeftJoinWithProject() .on(p -> p.join( LEFT, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), p.values(5, p.symbol("c")))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); @@ -223,12 +223,12 @@ public void testReplaceLeftJoinWithProject() p.join( LEFT, p.values(5, p.symbol("c")), - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))))) + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -240,13 +240,13 @@ public void testReplaceRightJoinWithProject() .on(p -> p.join( RIGHT, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), p.values(5, p.symbol("c")))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); @@ -255,12 +255,12 @@ public void testReplaceRightJoinWithProject() p.join( RIGHT, p.values(5, p.symbol("c")), - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))))) + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -272,13 +272,13 @@ public void testReplaceFullJoinWithProject() .on(p -> p.join( FULL, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), p.values(5, p.symbol("c")))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); @@ -287,12 +287,12 @@ public void testReplaceFullJoinWithProject() p.join( FULL, p.values(5, p.symbol("c")), - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))))) + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))))) .matches( project( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } @@ -304,17 +304,17 @@ public void testRemoveOutputDuplicates() .on(p -> p.join( INNER, - p.valuesOfExpressions(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")))))), + p.valuesOfExpressions(ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("x")))))), p.values(5, p.symbol("c")), ImmutableList.of(), - ImmutableList.of(p.symbol("a"), p.symbol("b"), p.symbol("a"), p.symbol("b")), + ImmutableList.of(p.symbol("a", INTEGER), p.symbol("b", VARCHAR), p.symbol("a", INTEGER), p.symbol("b", VARCHAR)), ImmutableList.of(p.symbol("c"), p.symbol("c")), Optional.empty())) .matches( strictProject( ImmutableMap.of( "a", PlanMatchPattern.expression(new Constant(INTEGER, 1L)), - "b", PlanMatchPattern.expression(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x"))), + "b", PlanMatchPattern.expression(new Constant(VARCHAR, Slices.utf8Slice("x"))), "c", PlanMatchPattern.expression(new Reference(BIGINT, "c"))), values("c"))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java index 351131c16d66..5ba1fc5e7cd2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.java @@ -29,6 +29,7 @@ import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -98,7 +99,7 @@ public void testDistinctWithFilter() Assignments.builder() .putIdentity(p.symbol("input1")) .putIdentity(p.symbol("input2")) - .put(p.symbol("filter1"), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) + .put(p.symbol("filter1", BOOLEAN), new Comparison(GREATER_THAN, new Reference(INTEGER, "input2"), new Constant(INTEGER, 0L))) .build(), p.values( p.symbol("input1"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java index 9321b192058f..b6b4e17da4bc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedDistinctAggregationWithProjection.java @@ -86,7 +86,7 @@ public void rewritesOnSubqueryWithDistinct() JoinType.LEFT, TRUE, p.project( - Assignments.of(p.symbol("x"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 100L)))), + Assignments.of(p.symbol("x", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 100L)))), p.aggregation(innerBuilder -> innerBuilder .singleGroupingSet(p.symbol("a")) .source(p.filter( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java index bb8ab0dc05aa..30201a1b6b8c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java @@ -108,9 +108,9 @@ public void doesNotFireOnMultipleProjections() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("expr_2"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("expr_2", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L)))), p.project( - Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("expr", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -139,7 +139,7 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + p.project(Assignments.of(p.symbol("expr", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -163,8 +163,8 @@ public void rewritesOnSubqueryWithDistinct() p.values(p.symbol("corr")), p.project( Assignments.of( - p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), - p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), + p.symbol("expr_sum", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) @@ -216,8 +216,8 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() p.values(p.symbol("corr")), p.project( Assignments.of( - p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), - p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), + p.symbol("expr_sum", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) .addAggregation(p.symbol("count"), PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of()) @@ -265,7 +265,7 @@ public void testWithPreexistingMask() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + p.project(Assignments.of(p.symbol("expr", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("mask", BOOLEAN))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT), p.symbol("mask", BOOLEAN)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java index 5aebf76ded0c..a78c52b1e0fe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java @@ -108,9 +108,9 @@ public void doesNotFireOnMultipleProjections() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("expr_2"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("expr_2", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "expr"), new Constant(INTEGER, 1L)))), p.project( - Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("expr", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -146,7 +146,7 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.correlatedJoin( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), + p.project(Assignments.of(p.symbol("expr", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L)))), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java index fa0389e8a41d..3092bcb6320d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java @@ -92,8 +92,8 @@ public void rewritesOnSubqueryWithoutDistinct() TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), - p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), + p.symbol("expr_sum", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -134,8 +134,8 @@ public void rewritesOnSubqueryWithDistinct() TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), - p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), + p.symbol("expr_sum", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) @@ -185,8 +185,8 @@ public void rewritesOnSubqueryWithDecorrelatableDistinct() TRUE, p.project( Assignments.of( - p.symbol("expr_sum"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), - p.symbol("expr_count"), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), + p.symbol("expr_sum", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "sum"), new Constant(INTEGER, 1L))), + p.symbol("expr_count", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "count"), new Constant(INTEGER, 1L)))), p.aggregation(outerBuilder -> outerBuilder .singleGroupingSet(p.symbol("a")) .addAggregation(p.symbol("sum"), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "a"))), ImmutableList.of(BIGINT)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index 90e6e6c1bece..19cbf828dd20 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -133,7 +133,7 @@ public void rewritesOnSubqueryWithProjection() p.values(p.symbol("corr")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), + Assignments.of(p.symbol("a2", INTEGER), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), p.filter( new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS)))))) @@ -163,10 +163,10 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("a3"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("a3", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "a2"), new Constant(INTEGER, 1L)))), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), + Assignments.of(p.symbol("a2", INTEGER), new Call(MULTIPLY_INTEGER, ImmutableList.of(new Reference(INTEGER, "a"), new Constant(INTEGER, 2L)))), p.filter( new Comparison(EQUAL, new Constant(INTEGER, 1L), new Reference(INTEGER, "a")), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 5dd040b1f04f..b0569415b0ed 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -69,7 +69,7 @@ public void testRewrite() ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", BIGINT))), p.project( - Assignments.of(p.symbol("l_expr2"), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L)))), + Assignments.of(p.symbol("l_expr2", INTEGER), new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "l_nationkey"), new Constant(INTEGER, 1L)))), p.values( ImmutableList.of(), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java index a4ff34663243..e255aaf8a874 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestUnwrapRowSubscript.java @@ -90,7 +90,7 @@ private void test(Expression original, Expression unwrapped) tester().assertThat(new UnwrapRowSubscript().projectExpressionRewrite()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("output"), original) + .put(p.symbol("output", original.type()), original) .build(), p.values())) .matches( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java index 0cf91988072c..02f66c0c211a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TestRuleTester.java @@ -52,7 +52,7 @@ public void testReportWrongMatch() (node, captures, context) -> Result.ofPlanNode(node.replaceChildren(node.getSources())))) .on(p -> p.project( - Assignments.of(p.symbol("y"), new Reference(INTEGER, "x")), + Assignments.of(p.symbol("y", INTEGER), new Reference(INTEGER, "x")), p.values( ImmutableList.of(p.symbol("x")), ImmutableList.of(ImmutableList.of(new Constant(INTEGER, 1L)))))); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java index 11854529a8c6..df6532e4d676 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestConnectorPushdownRulesWithHive.java @@ -260,7 +260,7 @@ public void testColumnPruningProjectionPushdown() Symbol symbolA = p.symbol("a", INTEGER); Symbol symbolB = p.symbol("b", INTEGER); return p.project( - Assignments.of(p.symbol("x"), symbolA.toSymbolReference()), + Assignments.of(p.symbol("x", INTEGER), symbolA.toSymbolReference()), p.tableScan( table, ImmutableList.of(symbolA, symbolB), diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java index 29516fff4b57..1f8748b91356 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/optimizer/TestConnectorPushdownRulesWithIceberg.java @@ -324,7 +324,7 @@ public void testColumnPruningProjectionPushdown() Symbol symbolA = p.symbol("a", INTEGER); Symbol symbolB = p.symbol("b", INTEGER); return p.project( - Assignments.of(p.symbol("x"), symbolA.toSymbolReference()), + Assignments.of(p.symbol("x", INTEGER), symbolA.toSymbolReference()), p.tableScan( table, ImmutableList.of(symbolA, symbolB), From 4e9403b4374ee4c3e7ada3466adfd794d8145708 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Tue, 26 Mar 2024 14:03:49 -0700 Subject: [PATCH 6/6] Always require default value for Switch and Case --- .../main/java/io/trino/sql/ir/Booleans.java | 1 + .../src/main/java/io/trino/sql/ir/Case.java | 12 +-- .../trino/sql/ir/DefaultTraversalVisitor.java | 7 +- .../io/trino/sql/ir/ExpressionFormatter.java | 8 +- .../trino/sql/ir/ExpressionTreeRewriter.java | 10 +- .../java/io/trino/sql/ir/IrExpressions.java | 6 +- .../src/main/java/io/trino/sql/ir/Switch.java | 10 +- .../sql/planner/IrExpressionInterpreter.java | 15 ++- .../io/trino/sql/planner/QueryPlanner.java | 2 +- .../io/trino/sql/planner/TranslationMap.java | 8 +- .../rule/PreAggregateCaseAggregations.java | 48 +++++----- .../rule/SimplifyFilterPredicate.java | 44 ++++----- .../TransformCorrelatedInPredicateToJoin.java | 3 +- .../TransformCorrelatedScalarSubquery.java | 4 +- ...tifiedComparisonApplyToCorrelatedJoin.java | 4 +- .../SqlToRowExpressionTranslator.java | 9 +- ...stPageFieldsToInputParametersRewriter.java | 9 +- .../trino/sql/TestExpressionInterpreter.java | 92 +++++++++---------- .../sql/planner/TestEqualityInference.java | 5 +- .../trino/sql/planner/TestLogicalPlanner.java | 6 +- .../assertions/ExpressionVerifier.java | 16 ---- .../TestCanonicalizeExpressionRewriter.java | 4 +- .../TestPreAggregateCaseAggregations.java | 50 +++++----- .../rule/TestSimplifyFilterPredicate.java | 45 +++++---- ...TestTransformCorrelatedScalarSubquery.java | 5 +- .../geospatial/TestSpatialJoinPlanning.java | 4 +- 26 files changed, 192 insertions(+), 235 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java b/core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java index 80e90aa0707f..eb9f34867d62 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Booleans.java @@ -19,6 +19,7 @@ public final class Booleans { public static final Constant TRUE = new Constant(BOOLEAN, true); public static final Constant FALSE = new Constant(BOOLEAN, false); + public static final Constant NULL_BOOLEAN = new Constant(BOOLEAN, null); private Booleans() {} } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Case.java b/core/trino-main/src/main/java/io/trino/sql/ir/Case.java index 7fcf1b9ec1c9..3afadf9426e8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Case.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Case.java @@ -18,7 +18,6 @@ import io.trino.spi.type.Type; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -26,7 +25,7 @@ import static java.util.Objects.requireNonNull; @JsonSerialize -public record Case(List whenClauses, Optional defaultValue) +public record Case(List whenClauses, Expression defaultValue) implements Expression { public Case @@ -42,9 +41,7 @@ public record Case(List whenClauses, Optional defaultVal validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult()); } - if (defaultValue.isPresent()) { - validateType(whenClauses.getFirst().getResult().type(), defaultValue.get()); - } + validateType(whenClauses.getFirst().getResult().type(), defaultValue); } @Override @@ -67,7 +64,8 @@ public List children() builder.add(clause.getOperand()); builder.add(clause.getResult()); }); - defaultValue.ifPresent(builder::add); + + builder.add(defaultValue); return builder.build(); } @@ -79,6 +77,6 @@ public String toString() whenClauses.stream() .map(WhenClause::toString) .collect(Collectors.joining(", ")), - defaultValue.map(Expression::toString).orElse("null")); + defaultValue); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java index 25509d2dfad1..dfcd822b2d04 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java @@ -90,8 +90,7 @@ protected Void visitSwitch(Switch node, C context) process(clause.getResult(), context); } - node.defaultValue() - .ifPresent(value -> process(value, context)); + process(node.defaultValue(), context); return null; } @@ -130,8 +129,8 @@ protected Void visitCase(Case node, C context) process(clause.getOperand(), context); process(clause.getResult(), context); } - node.defaultValue() - .ifPresent(value -> process(value, context)); + + process(node.defaultValue(), context); return null; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java index 1b31dab969be..7cd2cb900ce2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java @@ -191,9 +191,7 @@ protected String visitCase(Case node, Void context) parts.add(format(whenClause, context)); } - node.defaultValue() - .ifPresent(value -> parts.add("ELSE").add(process(value, context))); - + parts.add("ELSE").add(process(node.defaultValue(), context)); parts.add("END"); return "(" + Joiner.on(' ').join(parts.build()) + ")"; @@ -211,9 +209,7 @@ protected String visitSwitch(Switch node, Void context) parts.add(format(whenClause, context)); } - node.defaultValue() - .ifPresent(value -> parts.add("ELSE").add(process(value, context))); - + parts.add("ELSE").add(process(node.defaultValue(), context)); parts.add("END"); return "(" + Joiner.on(' ').join(parts.build()) + ")"; diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java index f7e28a45158c..268a4810d36e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java @@ -248,10 +248,9 @@ protected Expression visitCase(Case node, Context context) builder.add(rewriteWhenClause(expression, context)); } - Optional defaultValue = node.defaultValue() - .map(value -> rewrite(value, context.get())); + Expression defaultValue = rewrite(node.defaultValue(), context.get()); - if (!sameElements(node.defaultValue(), defaultValue) || !sameElements(node.whenClauses(), builder.build())) { + if (node.defaultValue() != defaultValue || !sameElements(node.whenClauses(), builder.build())) { return new Case(builder.build(), defaultValue); } @@ -275,11 +274,10 @@ protected Expression visitSwitch(Switch node, Context context) builder.add(rewriteWhenClause(expression, context)); } - Optional defaultValue = node.defaultValue() - .map(value -> rewrite(value, context.get())); + Expression defaultValue = rewrite(node.defaultValue(), context.get()); if (operand != node.operand() || - !sameElements(node.defaultValue(), defaultValue) || + node.defaultValue() != defaultValue || !sameElements(node.whenClauses(), builder.build())) { return new Switch(operand, builder.build(), defaultValue); } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java index 76409b1c356f..ecc48e98e80b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java @@ -15,19 +15,17 @@ import com.google.common.collect.ImmutableList; -import java.util.Optional; - public class IrExpressions { private IrExpressions() {} public static Expression ifExpression(Expression condition, Expression trueCase) { - return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.empty()); + return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), new Constant(trueCase.type(), null)); } public static Expression ifExpression(Expression condition, Expression trueCase, Expression falseCase) { - return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), Optional.of(falseCase)); + return new Case(ImmutableList.of(new WhenClause(condition, trueCase)), falseCase); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Switch.java b/core/trino-main/src/main/java/io/trino/sql/ir/Switch.java index 75512c8bf591..b7fa0890c7ad 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Switch.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Switch.java @@ -18,13 +18,12 @@ import io.trino.spi.type.Type; import java.util.List; -import java.util.Optional; import static io.trino.sql.ir.IrUtils.validateType; import static java.util.Objects.requireNonNull; @JsonSerialize -public record Switch(Expression operand, List whenClauses, Optional defaultValue) +public record Switch(Expression operand, List whenClauses, Expression defaultValue) implements Expression { public Switch @@ -40,9 +39,7 @@ public record Switch(Expression operand, List whenClauses, Optional< validateType(whenClauses.getFirst().getResult().type(), whenClauses.get(i).getResult()); } - if (defaultValue.isPresent()) { - validateType(whenClauses.getFirst().getResult().type(), defaultValue.get()); - } + validateType(whenClauses.getFirst().getResult().type(), defaultValue); } @Override @@ -68,8 +65,7 @@ public List children() builder.add(clause.getResult()); }); - defaultValue.ifPresent(builder::add); - + builder.add(defaultValue); return builder.build(); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index 55c0f1dbb523..f9d89a901038 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -230,16 +230,14 @@ else if (Boolean.TRUE.equals(whenOperand)) { defaultResult = newDefault; } else { - defaultResult = processWithExceptionHandling(node.defaultValue().orElse(null), context); + defaultResult = processWithExceptionHandling(node.defaultValue(), context); } if (whenClauses.isEmpty()) { return defaultResult; } - Expression defaultExpression; - defaultExpression = defaultResult == null ? null : toExpression(defaultResult, ((Expression) node).type()); - return new Case(whenClauses, Optional.ofNullable(defaultExpression)); + return new Case(whenClauses, toExpression(defaultResult, ((Expression) node).type())); } @Override @@ -250,7 +248,7 @@ protected Object visitSwitch(Switch node, SymbolResolver context) // if operand is null, return defaultValue if (operand == null) { - return processWithExceptionHandling(node.defaultValue().orElse(null), context); + return processWithExceptionHandling(node.defaultValue(), context); } Object newDefault = null; @@ -281,16 +279,15 @@ protected Object visitSwitch(Switch node, SymbolResolver context) defaultResult = newDefault; } else { - defaultResult = processWithExceptionHandling(node.defaultValue().orElse(null), context); + defaultResult = processWithExceptionHandling(node.defaultValue(), context); } if (whenClauses.isEmpty()) { return defaultResult; } - Expression defaultExpression; - defaultExpression = defaultResult == null ? null : toExpression(defaultResult, ((Expression) node).type()); - return new Switch(toExpression(operand, node.operand().type()), whenClauses, Optional.ofNullable(defaultExpression)); + Expression defaultExpression = toExpression(defaultResult, ((Expression) node).type()); + return new Switch(toExpression(operand, node.operand().type()), whenClauses, defaultExpression); } private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index a901f03ce3f9..fc957d7ad6e6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -861,7 +861,7 @@ public MergeWriterNode plan(Merge merge) // The case number rowBuilder.add(new Constant(INTEGER, -1L)); - Case caseExpression = new Case(whenClauses.build(), Optional.of(new Row(rowBuilder.build()))); + Case caseExpression = new Case(whenClauses.build(), new Row(rowBuilder.build())); Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", mergeAnalysis.getMergeRowType()); Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index ac211d2f500d..efebef927db8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -405,7 +405,9 @@ private io.trino.sql.ir.Expression translate(SearchedCaseExpression expression) translateExpression(clause.getOperand()), translateExpression(clause.getResult()))) .collect(toImmutableList()), - expression.getDefaultValue().map(this::translateExpression)); + expression.getDefaultValue() + .map(this::translateExpression) + .orElse(new Constant(analysis.getType(expression), null))); } private io.trino.sql.ir.Expression translate(SimpleCaseExpression expression) @@ -417,7 +419,9 @@ private io.trino.sql.ir.Expression translate(SimpleCaseExpression expression) translateExpression(clause.getOperand()), translateExpression(clause.getResult()))) .collect(toImmutableList()), - expression.getDefaultValue().map(this::translateExpression)); + expression.getDefaultValue() + .map(this::translateExpression) + .orElse(new Constant(analysis.getType(expression), null))); } private io.trino.sql.ir.Expression translate(InPredicate expression) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java index e48588202218..052752cb43d3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java @@ -381,33 +381,27 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo return Optional.empty(); } - Optional cumulativeAggregationDefaultValue = Optional.empty(); - if (caseExpression.defaultValue().isPresent()) { - Expression defaultValue = optimizeExpression(caseExpression.defaultValue().get(), context); - if (defaultValue instanceof Constant(Type type, Object value) && value != null) { - if (!name.equals(SUM)) { - return Optional.empty(); - } + Expression defaultValue = optimizeExpression(caseExpression.defaultValue(), context); + if (defaultValue instanceof Constant(Type type, Object value) && value != null) { + if (!name.equals(SUM)) { + return Optional.empty(); + } - // sum aggregation is only supported if default value is null or 0, otherwise it wouldn't be cumulative - if (type instanceof BigintType - || type == INTEGER - || type == SMALLINT - || type == TINYINT - || type == DOUBLE - || type == REAL - || type instanceof DecimalType) { - if (!value.equals(0L) && !value.equals(0.0d) && !value.equals(Int128.ZERO)) { - return Optional.empty(); - } - } - else { + // sum aggregation is only supported if default value is null or 0, otherwise it wouldn't be cumulative + if (type instanceof BigintType + || type == INTEGER + || type == SMALLINT + || type == TINYINT + || type == DOUBLE + || type == REAL + || type instanceof DecimalType) { + if (!value.equals(0L) && !value.equals(0.0d) && !value.equals(Int128.ZERO)) { return Optional.empty(); } } - - // cumulative aggregation default value need to be CAST to cumulative aggregation input type - cumulativeAggregationDefaultValue = Optional.of(new Cast(caseExpression.defaultValue().get(), aggregationType)); + else { + return Optional.empty(); + } } return Optional.of(new CaseAggregation( @@ -417,7 +411,7 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo name, caseExpression.whenClauses().get(0).getOperand(), caseExpression.whenClauses().get(0).getResult(), - cumulativeAggregationDefaultValue)); + new Cast(caseExpression.defaultValue(), aggregationType))); } private Type getType(Expression expression) @@ -445,7 +439,7 @@ private static class CaseAggregation // CASE expression only result expression private final Expression result; // default value of cumulative aggregation - private final Optional cumulativeAggregationDefaultValue; + private final Expression cumulativeAggregationDefaultValue; public CaseAggregation( Symbol aggregationSymbol, @@ -454,7 +448,7 @@ public CaseAggregation( CatalogSchemaFunctionName name, Expression operand, Expression result, - Optional cumulativeAggregationDefaultValue) + Expression cumulativeAggregationDefaultValue) { this.aggregationSymbol = requireNonNull(aggregationSymbol, "aggregationSymbol is null"); this.function = requireNonNull(function, "function is null"); @@ -495,7 +489,7 @@ public Expression getResult() return result; } - public Optional getCumulativeAggregationDefaultValue() + public Expression getCumulativeAggregationDefaultValue() { return cumulativeAggregationDefaultValue; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java index ba3860dc3257..873688b68cdb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java @@ -86,45 +86,47 @@ public Result apply(FilterNode node, Captures captures, Context context) return Result.empty(); } + Expression predicate = combineConjuncts(newConjuncts.build()); + if (predicate instanceof Constant constant && constant.value() == null) { + predicate = FALSE; + } return Result.ofPlanNode(new FilterNode( node.getId(), node.getSource(), - combineConjuncts(newConjuncts.build()))); + predicate)); } - private static Optional simplify(Expression condition, Expression trueValue, Optional falseValue) + private static Optional simplify(Expression condition, Expression trueValue, Expression falseValue) { - if (trueValue.equals(TRUE) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) { + if (trueValue.equals(TRUE) && isNotTrue(falseValue)) { return Optional.of(condition); } - if (isNotTrue(trueValue) && falseValue.isPresent() && falseValue.get().equals(TRUE)) { + if (isNotTrue(trueValue) && falseValue.equals(TRUE)) { return Optional.of(isFalseOrNullPredicate(condition)); } - if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue)) { + if (falseValue.equals(trueValue) && isDeterministic(trueValue)) { return Optional.of(trueValue); } - if (isNotTrue(trueValue) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) { + if (isNotTrue(trueValue) && isNotTrue(falseValue)) { return Optional.of(FALSE); } if (condition.equals(TRUE)) { return Optional.of(trueValue); } if (isNotTrue(condition)) { - return Optional.of(falseValue.orElse(FALSE)); + return Optional.of(falseValue); } return Optional.empty(); } private static Optional simplify(Case caseExpression) { - Optional defaultValue = caseExpression.defaultValue(); - if (caseExpression.whenClauses().size() == 1) { // if-like expression return simplify( caseExpression.whenClauses().getFirst().getOperand(), caseExpression.whenClauses().getFirst().getResult(), - defaultValue); + caseExpression.defaultValue()); } List operands = caseExpression.whenClauses().stream() @@ -141,15 +143,15 @@ private static Optional simplify(Case caseExpression) .filter(SimplifyFilterPredicate::isNotTrue) .count(); // all results true - if (trueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE)) { + if (trueResultsCount == results.size() && caseExpression.defaultValue().equals(TRUE)) { return Optional.of(TRUE); } // all results not true - if (notTrueResultsCount == results.size() && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) { + if (notTrueResultsCount == results.size() && isNotTrue(caseExpression.defaultValue())) { return Optional.of(FALSE); } // one result true, and remaining results not true - if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) { + if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && isNotTrue(caseExpression.defaultValue())) { ImmutableList.Builder builder = ImmutableList.builder(); for (WhenClause whenClause : caseExpression.whenClauses()) { Expression operand = whenClause.getOperand(); @@ -164,7 +166,7 @@ private static Optional simplify(Case caseExpression) } } // all results not true, and default true - if (notTrueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE)) { + if (notTrueResultsCount == results.size() && caseExpression.defaultValue().equals(TRUE)) { ImmutableList.Builder builder = ImmutableList.builder(); operands.forEach(operand -> builder.add(isFalseOrNullPredicate(operand))); return Optional.of(combineConjuncts(builder.build())); @@ -177,36 +179,36 @@ private static Optional simplify(Case caseExpression) if (whenClauses.isEmpty()) { return Optional.of(whenClause.getResult()); } - return Optional.of(new Case(whenClauses, Optional.of(whenClause.getResult()))); + return Optional.of(new Case(whenClauses, whenClause.getResult())); } if (!isNotTrue(operand)) { whenClauses.add(whenClause); } } if (whenClauses.isEmpty()) { - return Optional.of(defaultValue.orElse(FALSE)); + return Optional.of(caseExpression.defaultValue()); } if (whenClauses.size() < caseExpression.whenClauses().size()) { - return Optional.of(new Case(whenClauses, defaultValue)); + return Optional.of(new Case(whenClauses, caseExpression.defaultValue())); } return Optional.empty(); } private static Optional simplify(Switch caseExpression) { - Optional defaultValue = caseExpression.defaultValue(); + Optional defaultValue = Optional.of(caseExpression.defaultValue()); if (caseExpression.operand() instanceof Constant literal && literal.value() == null) { - return Optional.of(defaultValue.orElse(FALSE)); + return defaultValue; } List results = caseExpression.whenClauses().stream() .map(WhenClause::getResult) .collect(toImmutableList()); - if (results.stream().allMatch(result -> result.equals(TRUE)) && defaultValue.isPresent() && defaultValue.get().equals(TRUE)) { + if (results.stream().allMatch(result -> result.equals(TRUE)) && defaultValue.get().equals(TRUE)) { return Optional.of(TRUE); } - if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) { + if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && isNotTrue(defaultValue.get())) { return Optional.of(FALSE); } return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 9a61811f8c94..770b456aa2de 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -55,6 +55,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.ir.Booleans.FALSE; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.or; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; @@ -220,7 +221,7 @@ private PlanNode buildInPredicateEquivalent( ImmutableList.of( new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null))), - Optional.of(booleanConstant(false))); + FALSE); return new ProjectNode( idAllocator.getNextId(), aggregation, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index 170eadb83e20..b71047ba811f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -163,9 +163,9 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co isDistinct.toSymbolReference(), ImmutableList.of( new WhenClause(TRUE, TRUE)), - Optional.of(new Cast( + new Cast( failFunction(metadata, SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"), - BOOLEAN)))); + BOOLEAN))); return Result.ofPlanNode(new ProjectNode( context.getIdAllocator().getNextId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index 951ba36480c0..6412b812f636 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -192,14 +192,14 @@ public Expression rewriteUsingBounds(ApplyNode.QuantifiedComparison quantifiedCo ImmutableList.of(new WhenClause( new Constant(BIGINT, 0L), emptySetResult)), - Optional.of(quantifier.apply(ImmutableList.of( + quantifier.apply(ImmutableList.of( comparisonWithExtremeValue, new Case( ImmutableList.of( new WhenClause( new Comparison(NOT_EQUAL, countAllValue.toSymbolReference(), countNonNullValue.toSymbolReference()), new Constant(BOOLEAN, null))), - Optional.of(emptySetResult)))))); + emptySetResult)))); } private Expression getBoundComparisons(ApplyNode.QuantifiedComparison quantifiedComparison, Symbol minValue, Symbol maxValue) diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index e0e6ec615a93..55ca50db667e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -61,7 +61,6 @@ import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.relational.Expressions.call; import static io.trino.sql.relational.Expressions.constant; -import static io.trino.sql.relational.Expressions.constantNull; import static io.trino.sql.relational.Expressions.field; import static io.trino.sql.relational.SpecialForm.Form.AND; import static io.trino.sql.relational.SpecialForm.Form.BETWEEN; @@ -339,9 +338,7 @@ protected RowExpression visitSwitch(Switch node, Void context) Type returnType = ((Expression) node).type(); - arguments.add(node.defaultValue() - .map(defaultValue -> process(defaultValue, context)) - .orElse(constantNull(returnType))); + arguments.add(process(node.defaultValue(), context)); return new SpecialForm(SWITCH, returnType, arguments.build(), functionDependencies.build()); } @@ -369,9 +366,7 @@ protected RowExpression visitCase(Case node, Void context) value4))) */ - RowExpression expression = node.defaultValue() - .map(value -> process(value, context)) - .orElse(constantNull(((Expression) node).type())); + RowExpression expression = process(node.defaultValue(), context); for (WhenClause clause : node.whenClauses().reversed()) { expression = new SpecialForm( diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java index 37117fa514f2..e8b4d7f0a2f5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java @@ -50,7 +50,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.IntStream; @@ -102,16 +101,16 @@ public void testEagerLoading() verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Comparison(EQUAL, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 1L))), new Constant(BIGINT, 0L))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Between(new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 1L), new Constant(BIGINT, 10L))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L)), new Reference(BIGINT, "bigint0"))), Optional.of(new Constant(BIGINT, null)))), 1); - verifyEagerlyLoadedColumns(builder.buildExpression(new Switch(new Reference(BIGINT, "bigint0"), ImmutableList.of(new WhenClause(new Constant(BIGINT, 1L), new Constant(BIGINT, 1L))), Optional.of(new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0")))))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "bigint0"), new Constant(INTEGER, 0L)), new Reference(BIGINT, "bigint0"))), new Constant(BIGINT, null))), 1); + verifyEagerlyLoadedColumns(builder.buildExpression(new Switch(new Reference(BIGINT, "bigint0"), ImmutableList.of(new WhenClause(new Constant(BIGINT, 1L), new Constant(BIGINT, 1L))), new Call(NEGATION_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"))))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(new Coalesce(new Constant(BIGINT, 0L), new Reference(BIGINT, "bigint0")), new Reference(BIGINT, "bigint0")))), 1); verifyEagerlyLoadedColumns(builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Constant(BIGINT, 2L), new Reference(BIGINT, "bigint1")))))), 2); verifyEagerlyLoadedColumns(builder.buildExpression(new NullIf(new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1"))), 2); verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(CEIL, ImmutableList.of(new Call(DIVIDE_BIGINT, ImmutableList.of(new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1"))))), new Constant(BIGINT, 0L))), 2); - verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1")), new Constant(INTEGER, 1L))), Optional.of(new Constant(INTEGER, 0L)))), 2); + verifyEagerlyLoadedColumns(builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Reference(BIGINT, "bigint1")), new Constant(INTEGER, 1L))), new Constant(INTEGER, 0L))), 2); verifyEagerlyLoadedColumns( - builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Reference(BIGINT, "bigint1"))), Optional.of(new Constant(BIGINT, 0L)))), 2, ImmutableSet.of(0)); + builder.buildExpression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Reference(BIGINT, "bigint1"))), new Constant(BIGINT, 0L))), 2, ImmutableSet.of(0)); verifyEagerlyLoadedColumns(builder.buildExpression(new Coalesce(new Call(ROUND, ImmutableList.of(new Reference(BIGINT, "bigint0"))), new Reference(BIGINT, "bigint1"))), 2, ImmutableSet.of(0)); verifyEagerlyLoadedColumns(builder.buildExpression(new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint1"), new Constant(BIGINT, 0L))))), 2, ImmutableSet.of(0)); verifyEagerlyLoadedColumns(builder.buildExpression(new Logical(OR, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint0"), new Constant(BIGINT, 0L)), new Comparison(GREATER_THAN, new Reference(BIGINT, "bigint1"), new Constant(BIGINT, 0L))))), 2, ImmutableSet.of(0)); diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index 314243d7861d..c886f5e3b17b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -451,92 +451,92 @@ public void testSearchCase() assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(TRUE, new Constant(INTEGER, 33L))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, 33L)); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(FALSE, new Constant(INTEGER, 1L))), - Optional.of(new Constant(INTEGER, 33L))), + new Constant(INTEGER, 33L)), new Constant(INTEGER, 33L)); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), new Constant(INTEGER, 33L))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, 33L)); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(TRUE, new Reference(INTEGER, "bound_value"))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, 1234L)); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(FALSE, new Constant(INTEGER, 1L))), - Optional.of(new Reference(INTEGER, "bound_value"))), + new Reference(INTEGER, "bound_value")), new Constant(INTEGER, 1234L)); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 1234L)), new Constant(INTEGER, 33L))), - Optional.of(new Reference(INTEGER, "unbound_value"))), + new Reference(INTEGER, "unbound_value")), new Constant(INTEGER, 33L)); assertOptimizedMatches( new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), - Optional.empty()), + new Constant(INTEGER, null)), new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 1L))), - Optional.empty())); + new Constant(INTEGER, null))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("b")))), - Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), + new Constant(VARCHAR, Slices.utf8Slice("c"))), new Constant(VARCHAR, Slices.utf8Slice("a"))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(TRUE, new Constant(VARCHAR, Slices.utf8Slice("b")))), - Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), + new Constant(VARCHAR, Slices.utf8Slice("c"))), new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), - Optional.of(new Constant(VARCHAR, Slices.utf8Slice("b"))))); + new Constant(VARCHAR, Slices.utf8Slice("b")))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), - Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c")))), + new Constant(VARCHAR, Slices.utf8Slice("c"))), new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), - Optional.of(new Constant(VARCHAR, Slices.utf8Slice("c"))))); + new Constant(VARCHAR, Slices.utf8Slice("c")))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a"))), new WhenClause(FALSE, new Constant(VARCHAR, Slices.utf8Slice("b")))), - Optional.empty()), + new Constant(VARCHAR, null)), new Case(ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(VARCHAR, Slices.utf8Slice("a")))), - Optional.empty())); + new Constant(VARCHAR, null))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(TRUE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new WhenClause(FALSE, new Constant(INTEGER, 1L))), - Optional.empty()), + new Constant(INTEGER, null)), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( new Case(ImmutableList.of( new WhenClause(FALSE, new Constant(INTEGER, 1L)), new WhenClause(FALSE, new Constant(INTEGER, 2L))), - Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertEvaluatedEquals( new Case(ImmutableList.of( new WhenClause(FALSE, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new WhenClause(TRUE, new Constant(INTEGER, 1L))), - Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); assertEvaluatedEquals( new Case(ImmutableList.of( new WhenClause(TRUE, new Constant(INTEGER, 1L)), new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), - Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 1L)); } @@ -549,7 +549,7 @@ public void testSimpleCase() ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 33L)), new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 34L))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, 33L)); assertOptimizedEquals( @@ -557,28 +557,28 @@ public void testSimpleCase() new Constant(BOOLEAN, null), ImmutableList.of( new WhenClause(TRUE, new Constant(INTEGER, 33L))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, null)); for (Switch aSwitch : Arrays.asList(new Switch( new Constant(BOOLEAN, null), ImmutableList.of( new WhenClause(TRUE, new Constant(INTEGER, 33L))), - Optional.of(new Constant(INTEGER, 33L))), + new Constant(INTEGER, 33L)), new Switch( new Constant(INTEGER, 33L), ImmutableList.of( new WhenClause(new Constant(INTEGER, null), new Constant(INTEGER, 1L))), - Optional.of(new Constant(INTEGER, 33L))), + new Constant(INTEGER, 33L)), new Switch( new Reference(INTEGER, "bound_value"), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1234L), new Constant(INTEGER, 33L))), - Optional.empty()), + new Constant(INTEGER, null)), new Switch( new Constant(INTEGER, 1234L), ImmutableList.of( new WhenClause(new Reference(INTEGER, "bound_value"), new Constant(INTEGER, 33L))), - Optional.empty()))) { + new Constant(INTEGER, null)))) { assertOptimizedEquals( aSwitch, new Constant(INTEGER, 33L)); @@ -589,14 +589,14 @@ public void testSimpleCase() TRUE, ImmutableList.of( new WhenClause(TRUE, new Reference(INTEGER, "bound_value"))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, 1234L)); assertOptimizedEquals( new Switch( TRUE, ImmutableList.of( new WhenClause(FALSE, new Constant(INTEGER, 1L))), - Optional.of(new Reference(INTEGER, "bound_value"))), + new Reference(INTEGER, "bound_value")), new Constant(INTEGER, 1234L)); assertOptimizedEquals( @@ -605,13 +605,13 @@ public void testSimpleCase() ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 33L))), + new Constant(INTEGER, 33L)), new Switch( TRUE, ImmutableList.of( new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "unbound_value"), new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)), new WhenClause(new Comparison(EQUAL, new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 0L)), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 33L)))); + new Constant(INTEGER, 33L))); assertOptimizedMatches( new Switch( @@ -619,80 +619,80 @@ public void testSimpleCase() ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 1L))), + new Constant(INTEGER, 1L)), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L)), new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 1L)))); + new Constant(INTEGER, 1L))); assertOptimizedEquals( new Switch( new Constant(INTEGER, null), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), - Optional.of(new Constant(INTEGER, 1L))), + new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); assertOptimizedEquals( new Switch( new Constant(INTEGER, null), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))); assertOptimizedEquals( new Switch( new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 3L))), + new Constant(INTEGER, 3L)), new Switch( new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 3L)))); + new Constant(INTEGER, 3L))); assertOptimizedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 3L))), + new Constant(INTEGER, 3L)), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), - Optional.of(new Constant(INTEGER, 3L)))); + new Constant(INTEGER, 3L))); assertOptimizedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 3L))), - Optional.of(new Constant(INTEGER, 4L))), + new Constant(INTEGER, 4L)), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 3L))), - Optional.of(new Constant(INTEGER, 4L)))); + new Constant(INTEGER, 4L))); assertOptimizedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), - Optional.empty()), + new Constant(INTEGER, null)), new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Constant(INTEGER, 2L))), - Optional.empty())); + new Constant(INTEGER, null))); assertOptimizedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 2L), new Constant(INTEGER, 2L)), new WhenClause(new Constant(INTEGER, 3L), new Constant(INTEGER, 3L))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, null)); assertEvaluatedEquals( @@ -700,14 +700,14 @@ public void testSimpleCase() new Constant(INTEGER, null), ImmutableList.of( new WhenClause(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), - Optional.of(new Constant(INTEGER, 1L))), + new Constant(INTEGER, 1L)), new Constant(INTEGER, 1L)); assertEvaluatedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 2L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), - Optional.of(new Constant(INTEGER, 3L))), + new Constant(INTEGER, 3L)), new Constant(INTEGER, 3L)); assertEvaluatedEquals( new Switch( @@ -715,14 +715,14 @@ public void testSimpleCase() ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L)), new WhenClause(new Constant(INTEGER, 1L), new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), - Optional.empty()), + new Constant(INTEGER, null)), new Constant(INTEGER, 2L)); assertEvaluatedEquals( new Switch( new Constant(INTEGER, 1L), ImmutableList.of( new WhenClause(new Constant(INTEGER, 1L), new Constant(INTEGER, 2L))), - Optional.of(new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L))))), + new Call(DIVIDE_INTEGER, ImmutableList.of(new Constant(INTEGER, 0L), new Constant(INTEGER, 0L)))), new Constant(INTEGER, 2L)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java index 86614ce2e218..b1d8edd8dea1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java @@ -43,7 +43,6 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; -import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -317,8 +316,8 @@ public void testExpressionsThatMayReturnNullOnNonNullInput() .build(), new NullIf(new Reference(BIGINT, "b"), number(1)), new In(new Reference(BIGINT, "b"), ImmutableList.of(new Constant(BIGINT, null))), - new Case(ImmutableList.of(new WhenClause(new Not(new IsNull(new Reference(BIGINT, "b"))), new Constant(UnknownType.UNKNOWN, null))), Optional.empty()), - new Switch(new Reference(INTEGER, "b"), ImmutableList.of(new WhenClause(number(1), new Constant(INTEGER, null))), Optional.empty())); + new Case(ImmutableList.of(new WhenClause(new Not(new IsNull(new Reference(BIGINT, "b"))), new Constant(UnknownType.UNKNOWN, null))), new Constant(UnknownType.UNKNOWN, null)), + new Switch(new Reference(INTEGER, "b"), ImmutableList.of(new WhenClause(number(1), new Constant(INTEGER, null))), new Constant(INTEGER, null))); for (Expression candidate : candidates) { EqualityInference inference = new EqualityInference( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index a018fce1f91f..ff8dc013a4d0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -879,7 +879,7 @@ public void testCorrelatedScalarSubqueryInSelect() new Switch( new Reference(BOOLEAN, "is_distinct"), ImmutableList.of(new WhenClause(TRUE, TRUE)), - Optional.of(new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN)), project( markDistinct("is_distinct", ImmutableList.of("unique"), join(LEFT, builder -> builder @@ -897,7 +897,7 @@ public void testCorrelatedScalarSubqueryInSelect() new Switch( new Reference(BOOLEAN, "is_distinct"), ImmutableList.of(new WhenClause(TRUE, TRUE)), - Optional.of(new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN)), project( markDistinct("is_distinct", ImmutableList.of("unique"), join(LEFT, builder -> builder @@ -1155,7 +1155,7 @@ public void testCorrelatedDistinctGroupedAggregationRewriteToLeftOuterJoin() new Switch( new Reference(BOOLEAN, "is_distinct"), ImmutableList.of(new WhenClause(TRUE, TRUE)), - Optional.of(new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + new Cast(new Call(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN)), project(markDistinct( "is_distinct", ImmutableList.of("unique"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java index e3106ef081ca..80fb22d590e9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java @@ -35,7 +35,6 @@ import java.util.List; import java.util.Objects; -import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -238,10 +237,6 @@ protected Boolean visitCase(Case actual, Expression expected) return false; } - if (actual.defaultValue().isPresent() != expectedCase.defaultValue().isPresent()) { - return false; - } - return process(actual.defaultValue(), expectedCase.defaultValue()); } @@ -322,15 +317,4 @@ private boolean process(List actuals, List expected } return true; } - - private boolean process(Optional actual, Optional expected) - { - if (actual.isPresent() != expected.isPresent()) { - return false; - } - if (actual.isPresent()) { - return process(actual.get(), expected.get()); - } - return true; - } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java index e1491d6fb1e5..a60624454117 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java @@ -34,8 +34,6 @@ import io.trino.transaction.TransactionManager; import org.junit.jupiter.api.Test; -import java.util.Optional; - import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateType.DATE; @@ -84,7 +82,7 @@ public void testRewriteIfExpression() { assertRewritten( ifExpression(new Comparison(EQUAL, new Reference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L), new Constant(INTEGER, 1L)), - new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L))), Optional.of(new Constant(INTEGER, 1L)))); + new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "x"), new Constant(INTEGER, 0L)), new Constant(INTEGER, 0L))), new Constant(INTEGER, 1L))); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java index 16f356dcd755..7039088ccc5a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java @@ -148,12 +148,12 @@ public void testPreAggregatesCaseAggregations() Optional.empty(), SINGLE, project(ImmutableMap.builder() - .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) - .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) - .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) + .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), new Constant(BIGINT, 0L)))) + .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), new Constant(BIGINT, 0L)))) + .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), new Constant(BIGINT, null)))) + .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), new Constant(BIGINT, null)))) + .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(createDecimalType(38, 1), "SUM_DECIMAL"))), new Constant(createDecimalType(38, 1), null)))) + .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), new Constant(BIGINT, null)))) .buildOrThrow(), aggregation( singleGroupingSet("KEY", "COL_BIGINT"), @@ -168,10 +168,10 @@ public void testPreAggregatesCaseAggregations() exchange( project(ImmutableMap.of( "KEY", expression(new Call(CONCAT, ImmutableList.of(new Reference(VARCHAR, "COL_VARCHAR"), new Constant(VARCHAR, Slices.utf8Slice("a"))))), - "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), - "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), INTEGER), BIGINT))), Optional.empty())), - "VALUE_2_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), - "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Call(MULTIPLY_DECIMAL_10_0, ImmutableList.of(new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2"))))), BIGINT))), Optional.empty()))), + "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), new Constant(BIGINT, null))), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), INTEGER), BIGINT))), new Constant(BIGINT, null))), + "VALUE_2_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), new Constant(BIGINT, null))), + "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Call(MULTIPLY_DECIMAL_10_0, ImmutableList.of(new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2"))))), BIGINT))), new Constant(BIGINT, null)))), tableScan( "t", ImmutableMap.of( @@ -208,12 +208,12 @@ public void testGlobalPreAggregatesCaseAggregations() Optional.empty(), SINGLE, project(ImmutableMap.builder() - .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), Optional.empty()))) - .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(BIGINT, "SUM_DECIMAL"))), Optional.empty()))) - .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), Optional.empty()))) + .put("SUM_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), new Constant(BIGINT, 0L)))) + .put("SUM_2_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_INT_CAST"))), new Constant(BIGINT, 0L)))) + .put("SUM_3_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_BIGINT"))), new Constant(BIGINT, null)))) + .put("MIN_1_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Reference(BIGINT, "MIN_BIGINT"))), new Constant(BIGINT, null)))) + .put("SUM_4_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 3L)), new Reference(createDecimalType(38, 1), "SUM_DECIMAL"))), new Constant(createDecimalType(38, 1), null)))) + .put("SUM_5_INPUT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Reference(BIGINT, "SUM_DECIMAL_CAST"))), new Constant(BIGINT, null)))) .buildOrThrow(), aggregation( singleGroupingSet("COL_BIGINT"), @@ -227,10 +227,10 @@ public void testGlobalPreAggregatesCaseAggregations() SINGLE, exchange( project(ImmutableMap.of( - "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), - "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), INTEGER), BIGINT))), Optional.empty())), - "VALUE_2_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), Optional.empty())), - "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Call(MULTIPLY_DECIMAL_10_0, ImmutableList.of(new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2"))))), BIGINT))), Optional.empty()))), + "VALUE_BIGINT", expression(new Case(ImmutableList.of(new WhenClause(new In(new Reference(BIGINT, "COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), new Constant(BIGINT, null))), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), INTEGER), BIGINT))), new Constant(BIGINT, null))), + "VALUE_2_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 1L)), new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L))))), new Constant(BIGINT, null))), + "VALUE_DECIMAL_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 4L)), new Cast(new Call(MULTIPLY_DECIMAL_10_0, ImmutableList.of(new Reference(createDecimalType(10, 0), "COL_DECIMAL"), new Constant(createDecimalType(10, 0), Decimals.valueOfShort(new BigDecimal("2"))))), BIGINT))), new Constant(BIGINT, null)))), tableScan( "t", ImmutableMap.of( @@ -260,10 +260,10 @@ public void testPreAggregatesWithDefaultValues() Optional.empty(), SINGLE, project(ImmutableMap.builder() - .put("SUM_BIGINT_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.empty()))) - .put("SUM_BIGINT_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), Optional.of(new Constant(BIGINT, 0L))))) - .put("SUM_INT_CAST_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.empty()))) - .put("SUM_INT_CAST_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_INT_CAST"))), Optional.of(new Constant(BIGINT, 0L))))) + .put("SUM_BIGINT_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), new Constant(BIGINT, null)))) + .put("SUM_BIGINT_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 1L)), new Reference(BIGINT, "SUM_BIGINT"))), new Constant(BIGINT, 0L)))) + .put("SUM_INT_CAST_FINAL", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_INT_CAST"))), new Constant(BIGINT, null)))) + .put("SUM_INT_CAST_FINAL_DEFAULT", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Reference(BIGINT, "SUM_INT_CAST"))), new Constant(BIGINT, 0L)))) .buildOrThrow(), aggregation( singleGroupingSet("COL_BIGINT"), @@ -274,7 +274,7 @@ public void testPreAggregatesWithDefaultValues() SINGLE, exchange( project(ImmutableMap.of( - "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Cast(new Cast(new Reference(BIGINT, "COL_BIGINT"), INTEGER), BIGINT))), Optional.empty()))), + "VALUE_INT_CAST", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "COL_BIGINT"), new Constant(BIGINT, 2L)), new Cast(new Cast(new Reference(BIGINT, "COL_BIGINT"), INTEGER), BIGINT))), new Constant(BIGINT, null)))), tableScan( "t", ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java index febbe2fbaded..bf115d8f52da 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java @@ -31,12 +31,11 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; -import java.util.Optional; - import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.Booleans.FALSE; +import static io.trino.sql.ir.Booleans.NULL_BOOLEAN; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; @@ -214,7 +213,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE)), - Optional.of(FALSE)), + FALSE), p.values(p.symbol("a")))) .doesNotFire(); @@ -225,7 +224,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE)), - Optional.of(TRUE)), + TRUE), p.values(p.symbol("a")))) .matches( filter( @@ -239,7 +238,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(BOOLEAN, null)), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), - Optional.of(FALSE)), + FALSE), p.values(p.symbol("a")))) .matches( filter( @@ -253,7 +252,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(BOOLEAN, null)), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a")))) .matches( filter( @@ -267,7 +266,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(BOOLEAN, null)), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE)), - Optional.of(FALSE)), + FALSE), p.values(p.symbol("a")))) .matches( filter( @@ -281,7 +280,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), TRUE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(BOOLEAN, null)), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), - Optional.of(FALSE)), + FALSE), p.values(p.symbol("a")))) .matches( filter( @@ -295,7 +294,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE), new WhenClause(new Comparison(EQUAL, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), new Constant(BOOLEAN, null)), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "a"), new Constant(INTEGER, 0L)), FALSE)), - Optional.of(TRUE)), + TRUE), p.values(p.symbol("a")))) .matches( filter( @@ -318,7 +317,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(FALSE, new Reference(BOOLEAN, "a")), new WhenClause(FALSE, new Reference(BOOLEAN, "a")), new WhenClause(new Constant(BOOLEAN, null), new Reference(BOOLEAN, "a"))), - Optional.of(new Reference(BOOLEAN, "b"))), + new Reference(BOOLEAN, "b")), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( @@ -332,7 +331,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(FALSE, new Reference(BOOLEAN, "a")), new WhenClause(FALSE, new Not(new Reference(BOOLEAN, "a"))), new WhenClause(new Constant(BOOLEAN, null), new Reference(BOOLEAN, "a"))), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a")))) .matches( filter( @@ -346,7 +345,7 @@ public void testSimplifySearchedCaseExpression() new WhenClause(FALSE, new Reference(BOOLEAN, "a")), new WhenClause(new Constant(BOOLEAN, null), new Not(new Reference(BOOLEAN, "a"))), new WhenClause(TRUE, new Reference(BOOLEAN, "b"))), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( @@ -360,11 +359,11 @@ public void testSimplifySearchedCaseExpression() new WhenClause(FALSE, new Reference(BOOLEAN, "a")), new WhenClause(new Reference(BOOLEAN, "b"), new Not(new Reference(BOOLEAN, "a"))), new WhenClause(TRUE, new Reference(BOOLEAN, "b"))), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( - new Case(ImmutableList.of(new WhenClause(new Reference(BOOLEAN, "b"), new Not(new Reference(BOOLEAN, "a")))), Optional.of(new Reference(BOOLEAN, "b"))), + new Case(ImmutableList.of(new WhenClause(new Reference(BOOLEAN, "b"), new Not(new Reference(BOOLEAN, "a")))), new Reference(BOOLEAN, "b")), values("a", "b"))); // move the result associated with the first true condition to default @@ -375,14 +374,14 @@ public void testSimplifySearchedCaseExpression() new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Not(new Reference(BOOLEAN, "a"))), new WhenClause(TRUE, new Reference(BOOLEAN, "b")), new WhenClause(TRUE, new Not(new Reference(BOOLEAN, "b")))), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( new Case(ImmutableList.of( new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Reference(BOOLEAN, "a")), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Not(new Reference(BOOLEAN, "a")))), - Optional.of(new Reference(BOOLEAN, "b"))), + new Reference(BOOLEAN, "b")), values("a", "b"))); // cannot remove any clause @@ -391,7 +390,7 @@ public void testSimplifySearchedCaseExpression() new Case(ImmutableList.of( new WhenClause(new Comparison(LESS_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Reference(BOOLEAN, "a")), new WhenClause(new Comparison(GREATER_THAN, new Reference(INTEGER, "b"), new Constant(INTEGER, 0L)), new Not(new Reference(BOOLEAN, "a")))), - Optional.of(new Reference(BOOLEAN, "b"))), + new Reference(BOOLEAN, "b")), p.values(p.symbol("a"), p.symbol("b")))) .doesNotFire(); } @@ -406,7 +405,7 @@ public void testSimplifySimpleCaseExpression() ImmutableList.of( new WhenClause(new Reference(BOOLEAN, "b"), TRUE), new WhenClause(new Comparison(EQUAL, new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), new Constant(INTEGER, 0L)), FALSE)), - Optional.of(TRUE)), + TRUE), p.values(p.symbol("a"), p.symbol("b")))) .doesNotFire(); @@ -418,7 +417,7 @@ public void testSimplifySimpleCaseExpression() ImmutableList.of( new WhenClause(new Constant(BOOLEAN, null), TRUE), new WhenClause(new Reference(BOOLEAN, "a"), FALSE)), - Optional.of(new Reference(BOOLEAN, "b"))), + new Reference(BOOLEAN, "b")), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( @@ -433,7 +432,7 @@ public void testSimplifySimpleCaseExpression() ImmutableList.of( new WhenClause(new Constant(BOOLEAN, null), TRUE), new WhenClause(new Reference(BOOLEAN, "a"), FALSE)), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a")))) .matches( filter( @@ -448,7 +447,7 @@ public void testSimplifySimpleCaseExpression() ImmutableList.of( new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), TRUE), new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), TRUE)), - Optional.of(TRUE)), + TRUE), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( @@ -463,7 +462,7 @@ public void testSimplifySimpleCaseExpression() ImmutableList.of( new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), FALSE), new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), new Constant(BOOLEAN, null))), - Optional.of(FALSE)), + FALSE), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( @@ -478,7 +477,7 @@ public void testSimplifySimpleCaseExpression() ImmutableList.of( new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 1L))), FALSE), new WhenClause(new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "b"), new Constant(INTEGER, 2L))), new Constant(BOOLEAN, null))), - Optional.empty()), + NULL_BOOLEAN), p.values(p.symbol("a"), p.symbol("b")))) .matches( filter( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index 19cbf828dd20..8ada30849043 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -35,7 +35,6 @@ import org.junit.jupiter.api.Test; import java.util.List; -import java.util.Optional; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; @@ -220,8 +219,8 @@ private Expression ensureScalarSubquery() return new Switch( new Reference(BOOLEAN, "is_distinct"), ImmutableList.of(new WhenClause(TRUE, TRUE)), - Optional.of(new Cast( + new Cast( failFunction(tester().getMetadata(), SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"), - BOOLEAN))); + BOOLEAN)); } } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java index 551221a0d211..612ed0bfedcf 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java @@ -361,13 +361,13 @@ public void testNotIntersects() .left( project( ImmutableMap.of( - "wkt_a", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN_OR_EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), + "wkt_a", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN_OR_EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), new Constant(createVarcharType(45), null))), "name_a", expression(new Constant(createVarcharType(1), Slices.utf8Slice("a")))), singleRow())) .right( any(project( ImmutableMap.of( - "wkt_b", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN_OR_EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), + "wkt_b", expression(new Case(ImmutableList.of(new WhenClause(new Comparison(GREATER_THAN_OR_EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), new Constant(createVarcharType(45), null))), "name_b", expression(new Constant(createVarcharType(1), Slices.utf8Slice("a")))), singleRow()))))))); }