diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index bd3352c20c93..3bc7080436ee 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -111,7 +111,7 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ { // TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, predicate); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, predicate); IrExpressionInterpreter interpreter = new IrExpressionInterpreter(predicate, plannerContext, session, expressionTypes); Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); @@ -445,7 +445,7 @@ private Type getType(Expression expression) return requireNonNull(types.get(symbol), () -> format("No type for symbol %s", symbol)); } - return typeAnalyzer.getType(session, types, expression); + return typeAnalyzer.getType(types, expression); } private SymbolStatsEstimate getExpressionStats(Expression expression) diff --git a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java index e2aa0eb38953..278643c258e4 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ScalarStatsCalculator.java @@ -118,7 +118,7 @@ protected SymbolStatsEstimate visitConstant(Constant node, Void context) @Override protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) { - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, node); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, node); IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session, expressionTypes); Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE); @@ -148,7 +148,7 @@ protected SymbolStatsEstimate visitCast(Cast node, Void context) double lowValue = sourceStats.getLowValue(); double highValue = sourceStats.getHighValue(); - if (isIntegralType(typeAnalyzer.getType(session, types, node))) { + if (isIntegralType(typeAnalyzer.getType(types, node))) { // todo handle low/high value changes if range gets narrower due to cast (e.g. BIGINT -> SMALLINT) if (isFinite(lowValue)) { lowValue = Math.round(lowValue); diff --git a/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java index d628ed267354..69416bff8bcb 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java @@ -17,7 +17,6 @@ import io.trino.Session; import io.trino.cost.StatsCalculator.Context; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.ir.NotExpression; import io.trino.sql.ir.SymbolReference; @@ -49,13 +48,11 @@ public class SimpleFilterProjectSemiJoinStatsRule { private static final Pattern PATTERN = filter(); - private final Metadata metadata; private final FilterStatsCalculator filterStatsCalculator; - public SimpleFilterProjectSemiJoinStatsRule(Metadata metadata, StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator) + public SimpleFilterProjectSemiJoinStatsRule(StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator) { super(normalizer); - this.metadata = requireNonNull(metadata, "metadata is null"); this.filterStatsCalculator = requireNonNull(filterStatsCalculator, "filterStatsCalculator cannot be null"); } @@ -135,7 +132,7 @@ private Optional extractSemiJoinOutputFilter(Expression pr } Expression semiJoinOutputReference = Iterables.getOnlyElement(semiJoinOutputReferences); - Expression remainingPredicate = combineConjuncts(metadata, conjuncts.stream() + Expression remainingPredicate = combineConjuncts(conjuncts.stream() .filter(conjunct -> conjunct != semiJoinOutputReference) .collect(toImmutableList())); boolean negated = semiJoinOutputReference instanceof NotExpression; diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java index ad10a6112262..c5fc613b803f 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java @@ -65,7 +65,7 @@ public List> get() rules.add(new OutputStatsRule()); rules.add(new TableScanStatsRule(normalizer)); - rules.add(new SimpleFilterProjectSemiJoinStatsRule(plannerContext.getMetadata(), normalizer, filterStatsCalculator)); // this must be before FilterStatsRule + rules.add(new SimpleFilterProjectSemiJoinStatsRule(normalizer, filterStatsCalculator)); // this must be before FilterStatsRule rules.add(new FilterProjectAggregationStatsRule(normalizer, filterStatsCalculator)); // this must be before FilterStatsRule rules.add(new FilterStatsRule(normalizer, filterStatsCalculator)); rules.add(new ValuesStatsRule(plannerContext)); diff --git a/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java b/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java index 3672d9eefb86..a8c140db7aff 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java @@ -145,23 +145,25 @@ public static boolean isResolved(QualifiedName name) return SerializedResolvedFunction.isSerializedResolvedFunction(name); } + public CatalogSchemaFunctionName getName() + { + QualifiedName qualifiedName = toQualifiedName(); + return SerializedResolvedFunction.fromSerializedName(qualifiedName).functionName(); + } + + @Deprecated public QualifiedName toQualifiedName() { CatalogSchemaFunctionName name = toCatalogSchemaFunctionName(); return QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getFunctionName()); } + @Deprecated public CatalogSchemaFunctionName toCatalogSchemaFunctionName() { return ResolvedFunctionDecoder.toCatalogSchemaFunctionName(this); } - public static CatalogSchemaFunctionName extractFunctionName(QualifiedName qualifiedName) - { - checkArgument(isResolved(qualifiedName), "Expected qualifiedName to be a resolved function: %s", qualifiedName); - return SerializedResolvedFunction.fromSerializedName(qualifiedName).functionName(); - } - @Override public boolean equals(Object o) { diff --git a/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java index 4563eba10f34..9ee847f5f841 100644 --- a/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/DynamicFilters.java @@ -19,7 +19,6 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.metadata.Metadata; -import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.function.IsNull; import io.trino.spi.function.ScalarFunction; @@ -141,7 +140,7 @@ private static Symbol extractSourceSymbol(DynamicFilters.Descriptor descriptor) public static Expression replaceDynamicFilterId(FunctionCall dynamicFilterFunctionCall, DynamicFilterId newId) { return new FunctionCall( - dynamicFilterFunctionCall.getName(), + dynamicFilterFunctionCall.getFunction(), ImmutableList.of( dynamicFilterFunctionCall.getArguments().get(0), dynamicFilterFunctionCall.getArguments().get(1), @@ -186,7 +185,7 @@ public static Optional getDescriptor(Expression expression) private static boolean isDynamicFilterFunction(FunctionCall functionCall) { - return isDynamicFilterFunction(ResolvedFunction.extractFunctionName(functionCall.getName())); + return isDynamicFilterFunction(functionCall.getFunction().getName()); } public static boolean isDynamicFilterFunction(CatalogSchemaFunctionName functionName) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java index 6c4b8ebf2a60..c20669615abb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ConstantEvaluator.java @@ -64,7 +64,7 @@ public static Object evaluateConstant( io.trino.sql.ir.Expression rewritten = translationMap.rewrite(expression); IrTypeAnalyzer analyzer = new IrTypeAnalyzer(plannerContext); - Map, Type> types = analyzer.getTypes(session, TypeProvider.empty(), rewritten); + Map, Type> types = analyzer.getTypes(TypeProvider.empty(), rewritten); Type actualType = types.get(io.trino.sql.ir.NodeRef.of(rewritten)); if (!new TypeCoercion(plannerContext.getTypeManager()::getType).canCoerce(actualType, expectedType)) { @@ -73,7 +73,7 @@ public static Object evaluateConstant( if (!actualType.equals(expectedType)) { rewritten = new Cast(rewritten, expectedType, false); - types = analyzer.getTypes(session, TypeProvider.empty(), rewritten); + types = analyzer.getTypes(TypeProvider.empty(), rewritten); } return new IrExpressionInterpreter(rewritten, plannerContext, session, types).evaluate(); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java index f915b2c8f763..986b963e09e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java @@ -20,7 +20,6 @@ import java.util.Optional; import java.util.function.Function; -import static io.trino.sql.SqlFormatter.formatName; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; @@ -93,7 +92,7 @@ protected String visitConstant(Constant node, Void context) @Override protected String visitFunctionCall(FunctionCall node, Void context) { - return formatName(node.getName()) + '(' + joinExpressions(node.getArguments()) + ')'; + return node.getFunction().getName().toString() + '(' + joinExpressions(node.getArguments()) + ')'; } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java index f3c545712c99..ad0285c9a322 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java @@ -439,7 +439,7 @@ public Expression visitFunctionCall(FunctionCall node, Context context) List arguments = rewrite(node.getArguments(), context); if (!sameElements(node.getArguments(), arguments)) { - return new FunctionCall(node.getName(), arguments); + return new FunctionCall(node.getFunction(), arguments); } return node; } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java b/core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java index 95915d564565..e55cc86f6a73 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java @@ -16,43 +16,31 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import io.trino.connector.system.GlobalSystemConnector; -import io.trino.sql.tree.QualifiedName; +import io.trino.metadata.ResolvedFunction; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import static java.util.Objects.requireNonNull; + public final class FunctionCall extends Expression { - private final QualifiedName name; + private final ResolvedFunction function; private final List arguments; @JsonCreator - public FunctionCall(String resolvedFunction, List arguments) - { - this( - QualifiedName.of(GlobalSystemConnector.NAME, "$resolved", resolvedFunction), - ImmutableList.copyOf(arguments)); - } - - public FunctionCall(QualifiedName name, List arguments) + public FunctionCall(ResolvedFunction function, List arguments) { - this.name = name; + this.function = requireNonNull(function, "function is null"); this.arguments = ImmutableList.copyOf(arguments); } - @Deprecated - public QualifiedName getName() - { - return name; - } - @JsonProperty - public String getResolvedFunction() + public ResolvedFunction getFunction() { - return name.getSuffix(); + return function; } @JsonProperty @@ -83,21 +71,21 @@ public boolean equals(Object obj) return false; } FunctionCall o = (FunctionCall) obj; - return Objects.equals(name, o.name) && + return Objects.equals(function, o.function) && Objects.equals(arguments, o.arguments); } @Override public int hashCode() { - return Objects.hash(name, arguments); + return Objects.hash(function, arguments); } @Override public String toString() { return "%s(%s)".formatted( - name.getSuffix(), + function.getName(), arguments.stream() .map(Expression::toString) .collect(Collectors.joining(", "))); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java index 15381115d6df..a87a471111e4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java @@ -135,18 +135,18 @@ public static Expression combinePredicates(Metadata metadata, LogicalExpression. public static Expression combinePredicates(Metadata metadata, LogicalExpression.Operator operator, Collection expressions) { if (operator == LogicalExpression.Operator.AND) { - return combineConjuncts(metadata, expressions); + return combineConjuncts(expressions); } - return combineDisjuncts(metadata, expressions); + return combineDisjuncts(expressions); } - public static Expression combineConjuncts(Metadata metadata, Expression... expressions) + public static Expression combineConjuncts(Expression... expressions) { - return combineConjuncts(metadata, Arrays.asList(expressions)); + return combineConjuncts(Arrays.asList(expressions)); } - public static Expression combineConjuncts(Metadata metadata, Collection expressions) + public static Expression combineConjuncts(Collection expressions) { requireNonNull(expressions, "expressions is null"); @@ -155,7 +155,7 @@ public static Expression combineConjuncts(Metadata metadata, Collection !e.equals(TRUE_LITERAL)) .collect(toList()); - conjuncts = removeDuplicates(metadata, conjuncts); + conjuncts = removeDuplicates(conjuncts); if (conjuncts.contains(FALSE_LITERAL)) { return FALSE_LITERAL; @@ -180,17 +180,17 @@ public static Expression combineConjunctsWithDuplicates(Collection e return and(conjuncts); } - public static Expression combineDisjuncts(Metadata metadata, Expression... expressions) + public static Expression combineDisjuncts(Expression... expressions) { - return combineDisjuncts(metadata, Arrays.asList(expressions)); + return combineDisjuncts(Arrays.asList(expressions)); } - public static Expression combineDisjuncts(Metadata metadata, Collection expressions) + public static Expression combineDisjuncts(Collection expressions) { - return combineDisjunctsWithDefault(metadata, expressions, FALSE_LITERAL); + return combineDisjunctsWithDefault(expressions, FALSE_LITERAL); } - public static Expression combineDisjunctsWithDefault(Metadata metadata, Collection expressions, Expression emptyDefault) + public static Expression combineDisjunctsWithDefault(Collection expressions, Expression emptyDefault) { requireNonNull(expressions, "expressions is null"); @@ -199,7 +199,7 @@ public static Expression combineDisjunctsWithDefault(Metadata metadata, Collecti .filter(e -> !e.equals(FALSE_LITERAL)) .collect(toList()); - disjuncts = removeDuplicates(metadata, disjuncts); + disjuncts = removeDuplicates(disjuncts); if (disjuncts.contains(TRUE_LITERAL)) { return TRUE_LITERAL; @@ -210,21 +210,21 @@ public static Expression combineDisjunctsWithDefault(Metadata metadata, Collecti public static Expression filterDeterministicConjuncts(Metadata metadata, Expression expression) { - return filterConjuncts(metadata, expression, expression1 -> DeterminismEvaluator.isDeterministic(expression1, metadata)); + return filterConjuncts(expression, expression1 -> DeterminismEvaluator.isDeterministic(expression1)); } public static Expression filterNonDeterministicConjuncts(Metadata metadata, Expression expression) { - return filterConjuncts(metadata, expression, not(testExpression -> DeterminismEvaluator.isDeterministic(testExpression, metadata))); + return filterConjuncts(expression, not(testExpression -> DeterminismEvaluator.isDeterministic(testExpression))); } - public static Expression filterConjuncts(Metadata metadata, Expression expression, Predicate predicate) + public static Expression filterConjuncts(Expression expression, Predicate predicate) { List conjuncts = extractConjuncts(expression).stream() .filter(predicate) .collect(toList()); - return combineConjuncts(metadata, conjuncts); + return combineConjuncts(conjuncts); } @SafeVarargs @@ -276,7 +276,7 @@ public static boolean isEffectivelyLiteral(PlannerContext plannerContext, Sessio private static boolean constantExpressionEvaluatesSuccessfully(PlannerContext plannerContext, Session session, Expression constantExpression) { - Map, Type> types = new IrTypeAnalyzer(plannerContext).getTypes(session, TypeProvider.empty(), constantExpression); + Map, Type> types = new IrTypeAnalyzer(plannerContext).getTypes(TypeProvider.empty(), constantExpression); IrExpressionInterpreter interpreter = new IrExpressionInterpreter(constantExpression, plannerContext, session, types); Object literalValue = interpreter.optimize(NoOpSymbolResolver.INSTANCE); return !(literalValue instanceof Expression); @@ -286,13 +286,13 @@ private static boolean constantExpressionEvaluatesSuccessfully(PlannerContext pl * Removes duplicate deterministic expressions. Preserves the relative order * of the expressions in the list. */ - private static List removeDuplicates(Metadata metadata, List expressions) + private static List removeDuplicates(List expressions) { Set seen = new HashSet<>(); ImmutableList.Builder result = ImmutableList.builder(); for (Expression expression : expressions) { - if (!DeterminismEvaluator.isDeterministic(expression, metadata)) { + if (!DeterminismEvaluator.isDeterministic(expression)) { result.add(expression); } else if (!seen.contains(expression)) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java b/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java index 48a4521f2c26..76629948cf9e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/BuiltinFunctionCallBuilder.java @@ -86,6 +86,6 @@ public BuiltinFunctionCallBuilder setArguments(List types, List translate(Session session, Expression expression, TypeProvider types, PlannerContext plannerContext, IrTypeAnalyzer typeAnalyzer) { - return new SqlToConnectorExpressionTranslator(session, typeAnalyzer.getTypes(session, types, expression), plannerContext) + return new SqlToConnectorExpressionTranslator(session, typeAnalyzer.getTypes(types, expression), plannerContext) .process(expression); } @@ -134,7 +134,7 @@ public static ConnectorExpressionTranslation translateConjuncts( PlannerContext plannerContext, IrTypeAnalyzer typeAnalyzer) { - Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, types, expression); + Map, Type> remainingExpressionTypes = typeAnalyzer.getTypes(types, expression); SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator( session, remainingExpressionTypes, @@ -154,7 +154,7 @@ public static ConnectorExpressionTranslation translateConjuncts( } return new ConnectorExpressionTranslation( ConnectorExpressions.and(converted), - combineConjuncts(plannerContext.getMetadata(), remaining)); + combineConjuncts(remaining)); } @VisibleForTesting @@ -664,7 +664,7 @@ protected Optional visitFunctionCall(FunctionCall node, Voi return Optional.of(constantFor(typeOf(node), evaluateConstantExpression(node, plannerContext, session))); } - CatalogSchemaFunctionName functionName = ResolvedFunction.extractFunctionName(node.getName()); + CatalogSchemaFunctionName functionName = node.getFunction().getName(); checkArgument(!isDynamicFilterFunction(functionName), "Dynamic filter has no meaning for a connector, it should not be translated into ConnectorExpression"); if (functionName.equals(builtinFunctionName(LIKE_FUNCTION_NAME))) { @@ -716,7 +716,7 @@ private Optional translateLike(FunctionCall node) arguments.add(new io.trino.spi.expression.Constant(Slices.utf8Slice(matcher.getEscape().get().toString()), createVarcharType(1))); } } - else if (patternArgument instanceof FunctionCall call && ResolvedFunction.extractFunctionName(call.getName()).equals(builtinFunctionName(LIKE_PATTERN_FUNCTION_NAME))) { + else if (patternArgument instanceof FunctionCall call && call.getFunction().getName().equals(builtinFunctionName(LIKE_PATTERN_FUNCTION_NAME))) { Optional translatedPattern = process(call.getArguments().get(0)); if (translatedPattern.isEmpty()) { return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java index 6752f8f74d6f..959f91411270 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java @@ -13,16 +13,11 @@ */ package io.trino.sql.planner; -import io.trino.metadata.Metadata; -import io.trino.metadata.ResolvedFunction; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; import io.trino.sql.ir.FunctionCall; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - -import static java.util.Objects.requireNonNull; /** * Determines whether a given Expression is deterministic @@ -31,35 +26,20 @@ public final class DeterminismEvaluator { private DeterminismEvaluator() {} - public static boolean isDeterministic(Expression expression, Metadata metadata) + public static boolean isDeterministic(Expression expression) { - return isDeterministic(expression, functionCall -> metadata.decodeFunction(functionCall.getName())); - } - - public static boolean isDeterministic(Expression expression, Function resolvedFunctionSupplier) - { - requireNonNull(resolvedFunctionSupplier, "resolvedFunctionSupplier is null"); - requireNonNull(expression, "expression is null"); - AtomicBoolean deterministic = new AtomicBoolean(true); - new Visitor(resolvedFunctionSupplier).process(expression, deterministic); + new Visitor().process(expression, deterministic); return deterministic.get(); } private static class Visitor extends DefaultTraversalVisitor { - private final Function resolvedFunctionSupplier; - - public Visitor(Function resolvedFunctionSupplier) - { - this.resolvedFunctionSupplier = resolvedFunctionSupplier; - } - @Override protected Void visitFunctionCall(FunctionCall node, AtomicBoolean deterministic) { - if (!resolvedFunctionSupplier.apply(node).isDeterministic()) { + if (!node.getFunction().isDeterministic()) { deterministic.set(false); return null; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index fbc35ffc4580..4ed62a1dfcca 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -109,13 +109,6 @@ public final class DomainTranslator { - private final PlannerContext plannerContext; - - public DomainTranslator(PlannerContext plannerContext) - { - this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); - } - public Expression toPredicate(TupleDomain tupleDomain) { if (tupleDomain.isNone()) { @@ -125,7 +118,7 @@ public Expression toPredicate(TupleDomain tupleDomain) Map domains = tupleDomain.getDomains().get(); return domains.entrySet().stream() .map(entry -> toPredicate(entry.getValue(), entry.getKey().toSymbolReference())) - .collect(collectingAndThen(toImmutableList(), expressions -> combineConjuncts(plannerContext.getMetadata(), expressions))); + .collect(collectingAndThen(toImmutableList(), expressions -> combineConjuncts(expressions))); } private Expression toPredicate(Domain domain, SymbolReference reference) @@ -152,7 +145,7 @@ private Expression toPredicate(Domain domain, SymbolReference reference) disjuncts.add(new IsNullPredicate(reference)); } - return combineDisjunctsWithDefault(plannerContext.getMetadata(), disjuncts, TRUE_LITERAL); + return combineDisjunctsWithDefault(disjuncts, TRUE_LITERAL); } private Expression processRange(Type type, Range range, SymbolReference reference) @@ -184,7 +177,7 @@ private Expression processRange(Type type, Range range, SymbolReference referenc } // If rangeConjuncts is null, then the range was ALL, which should already have been checked for checkState(!rangeConjuncts.isEmpty()); - return combineConjuncts(plannerContext.getMetadata(), rangeConjuncts); + return combineConjuncts(rangeConjuncts); } private Expression combineRangeWithExcludedPoints(Type type, SymbolReference reference, Range range, List excludedPoints) @@ -198,7 +191,7 @@ private Expression combineRangeWithExcludedPoints(Type type, SymbolReference ref excludedPointsExpression = new ComparisonExpression(NOT_EQUAL, reference, getOnlyElement(excludedPoints)); } - return combineConjuncts(plannerContext.getMetadata(), processRange(type, range, reference), excludedPointsExpression); + return combineConjuncts(processRange(type, range, reference), excludedPointsExpression); } private List extractDisjuncts(Type type, Ranges ranges, SymbolReference reference) @@ -374,7 +367,7 @@ protected ExtractionResult visitLogicalExpression(LogicalExpression node, Boolea case AND: return new ExtractionResult( TupleDomain.intersect(tupleDomains), - combineConjuncts(plannerContext.getMetadata(), residuals)); + combineConjuncts(residuals)); case OR: TupleDomain columnUnionedTupleDomain = TupleDomain.columnWiseUnion(tupleDomains); @@ -387,7 +380,7 @@ protected ExtractionResult visitLogicalExpression(LogicalExpression node, Boolea // some of these cases, we won't have to double check the bounds unnecessarily at execution time. // We can only make inferences if the remaining expressions on all terms are equal and deterministic - if (Set.copyOf(residuals).size() == 1 && DeterminismEvaluator.isDeterministic(residuals.get(0), plannerContext.getMetadata())) { + if (Set.copyOf(residuals).size() == 1 && DeterminismEvaluator.isDeterministic(residuals.get(0))) { // NONE are no-op for the purpose of OR tupleDomains = tupleDomains.stream() .filter(domain -> !domain.isNone()) @@ -573,7 +566,7 @@ private boolean isImplicitCoercion(Map, Type> expressionType private Map, Type> analyzeExpression(Expression expression) { - return typeAnalyzer.getTypes(session, types, expression); + return typeAnalyzer.getTypes(types, expression); } private Optional createVarcharCastToDateComparisonExtractionResult( @@ -1026,7 +1019,7 @@ private Optional tryVisitLikeFunction(FunctionCall node, Boole return Optional.empty(); } - Type type = typeAnalyzer.getType(session, types, value); + Type type = typeAnalyzer.getType(types, value); if (!(type instanceof VarcharType varcharType)) { // TODO support CharType return Optional.empty(); @@ -1034,7 +1027,7 @@ private Optional tryVisitLikeFunction(FunctionCall node, Boole Symbol symbol = Symbol.from(value); - if (!(typeAnalyzer.getType(session, types, patternArgument) instanceof LikePatternType) || + if (!(typeAnalyzer.getType(types, patternArgument) instanceof LikePatternType) || !SymbolsExtractor.extractAll(patternArgument).isEmpty()) { // dynamic pattern or escape return Optional.empty(); @@ -1075,7 +1068,7 @@ private Optional tryVisitLikeFunction(FunctionCall node, Boole @Override protected ExtractionResult visitFunctionCall(FunctionCall node, Boolean complement) { - CatalogSchemaFunctionName name = ResolvedFunction.extractFunctionName(node.getName()); + CatalogSchemaFunctionName name = node.getFunction().getName(); if (name.equals(builtinFunctionName("starts_with"))) { Optional result = tryVisitStartsWithFunction(node, complement); if (result.isPresent()) { @@ -1110,7 +1103,7 @@ private Optional tryVisitStartsWithFunction(FunctionCall node, return Optional.empty(); } - Type type = typeAnalyzer.getType(session, types, target); + Type type = typeAnalyzer.getType(types, target); if (!(type instanceof VarcharType)) { // TODO support CharType return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java index d0aa5e72dd36..92017ace9ff8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java @@ -169,7 +169,7 @@ public Expression visitFilter(FilterNode node, Void context) // Remove non-deterministic conjuncts predicate = filterDeterministicConjuncts(metadata, predicate); - return combineConjuncts(metadata, predicate, underlyingPredicate); + return combineConjuncts(predicate, underlyingPredicate); } @Override @@ -216,11 +216,10 @@ public Expression visitProject(ProjectNode node, Void context) .collect(toImmutableList()); return pullExpressionThroughSymbols(combineConjuncts( - metadata, - ImmutableList.builder() - .addAll(projectionEqualities) - .addAll(validUnderlyingEqualities) - .build()), + ImmutableList.builder() + .addAll(projectionEqualities) + .addAll(validUnderlyingEqualities) + .build()), node.getOutputSymbols()); } @@ -306,23 +305,23 @@ public Expression visitJoin(JoinNode node, Void context) .collect(toImmutableList()); return switch (node.getType()) { - case INNER -> pullExpressionThroughSymbols(combineConjuncts(metadata, ImmutableList.builder() + case INNER -> pullExpressionThroughSymbols(combineConjuncts(ImmutableList.builder() .add(leftPredicate) .add(rightPredicate) - .add(combineConjuncts(metadata, joinConjuncts)) + .add(combineConjuncts(joinConjuncts)) .add(node.getFilter().orElse(TRUE_LITERAL)) .build()), node.getOutputSymbols()); - case LEFT -> combineConjuncts(metadata, ImmutableList.builder() + case LEFT -> combineConjuncts(ImmutableList.builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .build()); - case RIGHT -> combineConjuncts(metadata, ImmutableList.builder() + case RIGHT -> combineConjuncts(ImmutableList.builder() .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .build()); - case FULL -> combineConjuncts(metadata, ImmutableList.builder() + case FULL -> combineConjuncts(ImmutableList.builder() .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains, node.getRight().getOutputSymbols()::contains)) @@ -350,7 +349,7 @@ public Expression visitValues(ValuesNode node, Void context) }) .collect(toImmutableList()); - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, processedExpressions); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, processedExpressions); boolean[] hasNull = new boolean[node.getOutputSymbols().size()]; boolean[] hasNaN = new boolean[node.getOutputSymbols().size()]; @@ -365,7 +364,7 @@ public Expression visitValues(ValuesNode node, Void context) if (row instanceof Row) { for (int i = 0; i < node.getOutputSymbols().size(); i++) { Expression value = ((Row) row).getItems().get(i); - if (!DeterminismEvaluator.isDeterministic(value, metadata)) { + if (!DeterminismEvaluator.isDeterministic(value)) { nonDeterministic[i] = true; } else { @@ -396,7 +395,7 @@ public Expression visitValues(ValuesNode node, Void context) } } else { - if (!DeterminismEvaluator.isDeterministic(row, metadata)) { + if (!DeterminismEvaluator.isDeterministic(row)) { return TRUE_LITERAL; } IrExpressionInterpreter interpreter = new IrExpressionInterpreter(row, plannerContext, session, expressionTypes); @@ -532,11 +531,11 @@ public Expression visitSpatialJoin(SpatialJoinNode node, Void context) Expression rightPredicate = node.getRight().accept(this, context); return switch (node.getType()) { - case INNER -> combineConjuncts(metadata, ImmutableList.builder() + case INNER -> combineConjuncts(ImmutableList.builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) .build()); - case LEFT -> combineConjuncts(metadata, ImmutableList.builder() + case LEFT -> combineConjuncts(ImmutableList.builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .build()); @@ -556,11 +555,10 @@ private Expression deriveCommonPredicates(PlanNode node, Functionbuilder() - .addAll(equalities) - .add(underlyingPredicate) - .build()), + ImmutableList.builder() + .addAll(equalities) + .add(underlyingPredicate) + .build()), node.getOutputSymbols())))); } @@ -572,17 +570,17 @@ private Expression deriveCommonPredicates(PlanNode node, Function symbols) { - EqualityInference equalityInference = new EqualityInference(metadata, expression); + EqualityInference equalityInference = new EqualityInference(expression); ImmutableList.Builder effectiveConjuncts = ImmutableList.builder(); Set scope = ImmutableSet.copyOf(symbols); - EqualityInference.nonInferrableConjuncts(metadata, expression).forEach(conjunct -> { - if (DeterminismEvaluator.isDeterministic(conjunct, metadata)) { + EqualityInference.nonInferrableConjuncts(expression).forEach(conjunct -> { + if (DeterminismEvaluator.isDeterministic(conjunct)) { Expression rewritten = equalityInference.rewrite(conjunct, scope); if (rewritten != null) { effectiveConjuncts.add(rewritten); @@ -592,7 +590,7 @@ private Expression pullExpressionThroughSymbols(Expression expression, Collectio effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(scope).getScopeEqualities()); - return combineConjuncts(metadata, effectiveConjuncts.build()); + return combineConjuncts(effectiveConjuncts.build()); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java index 8764c7b79e13..c75b20a5449d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java @@ -19,7 +19,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.Multimap; -import io.trino.metadata.Metadata; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Expression; import io.trino.sql.ir.IrUtils; @@ -62,17 +61,17 @@ public class EqualityInference private final Map> symbolsCache = new HashMap<>(); private final Map> uniqueSymbolsCache = new HashMap<>(); - public EqualityInference(Metadata metadata, Expression... expressions) + public EqualityInference(Expression... expressions) { - this(metadata, Arrays.asList(expressions)); + this(Arrays.asList(expressions)); } - public EqualityInference(Metadata metadata, Collection expressions) + public EqualityInference(Collection expressions) { DisjointSet equalities = new DisjointSet<>(); expressions.stream() .flatMap(expression -> extractConjuncts(expression).stream()) - .filter(expression -> isInferenceCandidate(metadata, expression)) + .filter(expression -> isInferenceCandidate(expression)) .forEach(expression -> { ComparisonExpression comparison = (ComparisonExpression) expression; Expression expression1 = comparison.getLeft(); @@ -255,10 +254,10 @@ else if (complementCanonical != null) { /** * Determines whether an Expression may be successfully applied to the equality inference */ - public static boolean isInferenceCandidate(Metadata metadata, Expression expression) + public static boolean isInferenceCandidate(Expression expression) { if (expression instanceof ComparisonExpression comparison && - isDeterministic(expression, metadata) && + isDeterministic(expression) && !mayReturnNullOnNonNullInput(expression)) { if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) { // We should only consider equalities that have distinct left and right components @@ -271,10 +270,10 @@ public static boolean isInferenceCandidate(Metadata metadata, Expression express /** * Provides a convenience Stream of Expression conjuncts which have not been added to the inference */ - public static Stream nonInferrableConjuncts(Metadata metadata, Expression expression) + public static Stream nonInferrableConjuncts(Expression expression) { return extractConjuncts(expression).stream() - .filter(e -> !isInferenceCandidate(metadata, e)); + .filter(e -> !isInferenceCandidate(e)); } private Expression rewrite(Expression expression, Predicate symbolScope, boolean allowFullReplacement) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java index c5f9febdb6c2..529a47ac1130 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrExpressionInterpreter.java @@ -139,7 +139,7 @@ public IrExpressionInterpreter(Expression expression, PlannerContext plannerCont public static Object evaluateConstantExpression(Expression expression, PlannerContext plannerContext, Session session) { - Map, Type> types = new IrTypeAnalyzer(plannerContext).getTypes(session, TypeProvider.empty(), expression); + Map, Type> types = new IrTypeAnalyzer(plannerContext).getTypes(TypeProvider.empty(), expression); return new IrExpressionInterpreter(expression, plannerContext, session, types).evaluate(); } @@ -374,7 +374,7 @@ private List processOperands(CoalesceExpression node, Object context) // The nested CoalesceExpression was recursively processed. It does not contain null. for (Expression nestedOperand : ((CoalesceExpression) value).getOperands()) { // Skip duplicates unless they are non-deterministic. - if (!isDeterministic(nestedOperand, metadata) || uniqueNewOperands.add(nestedOperand)) { + if (!isDeterministic(nestedOperand) || uniqueNewOperands.add(nestedOperand)) { newOperands.add(nestedOperand); } // This operand can be evaluated to a non-null value. Remaining operands can be skipped. @@ -385,7 +385,7 @@ private List processOperands(CoalesceExpression node, Object context) } else if (value instanceof Expression expression) { // Skip duplicates unless they are non-deterministic. - if (!isDeterministic(expression, metadata) || uniqueNewOperands.add(expression)) { + if (!isDeterministic(expression) || uniqueNewOperands.add(expression)) { newOperands.add(expression); } } @@ -492,10 +492,10 @@ else if (!found && result) { List expressionValues = toExpressions(values, types); List simplifiedExpressionValues = Stream.concat( expressionValues.stream() - .filter(expression -> isDeterministic(expression, metadata)) + .filter(expression -> isDeterministic(expression)) .distinct(), expressionValues.stream() - .filter(expression -> !isDeterministic(expression, metadata))) + .filter(expression -> !isDeterministic(expression))) .collect(toImmutableList()); if (simplifiedExpressionValues.size() == 1) { @@ -811,7 +811,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context) argumentTypes.add(type); } - ResolvedFunction resolvedFunction = metadata.decodeFunction(node.getName()); + ResolvedFunction resolvedFunction = node.getFunction(); FunctionNullability functionNullability = resolvedFunction.getFunctionNullability(); for (int i = 0; i < argumentValues.size(); i++) { Object value = argumentValues.get(i); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/IrTypeAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/planner/IrTypeAnalyzer.java index fc2c9c325804..55023f8f3b77 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/IrTypeAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/IrTypeAnalyzer.java @@ -16,12 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.FunctionResolver; -import io.trino.metadata.ResolvedFunction; -import io.trino.security.AccessControl; -import io.trino.security.AllowAllAccessControl; import io.trino.spi.function.BoundSignature; import io.trino.spi.function.OperatorType; import io.trino.spi.type.ArrayType; @@ -85,9 +79,9 @@ public IrTypeAnalyzer(PlannerContext plannerContext) this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); } - public Map, Type> getTypes(Session session, TypeProvider inputTypes, Iterable expressions) + public Map, Type> getTypes(TypeProvider inputTypes, Iterable expressions) { - Visitor visitor = new Visitor(plannerContext, session, inputTypes); + Visitor visitor = new Visitor(plannerContext, inputTypes); for (Expression expression : expressions) { visitor.process(expression, new Context(ImmutableMap.of())); @@ -96,34 +90,28 @@ public Map, Type> getTypes(Session session, TypeProvider inp return visitor.getTypes(); } - public Map, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) + public Map, Type> getTypes(TypeProvider inputTypes, Expression expression) { - return getTypes(session, inputTypes, ImmutableList.of(expression)); + return getTypes(inputTypes, ImmutableList.of(expression)); } - public Type getType(Session session, TypeProvider inputTypes, Expression expression) + public Type getType(TypeProvider inputTypes, Expression expression) { - return getTypes(session, inputTypes, expression).get(NodeRef.of(expression)); + return getTypes(inputTypes, expression).get(NodeRef.of(expression)); } private static class Visitor extends IrVisitor { - private static final AccessControl ALLOW_ALL_ACCESS_CONTROL = new AllowAllAccessControl(); - private final PlannerContext plannerContext; - private final Session session; private final TypeProvider symbolTypes; - private final FunctionResolver functionResolver; private final Map, Type> expressionTypes = new LinkedHashMap<>(); - public Visitor(PlannerContext plannerContext, Session session, TypeProvider symbolTypes) + public Visitor(PlannerContext plannerContext, TypeProvider symbolTypes) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); - this.session = requireNonNull(session, "session is null"); this.symbolTypes = requireNonNull(symbolTypes, "symbolTypes is null"); - this.functionResolver = plannerContext.getFunctionResolver(WarningCollector.NOOP); } public Map, Type> getTypes() @@ -355,10 +343,7 @@ protected Type visitConstant(Constant node, Context context) @Override protected Type visitFunctionCall(FunctionCall node, Context context) { - // Function should already be resolved in IR - ResolvedFunction function = functionResolver.resolveFunction(session, node.getName(), null, ALLOW_ALL_ACCESS_CONTROL); - - BoundSignature signature = function.getSignature(); + BoundSignature signature = node.getFunction().getSignature(); for (int i = 0; i < node.getArguments().size(); i++) { Expression argument = node.getArguments().get(i); Type formalType = signature.getArgumentTypes().get(i); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java index b244d73a21ae..f100364fd97c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LayoutConstraintEvaluator.java @@ -37,7 +37,7 @@ public class LayoutConstraintEvaluator public LayoutConstraintEvaluator(PlannerContext plannerContext, IrTypeAnalyzer typeAnalyzer, Session session, TypeProvider types, Map assignments, Expression expression) { this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); - evaluator = new IrExpressionInterpreter(expression, plannerContext, session, typeAnalyzer.getTypes(session, types, expression)); + evaluator = new IrExpressionInterpreter(expression, plannerContext, session, typeAnalyzer.getTypes(types, expression)); arguments = SymbolsExtractor.extractUnique(expression).stream() .map(assignments::get) .collect(toImmutableSet()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 8aa614598af0..c43cd1987b59 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -1473,7 +1473,7 @@ private Supplier prepareProjection(ExpressionAndValuePointers ex } // compile expression using input layout and input types - RowExpression rowExpression = toRowExpression(rewritten, typeAnalyzer.getTypes(session, TypeProvider.viewOf(inputTypes.buildOrThrow()), rewritten), inputLayout.buildOrThrow()); + RowExpression rowExpression = toRowExpression(rewritten, typeAnalyzer.getTypes(TypeProvider.viewOf(inputTypes.buildOrThrow()), rewritten), inputLayout.buildOrThrow()); return pageFunctionCompiler.compileProjection(rowExpression, Optional.empty()); } @@ -1631,7 +1631,7 @@ else if (matchNumberSymbol.isPresent() && inputSymbols.get(i).equals(matchNumber } // compile expression using input layout and input types - RowExpression rowExpression = toRowExpression(argument, typeAnalyzer.getTypes(session, TypeProvider.viewOf(inputTypes.buildOrThrow()), argument), inputLayout.buildOrThrow()); + RowExpression rowExpression = toRowExpression(argument, typeAnalyzer.getTypes(TypeProvider.viewOf(inputTypes.buildOrThrow()), argument), inputLayout.buildOrThrow()); return pageFunctionCompiler.compileProjection(rowExpression, Optional.empty()); } @@ -2016,7 +2016,6 @@ else if (sourceNode instanceof SampleNode sampleNode) { } Map, Type> expressionTypes = typeAnalyzer.getTypes( - session, context.getTypes(), concat(staticFilters.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions())); @@ -2096,7 +2095,7 @@ private PhysicalOperation visitTableScan(PlanNodeId planNodeId, TableScanNode no private Optional getStaticFilter(Expression filterExpression) { DynamicFilters.ExtractResult extractDynamicFilterResult = extractDynamicFilters(filterExpression); - Expression staticFilter = combineConjuncts(metadata, extractDynamicFilterResult.getStaticConjuncts()); + Expression staticFilter = combineConjuncts(extractDynamicFilterResult.getStaticConjuncts()); if (staticFilter.equals(TRUE_LITERAL)) { return Optional.empty(); } @@ -2142,7 +2141,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext // evaluate values for non-empty rows if (node.getRows().isPresent()) { Expression row = node.getRows().get().get(i); - Map, Type> types = typeAnalyzer.getTypes(session, TypeProvider.empty(), row); + Map, Type> types = typeAnalyzer.getTypes(TypeProvider.empty(), row); checkState(types.get(NodeRef.of(row)) instanceof RowType, "unexpected type of Values row: %s", types); // evaluate the literal value SqlRow result = (SqlRow) new IrExpressionInterpreter(row, plannerContext, session, types).evaluate(); @@ -2551,7 +2550,7 @@ private Optional removeExpressionFromFilter(Expression filter, Expre private SpatialPredicate spatialTest(FunctionCall functionCall, boolean probeFirst, Optional comparisonOperator) { - CatalogSchemaFunctionName functionName = ResolvedFunction.extractFunctionName(functionCall.getName()); + CatalogSchemaFunctionName functionName = functionCall.getFunction().getName(); if (functionName.equals(builtinFunctionName(ST_CONTAINS))) { if (probeFirst) { return (buildGeometry, probeGeometry, radius) -> probeGeometry.contains(buildGeometry); @@ -2730,8 +2729,7 @@ private PagesSpatialIndexFactory createPagesSpatialIndexFactory( filterExpression, probeLayout, buildLayout, - context.getTypes(), - session)); + context.getTypes())); Optional partitionChannel = node.getRightPartitionSymbol().map(buildChannelGetter); @@ -2815,11 +2813,10 @@ private PhysicalOperation createLookupJoin( filterExpression, probeSource.getLayout(), buildLayout, - context.getTypes(), - session)); + context.getTypes())); Optional sortExpressionContext = node.getFilter() - .flatMap(filter -> extractSortExpression(metadata, ImmutableSet.copyOf(node.getRight().getOutputSymbols()), filter)); + .flatMap(filter -> extractSortExpression(ImmutableSet.copyOf(node.getRight().getOutputSymbols()), filter)); Optional sortChannel = sortExpressionContext .map(SortExpressionContext::getSortExpression) @@ -2833,8 +2830,7 @@ private PhysicalOperation createLookupJoin( searchExpression, probeSource.getLayout(), buildLayout, - context.getTypes(), - session)) + context.getTypes())) .collect(toImmutableList())) .orElse(ImmutableList.of()); @@ -3101,12 +3097,11 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( Expression filterExpression, Map probeLayout, Map buildLayout, - TypeProvider types, - Session session) + TypeProvider types) { Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - RowExpression translatedFilter = toRowExpression(filterExpression, typeAnalyzer.getTypes(session, types, filterExpression), joinSourcesLayout); + RowExpression translatedFilter = toRowExpression(filterExpression, typeAnalyzer.getTypes(types, filterExpression), joinSourcesLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } @@ -3880,7 +3875,7 @@ private List> makeLambdaProviders(List lambda .put(NodeRef.of(lambdaExpression), functionType) // expressions from lambda arguments // expressions from lambda body - .putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody())) + .putAll(typeAnalyzer.getTypes(TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody())) .buildOrThrow(); LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index 13cf33ab5324..910ebac31cc9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -825,7 +825,7 @@ private Expression noTruncationCast(Expression expression, Type fromType, Type t new Constant(BIGINT, (long) targetLength), new CoalesceExpression( new FunctionCall( - spaceTrimmedLength.toQualifiedName(), + spaceTrimmedLength, ImmutableList.of(new Cast(expression, VARCHAR))), new Constant(BIGINT, 0L))), new Cast(expression, toType), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java index 80ae51791a84..ba6840fb1df3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartialTranslator.java @@ -50,7 +50,7 @@ public static Map, ConnectorExpression> extractPartialTransl requireNonNull(typeProvider, "typeProvider is null"); Map, ConnectorExpression> partialTranslations = new HashMap<>(); - new Visitor(session, typeAnalyzer.getTypes(session, typeProvider, inputExpression), partialTranslations, plannerContext).process(inputExpression); + new Visitor(session, typeAnalyzer.getTypes(typeProvider, inputExpression), partialTranslations, plannerContext).process(inputExpression); return ImmutableMap.copyOf(partialTranslations); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java index f90c006c7cec..5b315b12fe37 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java @@ -13,7 +13,6 @@ */ package io.trino.sql.planner; -import io.trino.metadata.Metadata; import io.trino.sql.planner.optimizations.UnaliasSymbolReferences; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ApplyNode; @@ -55,10 +54,10 @@ public final class PlanCopier { private PlanCopier() {} - public static NodeAndMappings copyPlan(PlanNode plan, List fields, Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) + public static NodeAndMappings copyPlan(PlanNode plan, List fields, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { PlanNode copy = SimplePlanRewriter.rewriteWith(new Copier(idAllocator), plan, null); - return new UnaliasSymbolReferences(metadata).reallocateSymbols(copy, fields, symbolAllocator); + return new UnaliasSymbolReferences().reallocateSymbols(copy, fields, symbolAllocator); } private static class Copier diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index e4ec21272be9..8f55231615d8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -426,7 +426,7 @@ public PlanOptimizers( new RemoveEmptyUnionBranches(), new EvaluateEmptyIntersect(), new RemoveEmptyExceptBranches(), - new MergeFilters(metadata), + new MergeFilters(), new InlineProjections(plannerContext, typeAnalyzer), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -451,7 +451,7 @@ public PlanOptimizers( new RemoveRedundantEnforceSingleRowNode(), new RemoveRedundantExists(), new RemoveRedundantWindow(), - new ImplementFilteredAggregations(metadata), + new ImplementFilteredAggregations(), new SingleDistinctAggregationToGroupBy(), new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata), @@ -474,7 +474,7 @@ public PlanOptimizers( .addAll(ImmutableSet.of( new MergeUnion(), new RemoveEmptyUnionBranches(), - new MergeFilters(metadata), + new MergeFilters(), new RemoveTrivialFilters(), new MergeLimits(), new MergeLimitWithSort(), @@ -490,7 +490,7 @@ public PlanOptimizers( costCalculator, ImmutableSet.of(new ImplementOffset())), simplifyOptimizer, - new UnaliasSymbolReferences(metadata), + new UnaliasSymbolReferences(), new IterativeOptimizer( plannerContext, ruleStats, @@ -564,7 +564,7 @@ public PlanOptimizers( new TransformCorrelatedInPredicateToJoin(metadata), // must be run after columnPruningOptimizer new TransformCorrelatedScalarSubquery(metadata), // must be run after TransformCorrelatedAggregation rules new TransformCorrelatedJoinToJoin(plannerContext), - new ImplementFilteredAggregations(metadata))), + new ImplementFilteredAggregations())), new IterativeOptimizer( plannerContext, ruleStats, @@ -575,8 +575,8 @@ public PlanOptimizers( new RemoveRedundantIdentityProjections(), new TransformCorrelatedSingleRowSubqueryToProject(), new RemoveAggregationInSemiJoin(), - new MergeProjectWithValues(metadata), - new ReplaceJoinOverConstantWithProject(metadata))), + new MergeProjectWithValues(), + new ReplaceJoinOverConstantWithProject())), new CheckSubqueryNodesAreRewritten(), simplifyOptimizer, // Should run after MergeProjectWithValues new StatsRecordingPlanOptimizer( @@ -598,8 +598,8 @@ public PlanOptimizers( statsCalculator, costCalculator, ImmutableSet.>builder() - .add(new InlineProjectIntoFilter(metadata)) - .add(new SimplifyFilterPredicate(metadata)) + .add(new InlineProjectIntoFilter()) + .add(new SimplifyFilterPredicate()) .addAll(columnPruningRules) .add(new InlineProjections(plannerContext, typeAnalyzer)) .addAll(new PushFilterThroughCountAggregation(plannerContext).rules()) // must run after PredicatePushDown and after TransformFilteringSemiJoinToInnerJoin @@ -643,7 +643,7 @@ public PlanOptimizers( costCalculator, pushIntoTableScanRulesExceptJoins); builder.add(pushIntoTableScanOptimizer); - builder.add(new UnaliasSymbolReferences(metadata)); + builder.add(new UnaliasSymbolReferences()); builder.add(pushIntoTableScanOptimizer); // TODO (https://github.com/trinodb/trino/issues/811) merge with the above after migrating UnaliasSymbolReferences to rules IterativeOptimizer pushProjectionIntoTableScanOptimizer = new IterativeOptimizer( @@ -695,7 +695,7 @@ public PlanOptimizers( .addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown .add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)) .build()), - new UnaliasSymbolReferences(metadata), // Run again because predicate pushdown and projection pushdown might add more projections + new UnaliasSymbolReferences(), // Run again because predicate pushdown and projection pushdown might add more projections columnPruningOptimizer, // Make sure to run this before index join. Filtered projections may not have all the columns. new IndexJoinOptimizer(plannerContext), // Run this after projections and filters have been fully simplified and pushed down new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown @@ -827,7 +827,7 @@ public PlanOptimizers( new DetermineTableScanNodePartitioning(metadata, nodePartitioningManager, taskCountEstimator), // Must run after join reordering because join reordering creates // new join nodes without JoinNode.maySkipOutputDuplicates flag set - new OptimizeDuplicateInsensitiveJoins(metadata)))); + new OptimizeDuplicateInsensitiveJoins()))); // Previous invocations of PushPredicateIntoTableScan do not prune using predicate expression. The invocation in AddExchanges // does this pruning - and we may end up with empty union branches after that. We invoke PushPredicateIntoTableScan @@ -880,7 +880,7 @@ public PlanOptimizers( ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges // unalias symbols before adding exchanges to use same partitioning symbols in joins, aggregations and other // operators that require node partitioning - builder.add(new UnaliasSymbolReferences(metadata)); + builder.add(new UnaliasSymbolReferences()); builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(plannerContext, typeAnalyzer, statsCalculator, taskCountEstimator))); // It can only run after AddExchanges since it estimates the hash partition count for all remote exchanges builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new DeterminePartitionCount(statsCalculator, taskCountEstimator))); @@ -920,7 +920,7 @@ public PlanOptimizers( ruleStats, statsCalculator, costCalculator, - ImmutableSet.copyOf(new PushInequalityFilterExpressionBelowJoinRuleSet(metadata, typeAnalyzer).rules()))); + ImmutableSet.copyOf(new PushInequalityFilterExpressionBelowJoinRuleSet(typeAnalyzer).rules()))); // Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential // pushdown into the connectors. Invoke PredicatePushdown and PushPredicateIntoTableScan after this // to leverage predicate pushdown on projected columns and to pushdown dynamic filters. @@ -939,7 +939,7 @@ public PlanOptimizers( // PushPredicateIntoTableScan and RemoveRedundantPredicateAboveTableScan due to those rules replacing table scans with empty ValuesNode builder.add(new RemoveUnsupportedDynamicFilters(plannerContext)); builder.add(inlineProjections); - builder.add(new UnaliasSymbolReferences(metadata)); // Run unalias after merging projections to simplify projections more efficiently + builder.add(new UnaliasSymbolReferences()); // Run unalias after merging projections to simplify projections more efficiently builder.add(columnPruningOptimizer); builder.add(new IterativeOptimizer( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index cd465b99d626..a5eed6e99c72 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -377,7 +377,7 @@ public RelationPlan planExpand(Query query) // order and might be used to identify the original output symbols with their copies. private NodeAndMappings copy(PlanNode plan, List fields) { - return PlanCopier.copyPlan(plan, fields, plannerContext.getMetadata(), symbolAllocator, idAllocator); + return PlanCopier.copyPlan(plan, fields, symbolAllocator, idAllocator); } private PlanNode replace(PlanNode plan, NodeAndMappings replacementSpot, NodeAndMappings replacement) @@ -1568,7 +1568,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp // Note: if frameOffset needs a coercion, it was added before by a call to coerce() method. ResolvedFunction function = frameBoundCalculationFunction.get(); Expression functionCall = new FunctionCall( - function.toQualifiedName(), + function, ImmutableList.of( sortKeyCoercedForFrameBoundCalculation.toSymbolReference(), offsetSymbol.toSymbolReference())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index d1e97afd3880..f1e5a5067894 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -1367,7 +1367,7 @@ private RelationPlan planJoinJsonTable(PlanBuilder leftPlan, List leftFi // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, jsonTable.getErrorBehavior().orElse(JsonTable.ErrorBehavior.EMPTY) == JsonTable.ErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(inputExpression); - Expression inputJson = new FunctionCall(inputToJson.toQualifiedName(), ImmutableList.of(coerced.get(inputExpression).toSymbolReference(), failOnError)); + Expression inputJson = new FunctionCall(inputToJson, ImmutableList.of(coerced.get(inputExpression).toSymbolReference(), failOnError)); // apply the input functions to the JSON path parameters having FORMAT, // and collect all JSON path parameters in a Row @@ -1498,7 +1498,7 @@ else if (jsonTable.getPlan().orElseThrow() instanceof JsonTableDefaultPlan defau Constant errorBehavior = new Constant(TINYINT, (long) queryColumn.getErrorBehavior().orElse(defaultErrorOnError ? ERROR : NULL).ordinal()); Constant omitQuotes = new Constant(BOOLEAN, queryColumn.getQuotesBehavior().orElse(KEEP) == OMIT); ResolvedFunction outputFunction = analysis.getJsonOutputFunction(queryColumn); - Expression result = new FunctionCall(outputFunction.toQualifiedName(), ImmutableList.of(properOutput.toSymbolReference(), errorBehavior, omitQuotes)); + Expression result = new FunctionCall(outputFunction, ImmutableList.of(properOutput.toSymbolReference(), errorBehavior, omitQuotes)); // cast to declared returned type Type expectedType = jsonTableRelationType.getFieldByIndex(i).getType(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java b/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java index 75ac3b9cfd67..cdfbbf0586ce 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ResolvedFunctionCallBuilder.java @@ -53,6 +53,6 @@ public ResolvedFunctionCallBuilder setArguments(List values) public FunctionCall build() { - return new FunctionCall(resolvedFunction.toQualifiedName(), argumentValues); + return new FunctionCall(resolvedFunction, argumentValues); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java index bb7b0b5c1093..ba8e94c255f1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java @@ -14,7 +14,6 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableList; -import io.trino.metadata.Metadata; import io.trino.operator.join.SortedPositionLinks; import io.trino.sql.ir.BetweenPredicate; import io.trino.sql.ir.ComparisonExpression; @@ -59,13 +58,13 @@ by sorting position links according to the result of f(...) function. */ private SortExpressionExtractor() {} - public static Optional extractSortExpression(Metadata metadata, Set buildSymbols, Expression filter) + public static Optional extractSortExpression(Set buildSymbols, Expression filter) { List filterConjuncts = IrUtils.extractConjuncts(filter); SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols); List sortExpressionCandidates = ImmutableList.copyOf(filterConjuncts.stream() - .filter(expression -> DeterminismEvaluator.isDeterministic(expression, metadata)) + .filter(expression -> DeterminismEvaluator.isDeterministic(expression)) .map(visitor::process) .filter(Optional::isPresent) .map(Optional::get) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java index 26e961462ed0..88e8d9b11c23 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java @@ -179,7 +179,7 @@ private SplitSource createSplitSource(TableHandle table, Map filterConjuncts(plannerContext.getMetadata(), predicate, expression -> !DynamicFilters.isDynamicFilter(expression))) + .map(predicate -> filterConjuncts(predicate, expression -> !DynamicFilters.isDynamicFilter(expression))) .map(predicate -> new LayoutConstraintEvaluator(plannerContext, typeAnalyzer, session, typeProvider, assignments, predicate)) .map(evaluator -> new Constraint(TupleDomain.all(), evaluator::isCandidate, evaluator.getArguments())) // we are interested only in functional predicate here, so we set the summary to ALL. .orElse(alwaysTrue()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SymbolAllocator.java b/core/trino-main/src/main/java/io/trino/sql/planner/SymbolAllocator.java index 8eeac83d9789..7f8003481bb3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SymbolAllocator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SymbolAllocator.java @@ -14,7 +14,6 @@ package io.trino.sql.planner; import com.google.common.primitives.Ints; -import io.trino.metadata.ResolvedFunction; import io.trino.spi.type.BigintType; import io.trino.spi.type.Type; import io.trino.sql.analyzer.Field; @@ -109,12 +108,7 @@ public Symbol newSymbol(Expression expression, Type type, String suffix) String nameHint = "expr"; if (expression instanceof FunctionCall functionCall) { // symbol allocation can happen during planning, before function calls are rewritten - if (ResolvedFunction.isResolved(functionCall.getName())) { - nameHint = ResolvedFunction.extractFunctionName(functionCall.getName()).getFunctionName(); - } - else { - nameHint = functionCall.getName().getSuffix(); - } + nameHint = functionCall.getFunction().getName().getFunctionName(); } else if (expression instanceof SymbolReference symbolReference) { nameHint = symbolReference.getName(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index d83280e24c85..0d9883260b4c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -660,7 +660,7 @@ private io.trino.sql.ir.Expression translate(FunctionCall expression) checkArgument(resolvedFunction != null, "Function has not been analyzed: %s", expression); return new io.trino.sql.ir.FunctionCall( - resolvedFunction.toQualifiedName(), + resolvedFunction, expression.getArguments().stream() .map(this::translateExpression) .collect(toImmutableList())); @@ -716,8 +716,7 @@ private io.trino.sql.ir.Expression translate(CurrentCatalog unused) { return new io.trino.sql.ir.FunctionCall( plannerContext.getMetadata() - .resolveBuiltinFunction("$current_catalog", ImmutableList.of()) - .toQualifiedName(), + .resolveBuiltinFunction("$current_catalog", ImmutableList.of()), ImmutableList.of()); } @@ -725,8 +724,7 @@ private io.trino.sql.ir.Expression translate(CurrentSchema unused) { return new io.trino.sql.ir.FunctionCall( plannerContext.getMetadata() - .resolveBuiltinFunction("$current_schema", ImmutableList.of()) - .toQualifiedName(), + .resolveBuiltinFunction("$current_schema", ImmutableList.of()), ImmutableList.of()); } @@ -734,8 +732,7 @@ private io.trino.sql.ir.Expression translate(CurrentPath unused) { return new io.trino.sql.ir.FunctionCall( plannerContext.getMetadata() - .resolveBuiltinFunction("$current_path", ImmutableList.of()) - .toQualifiedName(), + .resolveBuiltinFunction("$current_path", ImmutableList.of()), ImmutableList.of()); } @@ -743,8 +740,7 @@ private io.trino.sql.ir.Expression translate(CurrentUser unused) { return new io.trino.sql.ir.FunctionCall( plannerContext.getMetadata() - .resolveBuiltinFunction("$current_user", ImmutableList.of()) - .toQualifiedName(), + .resolveBuiltinFunction("$current_user", ImmutableList.of()), ImmutableList.of()); } @@ -970,7 +966,7 @@ private io.trino.sql.ir.Expression translate(Trim node) .map(this::translateExpression) .ifPresent(arguments::add); - return new io.trino.sql.ir.FunctionCall(resolvedFunction.toQualifiedName(), arguments.build()); + return new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); } private io.trino.sql.ir.Expression translate(SubscriptExpression node) @@ -1015,7 +1011,7 @@ private io.trino.sql.ir.Expression translate(JsonExists node) // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, node.getErrorBehavior() == JsonExists.ErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(node.getJsonPathInvocation().getInputExpression()); - io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson.toQualifiedName(), ImmutableList.of( + io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson, ImmutableList.of( translateExpression(node.getJsonPathInvocation().getInputExpression()), failOnError)); @@ -1038,7 +1034,7 @@ private io.trino.sql.ir.Expression translate(JsonExists node) .add(orderedParameters.getParametersRow()) .add(new Constant(TINYINT, (long) node.getErrorBehavior().ordinal())); - return new io.trino.sql.ir.FunctionCall(resolvedFunction.toQualifiedName(), arguments.build()); + return new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); } private io.trino.sql.ir.Expression translate(JsonValue node) @@ -1049,7 +1045,7 @@ private io.trino.sql.ir.Expression translate(JsonValue node) // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, node.getErrorBehavior() == JsonValue.EmptyOrErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(node.getJsonPathInvocation().getInputExpression()); - io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson.toQualifiedName(), ImmutableList.of( + io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson, ImmutableList.of( translateExpression(node.getJsonPathInvocation().getInputExpression()), failOnError)); @@ -1079,7 +1075,7 @@ private io.trino.sql.ir.Expression translate(JsonValue node) .map(this::translateExpression) .orElseGet(() -> new Constant(resolvedFunction.getSignature().getReturnType(), null))); - return new io.trino.sql.ir.FunctionCall(resolvedFunction.toQualifiedName(), arguments.build()); + return new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); } private io.trino.sql.ir.Expression translate(JsonQuery node) @@ -1090,7 +1086,7 @@ private io.trino.sql.ir.Expression translate(JsonQuery node) // apply the input function to the input expression Constant failOnError = new Constant(BOOLEAN, node.getErrorBehavior() == JsonQuery.EmptyOrErrorBehavior.ERROR); ResolvedFunction inputToJson = analysis.getJsonInputFunction(node.getJsonPathInvocation().getInputExpression()); - io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson.toQualifiedName(), ImmutableList.of( + io.trino.sql.ir.Expression input = new io.trino.sql.ir.FunctionCall(inputToJson, ImmutableList.of( translateExpression(node.getJsonPathInvocation().getInputExpression()), failOnError)); @@ -1115,13 +1111,13 @@ private io.trino.sql.ir.Expression translate(JsonQuery node) .add(new Constant(TINYINT, (long) node.getEmptyBehavior().ordinal())) .add(new Constant(TINYINT, (long) node.getErrorBehavior().ordinal())); - io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction.toQualifiedName(), arguments.build()); + io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments.build()); // apply function to format output Constant errorBehavior = new Constant(TINYINT, (long) node.getErrorBehavior().ordinal()); Constant omitQuotes = new Constant(BOOLEAN, node.getQuotesBehavior().orElse(KEEP) == OMIT); ResolvedFunction outputFunction = analysis.getJsonOutputFunction(node); - io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction.toQualifiedName(), ImmutableList.of(function, errorBehavior, omitQuotes)); + io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction, ImmutableList.of(function, errorBehavior, omitQuotes)); // cast to requested returned type Type returnedType = node.getReturnedType() @@ -1164,7 +1160,7 @@ private io.trino.sql.ir.Expression translate(JsonObject node) io.trino.sql.ir.Expression rewrittenValue = translateExpression(value); ResolvedFunction valueToJson = analysis.getJsonInputFunction(value); if (valueToJson != null) { - values.add(new io.trino.sql.ir.FunctionCall(valueToJson.toQualifiedName(), ImmutableList.of(rewrittenValue, TRUE_LITERAL))); + values.add(new io.trino.sql.ir.FunctionCall(valueToJson, ImmutableList.of(rewrittenValue, TRUE_LITERAL))); } else { values.add(rewrittenValue); @@ -1181,11 +1177,11 @@ private io.trino.sql.ir.Expression translate(JsonObject node) .add(node.isUniqueKeys() ? TRUE_LITERAL : FALSE_LITERAL) .build(); - io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction.toQualifiedName(), arguments); + io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments); // apply function to format output ResolvedFunction outputFunction = analysis.getJsonOutputFunction(node); - io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction.toQualifiedName(), ImmutableList.of( + io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction, ImmutableList.of( function, new Constant(TINYINT, (long) ERROR.ordinal()), FALSE_LITERAL)); @@ -1223,7 +1219,7 @@ private io.trino.sql.ir.Expression translate(JsonArray node) io.trino.sql.ir.Expression rewrittenElement = translateExpression(element); ResolvedFunction elementToJson = analysis.getJsonInputFunction(element); if (elementToJson != null) { - elements.add(new io.trino.sql.ir.FunctionCall(elementToJson.toQualifiedName(), ImmutableList.of(rewrittenElement, TRUE_LITERAL))); + elements.add(new io.trino.sql.ir.FunctionCall(elementToJson, ImmutableList.of(rewrittenElement, TRUE_LITERAL))); } else { elements.add(rewrittenElement); @@ -1237,11 +1233,11 @@ private io.trino.sql.ir.Expression translate(JsonArray node) .add(node.isNullOnNull() ? TRUE_LITERAL : FALSE_LITERAL) .build(); - io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction.toQualifiedName(), arguments); + io.trino.sql.ir.Expression function = new io.trino.sql.ir.FunctionCall(resolvedFunction, arguments); // apply function to format output ResolvedFunction outputFunction = analysis.getJsonOutputFunction(node); - io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction.toQualifiedName(), ImmutableList.of( + io.trino.sql.ir.Expression result = new io.trino.sql.ir.FunctionCall(outputFunction, ImmutableList.of( function, new Constant(TINYINT, (long) ERROR.ordinal()), FALSE_LITERAL)); @@ -1311,7 +1307,7 @@ public ParametersRow getParametersRow( ResolvedFunction parameterToJson = analysis.getJsonInputFunction(pathParameters.get(i).getParameter()); io.trino.sql.ir.Expression rewrittenParameter = rewrittenPathParameters.get(i); if (parameterToJson != null) { - parameters.add(new io.trino.sql.ir.FunctionCall(parameterToJson.toQualifiedName(), ImmutableList.of(rewrittenParameter, failOnError))); + parameters.add(new io.trino.sql.ir.FunctionCall(parameterToJson, ImmutableList.of(rewrittenParameter, failOnError))); } else { parameters.add(rewrittenParameter); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java index 9b6436b1f8f4..c8d214762d07 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyTableScanRedirection.java @@ -207,7 +207,7 @@ public Result apply(TableScanNode scanNode, Captures captures, Context context) scanNode.isUpdateTarget(), Optional.empty()); - DomainTranslator domainTranslator = new DomainTranslator(plannerContext); + DomainTranslator domainTranslator = new DomainTranslator(); FilterNode filterNode = new FilterNode( context.getIdAllocator().getNextId(), applyProjection( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java index 5887b39d56f6..e8fa678df477 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ArraySortAfterArrayDistinct.java @@ -79,12 +79,12 @@ public Visitor(Metadata metadata) public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) { FunctionCall rewritten = treeRewriter.defaultRewrite(node, context); - if (metadata.decodeFunction(rewritten.getName()).getSignature().getName().equals(ARRAY_DISTINCT_NAME) && + if (node.getFunction().getName().equals(ARRAY_DISTINCT_NAME) && getOnlyElement(rewritten.getArguments()) instanceof FunctionCall) { Expression expression = getOnlyElement(rewritten.getArguments()); FunctionCall functionCall = (FunctionCall) expression; - ResolvedFunction resolvedFunction = metadata.decodeFunction(functionCall.getName()); - if (resolvedFunction.getSignature().getName().equals(ARRAY_SORT_NAME)) { + ResolvedFunction resolvedFunction = functionCall.getFunction(); + if (resolvedFunction.getName().equals(ARRAY_SORT_NAME)) { List arraySortArguments = functionCall.getArguments(); List arraySortArgumentsTypes = resolvedFunction.getSignature().getArgumentTypes(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java index bea59f5af34f..adab135fe31a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/CanonicalizeExpressionRewriter.java @@ -41,7 +41,6 @@ import java.util.Optional; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.spi.type.DateType.DATE; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; @@ -134,10 +133,10 @@ public Expression rewriteIfExpression(IfExpression node, Void context, Expressio @Override public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) { - CatalogSchemaFunctionName functionName = extractFunctionName(node.getName()); + CatalogSchemaFunctionName functionName = node.getFunction().getName(); if (functionName.equals(builtinFunctionName("date")) && node.getArguments().size() == 1) { Expression argument = node.getArguments().get(0); - Type argumentType = typeAnalyzer.getType(session, types, argument); + Type argumentType = typeAnalyzer.getType(types, argument); if (argumentType instanceof TimestampType || argumentType instanceof TimestampWithTimeZoneType || argumentType instanceof VarcharType) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java index 60ffe303ed51..60d75e01f1a7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DereferencePushdown.java @@ -14,7 +14,6 @@ package io.trino.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; -import io.trino.Session; import io.trino.spi.type.RowType; import io.trino.sql.ir.DefaultTraversalVisitor; import io.trino.sql.ir.Expression; @@ -41,10 +40,10 @@ class DereferencePushdown { private DereferencePushdown() {} - public static Set extractRowSubscripts(Collection expressions, boolean allowOverlap, Session session, IrTypeAnalyzer typeAnalyzer, TypeProvider types) + public static Set extractRowSubscripts(Collection expressions, boolean allowOverlap, IrTypeAnalyzer typeAnalyzer, TypeProvider types) { Set symbolReferencesAndRowSubscripts = expressions.stream() - .flatMap(expression -> getSymbolReferencesAndRowSubscripts(expression, session, typeAnalyzer, types).stream()) + .flatMap(expression -> getSymbolReferencesAndRowSubscripts(expression, typeAnalyzer, types).stream()) .collect(toImmutableSet()); // Remove overlap if required @@ -62,12 +61,12 @@ public static Set extractRowSubscripts(Collection projections, Session session, IrTypeAnalyzer typeAnalyzer, TypeProvider types) + public static boolean exclusiveDereferences(Set projections, IrTypeAnalyzer typeAnalyzer, TypeProvider types) { return projections.stream() .allMatch(expression -> expression instanceof SymbolReference || (expression instanceof SubscriptExpression && - isRowSubscriptChain((SubscriptExpression) expression, session, typeAnalyzer, types) && + isRowSubscriptChain((SubscriptExpression) expression, typeAnalyzer, types) && !prefixExists(expression, projections))); } @@ -80,7 +79,7 @@ public static Symbol getBase(SubscriptExpression expression) * Extract the sub-expressions of type {@link SubscriptExpression} or {@link SymbolReference} from the expression * in a top-down manner. The expressions within the base of a valid {@link SubscriptExpression} sequence are not extracted. */ - private static List getSymbolReferencesAndRowSubscripts(Expression expression, Session session, IrTypeAnalyzer typeAnalyzer, TypeProvider types) + private static List getSymbolReferencesAndRowSubscripts(Expression expression, IrTypeAnalyzer typeAnalyzer, TypeProvider types) { ImmutableList.Builder builder = ImmutableList.builder(); @@ -89,7 +88,7 @@ private static List getSymbolReferencesAndRowSubscripts(Expression e @Override protected Void visitSubscriptExpression(SubscriptExpression node, ImmutableList.Builder context) { - if (isRowSubscriptChain(node, session, typeAnalyzer, types)) { + if (isRowSubscriptChain(node, typeAnalyzer, types)) { context.add(node); } return null; @@ -112,14 +111,14 @@ protected Void visitLambdaExpression(LambdaExpression node, ImmutableList.Builde return builder.build(); } - private static boolean isRowSubscriptChain(SubscriptExpression expression, Session session, IrTypeAnalyzer typeAnalyzer, TypeProvider types) + private static boolean isRowSubscriptChain(SubscriptExpression expression, IrTypeAnalyzer typeAnalyzer, TypeProvider types) { - if (!(typeAnalyzer.getType(session, types, expression.getBase()) instanceof RowType)) { + if (!(typeAnalyzer.getType(types, expression.getBase()) instanceof RowType)) { return false; } return (expression.getBase() instanceof SymbolReference) || - ((expression.getBase() instanceof SubscriptExpression) && isRowSubscriptChain((SubscriptExpression) expression.getBase(), session, typeAnalyzer, types)); + ((expression.getBase() instanceof SubscriptExpression) && isRowSubscriptChain((SubscriptExpression) expression.getBase(), typeAnalyzer, types)); } private static boolean prefixExists(Expression expression, Set expressions) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java index b923b50a9e9a..cc0a952fa45e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractCommonPredicatesExpressionRewriter.java @@ -134,7 +134,7 @@ private static List> getSubPredicates(LogicalExpression express */ private Expression distributeIfPossible(LogicalExpression expression) { - if (!isDeterministic(expression, metadata)) { + if (!isDeterministic(expression)) { // Do not distribute boolean expressions if there are any non-deterministic elements // TODO: This can be optimized further if non-deterministic elements are not repeated return expression; @@ -179,7 +179,7 @@ private Expression distributeIfPossible(LogicalExpression expression) private Set filterDeterministicPredicates(List predicates) { return predicates.stream() - .filter(expression -> isDeterministic(expression, metadata)) + .filter(expression -> isDeterministic(expression)) .collect(toSet()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java index 71624be6f5cb..04476a094cd3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java @@ -86,12 +86,12 @@ public Pattern getPattern() @Override public Result apply(FilterNode node, Captures captures, Context context) { - Set dereferences = extractRowSubscripts(ImmutableList.of(node.getPredicate()), true, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(ImmutableList.of(node.getPredicate()), true, typeAnalyzer, context.getSymbolAllocator().getTypes()); if (dereferences.isEmpty()) { return Result.empty(); } - Assignments assignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments assignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); Map mappings = HashBiMap.create(assignments.getMap()) .inverse() .entrySet().stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java index 265e2acd6b98..a71ad1f56c36 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -381,8 +381,8 @@ private static Result tryCreateSpatialJoin( Expression secondArgument = arguments.get(1); Type sphericalGeographyType = plannerContext.getTypeManager().getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE); - if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType) - || typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) { + if (typeAnalyzer.getType(context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType) + || typeAnalyzer.getType(context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) { return Result.empty(); } @@ -435,9 +435,7 @@ else if (alignment < 0) { } } - ResolvedFunction resolvedFunction = plannerContext.getFunctionDecoder() - .fromQualifiedName(spatialFunction.getName()) - .orElseThrow(() -> new IllegalArgumentException("function call not resolved")); + ResolvedFunction resolvedFunction = spatialFunction.getFunction(); Expression newSpatialFunction = ResolvedFunctionCallBuilder.builder(resolvedFunction) .addArgument(newFirstArgument) .addArgument(newSecondArgument) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java index dfe6ebfa320f..c1fd5a7aec42 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementExceptAll.java @@ -101,7 +101,7 @@ public Result apply(ExceptNode node, Captures captures, Context context) Expression count = result.getCountSymbols().get(0).toSymbolReference(); for (int i = 1; i < result.getCountSymbols().size(); i++) { count = new FunctionCall( - greatest.toQualifiedName(), + greatest, ImmutableList.of( new ArithmeticBinaryExpression(SUBTRACT, count, result.getCountSymbols().get(i).toSymbolReference()), new Constant(BIGINT, 0L))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 5d291dfadf5b..bc11a1d0c92d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -66,13 +65,6 @@ public class ImplementFilteredAggregations private static final Pattern PATTERN = aggregation() .matching(ImplementFilteredAggregations::hasFilters); - private final Metadata metadata; - - public ImplementFilteredAggregations(Metadata metadata) - { - this.metadata = metadata; - } - private static boolean hasFilters(AggregationNode aggregation) { return aggregation.getAggregations() @@ -133,7 +125,7 @@ else if (mask.isPresent()) { Expression predicate = TRUE_LITERAL; if (!aggregationNode.hasNonEmptyGroupingSet() && !aggregateWithoutFilterOrMaskPresent) { - predicate = combineDisjunctsWithDefault(metadata, maskSymbols.build(), TRUE_LITERAL); + predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL); } // identity projection for all existing inputs diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java index 1664fbe2eb3e..eea2eb025b09 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementIntersectAll.java @@ -97,7 +97,7 @@ public Result apply(IntersectNode node, Captures captures, Context context) Expression minCount = result.getCountSymbols().get(0).toSymbolReference(); for (int i = 1; i < result.getCountSymbols().size(); i++) { - minCount = new FunctionCall(least.toQualifiedName(), ImmutableList.of(minCount, result.getCountSymbols().get(i).toSymbolReference())); + minCount = new FunctionCall(least, ImmutableList.of(minCount, result.getCountSymbols().get(i).toSymbolReference())); } // filter rows so that expected number of rows remains diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java index 84c551b26e40..48f2562a0fb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjectIntoFilter.java @@ -18,7 +18,6 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.Symbol; @@ -95,13 +94,6 @@ public class InlineProjectIntoFilter private static final Pattern PATTERN = filter() .with(source().matching(project().capturedAs(PROJECTION))); - private final Metadata metadata; - - public InlineProjectIntoFilter(Metadata metadata) - { - this.metadata = metadata; - } - @Override public Pattern getPattern() { @@ -184,7 +176,7 @@ public Result apply(FilterNode node, Captures captures, Context context) projectNode.getId(), projectNode.getSource(), newAssignments.build()), - combineConjuncts(metadata, newConjuncts.build())), + combineConjuncts(newConjuncts.build())), Assignments.builder() .putAll(outputAssignments) .build())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java index ba768d285bc0..3018befca162 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/InlineProjections.java @@ -195,7 +195,7 @@ private static Set extractInliningTargets(PlannerContext plannerContext, Expression assignment = child.getAssignments().get(entry.getKey()); if (assignment instanceof SubscriptExpression) { - if (typeAnalyzer.getType(session, types, ((SubscriptExpression) assignment).getBase()) instanceof RowType) { + if (typeAnalyzer.getType(types, ((SubscriptExpression) assignment).getBase()) instanceof RowType) { return false; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeFilters.java index 1794e1ba658e..38111b7193c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeFilters.java @@ -16,7 +16,6 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.FilterNode; @@ -24,7 +23,6 @@ import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.planner.plan.Patterns.filter; import static io.trino.sql.planner.plan.Patterns.source; -import static java.util.Objects.requireNonNull; public class MergeFilters implements Rule @@ -34,13 +32,6 @@ public class MergeFilters private static final Pattern PATTERN = filter() .with(source().matching(filter().capturedAs(CHILD))); - private final Metadata metadata; - - public MergeFilters(Metadata metadata) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - } - @Override public Pattern getPattern() { @@ -56,6 +47,6 @@ public Result apply(FilterNode parent, Captures captures, Context context) new FilterNode( parent.getId(), child.getSource(), - combineConjuncts(metadata, child.getPredicate(), parent.getPredicate()))); + combineConjuncts(child.getPredicate(), parent.getPredicate()))); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java index ae7a74ef67e5..0a4536c98045 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java @@ -20,7 +20,6 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Row; import io.trino.sql.ir.SymbolReference; @@ -47,7 +46,6 @@ import static io.trino.sql.planner.plan.Patterns.source; import static io.trino.sql.planner.plan.Patterns.values; import static java.util.Collections.nCopies; -import static java.util.Objects.requireNonNull; /** * Transforms: @@ -89,13 +87,6 @@ public class MergeProjectWithValues .matching(MergeProjectWithValues::isSupportedValues) .capturedAs(VALUES))); - private final Metadata metadata; - - public MergeProjectWithValues(Metadata metadata) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - } - @Override public Pattern getPattern() { @@ -140,7 +131,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) for (Expression rowExpression : valuesNode.getRows().get()) { Row row = (Row) rowExpression; for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) { - if (!isDeterministic(row.getItems().get(i), metadata)) { + if (!isDeterministic(row.getItems().get(i))) { nonDeterministicValuesOutputs.add(valuesNode.getOutputSymbols().get(i)); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java index ec603190a07f..96da45888ac1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeDuplicateInsensitiveJoins.java @@ -18,7 +18,6 @@ import io.trino.Session; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.planner.iterative.GroupReference; import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.iterative.Rule; @@ -52,13 +51,6 @@ public class OptimizeDuplicateInsensitiveJoins private static final Pattern PATTERN = aggregation() .matching(aggregation -> aggregation.getAggregations().isEmpty()); - private final Metadata metadata; - - public OptimizeDuplicateInsensitiveJoins(Metadata metadata) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - } - @Override public Pattern getPattern() { @@ -74,7 +66,7 @@ public boolean isEnabled(Session session) @Override public Result apply(AggregationNode aggregation, Captures captures, Context context) { - return aggregation.getSource().accept(new Rewriter(metadata, context.getLookup()), null) + return aggregation.getSource().accept(new Rewriter(context.getLookup()), null) .map(rewrittenSource -> Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(rewrittenSource)))) .orElse(Result.empty()); } @@ -82,12 +74,10 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont private static class Rewriter extends PlanVisitor, Void> { - private final Metadata metadata; private final Lookup lookup; - private Rewriter(Metadata metadata, Lookup lookup) + private Rewriter(Lookup lookup) { - this.metadata = requireNonNull(metadata, "metadata is null"); this.lookup = requireNonNull(lookup, "lookup is null"); } @@ -100,7 +90,7 @@ protected Optional visitPlan(PlanNode node, Void context) @Override public Optional visitFilter(FilterNode node, Void context) { - if (!isDeterministic(node.getPredicate(), metadata)) { + if (!isDeterministic(node.getPredicate())) { // non-deterministic expressions could filter duplicate rows probabilistically return Optional.empty(); } @@ -113,7 +103,7 @@ public Optional visitFilter(FilterNode node, Void context) public Optional visitProject(ProjectNode node, Void context) { boolean isDeterministic = node.getAssignments().getExpressions().stream() - .allMatch(expression -> isDeterministic(expression, metadata)); + .allMatch(expression -> isDeterministic(expression)); if (!isDeterministic) { // non-deterministic projections could be used in downstream filters which could // filter duplicate rows probabilistically @@ -144,7 +134,7 @@ public Optional visitJoin(JoinNode node, Void context) // LookupJoinOperator will evaluate non-deterministic condition on output rows until one of the // rows matches. Therefore it's safe to set maySkipOutputDuplicates for joins with non-deterministic // filters. - if (!isDeterministic(node.getFilter().orElse(TRUE_LITERAL), metadata)) { + if (!isDeterministic(node.getFilter().orElse(TRUE_LITERAL))) { if (node.isMaySkipOutputDuplicates()) { // join node is already set to skip duplicates, return empty to prevent rule from looping forever return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java index bd2825408327..812de498d250 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PreAggregateCaseAggregations.java @@ -427,12 +427,12 @@ private Optional extractCaseAggregation(Symbol aggregationSymbo private Type getType(Context context, Expression expression) { - return typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), expression); + return typeAnalyzer.getType(context.getSymbolAllocator().getTypes(), expression); } private Object optimizeExpression(Expression expression, Context context) { - Map, Type> expressionTypes = typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(context.getSymbolAllocator().getTypes(), expression); IrExpressionInterpreter expressionInterpreter = new IrExpressionInterpreter(expression, plannerContext, context.getSession(), expressionTypes); return expressionInterpreter.optimize(Symbol::toSymbolReference); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 69b0c78b96a8..bb8c5baa5355 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -195,7 +195,7 @@ public static Optional pushAggregationIntoTableScan( Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, context.getSymbolAllocator().getTypes(), translated); + Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(context.getSymbolAllocator().getTypes(), translated); Object optimized = new IrExpressionInterpreter(translated, plannerContext, session, translatedExpressionTypes) .optimize(NoOpSymbolResolver.INSTANCE); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java index 2bc34ac48884..43c5468c1a78 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java @@ -88,14 +88,14 @@ public Result apply(ProjectNode node, Captures captures, Rule.Context context) .build(); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(expressions, false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(expressions, false, typeAnalyzer, context.getSymbolAllocator().getTypes()); if (dereferences.isEmpty()) { return Result.empty(); } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java index b164b853386f..59774cc94de6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java @@ -100,7 +100,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); joinNode.getFilter().ifPresent(expressionsBuilder::add); - Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude criteria symbols ImmutableSet.Builder criteriaSymbolsBuilder = ImmutableSet.builder(); @@ -119,7 +119,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java index d6ad9cda4a56..3869c7ef7576 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java @@ -74,7 +74,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) ProjectNode child = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(node.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(node.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols being synthesized within child dereferences = dereferences.stream() @@ -86,7 +86,7 @@ public Result apply(ProjectNode node, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java index aabe4f52b615..c6dc03e5c965 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java @@ -86,7 +86,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) SemiJoinNode semiJoinNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // All dereferences can be assumed on the symbols coming from source, since filteringSource output is not propagated, // and semiJoinOutput is of type boolean. We exclude pushdown of dereferences on sourceJoinSymbol. @@ -99,7 +99,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java index 726c1c6811af..8b4cc6e002cc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java @@ -87,7 +87,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); // Extract dereferences for pushdown - Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(expressionsBuilder.build(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Only retain dereferences on replicate symbols dereferences = dereferences.stream() @@ -99,7 +99,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java index cce71c68e7f8..f83dc8819a66 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java @@ -78,7 +78,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) AssignUniqueId assignUniqueId = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // We do not need to filter dereferences on idColumn symbol since it is supposed to be of BIGINT type. @@ -87,7 +87,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java index cc32a67a0942..38b55abe3c51 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java @@ -83,7 +83,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) LimitNode limitNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols being used in tiesResolvingScheme and requiresPreSortedInputs Set excludedSymbols = ImmutableSet.builder() @@ -101,7 +101,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java index 2ce87e969ad5..6a2fb375ecf7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java @@ -83,7 +83,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) MarkDistinctNode markDistinctNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on distinct symbols being used in markDistinctNode. We do not need to filter // dereferences on markerSymbol since it is supposed to be of boolean type. @@ -96,7 +96,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java index 7280432451fd..cf405c5d0d14 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java @@ -83,7 +83,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) RowNumberNode rowNumberNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols being used in partitionBy dereferences = dereferences.stream() @@ -95,7 +95,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java index bcb24f5b0567..8f629ac0a88d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java @@ -83,7 +83,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) SortNode sortNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols used in ordering scheme to avoid replication of data dereferences = dereferences.stream() @@ -95,7 +95,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java index 09152a01c46c..bb6e6e849c77 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java @@ -83,7 +83,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) TopNNode topNNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols being used in orderBy dereferences = dereferences.stream() @@ -95,7 +95,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java index d3599f63f256..f733371274da 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java @@ -86,7 +86,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) TopNRankingNode topNRankingNode = captures.get(CHILD); // Extract dereferences from project node assignments for pushdown - Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); + Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols being used in partitionBy and orderBy DataOrganizationSpecification specification = topNRankingNode.getSpecification(); @@ -103,7 +103,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java index 23636c08f48b..8dd6b3907994 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java @@ -96,7 +96,6 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) .collect(toImmutableList())) .build(), false, - context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); @@ -116,7 +115,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) } // Create new symbols for dereference expressions - Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSymbolAllocator(), typeAnalyzer); // Rewrite project node assignments using new symbols for dereference expressions Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java index d2d31dbcef7f..6a5e4aabeb87 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java @@ -229,8 +229,7 @@ private static Result pushFilter(FilterNode filterNode, AggregationNode aggregat // After filtering out `0` values, filter predicate's domain contains all remaining countSymbol values. Remove the countSymbol domain. TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(countSymbol)); Expression newPredicate = combineConjuncts( - plannerContext.getMetadata(), - new DomainTranslator(plannerContext).toPredicate(newTupleDomain), + new DomainTranslator().toPredicate(newTupleDomain), extractionResult.getRemainingExpression()); if (newPredicate.equals(TRUE_LITERAL)) { return Result.ofPlanNode(filterSource); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java index 693d34f4345e..4b374415a8ae 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java @@ -19,7 +19,6 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Expression; import io.trino.sql.ir.SymbolReference; @@ -90,12 +89,10 @@ public class PushInequalityFilterExpressionBelowJoinRuleSet private static final Pattern FILTER_PATTERN = filter().with(source().matching( join().capturedAs(JOIN_CAPTURE))); - private final Metadata metadata; private final IrTypeAnalyzer typeAnalyzer; - public PushInequalityFilterExpressionBelowJoinRuleSet(Metadata metadata, IrTypeAnalyzer typeAnalyzer) + public PushInequalityFilterExpressionBelowJoinRuleSet(IrTypeAnalyzer typeAnalyzer) { - this.metadata = requireNonNull(metadata, "metadata is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @@ -171,7 +168,7 @@ private Result pushInequalityFilterExpressionBelowJoin(Context context, JoinNode private Optional conjunctsToFilter(List conjuncts) { - return Optional.of(combineConjuncts(metadata, conjuncts)).filter(expression -> !TRUE_LITERAL.equals(expression)); + return Optional.of(combineConjuncts(conjuncts)).filter(expression -> !TRUE_LITERAL.equals(expression)); } Map> extractPushDownCandidates(JoinNodeContext joinNodeContext, Expression filter) @@ -182,7 +179,7 @@ Map> extractPushDownCandidates(JoinNodeContext joinNod private boolean isSupportedExpression(JoinNodeContext joinNodeContext, Expression expression) { - if (!(expression instanceof ComparisonExpression comparison && isDeterministic(expression, metadata))) { + if (!(expression instanceof ComparisonExpression comparison && isDeterministic(expression))) { return false; } if (!SUPPORTED_COMPARISONS.contains(comparison.getOperator())) { @@ -293,7 +290,7 @@ private Assignments buildAssignments(PlanNode source, Map ne private Symbol symbolForExpression(Context context, Expression expression) { checkArgument(!(expression instanceof SymbolReference), "expression '%s' is a SymbolReference", expression); - return context.getSymbolAllocator().newSymbol(expression, typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), expression)); + return context.getSymbolAllocator().newSymbol(expression, typeAnalyzer.getType(context.getSymbolAllocator().getTypes(), expression)); } private class PushFilterExpressionBelowJoinFilterRule diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java index 83e0a4c488b3..0f01532e72cb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushLimitThroughProject.java @@ -73,8 +73,8 @@ public Result apply(LimitNode parent, Captures captures, Context context) // undoing of PushDownDereferencesThroughLimit. We still push limit in the case of overlapping dereferences since // it enables PushDownDereferencesThroughLimit rule to push optimal dereferences. Set projections = ImmutableSet.copyOf(projectNode.getAssignments().getExpressions()); - if (!extractRowSubscripts(projections, false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()).isEmpty() - && exclusiveDereferences(projections, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes())) { + if (!extractRowSubscripts(projections, false, typeAnalyzer, context.getSymbolAllocator().getTypes()).isEmpty() + && exclusiveDereferences(projections, typeAnalyzer, context.getSymbolAllocator().getTypes())) { return Result.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 6878b308f4c0..b1fcf4bfce8a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -126,7 +126,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) plannerContext, typeAnalyzer, context.getStatsProvider(), - new DomainTranslator(plannerContext)); + new DomainTranslator()); if (rewritten.isEmpty() || arePlansSame(filterNode, tableScan, rewritten.get())) { return Result.empty(); @@ -168,7 +168,7 @@ public static Optional pushFilterIntoTableScan( return Optional.empty(); } - SplitExpression splitExpression = splitExpression(plannerContext, filterNode.getPredicate()); + SplitExpression splitExpression = splitExpression(filterNode.getPredicate()); DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.getExtractionResult( plannerContext, @@ -202,7 +202,6 @@ public static Optional pushFilterIntoTableScan( symbolAllocator.getTypes(), node.getAssignments(), combineConjuncts( - plannerContext.getMetadata(), splitExpression.getDeterministicPredicate(), // Simplify the tuple domain to avoid creating an expression with too many nodes, // which would be expensive to evaluate in the call to isCandidate below. @@ -285,7 +284,7 @@ public static Optional pushFilterIntoTableScan( Expression translatedExpression = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression.get(), plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), translatedExpression); + Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(symbolAllocator.getTypes(), translatedExpression); Object optimized = new IrExpressionInterpreter(translatedExpression, plannerContext, session, translatedExpressionTypes) .optimize(NoOpSymbolResolver.INSTANCE); @@ -293,7 +292,7 @@ public static Optional pushFilterIntoTableScan( optimizedExpression : new Constant(translatedExpressionTypes.get(NodeRef.of(translatedExpression)), optimized); - remainingDecomposedPredicate = combineConjuncts(plannerContext.getMetadata(), translatedExpression, expressionTranslation.remainingExpression()); + remainingDecomposedPredicate = combineConjuncts(translatedExpression, expressionTranslation.remainingExpression()); } Expression resultingPredicate = createResultingPredicate( @@ -331,10 +330,8 @@ private static void verifyTablePartitioning( verify(newTablePartitioning.equals(oldTablePartitioning), "Partitioning must not change after predicate is pushed down"); } - private static SplitExpression splitExpression(PlannerContext plannerContext, Expression predicate) + private static SplitExpression splitExpression(Expression predicate) { - Metadata metadata = plannerContext.getMetadata(); - List dynamicFilters = new ArrayList<>(); List deterministicPredicates = new ArrayList<>(); List nonDeterministicPredicate = new ArrayList<>(); @@ -344,7 +341,7 @@ private static SplitExpression splitExpression(PlannerContext plannerContext, Ex // dynamic filters have no meaning for connectors, so don't pass them dynamicFilters.add(conjunct); } - else if (isDeterministic(conjunct, metadata)) { + else if (isDeterministic(conjunct)) { deterministicPredicates.add(conjunct); } else { @@ -354,9 +351,9 @@ else if (isDeterministic(conjunct, metadata)) { } return new SplitExpression( - combineConjuncts(metadata, dynamicFilters), - combineConjuncts(metadata, deterministicPredicates), - combineConjuncts(metadata, nonDeterministicPredicate)); + combineConjuncts(dynamicFilters), + combineConjuncts(deterministicPredicates), + combineConjuncts(nonDeterministicPredicate)); } static Expression createResultingPredicate( @@ -378,7 +375,7 @@ static Expression createResultingPredicate( // * Short of implementing the previous bullet point, the current order of non-deterministic expressions // and non-TupleDomain-expressible expressions should be retained. Changing the order can lead // to failures of previously successful queries. - Expression expression = combineConjuncts(plannerContext.getMetadata(), dynamicFilter, unenforcedConstraints, nonDeterministicPredicate, remainingDecomposedPredicate); + Expression expression = combineConjuncts(dynamicFilter, unenforcedConstraints, nonDeterministicPredicate, remainingDecomposedPredicate); // Make sure we produce an expression whose terms are consistent with the canonical form used in other optimizations // Otherwise, we'll end up ping-ponging among rules diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java index 1c72a0ca1a55..3a31e44fffc8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java @@ -139,9 +139,8 @@ public Result apply(FilterNode filter, Captures captures, Context context) // Remove the row number domain because it is absorbed into the node TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rowNumberSymbol)); Expression newPredicate = combineConjuncts( - plannerContext.getMetadata(), extractionResult.getRemainingExpression(), - new DomainTranslator(plannerContext).toPredicate(newTupleDomain)); + new DomainTranslator().toPredicate(newTupleDomain)); if (newPredicate.equals(TRUE_LITERAL)) { return Result.ofPlanNode(project); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java index 75c362207b49..4c15188be7c6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java @@ -145,9 +145,8 @@ public Result apply(FilterNode filter, Captures captures, Context context) // Remove the ranking domain because it is absorbed into the node TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rankingSymbol)); Expression newPredicate = combineConjuncts( - plannerContext.getMetadata(), extractionResult.getRemainingExpression(), - new DomainTranslator(plannerContext).toPredicate(newTupleDomain)); + new DomainTranslator().toPredicate(newTupleDomain)); if (newPredicate.equals(TRUE_LITERAL)) { return Result.ofPlanNode(project); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index 31a97186cf9f..408cefbc802a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -155,7 +155,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings); // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop // by ensuring expression is optimized. - Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, context.getSymbolAllocator().getTypes(), translated); + Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(context.getSymbolAllocator().getTypes(), translated); Object optimized = new IrExpressionInterpreter(translated, plannerContext, session, translatedExpressionTypes) .optimize(NoOpSymbolResolver.INSTANCE); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java index 89e9aa6ac536..07fa378fceeb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -56,7 +56,7 @@ public static Optional pushProjectionThroughJoin( IrTypeAnalyzer typeAnalyzer, TypeProvider types) { - if (!projectNode.getAssignments().getExpressions().stream().allMatch(expression -> isDeterministic(expression, plannerContext.getMetadata()))) { + if (!projectNode.getAssignments().getExpressions().stream().allMatch(expression -> isDeterministic(expression))) { return Optional.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java index 9212bf93478e..14d445490b80 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushTopNThroughProject.java @@ -93,8 +93,8 @@ public Result apply(TopNNode parent, Captures captures, Context context) // undoing of PushDownDereferencesThroughTopN. We still push topN in the case of overlapping dereferences since // it enables PushDownDereferencesThroughTopN rule to push optimal dereferences. Set projections = ImmutableSet.copyOf(projectNode.getAssignments().getExpressions()); - if (!extractRowSubscripts(projections, false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()).isEmpty() - && exclusiveDereferences(projections, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes())) { + if (!extractRowSubscripts(projections, false, typeAnalyzer, context.getSymbolAllocator().getTypes()).isEmpty() + && exclusiveDereferences(projections, typeAnalyzer, context.getSymbolAllocator().getTypes())) { return Result.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java index 73f82509b7a3..5e8b39a1dc7d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoRowNumber.java @@ -102,9 +102,8 @@ public Result apply(FilterNode node, Captures captures, Context context) TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rowNumberSymbol)); Expression newPredicate = combineConjuncts( - plannerContext.getMetadata(), extractionResult.getRemainingExpression(), - new DomainTranslator(plannerContext).toPredicate(newTupleDomain)); + new DomainTranslator().toPredicate(newTupleDomain)); if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { return Result.ofPlanNode(source); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java index 7047c5f72767..5acc313b6ad2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushdownFilterIntoWindow.java @@ -122,9 +122,8 @@ public Result apply(FilterNode node, Captures captures, Context context) // Remove the row number domain because it is absorbed into the node TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rankingSymbol)); Expression newPredicate = combineConjuncts( - plannerContext.getMetadata(), extractionResult.getRemainingExpression(), - new DomainTranslator(plannerContext).toPredicate(newTupleDomain)); + new DomainTranslator().toPredicate(newTupleDomain)); if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { return Result.ofPlanNode(newSource); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java index 02ced6dcb25e..4897a14023c8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDateTrunc.java @@ -33,7 +33,6 @@ import java.util.Map; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.spi.type.DateType.DATE; import static io.trino.sql.ir.IrUtils.isEffectivelyLiteral; import static java.util.Objects.requireNonNull; @@ -76,9 +75,9 @@ public Visitor(Session session, PlannerContext plannerContext, IrTypeAnalyzer ty @Override public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) { - CatalogSchemaFunctionName functionName = extractFunctionName(node.getName()); + CatalogSchemaFunctionName functionName = node.getFunction().getName(); if (functionName.equals(builtinFunctionName("date_trunc")) && node.getArguments().size() == 2) { - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, node); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, node); Expression unitExpression = node.getArguments().get(0); Expression argument = node.getArguments().get(1); if (expressionTypes.get(NodeRef.of(argument)) == DATE && expressionTypes.get(NodeRef.of(unitExpression)) instanceof VarcharType && isEffectivelyLiteral(plannerContext, session, unitExpression)) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java index 9140a8138389..fdec20959f82 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantPredicateAboveTableScan.java @@ -142,7 +142,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) context.getSymbolAllocator(), typeAnalyzer, TRUE_LITERAL, // Dynamic filters are included in decomposedPredicate.getRemainingExpression() - new DomainTranslator(plannerContext).toPredicate(unenforcedDomain.transformKeys(assignments::get)), + new DomainTranslator().toPredicate(unenforcedDomain.transformKeys(assignments::get)), nonDeterministicPredicate, decomposedPredicate.getRemainingExpression()); @@ -163,7 +163,6 @@ private ExtractionResult getFullyExtractedPredicates(Session session, Expression .map(ExtractionResult::getTupleDomain) .collect(toImmutableList())), combineConjuncts( - plannerContext.getMetadata(), extractedPredicates.getOrDefault(FALSE, ImmutableList.of()).stream() .map(ExtractionResult::getRemainingExpression) .collect(toImmutableList()))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index dae567e3c863..919cdfcd0786 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.trino.Session; import io.trino.metadata.OperatorNotFoundException; import io.trino.spi.type.Type; import io.trino.sql.DynamicFilters; @@ -87,20 +86,18 @@ public RemoveUnsupportedDynamicFilters(PlannerContext plannerContext) @Override public PlanNode optimize(PlanNode plan, Context context) { - PlanWithConsumedDynamicFilters result = plan.accept(new RemoveUnsupportedDynamicFilters.Rewriter(context.session(), context.types()), ImmutableSet.of()); + PlanWithConsumedDynamicFilters result = plan.accept(new RemoveUnsupportedDynamicFilters.Rewriter(context.types()), ImmutableSet.of()); return result.getNode(); } private class Rewriter extends PlanVisitor> { - private final Session session; private final TypeProvider types; private final TypeCoercion typeCoercion; - public Rewriter(Session session, TypeProvider types) + public Rewriter(TypeProvider types) { - this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); } @@ -295,7 +292,7 @@ public PlanWithConsumedDynamicFilters visitFilter(FilterNode node, Set allowedDynamicFilterIds, ImmutableSet.Builder consumedDynamicFilterIds) { - return combineConjuncts(plannerContext.getMetadata(), extractConjuncts(expression) + return combineConjuncts(extractConjuncts(expression) .stream() .map(this::removeNestedDynamicFilters) .filter(conjunct -> @@ -322,7 +319,7 @@ private boolean isSupportedDynamicFilterExpression(Expression expression) if (!(castExpression.getExpression() instanceof SymbolReference)) { return false; } - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, expression); Type castSourceType = expressionTypes.get(NodeRef.of(castExpression.getExpression())); Type castTargetType = expressionTypes.get(NodeRef.of(castExpression)); // CAST must be an implicit coercion @@ -350,7 +347,7 @@ private Expression removeAllDynamicFilters(Expression expression) if (extractResult.getDynamicConjuncts().isEmpty()) { return rewrittenExpression; } - return combineConjuncts(plannerContext.getMetadata(), extractResult.getStaticConjuncts()); + return combineConjuncts(extractResult.getStaticConjuncts()); } private Expression removeNestedDynamicFilters(Expression expression) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java index 051c47f82099..833d76ff52cc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java @@ -31,7 +31,6 @@ import io.trino.cost.StatsProvider; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.PlannerContext; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Expression; @@ -114,7 +113,7 @@ public ReorderJoins(PlannerContext plannerContext, CostComparator costComparator this.pattern = join().matching( joinNode -> joinNode.getDistributionType().isEmpty() && joinNode.getType() == INNER - && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL), plannerContext.getMetadata())); + && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL))); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @@ -158,7 +157,6 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) private JoinEnumerationResult chooseJoinOrder(MultiJoinNode multiJoinNode, Context context) { JoinEnumerator joinEnumerator = new JoinEnumerator( - plannerContext.getMetadata(), costComparator, multiJoinNode.getFilter(), context); @@ -168,7 +166,6 @@ private JoinEnumerationResult chooseJoinOrder(MultiJoinNode multiJoinNode, Conte @VisibleForTesting static class JoinEnumerator { - private final Metadata metadata; private final Session session; private final StatsProvider statsProvider; private final CostProvider costProvider; @@ -183,9 +180,8 @@ static class JoinEnumerator private final Map, JoinEnumerationResult> memo = new HashMap<>(); @VisibleForTesting - JoinEnumerator(Metadata metadata, CostComparator costComparator, Expression filter, Context context) + JoinEnumerator(CostComparator costComparator, Expression filter, Context context) { - this.metadata = requireNonNull(metadata, "metadata is null"); this.context = requireNonNull(context); this.session = requireNonNull(context.getSession(), "session is null"); this.statsProvider = requireNonNull(context.getStatsProvider(), "statsProvider is null"); @@ -193,7 +189,7 @@ static class JoinEnumerator this.resultComparator = costComparator.forSession(session).onResultOf(result -> result.cost); this.idAllocator = requireNonNull(context.getIdAllocator(), "idAllocator is null"); this.allFilter = requireNonNull(filter, "filter is null"); - this.allFilterInference = new EqualityInference(metadata, filter); + this.allFilterInference = new EqualityInference(filter); this.lookup = requireNonNull(context.getLookup(), "lookup is null"); } @@ -353,7 +349,7 @@ private List getJoinPredicates(Set leftSymbols, Set // This takes all conjuncts that were part of allFilters that // could not be used for equality inference. // If they use both the left and right symbols, we add them to the list of joinPredicates - nonInferrableConjuncts(metadata, allFilter) + nonInferrableConjuncts(allFilter) .map(conjunct -> allFilterInference.rewrite(conjunct, Sets.union(leftSymbols, rightSymbols))) .filter(Objects::nonNull) // filter expressions that contain only left or right symbols @@ -364,7 +360,7 @@ private List getJoinPredicates(Set leftSymbols, Set // create equality inference on available symbols // TODO: make generateEqualitiesPartitionedBy take left and right scope List joinEqualities = allFilterInference.generateEqualitiesPartitionedBy(Sets.union(leftSymbols, rightSymbols)).getScopeEqualities(); - EqualityInference joinInference = new EqualityInference(metadata, joinEqualities); + EqualityInference joinInference = new EqualityInference(joinEqualities); joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(leftSymbols).getScopeStraddlingEqualities()); return joinPredicatesBuilder.build(); @@ -377,11 +373,11 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< Set scope = ImmutableSet.copyOf(outputSymbols); ImmutableList.Builder predicates = ImmutableList.builder(); predicates.addAll(allFilterInference.generateEqualitiesPartitionedBy(scope).getScopeEqualities()); - nonInferrableConjuncts(metadata, allFilter) + nonInferrableConjuncts(allFilter) .map(conjunct -> allFilterInference.rewrite(conjunct, scope)) .filter(Objects::nonNull) .forEach(predicates::add); - Expression filter = combineConjuncts(metadata, predicates.build()); + Expression filter = combineConjuncts(predicates.build()); if (!TRUE_LITERAL.equals(filter)) { planNode = new FilterNode(idAllocator.getNextId(), planNode, filter); } @@ -645,7 +641,7 @@ private void flattenNode(PlanNode node, int limit) return; } - if (joinNode.getType() != INNER || !isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL), plannerContext.getMetadata()) || joinNode.getDistributionType().isPresent()) { + if (joinNode.getType() != INNER || !isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL)) || joinNode.getDistributionType().isPresent()) { sources.add(node); return; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java index d7afd9dfc806..82f3985ec3f4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceJoinOverConstantWithProject.java @@ -15,7 +15,6 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Row; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -37,7 +36,6 @@ import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; import static io.trino.sql.planner.plan.Patterns.join; -import static java.util.Objects.requireNonNull; /** * This rule transforms plans with join where one of the sources is @@ -83,13 +81,6 @@ public class ReplaceJoinOverConstantWithProject private static final Pattern PATTERN = join() .matching(ReplaceJoinOverConstantWithProject::isUnconditional); - private final Metadata metadata; - - public ReplaceJoinOverConstantWithProject(Metadata metadata) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - } - @Override public Pattern getPattern() { @@ -181,7 +172,7 @@ private boolean isSingleConstantRow(PlanNode node) Expression row = getOnlyElement(values.getRows().get()); - if (!isDeterministic(row, metadata)) { + if (!isDeterministic(row)) { return false; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index e13e03c2a1a1..1862069c808e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -153,8 +153,6 @@ private boolean isStEnvelopeFunctionCall(Expression expression, ResolvedFunction return false; } - return plannerContext.getMetadata().decodeFunction(functionCall.getName()) - .getFunctionId() - .equals(stEnvelopeFunction.getFunctionId()); + return functionCall.getFunction().getFunctionId().equals(stEnvelopeFunction.getFunctionId()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java index cc23251850c6..488158970586 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyExpressions.java @@ -45,11 +45,11 @@ public static Expression rewrite(Expression expression, Session session, SymbolA if (expression instanceof SymbolReference) { return expression; } - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(symbolAllocator.getTypes(), expression); expression = pushDownNegations(plannerContext.getMetadata(), expression, expressionTypes); expression = extractCommonPredicates(plannerContext.getMetadata(), expression); expression = normalizeOrExpression(expression); - expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); + expressionTypes = typeAnalyzer.getTypes(symbolAllocator.getTypes(), expression); IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, plannerContext, session, expressionTypes); Object optimized = interpreter.optimize(NoOpSymbolResolver.INSTANCE); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java index 3b4278bb2c58..d06cb9ce6508 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; @@ -56,13 +55,6 @@ public class SimplifyFilterPredicate { private static final Pattern PATTERN = filter(); - private final Metadata metadata; - - public SimplifyFilterPredicate(Metadata metadata) - { - this.metadata = metadata; - } - @Override public Pattern getPattern() { @@ -93,7 +85,7 @@ public Result apply(FilterNode node, Captures captures, Context context) return Result.ofPlanNode(new FilterNode( node.getId(), node.getSource(), - combineConjuncts(metadata, newConjuncts.build()))); + combineConjuncts(newConjuncts.build()))); } private Optional simplifyFilterExpression(Expression expression) @@ -109,7 +101,7 @@ private Optional simplifyFilterExpression(Expression expression) if (isNotTrue(trueValue) && falseValue.isPresent() && falseValue.get().equals(TRUE_LITERAL)) { return Optional.of(isFalseOrNullPredicate(condition)); } - if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue, metadata)) { + if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue)) { return Optional.of(trueValue); } if (isNotTrue(trueValue) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) { @@ -163,7 +155,7 @@ private Optional simplifyFilterExpression(Expression expression) } else { builder.add(operand); - return Optional.of(combineConjuncts(metadata, builder.build())); + return Optional.of(combineConjuncts(builder.build())); } } } @@ -171,7 +163,7 @@ private Optional simplifyFilterExpression(Expression expression) if (notTrueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) { ImmutableList.Builder builder = ImmutableList.builder(); operands.forEach(operand -> builder.add(isFalseOrNullPredicate(operand))); - return Optional.of(combineConjuncts(metadata, builder.build())); + return Optional.of(combineConjuncts(builder.build())); } // skip clauses with not true conditions List whenClauses = new ArrayList<>(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java index 14643895c525..2c91df0c53fc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedJoinToJoin.java @@ -75,7 +75,6 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co DecorrelatedNode decorrelatedSubquery = decorrelatedNodeOptional.get(); Expression filter = combineConjuncts( - plannerContext.getMetadata(), decorrelatedSubquery.getCorrelatedPredicates().orElse(TRUE_LITERAL), correlatedJoinNode.getFilter()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java index ffa8d1d5065f..15ed55306be5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -175,7 +175,7 @@ private Expression unwrapCast(ComparisonExpression expression) return expression; } - Object right = new IrExpressionInterpreter(expression.getRight(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getRight())) + Object right = new IrExpressionInterpreter(expression.getRight(), plannerContext, session, typeAnalyzer.getTypes(types, expression.getRight())) .optimize(NoOpSymbolResolver.INSTANCE); ComparisonExpression.Operator operator = expression.getOperator(); @@ -191,8 +191,8 @@ private Expression unwrapCast(ComparisonExpression expression) return expression; } - Type sourceType = typeAnalyzer.getType(session, types, cast.getExpression()); - Type targetType = typeAnalyzer.getType(session, types, expression.getRight()); + Type sourceType = typeAnalyzer.getType(types, cast.getExpression()); + Type targetType = typeAnalyzer.getType(types, expression.getRight()); if (sourceType instanceof TimestampType && targetType == DATE) { return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.getExpression(), (long) right).orElse(expression); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java index af7a421ab4e1..3d4b3c118701 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapDateTruncInComparison.java @@ -53,7 +53,6 @@ import static com.google.common.base.Verify.verify; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -152,12 +151,12 @@ private Expression unwrapDateTrunc(ComparisonExpression expression) // This is provided by CanonicalizeExpressionRewriter. if (!(expression.getLeft() instanceof FunctionCall call) || - !extractFunctionName(call.getName()).equals(builtinFunctionName("date_trunc")) || + !call.getFunction().getName().equals(builtinFunctionName("date_trunc")) || call.getArguments().size() != 2) { return expression; } - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, expression); Expression unitExpression = call.getArguments().get(0); if (!(expressionTypes.get(NodeRef.of(unitExpression)) instanceof VarcharType) || !isEffectivelyLiteral(plannerContext, session, unitExpression)) { return expression; @@ -193,7 +192,7 @@ private Expression unwrapDateTrunc(ComparisonExpression expression) return expression; } - ResolvedFunction resolvedFunction = plannerContext.getMetadata().decodeFunction(call.getName()); + ResolvedFunction resolvedFunction = call.getFunction(); Optional unitIfSupported = Enums.getIfPresent(SupportedUnit.class, unitName.toStringUtf8().toUpperCase(Locale.ENGLISH)).toJavaUtil(); if (unitIfSupported.isEmpty()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java index 15785b9eafb9..b5b7245746db 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapSingleColumnRowInApply.java @@ -148,7 +148,7 @@ else if (expression instanceof ApplyNode.QuantifiedComparison comparison) { private Optional unwrapSingleColumnRow(Context context, Expression value, Expression list, BiFunction function) { - Type type = typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), value); + Type type = typeAnalyzer.getType(context.getSymbolAllocator().getTypes(), value); if (type instanceof RowType rowType) { if (rowType.getFields().size() == 1) { Type elementType = rowType.getTypeParameters().get(0); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java index d77bf7247f5e..647d3da1fc46 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.java @@ -43,7 +43,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; @@ -130,7 +129,7 @@ public Expression rewriteInPredicate(InPredicate node, Void context, ExpressionT Expression value = inPredicate.getValue(); if (!(value instanceof FunctionCall call) || - !extractFunctionName(call.getName()).equals(builtinFunctionName("year")) || + !call.getFunction().getName().equals(builtinFunctionName("year")) || call.getArguments().size() != 1) { return inPredicate; } @@ -156,12 +155,12 @@ private Expression unwrapYear(ComparisonExpression expression) // Expect year on the left side and value on the right side of the comparison. // This is provided by CanonicalizeExpressionRewriter. if (!(expression.getLeft() instanceof FunctionCall call) || - !extractFunctionName(call.getName()).equals(builtinFunctionName("year")) || + !call.getFunction().getName().equals(builtinFunctionName("year")) || call.getArguments().size() != 1) { return expression; } - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, expression); Expression argument = getOnlyElement(call.getArguments()); Type argumentType = expressionTypes.get(NodeRef.of(argument)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 7e77ef9233d5..ca12163f1813 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -184,7 +184,7 @@ public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator this.types = symbolAllocator.getTypes(); this.statsProvider = new CachingStatsProvider(statsCalculator, session, types, tableStatsProvider); this.session = session; - this.domainTranslator = new DomainTranslator(plannerContext); + this.domainTranslator = new DomainTranslator(); this.redistributeWrites = SystemSessionProperties.isRedistributeWrites(session); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java index bb69094e2fb6..589ee5836600 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java @@ -96,7 +96,7 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma { return translate( expression, - typeAnalyzer.getTypes(session, types, expression), + typeAnalyzer.getTypes(types, expression), symbolInput, metadata, functionManager, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index f6006b048ac0..c9ee1a4128f6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -251,7 +251,7 @@ private IndexSourceRewriter( Session session) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); - this.domainTranslator = new DomainTranslator(plannerContext); + this.domainTranslator = new DomainTranslator(); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.session = requireNonNull(session, "session is null"); @@ -325,7 +325,6 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context node.getAssignments()); Expression resultingPredicate = combineConjuncts( - plannerContext.getMetadata(), domainTranslator.toPredicate(resolvedIndex.getUnresolvedTupleDomain().transformKeys(inverseAssignments::get)), decomposedPredicate.getRemainingExpression()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java index 7f7b39895e16..e1e309bff1cf 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -186,7 +186,7 @@ private Optional findTableScan(PlanNode source) } else if (source instanceof ProjectNode project) { // verify projections are deterministic - if (!Iterables.all(project.getAssignments().getExpressions(), expression -> isDeterministic(expression, plannerContext.getMetadata()))) { + if (!Iterables.all(project.getAssignments().getExpressions(), expression -> isDeterministic(expression))) { return Optional.empty(); } source = project.getSource(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java index 10d2f2635033..3b776991832a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -72,14 +72,12 @@ public class PlanNodeDecorrelator { - private final PlannerContext plannerContext; private final SymbolAllocator symbolAllocator; private final Lookup lookup; private final TypeCoercion typeCoercion; public PlanNodeDecorrelator(PlannerContext plannerContext, SymbolAllocator symbolAllocator, Lookup lookup) { - this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.lookup = requireNonNull(lookup, "lookup is null"); this.typeCoercion = new TypeCoercion(plannerContext.getTypeManager()::getType); @@ -159,7 +157,7 @@ public Optional visitFilter(FilterNode node, Void context) FilterNode newFilterNode = new FilterNode( node.getId(), childDecorrelationResult.node, - combineConjuncts(plannerContext.getMetadata(), uncorrelatedPredicates)); + combineConjuncts(uncorrelatedPredicates)); Set symbolsToPropagate = Sets.difference(SymbolsExtractor.extractUnique(correlatedPredicates), ImmutableSet.copyOf(correlation)); return Optional.of(new DecorrelationResult( @@ -528,7 +526,7 @@ private Set extractConstantSymbols(List correlatedConjuncts) // checks whether the expression is a deterministic combination of correlation symbols private boolean isConstant(Expression expression) { - return isDeterministic(expression, plannerContext.getMetadata()) && + return isDeterministic(expression) && ImmutableSet.copyOf(correlation).containsAll(SymbolsExtractor.extractUnique(expression)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 9d4f5cc12f32..fa702799552a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -184,7 +184,7 @@ private Rewriter( this.dynamicFiltering = dynamicFiltering; this.effectivePredicateExtractor = new EffectivePredicateExtractor( - new DomainTranslator(plannerContext), + new DomainTranslator(), plannerContext, useTableProperties && isPredicatePushdownUseTableProperties(session)); } @@ -247,15 +247,15 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) // function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by // pre-projected symbols. Predicate isSupported = conjunct -> - isDeterministic(conjunct, metadata) && + isDeterministic(conjunct) && partitionSymbols.containsAll(extractUnique(conjunct)); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); - PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, conjuncts.get(true))); + PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(conjuncts.get(true))); if (!conjuncts.get(false).isEmpty()) { - rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, conjuncts.get(false))); + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; @@ -265,7 +265,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) public PlanNode visitProject(ProjectNode node, RewriteContext context) { Set deterministicSymbols = node.getAssignments().entrySet().stream() - .filter(entry -> isDeterministic(entry.getValue(), metadata)) + .filter(entry -> isDeterministic(entry.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toSet()); @@ -288,7 +288,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex .map(conjunct -> unwrapCasts(session, plannerContext, typeAnalyzer, types, conjunct)) .collect(Collectors.toList()); - PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, inlinedDeterministicConjuncts)); + PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(inlinedDeterministicConjuncts)); // All deterministic conjuncts that contains non-inlining targets, and non-deterministic conjuncts, // if any, will be in the filter node. @@ -296,7 +296,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex nonInliningConjuncts.addAll(conjuncts.get(false)); if (!nonInliningConjuncts.isEmpty()) { - rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, nonInliningConjuncts)); + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(nonInliningConjuncts)); } return rewrittenNode; @@ -331,11 +331,11 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext contex Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); // Push down conjuncts from the inherited predicate that apply to common grouping symbols - PlanNode rewrittenNode = context.defaultRewrite(node, inlineSymbols(commonGroupingSymbolMapping, combineConjuncts(metadata, conjuncts.get(true)))); + PlanNode rewrittenNode = context.defaultRewrite(node, inlineSymbols(commonGroupingSymbolMapping, combineConjuncts(conjuncts.get(true)))); // All other conjuncts, if any, will be in the filter node. if (!conjuncts.get(false).isEmpty()) { - rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, conjuncts.get(false))); + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; @@ -348,10 +348,10 @@ public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext> conjuncts = extractConjuncts(context.get()).stream() .collect(Collectors.partitioningBy(conjunct -> pushDownableSymbols.containsAll(extractUnique(conjunct)))); - PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, conjuncts.get(true))); + PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(conjuncts.get(true))); if (!conjuncts.get(false).isEmpty()) { - rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, conjuncts.get(false))); + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; } @@ -388,7 +388,7 @@ public PlanNode visitUnion(UnionNode node, RewriteContext context) @Override public PlanNode visitFilter(FilterNode node, RewriteContext context) { - PlanNode rewrittenPlan = context.rewrite(node.getSource(), combineConjuncts(metadata, node.getPredicate(), context.get())); + PlanNode rewrittenPlan = context.rewrite(node.getSource(), combineConjuncts(node.getPredicate(), context.get())); if (!(rewrittenPlan instanceof FilterNode rewrittenFilterNode)) { return rewrittenPlan; } @@ -511,7 +511,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) List joinFilter = joinFilterBuilder.build(); DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node, equiJoinClauses, joinFilter, session, idAllocator); Map dynamicFilters = dynamicFiltersResult.getDynamicFilters(); - leftPredicate = combineConjuncts(metadata, leftPredicate, combineConjuncts(metadata, dynamicFiltersResult.getPredicates())); + leftPredicate = combineConjuncts(leftPredicate, combineConjuncts(dynamicFiltersResult.getPredicates())); PlanNode leftSource; PlanNode rightSource; @@ -525,7 +525,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) rightSource = context.rewrite(node.getRight(), rightPredicate); } - Optional newJoinFilter = Optional.of(combineConjuncts(metadata, joinFilter)); + Optional newJoinFilter = Optional.of(combineConjuncts(joinFilter)); if (newJoinFilter.get().equals(TRUE_LITERAL)) { newJoinFilter = Optional.empty(); } @@ -535,7 +535,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // inner join, so we plan execution as nested-loops-join followed by filter instead // hash join. // todo: remove the code when we have support for filter function in nested loop join - postJoinPredicate = combineConjuncts(metadata, postJoinPredicate, newJoinFilter.get()); + postJoinPredicate = combineConjuncts(postJoinPredicate, newJoinFilter.get()); newJoinFilter = Optional.empty(); } @@ -817,7 +817,7 @@ private Symbol symbolForExpression(Expression expression) return Symbol.from(expression); } - return symbolAllocator.newSymbol(expression, typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression)); + return symbolAllocator.newSymbol(expression, typeAnalyzer.getType(symbolAllocator.getTypes(), expression)); } private OuterJoinPushDownResult processLimitedOuterJoin( @@ -838,36 +838,36 @@ private OuterJoinPushDownResult processLimitedOuterJoin( // Strip out non-deterministic conjuncts extractConjuncts(inheritedPredicate).stream() - .filter(expression -> !isDeterministic(expression, metadata)) + .filter(expression -> !isDeterministic(expression)) .forEach(postJoinConjuncts::add); inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate); outerEffectivePredicate = filterDeterministicConjuncts(metadata, outerEffectivePredicate); innerEffectivePredicate = filterDeterministicConjuncts(metadata, innerEffectivePredicate); extractConjuncts(joinPredicate).stream() - .filter(expression -> !isDeterministic(expression, metadata)) + .filter(expression -> !isDeterministic(expression)) .forEach(joinConjuncts::add); joinPredicate = filterDeterministicConjuncts(metadata, joinPredicate); // Generate equality inferences - EqualityInference inheritedInference = new EqualityInference(metadata, inheritedPredicate); - EqualityInference outerInference = new EqualityInference(metadata, inheritedPredicate, outerEffectivePredicate); + EqualityInference inheritedInference = new EqualityInference(inheritedPredicate); + EqualityInference outerInference = new EqualityInference(inheritedPredicate, outerEffectivePredicate); Set innerScope = ImmutableSet.copyOf(innerSymbols); Set outerScope = ImmutableSet.copyOf(outerSymbols); EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(outerScope); - Expression outerOnlyInheritedEqualities = combineConjuncts(metadata, equalityPartition.getScopeEqualities()); - EqualityInference potentialNullSymbolInference = new EqualityInference(metadata, outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); + Expression outerOnlyInheritedEqualities = combineConjuncts(equalityPartition.getScopeEqualities()); + EqualityInference potentialNullSymbolInference = new EqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); // Push outer and join equalities into the inner side. For example: // SELECT * FROM nation LEFT OUTER JOIN region ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah' - EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = new EqualityInference(metadata, outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate); + EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = new EqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate); innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(innerScope).getScopeEqualities()); // TODO: we can further improve simplifying the equalities by considering other relationships from the outer side - EqualityInference.EqualityPartition joinEqualityPartition = new EqualityInference(metadata, joinPredicate).generateEqualitiesPartitionedBy(innerScope); + EqualityInference.EqualityPartition joinEqualityPartition = new EqualityInference(joinPredicate).generateEqualitiesPartitionedBy(innerScope); innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities()); joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities()) .addAll(joinEqualityPartition.getScopeStraddlingEqualities()); @@ -878,7 +878,7 @@ private OuterJoinPushDownResult processLimitedOuterJoin( postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); // See if we can push inherited predicates down - EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { Expression outerRewritten = outerInference.rewrite(conjunct, outerScope); if (outerRewritten != null) { outerPushdownConjuncts.add(outerRewritten); @@ -895,13 +895,13 @@ private OuterJoinPushDownResult processLimitedOuterJoin( }); // See if we can push down any outer effective predicates to the inner side - EqualityInference.nonInferrableConjuncts(metadata, outerEffectivePredicate) + EqualityInference.nonInferrableConjuncts(outerEffectivePredicate) .map(conjunct -> potentialNullSymbolInference.rewrite(conjunct, innerScope)) .filter(Objects::nonNull) .forEach(innerPushdownConjuncts::add); // See if we can push down join predicates to the inner side - EqualityInference.nonInferrableConjuncts(metadata, joinPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(joinPredicate).forEach(conjunct -> { Expression innerRewritten = potentialNullSymbolInference.rewrite(conjunct, innerScope); if (innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); @@ -911,10 +911,10 @@ private OuterJoinPushDownResult processLimitedOuterJoin( } }); - return new OuterJoinPushDownResult(combineConjuncts(metadata, outerPushdownConjuncts.build()), - combineConjuncts(metadata, innerPushdownConjuncts.build()), - combineConjuncts(metadata, joinConjuncts.build()), - combineConjuncts(metadata, postJoinConjuncts.build())); + return new OuterJoinPushDownResult(combineConjuncts(outerPushdownConjuncts.build()), + combineConjuncts(innerPushdownConjuncts.build()), + combineConjuncts(joinConjuncts.build()), + combineConjuncts(postJoinConjuncts.build())); } private static class OuterJoinPushDownResult @@ -970,12 +970,12 @@ private InnerJoinPushDownResult processInnerJoin( // Strip out non-deterministic conjuncts extractConjuncts(inheritedPredicate).stream() - .filter(deterministic -> !isDeterministic(deterministic, metadata)) + .filter(deterministic -> !isDeterministic(deterministic)) .forEach(joinConjuncts::add); inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate); extractConjuncts(joinPredicate).stream() - .filter(expression -> !isDeterministic(expression, metadata)) + .filter(expression -> !isDeterministic(expression)) .forEach(joinConjuncts::add); joinPredicate = filterDeterministicConjuncts(metadata, joinPredicate); @@ -986,9 +986,9 @@ private InnerJoinPushDownResult processInnerJoin( ImmutableSet rightScope = ImmutableSet.copyOf(rightSymbols); // Generate equality inferences - EqualityInference allInference = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate); - EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate); - EqualityInference allInferenceWithoutRightInferred = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate); + EqualityInference allInference = new EqualityInference(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate); + EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate); + EqualityInference allInferenceWithoutRightInferred = new EqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate); // Add equalities from the inference back in leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities()); @@ -996,7 +996,7 @@ private InnerJoinPushDownResult processInnerJoin( joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(leftScope).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate // Sort through conjuncts in inheritedPredicate that were not used for inference - EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { Expression leftRewrittenConjunct = allInference.rewrite(conjunct, leftScope); if (leftRewrittenConjunct != null) { leftPushDownConjuncts.add(leftRewrittenConjunct); @@ -1014,19 +1014,19 @@ private InnerJoinPushDownResult processInnerJoin( }); // See if we can push the right effective predicate to the left side - EqualityInference.nonInferrableConjuncts(metadata, rightEffectivePredicate) + EqualityInference.nonInferrableConjuncts(rightEffectivePredicate) .map(conjunct -> allInference.rewrite(conjunct, leftScope)) .filter(Objects::nonNull) .forEach(leftPushDownConjuncts::add); // See if we can push the left effective predicate to the right side - EqualityInference.nonInferrableConjuncts(metadata, leftEffectivePredicate) + EqualityInference.nonInferrableConjuncts(leftEffectivePredicate) .map(conjunct -> allInference.rewrite(conjunct, rightScope)) .filter(Objects::nonNull) .forEach(rightPushDownConjuncts::add); // See if we can push any parts of the join predicates to either side - EqualityInference.nonInferrableConjuncts(metadata, joinPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(joinPredicate).forEach(conjunct -> { Expression leftRewritten = allInference.rewrite(conjunct, leftScope); if (leftRewritten != null) { leftPushDownConjuncts.add(leftRewritten); @@ -1043,9 +1043,9 @@ private InnerJoinPushDownResult processInnerJoin( }); return new InnerJoinPushDownResult( - combineConjuncts(metadata, leftPushDownConjuncts.build()), - combineConjuncts(metadata, rightPushDownConjuncts.build()), - combineConjuncts(metadata, joinConjuncts.build()), + combineConjuncts(leftPushDownConjuncts.build()), + combineConjuncts(rightPushDownConjuncts.build()), + combineConjuncts(joinConjuncts.build()), TRUE_LITERAL); } @@ -1092,7 +1092,7 @@ private Expression extractJoinPredicate(JoinNode joinNode) builder.add(equiJoinClause.toExpression()); } joinNode.getFilter().ifPresent(builder::add); - return combineConjuncts(metadata, builder.build()); + return combineConjuncts(builder.build()); } private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) @@ -1171,7 +1171,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex { Set innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin); for (Expression conjunct : extractConjuncts(inheritedPredicate)) { - if (isDeterministic(conjunct, metadata)) { + if (isDeterministic(conjunct)) { // Ignore a conjunct for this test if we cannot deterministically get responses from it Object response = nullInputEvaluator(innerSymbols, conjunct); if (response == null || Boolean.FALSE.equals(response)) { @@ -1188,7 +1188,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex // Temporary implementation for joins because the SimplifyExpressions optimizers cannot run properly on join clauses private Expression simplifyExpression(Expression expression) { - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(symbolAllocator.getTypes(), expression); IrExpressionInterpreter optimizer = new IrExpressionInterpreter(expression, plannerContext, session, expressionTypes); Object object = optimizer.optimize(NoOpSymbolResolver.INSTANCE); @@ -1207,7 +1207,7 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r */ private Object nullInputEvaluator(Collection nullSymbols, Expression expression) { - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(symbolAllocator.getTypes(), expression); return new IrExpressionInterpreter(expression, plannerContext, session, expressionTypes) .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); } @@ -1227,8 +1227,8 @@ private boolean joinDynamicFilteringExpression(Expression expression, Collection } comparison = (ComparisonExpression) notExpression.getValue(); Set expressionTypes = ImmutableSet.of( - typeAnalyzer.getType(session, types, comparison.getLeft()), - typeAnalyzer.getType(session, types, comparison.getRight())); + typeAnalyzer.getType(types, comparison.getLeft()), + typeAnalyzer.getType(types, comparison.getRight())); // Dynamic filtering is not supported with IS NOT DISTINCT FROM clause on REAL or DOUBLE types to avoid dealing with NaN values if (expressionTypes.contains(REAL) || expressionTypes.contains(DOUBLE)) { return false; @@ -1249,7 +1249,7 @@ private boolean joinDynamicFilteringExpression(Expression expression, Collection private boolean joinComparisonExpression(Expression expression, Collection leftSymbols, Collection rightSymbols, Set operators) { // At this point in time, our join predicates need to be deterministic - if (expression instanceof ComparisonExpression comparison && isDeterministic(expression, metadata)) { + if (expression instanceof ComparisonExpression comparison && isDeterministic(expression)) { if (operators.contains(comparison.getOperator())) { Set symbols1 = extractUnique(comparison.getLeft()); Set symbols2 = extractUnique(comparison.getRight()); @@ -1285,8 +1285,8 @@ private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext sourceScope = ImmutableSet.copyOf(node.getSource().getOutputSymbols()); - EqualityInference inheritedInference = new EqualityInference(metadata, inheritedPredicate); - EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> { + EqualityInference inheritedInference = new EqualityInference(inheritedPredicate); + EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { Expression rewrittenConjunct = inheritedInference.rewrite(conjunct, sourceScope); // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down if (rewrittenConjunct != null) { @@ -1303,7 +1303,7 @@ private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext postJoinConjuncts = new ArrayList<>(); // Generate equality inferences - EqualityInference allInference = new EqualityInference(metadata, deterministicInheritedPredicate, sourceEffectivePredicate, filteringSourceEffectivePredicate, joinExpression); - EqualityInference allInferenceWithoutSourceInferred = new EqualityInference(metadata, deterministicInheritedPredicate, filteringSourceEffectivePredicate, joinExpression); - EqualityInference allInferenceWithoutFilteringSourceInferred = new EqualityInference(metadata, deterministicInheritedPredicate, sourceEffectivePredicate, joinExpression); + EqualityInference allInference = new EqualityInference(deterministicInheritedPredicate, sourceEffectivePredicate, filteringSourceEffectivePredicate, joinExpression); + EqualityInference allInferenceWithoutSourceInferred = new EqualityInference(deterministicInheritedPredicate, filteringSourceEffectivePredicate, joinExpression); + EqualityInference allInferenceWithoutFilteringSourceInferred = new EqualityInference(deterministicInheritedPredicate, sourceEffectivePredicate, joinExpression); // Push inheritedPredicates down to the source if they don't involve the semi join output Set sourceScope = ImmutableSet.copyOf(sourceSymbols); - EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { Expression rewrittenConjunct = allInference.rewrite(conjunct, sourceScope); // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down if (rewrittenConjunct != null) { @@ -1363,7 +1363,7 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext filterScope = ImmutableSet.copyOf(filteringSourceSymbols); - EqualityInference.nonInferrableConjuncts(metadata, deterministicInheritedPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(deterministicInheritedPredicate).forEach(conjunct -> { Expression rewrittenConjunct = allInference.rewrite(conjunct, filterScope); // We cannot push non-deterministic predicates to filtering side. Each filtering side row have to be // logically reevaluated for each source row. @@ -1374,13 +1374,13 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext filter // See if we can push the filtering source effective predicate to the source side - EqualityInference.nonInferrableConjuncts(metadata, filteringSourceEffectivePredicate) + EqualityInference.nonInferrableConjuncts(filteringSourceEffectivePredicate) .map(conjunct -> allInference.rewrite(conjunct, sourceScope)) .filter(Objects::nonNull) .forEach(sourceConjuncts::add); // See if we can push the source effective predicate to the filtering source side - EqualityInference.nonInferrableConjuncts(metadata, sourceEffectivePredicate) + EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate) .map(conjunct -> allInference.rewrite(conjunct, filterScope)) .filter(Objects::nonNull) .forEach(filteringSourceConjuncts::add); @@ -1402,8 +1402,8 @@ private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext pushdownConjuncts = new ArrayList<>(); List postAggregationConjuncts = new ArrayList<>(); // Strip out non-deterministic conjuncts extractConjuncts(inheritedPredicate).stream() - .filter(expression -> !isDeterministic(expression, metadata)) + .filter(expression -> !isDeterministic(expression)) .forEach(postAggregationConjuncts::add); inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate); // Sort non-equality predicates by those that can be pushed down and those that cannot Set groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys()); - EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { if (node.getGroupIdSymbol().isPresent() && extractUnique(conjunct).contains(node.getGroupIdSymbol().get())) { // aggregation operator synthesizes outputs for group ids corresponding to the global grouping set (i.e., ()), so we // need to preserve any predicates that evaluate the group id to run after the aggregation @@ -1474,7 +1474,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) } //TODO for LEFT or INNER join type, push down UnnestNode's filter on replicate symbols - EqualityInference equalityInference = new EqualityInference(metadata, inheritedPredicate); + EqualityInference equalityInference = new EqualityInference(inheritedPredicate); List pushdownConjuncts = new ArrayList<>(); List postUnnestConjuncts = new ArrayList<>(); // Strip out non-deterministic conjuncts extractConjuncts(inheritedPredicate).stream() - .filter(expression -> !isDeterministic(expression, metadata)) + .filter(expression -> !isDeterministic(expression)) .forEach(postUnnestConjuncts::add); inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate); // Sort non-equality predicates by those that can be pushed down and those that cannot Set replicatedSymbols = ImmutableSet.copyOf(node.getReplicateSymbols()); - EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> { + EqualityInference.nonInferrableConjuncts(inheritedPredicate).forEach(conjunct -> { Expression rewrittenConjunct = equalityInference.rewrite(conjunct, replicatedSymbols); if (rewrittenConjunct != null) { pushdownConjuncts.add(rewrittenConjunct); @@ -1527,14 +1527,14 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); - PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(metadata, pushdownConjuncts)); + PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts)); PlanNode output = node; if (rewrittenSource != node.getSource()) { output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getMappings(), node.getOrdinalitySymbol(), node.getJoinType()); } if (!postUnnestConjuncts.isEmpty()) { - output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(metadata, postUnnestConjuncts)); + output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postUnnestConjuncts)); } return output; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index a32e31b44b74..d57c52a58a09 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -782,7 +782,7 @@ public ActualProperties visitProject(ProjectNode node, List in for (Map.Entry assignment : node.getAssignments().entrySet()) { Expression expression = assignment.getValue(); - Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); + Map, Type> expressionTypes = typeAnalyzer.getTypes(types, expression); Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); IrExpressionInterpreter optimizer = new IrExpressionInterpreter(expression, plannerContext, session, expressionTypes); // TODO: diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index 955f6a9512a2..d6ad9d0b5bb6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -182,11 +182,11 @@ public Expression rewriteUsingBounds(ApplyNode.QuantifiedComparison quantifiedCo Function, Expression> quantifier; if (quantifiedComparison.quantifier() == ALL) { emptySetResult = TRUE_LITERAL; - quantifier = expressions -> combineConjuncts(metadata, expressions); + quantifier = expressions -> combineConjuncts(expressions); } else { emptySetResult = FALSE_LITERAL; - quantifier = expressions -> combineDisjuncts(metadata, expressions); + quantifier = expressions -> combineDisjuncts(expressions); } Expression comparisonWithExtremeValue = getBoundComparisons(quantifiedComparison, minValue, maxValue); @@ -210,7 +210,6 @@ private Expression getBoundComparisons(ApplyNode.QuantifiedComparison quantified if (mapOperator(quantifiedComparison) == EQUAL && quantifiedComparison.quantifier() == ALL) { // A = ALL B <=> min B = max B && A = min B return combineConjuncts( - metadata, new ComparisonExpression(EQUAL, minValue.toSymbolReference(), maxValue.toSymbolReference()), new ComparisonExpression(EQUAL, quantifiedComparison.value().toSymbolReference(), maxValue.toSymbolReference())); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 6fe4e31531dc..4ee1735745a5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -20,7 +20,6 @@ import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; import io.trino.cost.PlanNodeStatsEstimate; -import io.trino.metadata.Metadata; import io.trino.spi.connector.ColumnHandle; import io.trino.sql.DynamicFilters; import io.trino.sql.ir.Expression; @@ -128,19 +127,12 @@ public class UnaliasSymbolReferences implements PlanOptimizer { - private final Metadata metadata; - - public UnaliasSymbolReferences(Metadata metadata) - { - this.metadata = requireNonNull(metadata, "metadata is null"); - } - @Override public PlanNode optimize(PlanNode plan, Context context) { requireNonNull(plan, "plan is null"); - Visitor visitor = new Visitor(metadata, SymbolMapper::symbolMapper); + Visitor visitor = new Visitor(SymbolMapper::symbolMapper); PlanAndMappings result = plan.accept(visitor, UnaliasContext.empty()); return updateDynamicFilterIds(result.getRoot(), visitor.getDynamicFilterIdMap()); } @@ -157,7 +149,7 @@ public NodeAndMappings reallocateSymbols(PlanNode plan, List fields, Sym requireNonNull(fields, "fields is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); - Visitor visitor = new Visitor(metadata, mapping -> symbolReallocator(mapping, symbolAllocator)); + Visitor visitor = new Visitor(mapping -> symbolReallocator(mapping, symbolAllocator)); PlanAndMappings result = plan.accept(visitor, UnaliasContext.empty()); return new NodeAndMappings(updateDynamicFilterIds(result.getRoot(), visitor.getDynamicFilterIdMap()), symbolMapper(result.getMappings()).map(fields)); } @@ -165,7 +157,7 @@ public NodeAndMappings reallocateSymbols(PlanNode plan, List fields, Sym private PlanNode updateDynamicFilterIds(PlanNode resultNode, Map dynamicFilterIdMap) { if (!dynamicFilterIdMap.isEmpty()) { - resultNode = rewriteWith(new DynamicFilterVisitor(metadata, dynamicFilterIdMap), resultNode); + resultNode = rewriteWith(new DynamicFilterVisitor(dynamicFilterIdMap), resultNode); } return resultNode; } @@ -173,13 +165,11 @@ private PlanNode updateDynamicFilterIds(PlanNode resultNode, Map { - private final Metadata metadata; private final Function, SymbolMapper> mapperProvider; private final Map dynamicFilterIdMap = new HashMap<>(); - public Visitor(Metadata metadata, Function, SymbolMapper> mapperProvider) + public Visitor(Function, SymbolMapper> mapperProvider) { - this.metadata = requireNonNull(metadata, "metadata is null"); this.mapperProvider = requireNonNull(mapperProvider, "mapperProvider is null"); } @@ -926,7 +916,7 @@ private Map mappingFromAssignments(Map assig } // 2. map same deterministic expressions within a projection into the same symbol // omit NullLiterals since those have ambiguous types - else if (DeterminismEvaluator.isDeterministic(expression, metadata)) { + else if (DeterminismEvaluator.isDeterministic(expression)) { Symbol previous = inputsToOutputs.get(expression); if (previous == null) { inputsToOutputs.put(expression, assignment.getKey()); @@ -1417,12 +1407,10 @@ public Map getMappings() private static class DynamicFilterVisitor extends SimplePlanRewriter { - private final Metadata metadata; private final Map dynamicFilterIdMap; - private DynamicFilterVisitor(Metadata metadata, Map dynamicFilterIdMap) + private DynamicFilterVisitor(Map dynamicFilterIdMap) { - this.metadata = requireNonNull(metadata, "metadata is null"); this.dynamicFilterIdMap = requireNonNull(dynamicFilterIdMap, "dynamicFilterIdMap is null"); } @@ -1463,7 +1451,7 @@ private Expression updateDynamicFilterIds(Map newConjuncts.add(newConjunct); } if (updated) { - return combineConjuncts(metadata, newConjuncts.build()); + return combineConjuncts(newConjuncts.build()); } return predicate; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java index adba057e15df..2819aec80e9f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java @@ -93,7 +93,7 @@ private Rewriter( this.types = requireNonNull(types, "types is null"); rowNumberFunctionId = plannerContext.getMetadata().resolveBuiltinFunction("row_number", ImmutableList.of()).getFunctionId(); rankFunctionId = plannerContext.getMetadata().resolveBuiltinFunction("rank", ImmutableList.of()).getFunctionId(); - domainTranslator = new DomainTranslator(plannerContext); + domainTranslator = new DomainTranslator(); } @Override @@ -201,7 +201,6 @@ private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, Sym // Remove the ranking domain because it is absorbed into the node TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rankingSymbol)); Expression newPredicate = combineConjuncts( - plannerContext.getMetadata(), extractionResult.getRemainingExpression(), domainTranslator.toPredicate(newTupleDomain)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java index 074ee20c02e9..6aeef013a2bb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Assignments.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; -import io.trino.Session; import io.trino.spi.type.Type; import io.trino.sql.ir.Expression; import io.trino.sql.ir.SymbolReference; @@ -80,12 +79,12 @@ public static Assignments of(Symbol symbol1, Expression expression1, Symbol symb return builder().put(symbol1, expression1).put(symbol2, expression2).build(); } - public static Assignments of(Collection expressions, Session session, SymbolAllocator symbolAllocator, IrTypeAnalyzer typeAnalyzer) + public static Assignments of(Collection expressions, SymbolAllocator symbolAllocator, IrTypeAnalyzer typeAnalyzer) { Assignments.Builder assignments = Assignments.builder(); for (Expression expression : expressions) { - Type type = typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression); + Type type = typeAnalyzer.getType(symbolAllocator.getTypes(), expression); assignments.put(symbolAllocator.newSymbol(expression, type), expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 5c31c8cc1096..e2f547b3b162 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -55,9 +55,6 @@ import io.trino.sql.DynamicFilters; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Expression; -import io.trino.sql.ir.ExpressionRewriter; -import io.trino.sql.ir.ExpressionTreeRewriter; -import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.Row; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.OrderingScheme; @@ -138,7 +135,6 @@ import io.trino.sql.planner.rowpattern.MatchNumberValuePointer; import io.trino.sql.planner.rowpattern.ScalarValuePointer; import io.trino.sql.planner.rowpattern.ir.IrLabel; -import io.trino.sql.tree.QualifiedName; import java.util.ArrayList; import java.util.Collection; @@ -164,7 +160,6 @@ import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import static io.trino.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static io.trino.sql.DynamicFilters.extractDynamicFilters; @@ -687,7 +682,7 @@ public Void visitExplainAnalyze(ExplainAnalyzeNode node, Context context) public Void visitJoin(JoinNode node, Context context) { List criteriaExpressions = node.getCriteria().stream() - .map(clause -> unresolveFunctions(clause.toExpression())) + .map(JoinNode.EquiJoinClause::toExpression) .collect(toImmutableList()); NodeRepresentation nodeOutput; @@ -699,7 +694,7 @@ public Void visitJoin(JoinNode node, Context context) else { ImmutableMap.Builder descriptor = ImmutableMap.builder() .put("criteria", Joiner.on(" AND ").join(anonymizeExpressions(criteriaExpressions))); - node.getFilter().ifPresent(filter -> descriptor.put("filter", formatFilter(unresolveFunctions(filter)))); + node.getFilter().ifPresent(filter -> descriptor.put("filter", formatFilter(filter))); descriptor.put("hash", formatHash(node.getLeftHashSymbol(), node.getRightHashSymbol())); node.getDistributionType().ifPresent(distribution -> descriptor.put("distribution", distribution.name())); nodeOutput = addNode(node, node.getType().getJoinLabel(), descriptor.buildOrThrow(), node.getReorderJoinStatsAndCost(), context); @@ -1006,7 +1001,7 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Context context nodeOutput.appendDetails( "%s := %s", anonymizer.anonymize(entry.getKey()), - anonymizer.anonymize(unresolveFunctions(entry.getValue().getExpressionAndValuePointers().getExpression()))); + anonymizer.anonymize(entry.getValue().getExpressionAndValuePointers().getExpression())); appendValuePointers(nodeOutput, entry.getValue().getExpressionAndValuePointers()); } if (node.getRowsPerMatch() != WINDOW) { @@ -1016,7 +1011,7 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Context context nodeOutput.appendDetails("pattern[%s] (%s)", node.getPattern(), node.isInitial() ? "INITIAL" : "SEEK"); for (Entry entry : node.getVariableDefinitions().entrySet()) { - nodeOutput.appendDetails("%s := %s", entry.getKey().getName(), anonymizer.anonymize(unresolveFunctions(entry.getValue().getExpression()))); + nodeOutput.appendDetails("%s := %s", entry.getKey().getName(), anonymizer.anonymize(entry.getValue().getExpression())); appendValuePointers(nodeOutput, entry.getValue()); } @@ -1193,11 +1188,10 @@ public Void visitValues(ValuesNode node, Context context) .map(row -> { if (row instanceof Row) { return ((Row) row).getItems().stream() - .map(PlanPrinter::unresolveFunctions) .map(anonymizer::anonymize) .collect(joining(", ", "(", ")")); } - return anonymizer.anonymize(unresolveFunctions(row)); + return anonymizer.anonymize(row); }) .collect(toImmutableList()); for (String row : rows) { @@ -1259,7 +1253,7 @@ private Void visitScanFilterAndProjectInfo( operatorName += "Filter"; Expression predicate = filterNode.get().getPredicate(); DynamicFilters.ExtractResult extractResult = extractDynamicFilters(predicate); - descriptor.put("filterPredicate", formatFilter(unresolveFunctions(combineConjunctsWithDuplicates(extractResult.getStaticConjuncts())))); + descriptor.put("filterPredicate", formatFilter(combineConjunctsWithDuplicates(extractResult.getStaticConjuncts()))); if (!extractResult.getDynamicConjuncts().isEmpty()) { dynamicFilters = extractResult.getDynamicConjuncts(); descriptor.put("dynamicFilters", printDynamicFilters(dynamicFilters)); @@ -1973,7 +1967,7 @@ private void printAssignments(NodeRepresentation nodeOutput, Assignments assignm // skip identity assignments continue; } - nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(entry.getKey()), anonymizer.anonymize(unresolveFunctions(entry.getValue()))); + nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(entry.getKey()), anonymizer.anonymize(entry.getValue())); } } @@ -2252,27 +2246,6 @@ private static String formatFunctionName(ResolvedFunction function) return name.toString(); } - private static Expression unresolveFunctions(Expression expression) - { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<>() - { - @Override - public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) - { - FunctionCall rewritten = treeRewriter.defaultRewrite(node, context); - CatalogSchemaFunctionName name = extractFunctionName(node.getName()); - QualifiedName qualifiedName; - if (isInlineFunction(name) || isBuiltinFunctionName(name)) { - qualifiedName = QualifiedName.of(name.getFunctionName()); - } - else { - qualifiedName = QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getFunctionName()); - } - return new FunctionCall(qualifiedName, rewritten.getArguments()); - } - }, expression); - } - private record Context(Optional tag, Optional types, boolean isInitialPlan) { public Context(Optional types, boolean isInitialPlan) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/AllFunctionsResolved.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/AllFunctionsResolved.java deleted file mode 100644 index b2a08dd015e3..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/AllFunctionsResolved.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.planner.sanity; - -import com.google.common.collect.ImmutableList.Builder; -import io.trino.Session; -import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.ResolvedFunction; -import io.trino.sql.PlannerContext; -import io.trino.sql.ir.DefaultTraversalVisitor; -import io.trino.sql.ir.Expression; -import io.trino.sql.ir.FunctionCall; -import io.trino.sql.planner.ExpressionExtractor; -import io.trino.sql.planner.IrTypeAnalyzer; -import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.TypeProvider; -import io.trino.sql.planner.plan.PlanNode; - -import static com.google.common.base.Preconditions.checkArgument; - -public final class AllFunctionsResolved - implements PlanSanityChecker.Checker -{ - private static final Visitor VISITOR = new Visitor(); - - @Override - public void validate( - PlanNode planNode, - Session session, - PlannerContext plannerContext, - IrTypeAnalyzer typeAnalyzer, - TypeProvider types, - WarningCollector warningCollector) - { - ExpressionExtractor.forEachExpression(planNode, AllFunctionsResolved::validate); - } - - private static void validate(Expression expression) - { - VISITOR.process(expression, null); - } - - private static class Visitor - extends DefaultTraversalVisitor> - { - @Override - protected Void visitFunctionCall(FunctionCall node, Builder context) - { - checkArgument(ResolvedFunction.isResolved(node.getName()), "Function call has not been resolved: %s", node); - return super.visitFunctionCall(node, context); - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java index ffcc88ecf6fe..99f1a175aab2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java @@ -41,14 +41,12 @@ public PlanSanityChecker(boolean forceSingleNode) Stage.INTERMEDIATE, new ValidateDependenciesChecker(), new NoDuplicatePlanNodeIdsChecker(), - new AllFunctionsResolved(), new TypeValidator(), new VerifyOnlyOneOutputNode()) .putAll( Stage.FINAL, new ValidateDependenciesChecker(), new NoDuplicatePlanNodeIdsChecker(), - new AllFunctionsResolved(), new TypeValidator(), new VerifyOnlyOneOutputNode(), new VerifyNoFilteredAggregations(), @@ -63,7 +61,6 @@ public PlanSanityChecker(boolean forceSingleNode) Stage.AFTER_ADAPTIVE_PLANNING, new ValidateDependenciesChecker(), new NoDuplicatePlanNodeIdsChecker(), - new AllFunctionsResolved(), new TypeValidator(), new VerifyOnlyOneOutputNode(), new VerifyNoFilteredAggregations(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java index 58f8c6b71f17..ea32fb8d05c8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java @@ -55,19 +55,17 @@ public void validate(PlanNode plan, TypeProvider types, WarningCollector warningCollector) { - plan.accept(new Visitor(session, typeAnalyzer, types), null); + plan.accept(new Visitor(typeAnalyzer, types), null); } private static class Visitor extends SimplePlanVisitor { - private final Session session; private final IrTypeAnalyzer typeAnalyzer; private final TypeProvider types; - public Visitor(Session session, IrTypeAnalyzer typeAnalyzer, TypeProvider types) + public Visitor(IrTypeAnalyzer typeAnalyzer, TypeProvider types) { - this.session = requireNonNull(session, "session is null"); this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); } @@ -118,7 +116,7 @@ public Void visitProject(ProjectNode node, Void context) verifyTypeSignature(entry.getKey(), expectedType, types.get(Symbol.from(symbolReference))); continue; } - Type actualType = typeAnalyzer.getType(session, types, entry.getValue()); + Type actualType = typeAnalyzer.getType(types, entry.getValue()); verifyTypeSignature(entry.getKey(), expectedType, actualType); } @@ -173,7 +171,7 @@ private void checkCall(Symbol symbol, BoundSignature signature, List if (expectedTypeSignature instanceof FunctionType) { continue; } - Type actualTypeSignature = typeAnalyzer.getType(session, types, arguments.get(i)); + Type actualTypeSignature = typeAnalyzer.getType(types, arguments.get(i)); verifyTypeSignature(symbol, expectedTypeSignature, actualTypeSignature); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java index d195baf7a98e..99f2b0201daa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SqlToRowExpressionTranslator.java @@ -189,7 +189,7 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context) .map(value -> process(value, context)) .collect(toImmutableList()); - return new CallExpression(metadata.decodeFunction(node.getName()), arguments); + return new CallExpression(node.getFunction(), arguments); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java index e10fdd334977..a8ee520f2df5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java @@ -338,7 +338,7 @@ private RowExpression toRowExpression(Context context, Expression expression) // The expression tree has been rewritten which breaks all the identity maps, so redo the analysis // to re-analyze coercions that might be necessary IrTypeAnalyzer analyzer = new IrTypeAnalyzer(plannerContext); - Map, Type> types = analyzer.getTypes(session, typeProvider, lambdaCaptureDesugared); + Map, Type> types = analyzer.getTypes(typeProvider, lambdaCaptureDesugared); // optimize the expression IrExpressionInterpreter interpreter = new IrExpressionInterpreter(lambdaCaptureDesugared, plannerContext, session, types); @@ -349,7 +349,7 @@ private RowExpression toRowExpression(Context context, Expression expression) new Constant(types.get(io.trino.sql.ir.NodeRef.of(lambdaCaptureDesugared)), value); // Analyze again after optimization - types = analyzer.getTypes(session, typeProvider, optimized); + types = analyzer.getTypes(typeProvider, optimized); // translate to RowExpression TranslationVisitor translator = new TranslationVisitor(plannerContext.getMetadata(), plannerContext.getTypeManager(), types, ImmutableMap.of(), context.variables()); diff --git a/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java b/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java index 8f63354ae6c1..606fc0f77c03 100644 --- a/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java +++ b/core/trino-main/src/main/java/io/trino/util/SpatialJoinUtils.java @@ -22,7 +22,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.sql.ir.IrUtils.extractConjuncts; public final class SpatialJoinUtils @@ -53,7 +52,7 @@ public static List extractSupportedSpatialFunctions(Expression fil private static boolean isSupportedSpatialFunction(FunctionCall functionCall) { - CatalogSchemaFunctionName functionName = extractFunctionName(functionCall.getName()); + CatalogSchemaFunctionName functionName = functionCall.getFunction().getName(); return functionName.equals(builtinFunctionName(ST_CONTAINS)) || functionName.equals(builtinFunctionName(ST_WITHIN)) || functionName.equals(builtinFunctionName(ST_INTERSECTS)); @@ -94,8 +93,8 @@ private static boolean isSupportedSpatialComparison(ComparisonExpression express private static boolean isSTDistance(Expression expression) { - if (expression instanceof FunctionCall) { - return extractFunctionName(((FunctionCall) expression).getName()).equals(builtinFunctionName(ST_DISTANCE)); + if (expression instanceof FunctionCall call) { + return call.getFunction().getName().equals(builtinFunctionName(ST_DISTANCE)); } return false; diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index 52d01e833470..e1fd781bb013 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -347,9 +347,9 @@ public void testOrStats() @Test public void testUnsupportedExpression() { - assertExpression(new FunctionCall(SIN.toQualifiedName(), ImmutableList.of(new SymbolReference("x")))) + assertExpression(new FunctionCall(SIN, ImmutableList.of(new SymbolReference("x")))) .outputRowsCountUnknown(); - assertExpression(new ComparisonExpression(EQUAL, new SymbolReference("x"), new FunctionCall(SIN.toQualifiedName(), ImmutableList.of(new SymbolReference("x"))))) + assertExpression(new ComparisonExpression(EQUAL, new SymbolReference("x"), new FunctionCall(SIN, ImmutableList.of(new SymbolReference("x"))))) .outputRowsCountUnknown(); } @@ -381,7 +381,7 @@ public void testAndStats() .symbolStats(new Symbol("y"), SymbolStatsAssertion::emptyRange); // first argument unknown - assertExpression(new LogicalExpression(AND, ImmutableList.of(new FunctionCall(JSON_ARRAY_CONTAINS.toQualifiedName(), ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference("x"))), new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new Constant(DOUBLE, 0.0))))) + assertExpression(new LogicalExpression(AND, ImmutableList.of(new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference("x"))), new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new Constant(DOUBLE, 0.0))))) .outputRowsCount(337.5) .symbolStats(new Symbol("x"), symbolAssert -> symbolAssert.lowValue(-10) @@ -390,7 +390,7 @@ public void testAndStats() .nullsFraction(0)); // second argument unknown - assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new Constant(DOUBLE, 0.0)), new FunctionCall(JSON_ARRAY_CONTAINS.toQualifiedName(), ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference("x")))))) + assertExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new Constant(DOUBLE, 0.0)), new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference("x")))))) .outputRowsCount(337.5) .symbolStats(new Symbol("x"), symbolAssert -> symbolAssert.lowValue(-10) @@ -400,8 +400,8 @@ public void testAndStats() // both arguments unknown assertExpression(new LogicalExpression(AND, ImmutableList.of( - new FunctionCall(JSON_ARRAY_CONTAINS.toQualifiedName(), ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[11]"))), new SymbolReference("x"))), - new FunctionCall(JSON_ARRAY_CONTAINS.toQualifiedName(), ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[13]"))), new SymbolReference("x")))))) + new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[11]"))), new SymbolReference("x"))), + new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[13]"))), new SymbolReference("x")))))) .outputRowsCountUnknown(); assertExpression(new LogicalExpression(AND, ImmutableList.of(new InPredicate(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a")), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("b")), new Constant(VarcharType.VARCHAR, Slices.utf8Slice("c")))), new ComparisonExpression(EQUAL, new SymbolReference("unknownRange"), new Constant(DOUBLE, 3.0))))) @@ -592,7 +592,7 @@ public void testNotStats() .nullsFraction(0)) .symbolStats(new Symbol("y"), symbolAssert -> symbolAssert.isEqualTo(yStats)); - assertExpression(new NotExpression(new FunctionCall(JSON_ARRAY_CONTAINS.toQualifiedName(), ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference("x"))))) + assertExpression(new NotExpression(new FunctionCall(JSON_ARRAY_CONTAINS, ImmutableList.of(new Constant(JSON, JsonTypeUtil.jsonParse(Slices.utf8Slice("[]"))), new SymbolReference("x"))))) .outputRowsCountUnknown(); } diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java b/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java index 8212d6c3c3c6..5a78beda338d 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java @@ -216,7 +216,7 @@ public TestingFunctionCallBuilder setArguments(List types, List, Type> expressionTypes = new IrTypeAnalyzer(PLANNER_CONTEXT).getTypes(TEST_SESSION, SYMBOL_TYPES, parsedExpression); + Map, Type> expressionTypes = new IrTypeAnalyzer(PLANNER_CONTEXT).getTypes(SYMBOL_TYPES, parsedExpression); IrExpressionInterpreter interpreter = new IrExpressionInterpreter(parsedExpression, PLANNER_CONTEXT, TEST_SESSION, expressionTypes); return interpreter.optimize(INPUTS); } @@ -916,7 +916,7 @@ private static void assertEvaluatedEquals(Expression actual, Expression expected private static Object evaluate(Expression expression) { - Map, Type> expressionTypes = new IrTypeAnalyzer(PLANNER_CONTEXT).getTypes(TEST_SESSION, SYMBOL_TYPES, expression); + Map, Type> expressionTypes = new IrTypeAnalyzer(PLANNER_CONTEXT).getTypes(SYMBOL_TYPES, expression); IrExpressionInterpreter interpreter = new IrExpressionInterpreter(expression, PLANNER_CONTEXT, TEST_SESSION, expressionTypes); return interpreter.evaluate(INPUTS); diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java index df5bad11a365..95cab57a9e57 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionUtils.java @@ -14,13 +14,11 @@ package io.trino.sql; import com.google.common.collect.ImmutableList; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.ir.LogicalExpression; import io.trino.sql.ir.SymbolReference; import org.junit.jupiter.api.Test; -import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.combineConjuncts; import static io.trino.sql.ir.LogicalExpression.Operator.AND; @@ -28,8 +26,6 @@ public class TestExpressionUtils { - private final Metadata metadata = createTestMetadataManager(); - @Test public void testAnd() { @@ -41,6 +37,6 @@ public void testAnd() assertThat(and(a, b, c, d, e)).isEqualTo(new LogicalExpression(AND, ImmutableList.of(a, b, c, d, e))); - assertThat(combineConjuncts(metadata, a, b, a, c, d, c, e)).isEqualTo(new LogicalExpression(AND, ImmutableList.of(a, b, c, d, e))); + assertThat(combineConjuncts(a, b, a, c, d, c, e)).isEqualTo(new LogicalExpression(AND, ImmutableList.of(a, b, c, d, e))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java index 53538bac2074..44fb26d28bf0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestSqlToRowExpressionTranslator.java @@ -114,6 +114,6 @@ private Expression simplifyExpression(Expression expression) private Map, Type> getExpressionTypes(Expression expression) { - return new IrTypeAnalyzer(PLANNER_CONTEXT).getTypes(TEST_SESSION, TypeProvider.empty(), expression); + return new IrTypeAnalyzer(PLANNER_CONTEXT).getTypes(TypeProvider.empty(), expression); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java index 6b712bc9a019..56d659e5b705 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java @@ -185,7 +185,7 @@ else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression(new FunctionCall(CONCAT.toQualifiedName(), ImmutableList.of(new SymbolReference("varchar" + i), new Constant(VARCHAR, Slices.utf8Slice("foo")))))); + builder.add(rowExpression(new FunctionCall(CONCAT, ImmutableList.of(new SymbolReference("varchar" + i), new Constant(VARCHAR, Slices.utf8Slice("foo")))))); } } return builder.build(); @@ -193,7 +193,7 @@ else if (type == VARCHAR) { private RowExpression rowExpression(Expression expression) { - Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TypeProvider.copyOf(symbolTypes), expression); return SqlToRowExpressionTranslator.translate( expression, expressionTypes, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java index 3eb61c8b8106..defc220fcc11 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; @@ -32,7 +34,6 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.List; @@ -43,6 +44,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.DIVIDE; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; @@ -71,6 +73,11 @@ public abstract class AbstractPredicatePushdownTest extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + private static final ResolvedFunction ROUND = FUNCTIONS.resolveFunction("round", fromTypes(DOUBLE)); + private static final ResolvedFunction LENGTH = FUNCTIONS.resolveFunction("length", fromTypes(createVarcharType(1))); + private final boolean enableDynamicFiltering; protected AbstractPredicatePushdownTest(boolean enableDynamicFiltering) @@ -108,7 +115,7 @@ public void testNonDeterministicPredicatePropagatesOnlyToSourceSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", enableDynamicFiltering, filter( - new ComparisonExpression(EQUAL, new SymbolReference("LINE_ORDER_KEY"), new Cast(new FunctionCall(QualifiedName.of("random"), ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new ComparisonExpression(EQUAL, new SymbolReference("LINE_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey"))), node(ExchangeNode.class, // NO filter here @@ -118,7 +125,7 @@ public void testNonDeterministicPredicatePropagatesOnlyToSourceSideOfSemiJoin() anyTree( semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", filter( - new ComparisonExpression(EQUAL, new SymbolReference("LINE_ORDER_KEY"), new Cast(new FunctionCall(QualifiedName.of("random"), ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new ComparisonExpression(EQUAL, new SymbolReference("LINE_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("lineitem", ImmutableMap.of( "LINE_ORDER_KEY", "orderkey"))), anyTree( @@ -348,7 +355,7 @@ public void testPredicatePushDownOverProjection() anyTree( filter( new ComparisonExpression(GREATER_THAN, new SymbolReference("expr"), new Constant(DOUBLE, 5000.0)), - project(ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(MULTIPLY, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Cast(new SymbolReference("orderkey"), DOUBLE)))), + project(ImmutableMap.of("expr", expression(new ArithmeticBinaryExpression(MULTIPLY, new FunctionCall(RANDOM, ImmutableList.of()), new Cast(new SymbolReference("orderkey"), DOUBLE)))), tableScan("orders", ImmutableMap.of( "orderkey", "orderkey")))))); } @@ -427,7 +434,7 @@ public void testPredicateOnNonDeterministicSymbolsPushedDown() anyTree( filter( new ComparisonExpression(GREATER_THAN, new SymbolReference("ROUND"), new Constant(DOUBLE, 100.0)), - project(ImmutableMap.of("ROUND", expression(new FunctionCall(QualifiedName.of("round"), ImmutableList.of(new ArithmeticBinaryExpression(MULTIPLY, new Cast(new SymbolReference("CUST_KEY"), DOUBLE), new FunctionCall(QualifiedName.of("random"), ImmutableList.of())))))), + project(ImmutableMap.of("ROUND", expression(new FunctionCall(ROUND, ImmutableList.of(new ArithmeticBinaryExpression(MULTIPLY, new Cast(new SymbolReference("CUST_KEY"), DOUBLE), new FunctionCall(RANDOM, ImmutableList.of())))))), tableScan( "orders", ImmutableMap.of("CUST_KEY", "custkey")))))))); @@ -443,7 +450,7 @@ public void testNonDeterministicPredicateNotPushedDown() ") WHERE custkey > 100*rand()", anyTree( filter( - new ComparisonExpression(GREATER_THAN, new Cast(new SymbolReference("CUST_KEY"), DOUBLE), new ArithmeticBinaryExpression(MULTIPLY, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 100.0))), + new ComparisonExpression(GREATER_THAN, new Cast(new SymbolReference("CUST_KEY"), DOUBLE), new ArithmeticBinaryExpression(MULTIPLY, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 100.0))), anyTree( node(WindowNode.class, anyTree( @@ -464,7 +471,7 @@ public void testRemovesRedundantTableScanPredicate() JoinNode.class, node(ProjectNode.class, filter( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference("ORDERKEY"), new Constant(BIGINT, 123L)), new ComparisonExpression(EQUAL, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Cast(new SymbolReference("ORDERKEY"), DOUBLE)), new ComparisonExpression(LESS_THAN, new FunctionCall(QualifiedName.of("length"), ImmutableList.of(new SymbolReference("ORDERSTATUS"))), new Constant(BIGINT, 42L)))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(EQUAL, new SymbolReference("ORDERKEY"), new Constant(BIGINT, 123L)), new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Cast(new SymbolReference("ORDERKEY"), DOUBLE)), new ComparisonExpression(LESS_THAN, new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference("ORDERSTATUS"))), new Constant(BIGINT, 42L)))), tableScan( "orders", ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java index a0bff496248c..dd21e2c55eef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java @@ -70,7 +70,7 @@ public void testDuplicatesInWindowOrderBy() .addFunction(windowFunction("row_number", ImmutableList.of(), DEFAULT_FRAME)), values("A"))), ImmutableList.of( - new UnaliasSymbolReferences(getPlanTester().getPlannerContext().getMetadata()), + new UnaliasSymbolReferences(), new IterativeOptimizer( getPlanTester().getPlannerContext(), new RuleStatsRecorder(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java index d1f9f579a279..aa778f6a8ed9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDereferencePushDown.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; @@ -24,7 +26,6 @@ import io.trino.sql.ir.SubscriptExpression; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.BasePlanTest; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.SystemSessionProperties.FILTERING_SEMI_JOIN_TO_INNER; @@ -32,6 +33,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.ir.LogicalExpression.Operator.OR; @@ -52,6 +54,9 @@ public class TestDereferencePushDown extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction IS_FINITE = FUNCTIONS.resolveFunction("is_finite", fromTypes(DOUBLE)); + @Test public void testDereferencePushdownMultiLevel() { @@ -128,7 +133,7 @@ public void testDereferencePushdownFilter() filter( new LogicalExpression(OR, ImmutableList.of( new ComparisonExpression(EQUAL, new SymbolReference("a_x"), new Constant(BIGINT, 7L)), - new FunctionCall(QualifiedName.of("is_finite"), ImmutableList.of(new SymbolReference("b_y"))))), + new FunctionCall(IS_FINITE, ImmutableList.of(new SymbolReference("b_y"))))), values( ImmutableList.of("b_x", "b_y", "a_y", "a_x"), ImmutableList.of(ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java index 0a657be0e7be..8363f8734d69 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeterminismEvaluator.java @@ -14,7 +14,6 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableList; -import io.trino.metadata.Metadata; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; @@ -44,32 +43,31 @@ public class TestDeterminismEvaluator @Test public void testSanity() { - Metadata metadata = functionResolution.getMetadata(); - assertThat(DeterminismEvaluator.isDeterministic(function("rand"), metadata)).isFalse(); - assertThat(DeterminismEvaluator.isDeterministic(function("random"), metadata)).isFalse(); - assertThat(DeterminismEvaluator.isDeterministic(function("shuffle", ImmutableList.of(new ArrayType(VARCHAR)), ImmutableList.of(new Constant(UnknownType.UNKNOWN, null))), - metadata)).isFalse(); - assertThat(DeterminismEvaluator.isDeterministic(function("uuid"), metadata)).isFalse(); - assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(input("symbol"))), metadata)).isTrue(); - assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(function("rand"))), metadata)).isFalse(); + assertThat(DeterminismEvaluator.isDeterministic(function("rand"))).isFalse(); + assertThat(DeterminismEvaluator.isDeterministic(function("random"))).isFalse(); + assertThat(DeterminismEvaluator.isDeterministic(function("shuffle", ImmutableList.of(new ArrayType(VARCHAR)), ImmutableList.of(new Constant(UnknownType.UNKNOWN, null))) + )).isFalse(); + assertThat(DeterminismEvaluator.isDeterministic(function("uuid"))).isFalse(); + assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(input("symbol"))))).isTrue(); + assertThat(DeterminismEvaluator.isDeterministic(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(function("rand"))))).isFalse(); assertThat(DeterminismEvaluator.isDeterministic( function( "abs", ImmutableList.of(DOUBLE), - ImmutableList.of(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(input("symbol"))))), - metadata)).isTrue(); + ImmutableList.of(function("abs", ImmutableList.of(DOUBLE), ImmutableList.of(input("symbol"))))) + )).isTrue(); assertThat(DeterminismEvaluator.isDeterministic( function( "filter", ImmutableList.of(new ArrayType(INTEGER), new FunctionType(ImmutableList.of(INTEGER), BOOLEAN)), - ImmutableList.of(lambda("a", comparison(GREATER_THAN, input("a"), new Constant(INTEGER, 0L))))), - metadata)).isTrue(); + ImmutableList.of(lambda("a", comparison(GREATER_THAN, input("a"), new Constant(INTEGER, 0L))))) + )).isTrue(); assertThat(DeterminismEvaluator.isDeterministic( function( "filter", ImmutableList.of(new ArrayType(INTEGER), new FunctionType(ImmutableList.of(INTEGER), BOOLEAN)), - ImmutableList.of(lambda("a", comparison(GREATER_THAN, function("rand", ImmutableList.of(INTEGER), ImmutableList.of(input("a"))), new Constant(INTEGER, 0L))))), - metadata)).isFalse(); + ImmutableList.of(lambda("a", comparison(GREATER_THAN, function("rand", ImmutableList.of(INTEGER), ImmutableList.of(input("a"))), new Constant(INTEGER, 0L))))) + )).isFalse(); } private FunctionCall function(String name) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java index 62ac34f387dc..1bf90647c391 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java @@ -180,7 +180,7 @@ public class TestDomainTranslator public void setup() { functionResolution = new TestingFunctionResolution(); - domainTranslator = new DomainTranslator(functionResolution.getPlannerContext()); + domainTranslator = new DomainTranslator(); } @AfterAll @@ -1992,10 +1992,10 @@ public void testStartsWithFunction() public void testUnsupportedFunctions() { assertUnsupportedPredicate(new FunctionCall( - functionResolution.resolveFunction("length", fromTypes(VARCHAR)).toQualifiedName(), + functionResolution.resolveFunction("length", fromTypes(VARCHAR)), ImmutableList.of(C_VARCHAR.toSymbolReference()))); assertUnsupportedPredicate(new FunctionCall( - functionResolution.resolveFunction("replace", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + functionResolution.resolveFunction("replace", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(C_VARCHAR.toSymbolReference(), stringLiteral("abc")))); } @@ -2116,31 +2116,31 @@ private static ComparisonExpression isDistinctFrom(Symbol symbol, Expression exp private FunctionCall like(Symbol symbol, String pattern) { return new FunctionCall( - functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)).toQualifiedName(), + functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)), ImmutableList.of(symbol.toSymbolReference(), new Constant(LikePatternType.LIKE_PATTERN, LikePattern.compile(pattern, Optional.empty())))); } private FunctionCall like(Symbol symbol, Expression pattern, Expression escape) { FunctionCall likePattern = new FunctionCall( - functionResolution.resolveFunction(LIKE_PATTERN_FUNCTION_NAME, fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + functionResolution.resolveFunction(LIKE_PATTERN_FUNCTION_NAME, fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(symbol.toSymbolReference(), pattern, escape)); return new FunctionCall( - functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)).toQualifiedName(), + functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)), ImmutableList.of(symbol.toSymbolReference(), pattern, likePattern)); } private FunctionCall like(Symbol symbol, String pattern, Character escape) { return new FunctionCall( - functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)).toQualifiedName(), + functionResolution.resolveFunction(LIKE_FUNCTION_NAME, fromTypes(VARCHAR, LikePatternType.LIKE_PATTERN)), ImmutableList.of(symbol.toSymbolReference(), new Constant(LikePatternType.LIKE_PATTERN, LikePattern.compile(pattern, Optional.of(escape))))); } private FunctionCall startsWith(Symbol symbol, Expression expression) { return new FunctionCall( - functionResolution.resolveFunction("starts_with", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + functionResolution.resolveFunction("starts_with", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(symbol.toSymbolReference(), expression)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java index c7afdd1137e2..00702b6a8104 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDynamicFilter.java @@ -18,6 +18,8 @@ import io.trino.Session; import io.trino.cost.StatsProvider; import io.trino.metadata.Metadata; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.Type; import io.trino.sql.DynamicFilters; import io.trino.sql.ir.ArithmeticBinaryExpression; @@ -39,7 +41,6 @@ import io.trino.sql.planner.plan.EnforceSingleRowNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.PlanNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING; @@ -49,6 +50,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; @@ -80,6 +82,9 @@ public class TestDynamicFilter extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction MOD = FUNCTIONS.resolveFunction("mod", fromTypes(INTEGER, INTEGER)); + public TestDynamicFilter() { super(ImmutableMap.of( @@ -849,7 +854,7 @@ public void testDynamicFilterAliasDeDuplicated() .right( anyTree( project( - ImmutableMap.of("mod", expression(new FunctionCall(QualifiedName.of("mod"), ImmutableList.of(new SymbolReference("n_nationkey"), new Constant(BIGINT, 2L))))), + ImmutableMap.of("mod", expression(new FunctionCall(MOD, ImmutableList.of(new SymbolReference("n_nationkey"), new Constant(BIGINT, 2L))))), tableScan("nation", ImmutableMap.of("n_nationkey", "nationkey"))))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index 7b4f2e7f96a0..b2e2a9bbd262 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -170,8 +170,8 @@ public TableProperties getTableProperties(Session session, TableHandle handle) private final PlannerContext plannerContext = plannerContextBuilder().withMetadata(metadata).build(); private final IrTypeAnalyzer typeAnalyzer = new IrTypeAnalyzer(plannerContext); - private final EffectivePredicateExtractor effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(plannerContext), plannerContext, true); - private final EffectivePredicateExtractor effectivePredicateExtractorWithoutTableProperties = new EffectivePredicateExtractor(new DomainTranslator(plannerContext), plannerContext, false); + private final EffectivePredicateExtractor effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(), plannerContext, true); + private final EffectivePredicateExtractor effectivePredicateExtractorWithoutTableProperties = new EffectivePredicateExtractor(new DomainTranslator(), plannerContext, false); private Map scanAssignments; private TableScanNode baseTableScan; @@ -690,7 +690,7 @@ public void testValues() ValuesNode node = new ValuesNode( newId(), ImmutableList.of(A, B), - ImmutableList.of(new Row(ImmutableList.of(bigintLiteral(1), new FunctionCall(rand.toQualifiedName(), ImmutableList.of()))))); + ImmutableList.of(new Row(ImmutableList.of(bigintLiteral(1), new FunctionCall(rand, ImmutableList.of()))))); assertThat(extract(types, node)).isEqualTo(new ComparisonExpression(EQUAL, AE, bigintLiteral(1))); // non-constant @@ -1160,7 +1160,7 @@ private Set normalizeConjuncts(Expression... conjuncts) private Set normalizeConjuncts(Collection conjuncts) { - return normalizeConjuncts(combineConjuncts(metadata, conjuncts)); + return normalizeConjuncts(combineConjuncts(conjuncts)); } private Set normalizeConjuncts(Expression predicate) @@ -1170,10 +1170,10 @@ private Set normalizeConjuncts(Expression predicate) predicate = expressionNormalizer.normalize(predicate); // Equality inference rewrites and equality generation will always be stable across multiple runs in the same JVM - EqualityInference inference = new EqualityInference(metadata, predicate); + EqualityInference inference = new EqualityInference(predicate); Set scope = SymbolsExtractor.extractUnique(predicate); - Set rewrittenSet = EqualityInference.nonInferrableConjuncts(metadata, predicate) + Set rewrittenSet = EqualityInference.nonInferrableConjuncts(predicate) .map(expression -> inference.rewrite(expression, scope)) .peek(rewritten -> checkState(rewritten != null, "Rewrite with full symbol scope should always be possible")) .collect(Collectors.toSet()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java index c3dcef8a40ef..835c33ef295a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import io.trino.metadata.Metadata; import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.scalar.TryFunction; import io.trino.sql.ir.ArithmeticBinaryExpression; @@ -62,13 +61,11 @@ public class TestEqualityInference { private final TestingFunctionResolution functionResolution = new TestingFunctionResolution(); - private final Metadata metadata = functionResolution.getMetadata(); @Test public void testDoesNotInferRedundantStraddlingPredicates() { EqualityInference inference = new EqualityInference( - metadata, equals("a1", "b1"), equals(add(nameReference("a1"), number(1)), number(0)), equals(nameReference("a2"), add(nameReference("a1"), number(2))), @@ -91,7 +88,6 @@ public void testDoesNotInferRedundantStraddlingPredicates() public void testTransitivity() { EqualityInference inference = new EqualityInference( - metadata, equals("a1", "b1"), equals("b1", "c1"), equals("d1", "c1"), @@ -118,7 +114,7 @@ public void testTransitivity() @Test public void testTriviallyRewritable() { - Expression expression = new EqualityInference(metadata) + Expression expression = new EqualityInference() .rewrite(someExpression("a1", "a2"), symbols("a1", "a2")); assertThat(expression).isEqualTo(someExpression("a1", "a2")); @@ -128,7 +124,6 @@ public void testTriviallyRewritable() public void testUnrewritable() { EqualityInference inference = new EqualityInference( - metadata, equals("a1", "b1"), equals("a2", "b2")); @@ -140,7 +135,6 @@ public void testUnrewritable() public void testParseEqualityExpression() { EqualityInference inference = new EqualityInference( - metadata, equals("a1", "b1"), equals("a1", "c1"), equals("c1", "a1")); @@ -153,7 +147,6 @@ public void testParseEqualityExpression() public void testExtractInferrableEqualities() { EqualityInference inference = new EqualityInference( - metadata, and(equals("a1", "b1"), equals("b1", "c1"), someExpression("c1", "d1"))); // Able to rewrite to c1 due to equalities @@ -167,7 +160,6 @@ public void testExtractInferrableEqualities() public void testEqualityPartitionGeneration() { EqualityInference inference = new EqualityInference( - metadata, equals(nameReference("a1"), nameReference("b1")), equals(add("a1", "a1"), multiply(nameReference("a1"), number(2))), equals(nameReference("b1"), nameReference("c1")), @@ -187,22 +179,21 @@ public void testEqualityPartitionGeneration() // There should be equalities in the scope, that only use c1 and are all inferrable equalities assertThat(equalityPartition.getScopeEqualities().isEmpty()).isFalse(); assertThat(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1")))).isTrue(); - assertThat(Iterables.all(equalityPartition.getScopeEqualities(), expression -> isInferenceCandidate(metadata, expression))).isTrue(); + assertThat(Iterables.all(equalityPartition.getScopeEqualities(), expression -> isInferenceCandidate(expression))).isTrue(); // There should be equalities in the inverse scope, that never use c1 and are all inferrable equalities assertThat(equalityPartition.getScopeComplementEqualities().isEmpty()).isFalse(); assertThat(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1"))))).isTrue(); - assertThat(Iterables.all(equalityPartition.getScopeComplementEqualities(), expression -> isInferenceCandidate(metadata, expression))).isTrue(); + assertThat(Iterables.all(equalityPartition.getScopeComplementEqualities(), expression -> isInferenceCandidate(expression))).isTrue(); // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols assertThat(equalityPartition.getScopeStraddlingEqualities().isEmpty()).isFalse(); assertThat(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1")))).isTrue(); - assertThat(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), expression -> isInferenceCandidate(metadata, expression))).isTrue(); + assertThat(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), expression -> isInferenceCandidate(expression))).isTrue(); // There should be a "full cover" of all of the equalities used // THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around EqualityInference newInference = new EqualityInference( - metadata, ImmutableList.builder() .addAll(equalityPartition.getScopeEqualities()) .addAll(equalityPartition.getScopeComplementEqualities()) @@ -220,7 +211,6 @@ public void testEqualityPartitionGeneration() public void testMultipleEqualitySetsPredicateGeneration() { EqualityInference inference = new EqualityInference( - metadata, equals("a1", "b1"), equals("b1", "c1"), equals("c1", "d1"), @@ -234,22 +224,21 @@ public void testMultipleEqualitySetsPredicateGeneration() // There should be equalities in the scope, that only use a* and b* symbols and are all inferrable equalities assertThat(equalityPartition.getScopeEqualities().isEmpty()).isFalse(); assertThat(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a", "b")))).isTrue(); - assertThat(Iterables.all(equalityPartition.getScopeEqualities(), expression -> isInferenceCandidate(metadata, expression))).isTrue(); + assertThat(Iterables.all(equalityPartition.getScopeEqualities(), expression -> isInferenceCandidate(expression))).isTrue(); // There should be equalities in the inverse scope, that never use a* and b* symbols and are all inferrable equalities assertThat(equalityPartition.getScopeComplementEqualities().isEmpty()).isFalse(); assertThat(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(symbolBeginsWith("a", "b"))))).isTrue(); - assertThat(Iterables.all(equalityPartition.getScopeComplementEqualities(), expression -> isInferenceCandidate(metadata, expression))).isTrue(); + assertThat(Iterables.all(equalityPartition.getScopeComplementEqualities(), expression -> isInferenceCandidate(expression))).isTrue(); // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols assertThat(equalityPartition.getScopeStraddlingEqualities().isEmpty()).isFalse(); assertThat(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(symbolBeginsWith("a", "b")))).isTrue(); - assertThat(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), expression -> isInferenceCandidate(metadata, expression))).isTrue(); + assertThat(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), expression -> isInferenceCandidate(expression))).isTrue(); // Again, there should be a "full cover" of all of the equalities used // THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around EqualityInference newInference = new EqualityInference( - metadata, ImmutableList.builder() .addAll(equalityPartition.getScopeEqualities()) .addAll(equalityPartition.getScopeComplementEqualities()) @@ -267,7 +256,6 @@ public void testMultipleEqualitySetsPredicateGeneration() public void testSubExpressionRewrites() { EqualityInference inference = new EqualityInference( - metadata, equals(nameReference("a1"), add("b", "c")), // a1 = b + c equals(nameReference("a2"), multiply(nameReference("b"), add("b", "c"))), // a2 = b * (b + c) equals(nameReference("a3"), multiply(nameReference("a1"), add("b", "c")))); // a3 = a1 * (b + c) @@ -286,7 +274,6 @@ public void testSubExpressionRewrites() public void testConstantEqualities() { EqualityInference inference = new EqualityInference( - metadata, equals("a1", "b1"), equals("b1", "c1"), equals(nameReference("c1"), number(1))); @@ -307,7 +294,6 @@ public void testConstantEqualities() public void testEqualityGeneration() { EqualityInference inference = new EqualityInference( - metadata, equals(nameReference("a1"), add("b", "c")), // a1 = b + c equals(nameReference("e1"), add("b", "d")), // e1 = b + d equals("c", "d")); @@ -334,7 +320,6 @@ public void testExpressionsThatMayReturnNullOnNonNullInput() for (Expression candidate : candidates) { EqualityInference inference = new EqualityInference( - metadata, equals(nameReference("b"), nameReference("x")), equals(nameReference("a"), candidate)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java index 4d48318df452..2b8c6149e81b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestJsonTable.java @@ -39,7 +39,6 @@ import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.tree.JsonQuery; import io.trino.sql.tree.JsonValue; -import io.trino.sql.tree.QualifiedName; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Test; @@ -77,14 +76,19 @@ public class TestJsonTable extends BasePlanTest { - private static final ResolvedFunction JSON_VALUE_FUNCTION = new TestingFunctionResolution().resolveFunction( + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + + private static final ResolvedFunction JSON_VALUE_FUNCTION = FUNCTIONS.resolveFunction( JSON_VALUE_FUNCTION_NAME, fromTypes(JSON_2016, JSON_PATH_2016, JSON_NO_PARAMETERS_ROW_TYPE, TINYINT, BIGINT, TINYINT, BIGINT)); - private static final ResolvedFunction JSON_QUERY_FUNCTION = new TestingFunctionResolution().resolveFunction( + private static final ResolvedFunction JSON_QUERY_FUNCTION = FUNCTIONS.resolveFunction( JSON_QUERY_FUNCTION_NAME, fromTypes(JSON_2016, JSON_PATH_2016, JSON_NO_PARAMETERS_ROW_TYPE, TINYINT, TINYINT, TINYINT)); + private static final ResolvedFunction JSON_TO_VARCHAR = FUNCTIONS.resolveFunction("$json_to_varchar", fromTypes(JSON_2016, TINYINT, BOOLEAN)); + private static final ResolvedFunction VARCHAR_TO_JSON = FUNCTIONS.resolveFunction("$varchar_to_json", fromTypes(VARCHAR, BOOLEAN)); + @Test public void testJsonTableInitialPlan() { @@ -104,7 +108,7 @@ public void testJsonTableInitialPlan() ImmutableList.of("json_col", "int_col", "bigint_col", "formatted_varchar_col"), anyTree( project( - ImmutableMap.of("formatted_varchar_col", expression(new FunctionCall(QualifiedName.of("$json_to_varchar"), ImmutableList.of(new SymbolReference("varchar_col"), new Constant(TINYINT, 1L), FALSE_LITERAL)))), + ImmutableMap.of("formatted_varchar_col", expression(new FunctionCall(JSON_TO_VARCHAR, ImmutableList.of(new SymbolReference("varchar_col"), new Constant(TINYINT, 1L), FALSE_LITERAL)))), tableFunction(builder -> builder .name("$json_table") .addTableArgument( @@ -116,8 +120,8 @@ public void testJsonTableInitialPlan() .properOutputs(ImmutableList.of("bigint_col", "varchar_col")), project( ImmutableMap.of( - "context_item", expression(new FunctionCall(QualifiedName.of("$varchar_to_json"), ImmutableList.of(new SymbolReference("json_col_coerced"), FALSE_LITERAL))), // apply input function to context item - "parameters_row", expression(new Cast(new Row(ImmutableList.of(new SymbolReference("int_col"), new FunctionCall(QualifiedName.of("$varchar_to_json"), ImmutableList.of(new SymbolReference("name_coerced"), FALSE_LITERAL)))), rowType(field("id", INTEGER), field("name", JSON_2016))))), // apply input function to formatted path parameter and gather path parameters in a row + "context_item", expression(new FunctionCall(VARCHAR_TO_JSON, ImmutableList.of(new SymbolReference("json_col_coerced"), FALSE_LITERAL))), // apply input function to context item + "parameters_row", expression(new Cast(new Row(ImmutableList.of(new SymbolReference("int_col"), new FunctionCall(VARCHAR_TO_JSON, ImmutableList.of(new SymbolReference("name_coerced"), FALSE_LITERAL)))), rowType(field("id", INTEGER), field("name", JSON_2016))))), // apply input function to formatted path parameter and gather path parameters in a row project(// coerce context item, path parameters and default expressions ImmutableMap.of( "name_coerced", expression(new Cast(new SymbolReference("name"), VARCHAR)), // cast formatted path parameter to VARCHAR for the input function diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index f272ea6b5955..8f9fba5a7fd3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -18,10 +18,13 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.plugin.tpch.TpchColumnHandle; import io.trino.plugin.tpch.TpchTableHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.OperatorType; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; @@ -79,7 +82,6 @@ import io.trino.sql.planner.rowpattern.ScalarValuePointer; import io.trino.sql.planner.rowpattern.ir.IrLabel; import io.trino.sql.planner.rowpattern.ir.IrQuantified; -import io.trino.sql.tree.QualifiedName; import io.trino.tests.QueryTemplate; import io.trino.type.Reals; import org.junit.jupiter.api.Test; @@ -100,6 +102,7 @@ import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.trino.SystemSessionProperties.OPTIMIZE_HASH_GENERATION; import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS; import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.predicate.Domain.multipleValues; @@ -112,6 +115,7 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; @@ -191,6 +195,12 @@ public class TestLogicalPlanner extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction FAIL = FUNCTIONS.resolveFunction("fail", fromTypes(INTEGER, VARCHAR)); + private static final ResolvedFunction LOWER = FUNCTIONS.resolveFunction("lower", fromTypes(VARCHAR)); + private static final ResolvedFunction COMBINE_HASH = FUNCTIONS.resolveFunction("combine_hash", fromTypes(BIGINT, BIGINT)); + private static final ResolvedFunction HASH_CODE = createTestMetadataManager().resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(BIGINT)); + private static final WindowNode.Frame ROWS_FROM_CURRENT = new WindowNode.Frame( ROWS, CURRENT_ROW, @@ -305,7 +315,7 @@ public void testAllFieldsDereferenceOnSubquery() public void testAllFieldsDereferenceFromNonDeterministic() { FunctionCall randomFunction = new FunctionCall( - getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction("rand", ImmutableList.of()).toQualifiedName(), + getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction("rand", ImmutableList.of()), ImmutableList.of()); assertPlan("SELECT (x, x).* FROM (SELECT rand()) T(x)", @@ -867,7 +877,7 @@ public void testCorrelatedScalarSubqueryInSelect() new SimpleCaseExpression( new SymbolReference("is_distinct"), ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), - Optional.of(new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + Optional.of(new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), project( markDistinct("is_distinct", ImmutableList.of("unique"), join(LEFT, builder -> builder @@ -885,7 +895,7 @@ public void testCorrelatedScalarSubqueryInSelect() new SimpleCaseExpression( new SymbolReference("is_distinct"), ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), - Optional.of(new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + Optional.of(new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), project( markDistinct("is_distinct", ImmutableList.of("unique"), join(LEFT, builder -> builder @@ -1143,7 +1153,7 @@ public void testCorrelatedDistinctGroupedAggregationRewriteToLeftOuterJoin() new SimpleCaseExpression( new SymbolReference("is_distinct"), ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)), - Optional.of(new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), + Optional.of(new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()), new Constant(VARCHAR, utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN))), project(markDistinct( "is_distinct", ImmutableList.of("unique"), @@ -1859,10 +1869,10 @@ public void testRedundantHashRemovalForUnionAll() project( node(AggregationNode.class, exchange(LOCAL, REPARTITION, - project(ImmutableMap.of("hash", expression(new FunctionCall(QualifiedName.of("combine_hash"), ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(QualifiedName.of("$operator$hash_code"), ImmutableList.of(new SymbolReference("nationkey"))), new Constant(BIGINT, 0L)))))), + project(ImmutableMap.of("hash", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference("nationkey"))), new Constant(BIGINT, 0L)))))), node(AggregationNode.class, tableScan("customer", ImmutableMap.of("nationkey", "nationkey")))), - project(ImmutableMap.of("hash_1", expression(new FunctionCall(QualifiedName.of("combine_hash"), ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(QualifiedName.of("$operator$hash_code"), ImmutableList.of(new SymbolReference("nationkey_6"))), new Constant(BIGINT, 0L)))))), + project(ImmutableMap.of("hash_1", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference("nationkey_6"))), new Constant(BIGINT, 0L)))))), node(AggregationNode.class, tableScan("customer", ImmutableMap.of("nationkey_6", "nationkey"))))))))); } @@ -1882,8 +1892,8 @@ public void testRedundantHashRemovalForMarkDistinct() node(MarkDistinctNode.class, anyTree( project(ImmutableMap.of( - "hash_1", expression(new FunctionCall(QualifiedName.of("combine_hash"), ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(QualifiedName.of("$operator$hash_code"), ImmutableList.of(new SymbolReference("suppkey"))), new Constant(BIGINT, 0L))))), - "hash_2", expression(new FunctionCall(QualifiedName.of("combine_hash"), ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(QualifiedName.of("$operator$hash_code"), ImmutableList.of(new SymbolReference("partkey"))), new Constant(BIGINT, 0L)))))), + "hash_1", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference("suppkey"))), new Constant(BIGINT, 0L))))), + "hash_2", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference("partkey"))), new Constant(BIGINT, 0L)))))), node(MarkDistinctNode.class, tableScan("lineitem", ImmutableMap.of("suppkey", "suppkey", "partkey", "partkey")))))))))); } @@ -1904,8 +1914,8 @@ public void testRedundantHashRemovalForUnionAllAndMarkDistinct() exchange(LOCAL, REPARTITION, exchange(REMOTE, REPARTITION, project(ImmutableMap.of( - "hash_custkey", expression(new FunctionCall(QualifiedName.of("combine_hash"), ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(QualifiedName.of("$operator$hash_code"), ImmutableList.of(new SymbolReference("custkey"))), new Constant(BIGINT, 0L))))), - "hash_nationkey", expression(new FunctionCall(QualifiedName.of("combine_hash"), ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(QualifiedName.of("$operator$hash_code"), ImmutableList.of(new SymbolReference("nationkey"))), new Constant(BIGINT, 0L)))))), + "hash_custkey", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference("custkey"))), new Constant(BIGINT, 0L))))), + "hash_nationkey", expression(new FunctionCall(COMBINE_HASH, ImmutableList.of(new Constant(BIGINT, 0L), new CoalesceExpression(new FunctionCall(HASH_CODE, ImmutableList.of(new SymbolReference("nationkey"))), new Constant(BIGINT, 0L)))))), tableScan("customer", ImmutableMap.of("custkey", "custkey", "nationkey", "nationkey")))), exchange(REMOTE, REPARTITION, node(ProjectNode.class, @@ -2401,7 +2411,7 @@ public void testMergePatternRecognitionNodesWithProjections() ImmutableMap.of( "output1", expression(new SymbolReference("id")), "output2", expression(new ArithmeticBinaryExpression(MULTIPLY, new SymbolReference("value"), new Constant(INTEGER, 2L))), - "output3", expression(new FunctionCall(QualifiedName.of("lower"), ImmutableList.of(new SymbolReference("label")))), + "output3", expression(new FunctionCall(LOWER, ImmutableList.of(new SymbolReference("label")))), "output4", expression(new ArithmeticBinaryExpression(ADD, new SymbolReference("min"), new Constant(INTEGER, 1L)))), project( ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 9b6ecc10fde9..37a708edfb97 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -82,7 +82,7 @@ public void testPartialTranslator() assertFullTranslation(new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1)); Expression functionCallExpression = new FunctionCall( - PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("concat", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("concat", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(stringLiteral, dereferenceExpression2)); assertFullTranslation(functionCallExpression); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java index b326eebd5e80..6ba89d377ede 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdown.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; @@ -25,7 +27,6 @@ import io.trino.sql.ir.NotExpression; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.plan.ExchangeNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -33,6 +34,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -46,6 +48,9 @@ public class TestPredicatePushdown extends AbstractPredicatePushdownTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes(INTEGER)); + public TestPredicatePushdown() { super(true); @@ -173,7 +178,7 @@ public void testNonDeterministicPredicateDoesNotPropagateFromFilteringSideToSour "LINE_ORDER_KEY", "orderkey"))), node(ExchangeNode.class, filter( - new ComparisonExpression(EQUAL, new SymbolReference("ORDERS_ORDER_KEY"), new Cast(new FunctionCall(QualifiedName.of("random"), ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new ComparisonExpression(EQUAL, new SymbolReference("ORDERS_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java index 4db359ed6bad..ebe52b968ec7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; @@ -25,7 +27,6 @@ import io.trino.sql.ir.NotExpression; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.plan.ExchangeNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -33,6 +34,7 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -47,6 +49,9 @@ public class TestPredicatePushdownWithoutDynamicFilter extends AbstractPredicatePushdownTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + public TestPredicatePushdownWithoutDynamicFilter() { super(false); @@ -169,7 +174,7 @@ public void testNonDeterministicPredicateDoesNotPropagateFromFilteringSideToSour "LINE_ORDER_KEY", "orderkey")), node(ExchangeNode.class, filter( - new ComparisonExpression(EQUAL, new SymbolReference("ORDERS_ORDER_KEY"), new Cast(new FunctionCall(QualifiedName.of("random"), ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), + new ComparisonExpression(EQUAL, new SymbolReference("ORDERS_ORDER_KEY"), new Cast(new FunctionCall(RANDOM, ImmutableList.of(new Constant(INTEGER, 5L))), BIGINT)), tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java index 061d9f752758..5874362c52b3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestRecursiveCte.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; @@ -26,7 +28,6 @@ import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; -import io.trino.sql.tree.QualifiedName; import io.trino.testing.PlanTester; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Test; @@ -36,6 +37,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; @@ -55,6 +57,9 @@ public class TestRecursiveCte extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction FAIL = FUNCTIONS.resolveFunction("fail", fromTypes(INTEGER, VARCHAR)); + @Override protected PlanTester createPlanTester() { @@ -93,7 +98,7 @@ public void testRecursiveQuery() filter( new IfExpression( new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference("count"), new Constant(BIGINT, 0L)), - new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) NOT_SUPPORTED.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Recursion depth limit exceeded (1). Use 'max_recursion_depth' session property to modify the limit.")))), BOOLEAN), + new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) NOT_SUPPORTED.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Recursion depth limit exceeded (1). Use 'max_recursion_depth' session property to modify the limit.")))), BOOLEAN), TRUE_LITERAL), window(windowBuilder -> windowBuilder .addFunction( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java index f38f3ca4634c..07b95856514f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableSet; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.sql.PlannerContext; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.BetweenPredicate; import io.trino.sql.ir.ComparisonExpression; @@ -26,7 +25,6 @@ import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.LogicalExpression; import io.trino.sql.ir.SymbolReference; -import io.trino.transaction.TestingTransactionManager; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -46,15 +44,10 @@ import static io.trino.sql.ir.IrUtils.extractConjuncts; import static io.trino.sql.ir.LogicalExpression.Operator.AND; import static io.trino.sql.ir.LogicalExpression.Operator.OR; -import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static org.assertj.core.api.Assertions.assertThat; public class TestSortExpressionExtractor { - private static final TestingTransactionManager TRANSACTION_MANAGER = new TestingTransactionManager(); - private static final PlannerContext PLANNER_CONTEXT = plannerContextBuilder() - .withTransactionManager(TRANSACTION_MANAGER) - .build(); private static final Set BUILD_SYMBOLS = ImmutableSet.of(new Symbol("b1"), new Symbol("b2")); private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); @@ -77,31 +70,31 @@ public void testGetSortExpression() "b2"); assertGetSortExpression( - new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(SIN.toQualifiedName(), ImmutableList.of(new SymbolReference("p1")))), + new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(SIN, ImmutableList.of(new SymbolReference("p1")))), "b2"); - assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(RANDOM.toQualifiedName(), ImmutableList.of(new SymbolReference("p1"))))); + assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(RANDOM, ImmutableList.of(new SymbolReference("p1"))))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(RANDOM.toQualifiedName(), ImmutableList.of(new SymbolReference("p1")))), new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new SymbolReference("p1")))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(RANDOM, ImmutableList.of(new SymbolReference("p1")))), new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new SymbolReference("p1")))), "b2", new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new SymbolReference("p1"))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(RANDOM.toQualifiedName(), ImmutableList.of(new SymbolReference("p1")))), new ComparisonExpression(GREATER_THAN, new SymbolReference("b1"), new SymbolReference("p1")))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new SymbolReference("b2"), new FunctionCall(RANDOM, ImmutableList.of(new SymbolReference("p1")))), new ComparisonExpression(GREATER_THAN, new SymbolReference("b1"), new SymbolReference("p1")))), "b1", new ComparisonExpression(GREATER_THAN, new SymbolReference("b1"), new SymbolReference("p1"))); assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("b1"), new ArithmeticBinaryExpression(ADD, new SymbolReference("p1"), new SymbolReference("b2")))); - assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN.toQualifiedName(), ImmutableList.of(new SymbolReference("b1"))), new SymbolReference("p1"))); + assertNoSortExpression(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN, ImmutableList.of(new SymbolReference("b1"))), new SymbolReference("p1"))); assertNoSortExpression(new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b1"), new SymbolReference("p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new SymbolReference("p1"))))); - assertNoSortExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN.toQualifiedName(), ImmutableList.of(new SymbolReference("b2"))), new SymbolReference("p1")), new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new SymbolReference("p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new ArithmeticBinaryExpression(ADD, new SymbolReference("p1"), new Constant(INTEGER, 10L)))))))); + assertNoSortExpression(new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN, ImmutableList.of(new SymbolReference("b2"))), new SymbolReference("p1")), new LogicalExpression(OR, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new SymbolReference("p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new ArithmeticBinaryExpression(ADD, new SymbolReference("p1"), new Constant(INTEGER, 10L)))))))); assertGetSortExpression( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN.toQualifiedName(), ImmutableList.of(new SymbolReference("b2"))), new SymbolReference("p1")), new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new SymbolReference("p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new ArithmeticBinaryExpression(ADD, new SymbolReference("p1"), new Constant(INTEGER, 10L))))))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(GREATER_THAN, new FunctionCall(SIN, ImmutableList.of(new SymbolReference("b2"))), new SymbolReference("p1")), new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new SymbolReference("p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new ArithmeticBinaryExpression(ADD, new SymbolReference("p1"), new Constant(INTEGER, 10L))))))), "b2", new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new SymbolReference("p1")), new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("b2"), new ArithmeticBinaryExpression(ADD, new SymbolReference("p1"), new Constant(INTEGER, 10L)))); @@ -147,7 +140,7 @@ public void testGetSortExpression() private void assertNoSortExpression(Expression expression) { - Optional actual = SortExpressionExtractor.extractSortExpression(PLANNER_CONTEXT.getMetadata(), BUILD_SYMBOLS, expression); + Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertThat(actual).isEqualTo(Optional.empty()); } @@ -165,7 +158,7 @@ private void assertGetSortExpression(Expression expression, String expectedSymbo private void assertGetSortExpression(Expression expression, String expectedSymbol, List searchExpressions) { Optional expected = Optional.of(new SortExpressionContext(new SymbolReference(expectedSymbol), searchExpressions)); - Optional actual = SortExpressionExtractor.extractSortExpression(PLANNER_CONTEXT.getMetadata(), BUILD_SYMBOLS, expression); + Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertThat(actual).isEqualTo(expected); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java index 5f6e4c1d37d2..0984b58a5c38 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.TimeZoneKey; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; @@ -27,7 +29,6 @@ import io.trino.sql.ir.NotExpression; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.BasePlanTest; -import io.trino.sql.tree.QualifiedName; import io.trino.type.DateTimes; import io.trino.type.Reals; import io.trino.util.DateTimeUtils; @@ -45,6 +46,7 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; @@ -64,6 +66,9 @@ public class TestUnwrapCastInComparison extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + @Test public void testEquals() { @@ -807,7 +812,7 @@ private void testRemoveFilter(String inputType, String inputPredicate) { assertPlan(format("SELECT * FROM (VALUES CAST(NULL AS %s)) t(a) WHERE %s AND rand() = 42", inputType, inputPredicate), output( - filter(new ComparisonExpression(EQUAL, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 42.0)), + filter(new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)), values("a")))); } @@ -818,7 +823,7 @@ private void testUnwrap(String inputType, String inputPredicate, Expression expe private void testUnwrap(Session session, String inputType, String inputPredicate, Expression expected) { - Expression antiOptimization = new ComparisonExpression(EQUAL, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 42.0)); + Expression antiOptimization = new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)); if (expected instanceof LogicalExpression logical && logical.getOperator() == OR) { expected = new LogicalExpression(OR, ImmutableList.builder() .addAll(logical.getTerms()) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java index baaf9b897b4c..823e9dbd61c3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapYearInComparison.java @@ -14,6 +14,8 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableList; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.type.LongTimestamp; import io.trino.sql.ir.BetweenPredicate; import io.trino.sql.ir.ComparisonExpression; @@ -25,7 +27,6 @@ import io.trino.sql.ir.NotExpression; import io.trino.sql.ir.SymbolReference; import io.trino.sql.planner.assertions.BasePlanTest; -import io.trino.sql.tree.QualifiedName; import io.trino.type.DateTimes; import io.trino.util.DateTimeUtils; import org.junit.jupiter.api.Test; @@ -44,6 +45,7 @@ import static io.trino.spi.type.TimestampType.createTimestampType; import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; @@ -64,6 +66,11 @@ public class TestUnwrapYearInComparison extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + private static final ResolvedFunction YEAR_DATE = FUNCTIONS.resolveFunction("year", fromTypes(DATE)); + private static final ResolvedFunction YEAR_TIMESTAMP_3 = FUNCTIONS.resolveFunction("year", fromTypes(createTimestampType(3))); + @Test public void testEquals() { @@ -277,8 +284,8 @@ public void testNull() @Test public void testNaN() { - testUnwrap("date", "year(a) = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new FunctionCall(QualifiedName.of("year"), ImmutableList.of(new SymbolReference("a")))), new Constant(BOOLEAN, null)))); - testUnwrap("timestamp", "year(a) = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new FunctionCall(QualifiedName.of("year"), ImmutableList.of(new SymbolReference("a")))), new Constant(BOOLEAN, null)))); + testUnwrap("date", "year(a) = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new FunctionCall(YEAR_DATE, ImmutableList.of(new SymbolReference("a")))), new Constant(BOOLEAN, null)))); + testUnwrap("timestamp", "year(a) = nan()", new LogicalExpression(AND, ImmutableList.of(new IsNullPredicate(new FunctionCall(YEAR_TIMESTAMP_3, ImmutableList.of(new SymbolReference("a")))), new Constant(BOOLEAN, null)))); } @Test @@ -335,7 +342,7 @@ private static long toEpochMicros(LocalDateTime localDateTime) private void testUnwrap(String inputType, String inputPredicate, Expression expected) { - Expression antiOptimization = new ComparisonExpression(EQUAL, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 42.0)); + Expression antiOptimization = new ComparisonExpression(EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 42.0)); if (expected instanceof LogicalExpression logical && logical.getOperator() == OR) { expected = new LogicalExpression(OR, ImmutableList.builder() .addAll(logical.getTerms()) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java index d689baf194f3..9f68cdecafe1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowClause.java @@ -15,7 +15,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.OperatorType; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ArithmeticUnaryExpression; import io.trino.sql.ir.Cast; @@ -26,12 +28,12 @@ import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.tree.QualifiedName; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Test; import java.util.Optional; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; @@ -58,6 +60,9 @@ public class TestWindowClause extends BasePlanTest { + private static final ResolvedFunction SUBTRACT_INTEGER = createTestMetadataManager().resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction ADD_DOUBLE = createTestMetadataManager().resolveOperator(OperatorType.ADD, ImmutableList.of(DOUBLE, DOUBLE)); + @Test public void testPreprojectExpression() { @@ -106,7 +111,7 @@ public void testPreprojectExpressions() Optional.empty(), Optional.empty()))), project( - ImmutableMap.of("frame_start", expression(new FunctionCall(QualifiedName.of("$operator$subtract"), ImmutableList.of(new SymbolReference("expr_b"), new SymbolReference("expr_c"))))), + ImmutableMap.of("frame_start", expression(new FunctionCall(SUBTRACT_INTEGER, ImmutableList.of(new SymbolReference("expr_b"), new SymbolReference("expr_c"))))), anyTree(project( ImmutableMap.of( "expr_a", expression(new ArithmeticBinaryExpression(ADD, new SymbolReference("a"), new Constant(INTEGER, 1L))), @@ -180,7 +185,7 @@ public void testWindowWithFrameCoercions() Optional.of(new Symbol("frame_bound")), Optional.of(new Symbol("coerced_sortkey"))))), project(// frame bound value computation - ImmutableMap.of("frame_bound", expression(new FunctionCall(QualifiedName.of("$operator$add"), ImmutableList.of(new SymbolReference("coerced_sortkey"), new SymbolReference("frame_offset"))))), + ImmutableMap.of("frame_bound", expression(new FunctionCall(ADD_DOUBLE, ImmutableList.of(new SymbolReference("coerced_sortkey"), new SymbolReference("frame_offset"))))), project(// sort key coercion to frame bound type ImmutableMap.of("coerced_sortkey", expression(new Cast(new SymbolReference("sortkey"), DOUBLE))), node(FilterNode.class, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java index f6479cdad32f..9ac5059d9159 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestWindowFrameRange.java @@ -16,7 +16,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.OperatorType; import io.trino.spi.type.Decimals; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; @@ -27,18 +30,19 @@ import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.tree.QualifiedName; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Test; import java.math.BigDecimal; import java.util.Optional; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.StandardErrorCode.INVALID_WINDOW_FRAME; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; @@ -58,6 +62,13 @@ public class TestWindowFrameRange extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction FAIL = FUNCTIONS.resolveFunction("fail", fromTypes(INTEGER, VARCHAR)); + private static final ResolvedFunction ADD_DECIMAL_10_0 = createTestMetadataManager().resolveOperator(OperatorType.ADD, ImmutableList.of(createDecimalType(10, 0), createDecimalType(10, 0))); + private static final ResolvedFunction SUBTRACT_DECIMAL_10_0 = createTestMetadataManager().resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(createDecimalType(10, 0), createDecimalType(10, 0))); + private static final ResolvedFunction ADD_INTEGER = createTestMetadataManager().resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); + private static final ResolvedFunction SUBTRACT_INTEGER = createTestMetadataManager().resolveOperator(OperatorType.SUBTRACT, ImmutableList.of(INTEGER, INTEGER)); + @Test public void testFramePrecedingWithSortKeyCoercions() { @@ -88,14 +99,14 @@ public void testFramePrecedingWithSortKeyCoercions() project(// coerce sort key to compare sort key values with frame start values ImmutableMap.of("key_for_frame_start_comparison", expression(new Cast(new SymbolReference("key"), createDecimalType(12, 1)))), project(// calculate frame start value (sort key - frame offset) - ImmutableMap.of("frame_start_value", expression(new FunctionCall(QualifiedName.of("$operator$subtract"), ImmutableList.of(new SymbolReference("key_for_frame_start_calculation"), new SymbolReference("x"))))), + ImmutableMap.of("frame_start_value", expression(new FunctionCall(SUBTRACT_DECIMAL_10_0, ImmutableList.of(new SymbolReference("key_for_frame_start_calculation"), new SymbolReference("x"))))), project(// coerce sort key to calculate frame start values ImmutableMap.of("key_for_frame_start_calculation", expression(new Cast(new SymbolReference("key"), createDecimalType(10, 0)))), filter(// validate offset values new IfExpression( new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference("x"), new Constant(createDecimalType(2, 1), 0L)), TRUE_LITERAL, - new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), anyTree( values( ImmutableList.of("key", "x"), @@ -136,12 +147,12 @@ public void testFrameFollowingWithOffsetCoercion() project(// coerce sort key to compare sort key values with frame end values ImmutableMap.of("key_for_frame_end_comparison", expression(new Cast(new SymbolReference("key"), createDecimalType(12, 1)))), project(// calculate frame end value (sort key + frame offset) - ImmutableMap.of("frame_end_value", expression(new FunctionCall(QualifiedName.of("$operator$add"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("offset"))))), + ImmutableMap.of("frame_end_value", expression(new FunctionCall(ADD_DECIMAL_10_0, ImmutableList.of(new SymbolReference("key"), new SymbolReference("offset"))))), filter(// validate offset values new IfExpression( new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference("offset"), new Constant(createDecimalType(10, 0), 0L)), TRUE_LITERAL, - new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), project(// coerce offset value to calculate frame end values ImmutableMap.of("offset", expression(new Cast(new SymbolReference("x"), createDecimalType(10, 0)))), anyTree( @@ -182,19 +193,19 @@ public void testFramePrecedingFollowingNoCoercions() Optional.of(new Symbol("frame_end_value")), Optional.of(new Symbol("key"))))), project(// calculate frame end value (sort key + frame end offset) - ImmutableMap.of("frame_end_value", expression(new FunctionCall(QualifiedName.of("$operator$add"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("y"))))), + ImmutableMap.of("frame_end_value", expression(new FunctionCall(ADD_INTEGER, ImmutableList.of(new SymbolReference("key"), new SymbolReference("y"))))), filter(// validate frame end offset values new IfExpression( new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference("y"), new Constant(INTEGER, 0L)), TRUE_LITERAL, - new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), project(// calculate frame start value (sort key - frame start offset) - ImmutableMap.of("frame_start_value", expression(new FunctionCall(QualifiedName.of("$operator$subtract"), ImmutableList.of(new SymbolReference("key"), new SymbolReference("x"))))), + ImmutableMap.of("frame_start_value", expression(new FunctionCall(SUBTRACT_INTEGER, ImmutableList.of(new SymbolReference("key"), new SymbolReference("x"))))), filter(// validate frame start offset values new IfExpression( new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference("x"), new Constant(INTEGER, 0L)), TRUE_LITERAL, - new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), + new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, (long) INVALID_WINDOW_FRAME.toErrorCode().getCode()), new Constant(VARCHAR, Slices.utf8Slice("Window frame offset value must not be negative or null")))), BOOLEAN)), anyTree( values( ImmutableList.of("key", "x", "y"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java index e668b7233ef7..c9d1fb1ad71a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/BasePlanTest.java @@ -191,7 +191,7 @@ protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMat { Metadata metadata = getPlanTester().getPlannerContext().getMetadata(); List optimizers = ImmutableList.of( - new UnaliasSymbolReferences(metadata), + new UnaliasSymbolReferences(), new IterativeOptimizer( planTester.getPlannerContext(), new RuleStatsRecorder(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java index f11387b9e8d3..1842ea99557d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/CorrelatedJoinMatcher.java @@ -52,7 +52,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses Expression filter = correlatedJoinNode.getFilter(); ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); DynamicFilters.ExtractResult extractResult = extractDynamicFilters(filter); - return new MatchResult(verifier.process(combineConjuncts(metadata, extractResult.getStaticConjuncts()), filter)); + return new MatchResult(verifier.process(combineConjuncts(extractResult.getStaticConjuncts()), filter)); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java index b50600133372..fe02323f1181 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/ExpressionVerifier.java @@ -13,7 +13,6 @@ */ package io.trino.sql.planner.assertions; -import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ArithmeticUnaryExpression; import io.trino.sql.ir.BetweenPredicate; @@ -42,10 +41,6 @@ import java.util.Objects; import java.util.Optional; -import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; -import static io.trino.metadata.ResolvedFunction.extractFunctionName; -import static io.trino.metadata.ResolvedFunction.isResolved; import static java.util.Objects.requireNonNull; /** @@ -314,16 +309,7 @@ protected Boolean visitFunctionCall(FunctionCall actual, Expression expectedExpr return false; } - CatalogSchemaFunctionName expectedFunctionName; - if (isResolved(expected.getName())) { - expectedFunctionName = extractFunctionName(expected.getName()); - } - else { - checkArgument(expected.getName().getParts().size() == 1, "Unresolved function call name must not be qualified: %s", expected.getName()); - expectedFunctionName = builtinFunctionName(expected.getName().getSuffix()); - } - - return extractFunctionName(actual.getName()).equals(expectedFunctionName) && + return actual.getFunction().getName().equals(expected.getFunction().getName()) && process(actual.getArguments(), expected.getArguments()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/FilterMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/FilterMatcher.java index 7564beeffbf7..75476a1aaeee 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/FilterMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/FilterMatcher.java @@ -57,11 +57,11 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); if (dynamicFilter.isPresent()) { - return new MatchResult(verifier.process(filterPredicate, combineConjuncts(metadata, predicate, dynamicFilter.get()))); + return new MatchResult(verifier.process(filterPredicate, combineConjuncts(predicate, dynamicFilter.get()))); } DynamicFilters.ExtractResult extractResult = extractDynamicFilters(filterPredicate); - return new MatchResult(verifier.process(combineConjuncts(metadata, extractResult.getStaticConjuncts()), predicate)); + return new MatchResult(verifier.process(combineConjuncts(extractResult.getStaticConjuncts()), predicate)); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java index 98eaf7cce11e..143dba97c441 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestArraySortAfterArrayDistinct.java @@ -49,21 +49,21 @@ public class TestArraySortAfterArrayDistinct public void testArrayDistinctAfterArraySort() { test( - new FunctionCall(DISTINCT.toQualifiedName(), ImmutableList.of(new FunctionCall(SORT.toQualifiedName(), ImmutableList.of(new FunctionCall(ARRAY.toQualifiedName(), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))))), - new FunctionCall(SORT.toQualifiedName(), ImmutableList.of(new FunctionCall(DISTINCT.toQualifiedName(), ImmutableList.of(new FunctionCall(ARRAY.toQualifiedName(), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a"))))))))); + new FunctionCall(DISTINCT, ImmutableList.of(new FunctionCall(SORT, ImmutableList.of(new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))))), + new FunctionCall(SORT, ImmutableList.of(new FunctionCall(DISTINCT, ImmutableList.of(new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a"))))))))); } @Test public void testArrayDistinctAfterArraySortWithLambda() { test( - new FunctionCall(DISTINCT.toQualifiedName(), ImmutableList.of( - new FunctionCall(SORT_WITH_LAMBDA.toQualifiedName(), ImmutableList.of( - new FunctionCall(ARRAY.toQualifiedName(), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))), + new FunctionCall(DISTINCT, ImmutableList.of( + new FunctionCall(SORT_WITH_LAMBDA, ImmutableList.of( + new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))), new LambdaExpression(ImmutableList.of("a", "b"), new Constant(INTEGER, 1L)))))), - new FunctionCall(SORT_WITH_LAMBDA.toQualifiedName(), ImmutableList.of( - new FunctionCall(DISTINCT.toQualifiedName(), ImmutableList.of( - new FunctionCall(ARRAY.toQualifiedName(), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))), + new FunctionCall(SORT_WITH_LAMBDA, ImmutableList.of( + new FunctionCall(DISTINCT, ImmutableList.of( + new FunctionCall(ARRAY, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("a")))))), new LambdaExpression(ImmutableList.of("a", "b"), new Constant(INTEGER, 1L))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java index db033b72d0df..ecd448a5f822 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestCanonicalizeExpressionRewriter.java @@ -202,7 +202,7 @@ public void testCanonicalizeRewriteDateFunctionToCast() private static void assertCanonicalizedDate(Type type, String symbolName) { FunctionCall date = new FunctionCall( - PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("date", fromTypes(type)).toQualifiedName(), + PLANNER_CONTEXT.getMetadata().resolveBuiltinFunction("date", fromTypes(type)), ImmutableList.of(new SymbolReference(symbolName))); assertRewritten(date, new Cast(new SymbolReference(symbolName), DATE)); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java index b8e8c3cd7e04..5d85aa1bfa09 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateInnerUnnestWithGlobalAggregation.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ArithmeticUnaryExpression; import io.trino.sql.ir.Constant; @@ -28,7 +30,6 @@ import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.UnnestNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -58,6 +59,9 @@ public class TestDecorrelateInnerUnnestWithGlobalAggregation extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction REGEXP_EXTRACT_ALL = FUNCTIONS.resolveFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)); + @Test public void doesNotFireWithoutGlobalAggregation() { @@ -313,7 +317,7 @@ public void testPreprojectUnnestSymbol() .on(p -> { Symbol corr = p.symbol("corr", VARCHAR); FunctionCall regexpExtractAll = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(corr.toSymbolReference(), new Constant(VARCHAR, Slices.utf8Slice(".")))); return p.correlatedJoin( @@ -348,7 +352,7 @@ public void testPreprojectUnnestSymbol() Optional.of("ordinality"), LEFT, project( - ImmutableMap.of("char_array", expression(new FunctionCall(QualifiedName.of("regexp_extract_all"), ImmutableList.of(new SymbolReference("corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), + ImmutableMap.of("char_array", expression(new FunctionCall(REGEXP_EXTRACT_ALL, ImmutableList.of(new SymbolReference("corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), assignUniqueId("unique", values("corr")))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java index bf4bcb265cd9..9c31ad89c4c9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateLeftUnnestWithGlobalAggregation.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ArithmeticUnaryExpression; import io.trino.sql.ir.Constant; @@ -26,7 +28,6 @@ import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.UnnestNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -54,6 +55,9 @@ public class TestDecorrelateLeftUnnestWithGlobalAggregation extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction REGEXP_EXTRACT_ALL = FUNCTIONS.resolveFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)); + @Test public void doesNotFireWithoutGlobalAggregation() { @@ -294,7 +298,7 @@ public void testPreprojectUnnestSymbol() .on(p -> { Symbol corr = p.symbol("corr", VARCHAR); FunctionCall regexpExtractAll = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(corr.toSymbolReference(), new Constant(VARCHAR, Slices.utf8Slice(".")))); return p.correlatedJoin( @@ -326,7 +330,7 @@ public void testPreprojectUnnestSymbol() Optional.empty(), LEFT, project( - ImmutableMap.of("char_array", expression(new FunctionCall(QualifiedName.of("regexp_extract_all"), ImmutableList.of(new SymbolReference("corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), + ImmutableMap.of("char_array", expression(new FunctionCall(REGEXP_EXTRACT_ALL, ImmutableList.of(new SymbolReference("corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), assignUniqueId("unique", values("corr"))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java index aaf02bcec845..22c7af59291f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDecorrelateUnnest.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slices; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.Cast; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; @@ -29,7 +31,6 @@ import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.JoinType; import io.trino.sql.planner.plan.UnnestNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -61,6 +62,9 @@ public class TestDecorrelateUnnest extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction FAIL = FUNCTIONS.resolveFunction("fail", fromTypes(INTEGER, VARCHAR)); + @Test public void doesNotFireWithoutUnnest() { @@ -213,7 +217,7 @@ public void testEnforceSingleRow() "corr", expression(new SymbolReference("corr")), "unnested_corr", expression(new IfExpression(new IsNullPredicate(new SymbolReference("ordinality")), new Constant(BIGINT, null), new SymbolReference("unnested_corr")))), filter( - new IfExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("row_number"), new Constant(BIGINT, 1L)), new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE_LITERAL), + new IfExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("row_number"), new Constant(BIGINT, 1L)), new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE_LITERAL), rowNumber( builder -> builder .partitionBy(ImmutableList.of("unique")) @@ -426,7 +430,7 @@ public void testDifferentNodesInSubquery() .matches( project( filter(// enforce single row - new IfExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("row_number"), new Constant(BIGINT, 1L)), new Cast(new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE_LITERAL), + new IfExpression(new ComparisonExpression(GREATER_THAN, new SymbolReference("row_number"), new Constant(BIGINT, 1L)), new Cast(new FunctionCall(FAIL, ImmutableList.of(new Constant(INTEGER, 28L), new Constant(VARCHAR, Slices.utf8Slice("Scalar sub-query has returned multiple rows")))), BOOLEAN), TRUE_LITERAL), project(// second projection ImmutableMap.of( "corr", expression(new SymbolReference("corr")), @@ -494,7 +498,7 @@ public void testPreprojectUnnestSymbol() .on(p -> { Symbol corr = p.symbol("corr", VARCHAR); FunctionCall regexpExtractAll = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(corr.toSymbolReference(), new Constant(VARCHAR, Slices.utf8Slice(".")))); return p.correlatedJoin( @@ -519,7 +523,7 @@ public void testPreprojectUnnestSymbol() Optional.of("ordinality"), LEFT, project( - ImmutableMap.of("char_array", expression(new FunctionCall(QualifiedName.of("regexp_extract_all"), ImmutableList.of(new SymbolReference("corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), + ImmutableMap.of("char_array", expression(new FunctionCall(tester().getMetadata().resolveBuiltinFunction("regexp_extract_all", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(new SymbolReference("corr"), new Constant(VARCHAR, Slices.utf8Slice(".")))))), assignUniqueId("unique", values("corr")))))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java index 89f77f002eca..2c42f328a5b5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementExceptAll.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; @@ -25,13 +27,13 @@ import io.trino.sql.planner.assertions.SetOperationOutputMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.SUBTRACT; import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; @@ -51,6 +53,9 @@ public class TestImplementExceptAll extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction GREATEST = FUNCTIONS.resolveFunction("GREATEST", fromTypes(BIGINT, BIGINT)); + @Test public void test() { @@ -89,7 +94,7 @@ public void test() "a", expression(new SymbolReference("a")), "b", expression(new SymbolReference("b"))), filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("row_number"), new FunctionCall(QualifiedName.of("greatest"), ImmutableList.of(new ArithmeticBinaryExpression(SUBTRACT, new SymbolReference("count_1"), new SymbolReference("count_2")), new Constant(BIGINT, 0L)))), + new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("row_number"), new FunctionCall(GREATEST, ImmutableList.of(new ArithmeticBinaryExpression(SUBTRACT, new SymbolReference("count_1"), new SymbolReference("count_2")), new Constant(BIGINT, 0L)))), strictProject( ImmutableMap.of( "a", expression(new SymbolReference("a")), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java index a90287eaaf84..e6431399c0a9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementFilteredAggregations.java @@ -44,7 +44,7 @@ public class TestImplementFilteredAggregations @Test public void testFilterToMask() { - tester().assertThat(new ImplementFilteredAggregations(tester().getMetadata())) + tester().assertThat(new ImplementFilteredAggregations()) .on(p -> { Symbol a = p.symbol("a"); Symbol g = p.symbol("g"); @@ -75,7 +75,7 @@ public void testFilterToMask() @Test public void testCombineMaskAndFilter() { - tester().assertThat(new ImplementFilteredAggregations(tester().getMetadata())) + tester().assertThat(new ImplementFilteredAggregations()) .on(p -> { Symbol a = p.symbol("a"); Symbol g = p.symbol("g"); @@ -113,7 +113,7 @@ public void testCombineMaskAndFilter() @Test public void testWithFilterPushdown() { - tester().assertThat(new ImplementFilteredAggregations(tester().getMetadata())) + tester().assertThat(new ImplementFilteredAggregations()) .on(p -> { Symbol a = p.symbol("a"); Symbol g = p.symbol("g"); @@ -144,7 +144,7 @@ public void testWithFilterPushdown() @Test public void testWithMultipleAggregations() { - tester().assertThat(new ImplementFilteredAggregations(tester().getMetadata())) + tester().assertThat(new ImplementFilteredAggregations()) .on(p -> { Symbol a = p.symbol("a"); Symbol g = p.symbol("g"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java index 2528e95301b6..3b5189e51e9f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementIntersectAll.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FunctionCall; @@ -24,12 +26,13 @@ import io.trino.sql.planner.assertions.SetOperationOutputMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; +import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -48,6 +51,9 @@ public class TestImplementIntersectAll extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction LEAST = FUNCTIONS.resolveFunction("least", fromTypes(BIGINT, BIGINT)); + @Test public void test() { @@ -86,7 +92,7 @@ public void test() "a", expression(new SymbolReference("a")), "b", expression(new SymbolReference("b"))), filter( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("row_number"), new FunctionCall(QualifiedName.of("least"), ImmutableList.of(new SymbolReference("count_1"), new SymbolReference("count_2")))), + new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("row_number"), new FunctionCall(LEAST, ImmutableList.of(new SymbolReference("count_1"), new SymbolReference("count_2")))), strictProject( ImmutableMap.of( "a", expression(new SymbolReference("a")), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java index d20f2f245187..9e8478237398 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestInlineProjectIntoFilter.java @@ -43,7 +43,7 @@ public class TestInlineProjectIntoFilter @Test public void testInlineProjection() { - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new SymbolReference("a"), p.project( @@ -58,7 +58,7 @@ public void testInlineProjection() ImmutableMap.of("b", expression(new SymbolReference("b"))), values("b"))))); - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> { Symbol a = p.symbol("a"); Symbol b = p.symbol("b"); @@ -87,7 +87,7 @@ public void testInlineProjection() @Test public void testNoSimpleConjuncts() { - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new LogicalExpression(OR, ImmutableList.of(new SymbolReference("a"), FALSE_LITERAL)), p.project( @@ -99,7 +99,7 @@ public void testNoSimpleConjuncts() @Test public void testMultipleReferencesToConjunct() { - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new LogicalExpression(AND, ImmutableList.of(new SymbolReference("a"), new SymbolReference("a"))), p.project( @@ -107,7 +107,7 @@ public void testMultipleReferencesToConjunct() p.values(p.symbol("b"))))) .doesNotFire(); - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new LogicalExpression(AND, ImmutableList.of(new SymbolReference("a"), new LogicalExpression(OR, ImmutableList.of(new SymbolReference("a"), FALSE_LITERAL)))), p.project( @@ -119,7 +119,7 @@ public void testMultipleReferencesToConjunct() @Test public void testInlineMultiple() { - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new LogicalExpression(AND, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), p.project( @@ -139,7 +139,7 @@ public void testInlineMultiple() @Test public void testInlinePartially() { - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new LogicalExpression(AND, ImmutableList.of(new SymbolReference("a"), new SymbolReference("a"), new SymbolReference("b"))), p.project( @@ -163,7 +163,7 @@ public void testInlinePartially() public void testTrivialProjection() { // identity projection - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new SymbolReference("a"), p.project( @@ -172,7 +172,7 @@ public void testTrivialProjection() .doesNotFire(); // renaming projection - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new SymbolReference("a"), p.project( @@ -184,7 +184,7 @@ public void testTrivialProjection() @Test public void testCorrelationSymbol() { - tester().assertThat(new InlineProjectIntoFilter(tester().getMetadata())) + tester().assertThat(new InlineProjectIntoFilter()) .on(p -> p.filter( new SymbolReference("corr"), p.project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java index 94125c4f7988..b74bcce8551b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -104,7 +104,6 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() ImmutableList.of(a1, b1), false); JoinEnumerator joinEnumerator = new JoinEnumerator( - planTester.getPlannerContext().getMetadata(), new CostComparator(1, 1, 1), multiJoinNode.getFilter(), createContext()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java index 20908542a40c..ba49f60ca788 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeFilters.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.metadata.Metadata; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; import io.trino.sql.ir.LogicalExpression; @@ -23,7 +22,6 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import org.junit.jupiter.api.Test; -import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; @@ -34,12 +32,10 @@ public class TestMergeFilters extends BaseRuleTest { - private final Metadata metadata = createTestMetadataManager(); - @Test public void test() { - tester().assertThat(new MergeFilters(metadata)) + tester().assertThat(new MergeFilters()) .on(p -> p.filter( new ComparisonExpression(GREATER_THAN, new SymbolReference("b"), new Constant(INTEGER, 44L)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java index cdf061e5832d..dc86b4052519 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java @@ -37,7 +37,6 @@ import io.trino.sql.planner.rowpattern.MatchNumberValuePointer; import io.trino.sql.planner.rowpattern.ScalarValuePointer; import io.trino.sql.planner.rowpattern.ir.IrLabel; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -105,7 +104,7 @@ public void testSpecificationsDoNotMatch() .doesNotFire(); // aggregations in variable definitions do not match - QualifiedName count = tester().getMetadata().resolveBuiltinFunction("count", fromTypes(BIGINT)).toQualifiedName(); + ResolvedFunction count = tester().getMetadata().resolveBuiltinFunction("count", fromTypes(BIGINT)); tester().assertThat(new MergePatternRecognitionNodesWithoutProject()) .on(p -> p.patternRecognition(parentBuilder -> parentBuilder .pattern(new IrLabel("X")) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java index 7aa23fbfbaa5..16ae0bf27b44 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeProjectWithValues.java @@ -48,7 +48,7 @@ public class TestMergeProjectWithValues @Test public void testDoesNotFireOnNonRowType() { - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(), @@ -64,7 +64,7 @@ public void testDoesNotFireOnNonRowType() public void testProjectWithoutOutputSymbols() { // ValuesNode has two output symbols and two rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(), @@ -76,7 +76,7 @@ public void testProjectWithoutOutputSymbols() .matches(values(2)); // ValuesNode has no output symbols and two rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(), @@ -86,7 +86,7 @@ public void testProjectWithoutOutputSymbols() .matches(values(2)); // ValuesNode has two output symbols and no rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(), @@ -96,7 +96,7 @@ public void testProjectWithoutOutputSymbols() .matches(values()); // ValuesNode has no output symbols and no rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(), @@ -110,7 +110,7 @@ public void testProjectWithoutOutputSymbols() public void testValuesWithoutOutputSymbols() { // ValuesNode has two rows. Projected expressions are reproduced for every row of ValuesNode. - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE_LITERAL), @@ -124,7 +124,7 @@ public void testValuesWithoutOutputSymbols() ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("x")), TRUE_LITERAL)))); // ValuesNode has no rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("a"), new Constant(createCharType(1), Slices.utf8Slice("x")), p.symbol("b"), TRUE_LITERAL), @@ -138,10 +138,10 @@ public void testValuesWithoutOutputSymbols() public void testNonDeterministicValues() { FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("rand"), new SymbolReference("rand")), p.valuesOfExpressions( @@ -153,7 +153,7 @@ public void testNonDeterministicValues() ImmutableList.of(ImmutableList.of(randomFunction)))); // ValuesNode has multiple rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("output"), new SymbolReference("value")), p.valuesOfExpressions( @@ -171,7 +171,7 @@ public void testNonDeterministicValues() ImmutableList.of(new ArithmeticUnaryExpression(MINUS, randomFunction))))); // ValuesNode has multiple non-deterministic outputs - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of( p.symbol("x"), new ArithmeticUnaryExpression(MINUS, new SymbolReference("a")), @@ -195,10 +195,10 @@ public void testNonDeterministicValues() public void testDoNotFireOnNonDeterministicValues() { FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of( p.symbol("x"), new SymbolReference("rand"), @@ -208,7 +208,7 @@ public void testDoNotFireOnNonDeterministicValues() ImmutableList.of(new Row(ImmutableList.of(randomFunction)))))) .doesNotFire(); - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("x"), new ArithmeticBinaryExpression(ADD, new SymbolReference("rand"), new SymbolReference("rand"))), p.valuesOfExpressions( @@ -221,7 +221,7 @@ public void testDoNotFireOnNonDeterministicValues() public void testCorrelation() { // correlation symbol in projection (note: the resulting plan is not yet supported in execution) - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("x"), new ArithmeticBinaryExpression(ADD, new SymbolReference("a"), new SymbolReference("corr"))), p.valuesOfExpressions( @@ -230,7 +230,7 @@ public void testCorrelation() .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new ArithmeticBinaryExpression(ADD, new Constant(INTEGER, 1L), new SymbolReference("corr")))))); // correlation symbol in values (note: the resulting plan is not yet supported in execution) - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("x"), new SymbolReference("a")), p.valuesOfExpressions( @@ -239,7 +239,7 @@ public void testCorrelation() .matches(values(ImmutableList.of("x"), ImmutableList.of(ImmutableList.of(new SymbolReference("corr"))))); // correlation symbol is not present in the resulting expression - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("x"), new Constant(INTEGER, 1L)), p.valuesOfExpressions( @@ -253,7 +253,7 @@ public void testFailingExpression() { FunctionCall failFunction = failFunction(tester().getMetadata(), GENERIC_USER_ERROR, "message"); - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> p.project( Assignments.of(p.symbol("x"), failFunction), p.valuesOfExpressions( @@ -265,7 +265,7 @@ public void testFailingExpression() @Test public void testMergeProjectWithValues() { - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> { Symbol a = p.symbol("a"); Symbol b = p.symbol("b"); @@ -295,7 +295,7 @@ public void testMergeProjectWithValues() ImmutableList.of(new Constant(createCharType(1), Slices.utf8Slice("z")), TRUE_LITERAL, new IsNullPredicate(new Constant(createCharType(1), Slices.utf8Slice("z"))), new Constant(INTEGER, 1L))))); // ValuesNode has no rows - tester().assertThat(new MergeProjectWithValues(tester().getMetadata())) + tester().assertThat(new MergeProjectWithValues()) .on(p -> { Symbol a = p.symbol("a"); Symbol b = p.symbol("b"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java index c2a5ebc196aa..5d51710bb52f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestOptimizeDuplicateInsensitiveJoins.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FunctionCall; @@ -25,10 +27,10 @@ import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.JoinNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; @@ -43,10 +45,13 @@ public class TestOptimizeDuplicateInsensitiveJoins extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + @Test public void testNoAggregation() { - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> p.join( INNER, p.values(p.symbol("a")), @@ -57,7 +62,7 @@ public void testNoAggregation() @Test public void testAggregation() { - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); @@ -76,7 +81,7 @@ public void testAggregation() @Test public void testEmptyAggregation() { - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); @@ -98,7 +103,7 @@ public void testEmptyAggregation() @Test public void testNestedJoins() { - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); @@ -134,10 +139,10 @@ public void testNestedJoins() public void testNondeterministicJoins() { FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); @@ -156,7 +161,7 @@ public void testNondeterministicJoins() .matches( aggregation(ImmutableMap.of(), join(INNER, builder -> builder - .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference("B"), new FunctionCall(QualifiedName.of("random"), ImmutableList.of()))) + .filter(new ComparisonExpression(GREATER_THAN, new SymbolReference("B"), new FunctionCall(RANDOM, ImmutableList.of()))) .left(values("A")) .right( join(INNER, rightJoinBuilder -> rightJoinBuilder @@ -170,10 +175,10 @@ public void testNondeterministicJoins() public void testNondeterministicFilter() { FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); @@ -192,10 +197,10 @@ public void testNondeterministicFilter() public void testNondeterministicProjection() { FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); @@ -218,7 +223,7 @@ public void testNondeterministicProjection() @Test public void testUnion() { - tester().assertThat(new OptimizeDuplicateInsensitiveJoins(tester().getMetadata())) + tester().assertThat(new OptimizeDuplicateInsensitiveJoins()) .on(p -> { Symbol symbolA = p.symbol("a"); Symbol symbolB = p.symbol("b"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java index c697b2e8a158..4237ed5b6f22 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.java @@ -19,6 +19,8 @@ import io.trino.Session; import io.trino.connector.MockConnectorFactory; import io.trino.connector.MockConnectorTableHandle; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.Decimals; @@ -38,7 +40,6 @@ import io.trino.sql.planner.assertions.ExpressionMatcher; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.PlanNode; -import io.trino.sql.tree.QualifiedName; import io.trino.testing.PlanTester; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Test; @@ -57,6 +58,7 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; @@ -78,6 +80,9 @@ public class TestPreAggregateCaseAggregations extends BasePlanTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction("concat", fromTypes(VARCHAR, VARCHAR)); + private static final SchemaTableName TABLE = new SchemaTableName("default", "t"); @Override @@ -161,7 +166,7 @@ public void testPreAggregatesCaseAggregations() SINGLE, exchange( project(ImmutableMap.of( - "KEY", expression(new FunctionCall(QualifiedName.of("concat"), ImmutableList.of(new SymbolReference("COL_VARCHAR"), new Constant(VARCHAR, Slices.utf8Slice("a"))))), + "KEY", expression(new FunctionCall(CONCAT, ImmutableList.of(new SymbolReference("COL_VARCHAR"), new Constant(VARCHAR, Slices.utf8Slice("a"))))), "VALUE_BIGINT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new InPredicate(new SymbolReference("COL_BIGINT"), ImmutableList.of(new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))), new ArithmeticBinaryExpression(MULTIPLY, new SymbolReference("COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), "VALUE_INT_CAST", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(EQUAL, new SymbolReference("COL_BIGINT"), new Constant(BIGINT, 1L)), new Cast(new Cast(new ArithmeticBinaryExpression(MULTIPLY, new SymbolReference("COL_BIGINT"), new Constant(BIGINT, 2L)), INTEGER), BIGINT))), Optional.empty())), "VALUE_2_BIGINT", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN, new ArithmeticBinaryExpression(MODULUS, new SymbolReference("COL_BIGINT"), new Constant(BIGINT, 2L)), new Constant(BIGINT, 1L)), new ArithmeticBinaryExpression(MULTIPLY, new SymbolReference("COL_BIGINT"), new Constant(BIGINT, 2L)))), Optional.empty())), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java index 230c3aed78ba..e7c607a74fb1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownProjectionsFromPatternRecognition.java @@ -18,7 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slices; import io.trino.metadata.ResolvedFunction; -import io.trino.spi.type.VarcharType; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ArithmeticBinaryExpression; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; @@ -28,8 +28,10 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.rowpattern.AggregatedSetDescriptor; import io.trino.sql.planner.rowpattern.AggregationValuePointer; +import io.trino.sql.planner.rowpattern.ClassifierValuePointer; +import io.trino.sql.planner.rowpattern.LogicalIndexPointer; +import io.trino.sql.planner.rowpattern.MatchNumberValuePointer; import io.trino.sql.planner.rowpattern.ir.IrLabel; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -37,6 +39,7 @@ import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.ADD; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MULTIPLY; @@ -50,7 +53,9 @@ public class TestPushDownProjectionsFromPatternRecognition extends BaseRuleTest { - private static final QualifiedName MAX_BY = createTestMetadataManager().resolveBuiltinFunction("max_by", fromTypes(BIGINT, BIGINT)).toQualifiedName(); + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction("concat", fromTypes(VARCHAR, VARCHAR)); + private static final ResolvedFunction MAX_BY = createTestMetadataManager().resolveBuiltinFunction("max_by", fromTypes(BIGINT, BIGINT)); @Test public void testNoAggregations() @@ -72,9 +77,12 @@ public void testDoNotPushRuntimeEvaluatedArguments() .addVariableDefinition( new IrLabel("X"), new ComparisonExpression(GREATER_THAN, new FunctionCall(MAX_BY, ImmutableList.of( - new ArithmeticBinaryExpression(ADD, new Constant(INTEGER, 1L), new FunctionCall(QualifiedName.of("match_number"), ImmutableList.of())), - new FunctionCall(QualifiedName.of("concat"), ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")), new FunctionCall(QualifiedName.of("classifier"), ImmutableList.of()))))), - new Constant(INTEGER, 5L))) + new ArithmeticBinaryExpression(ADD, new Constant(INTEGER, 1L), new SymbolReference("match")), + new FunctionCall(CONCAT, ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("x")), new SymbolReference("classifier"))))), + new Constant(INTEGER, 5L)), + ImmutableMap.of( + "classifier", new ClassifierValuePointer(new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0)), + "match", new MatchNumberValuePointer())) .source(p.values(p.symbol("a"))))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java index 86d721749430..8b08a6ae973f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.java @@ -49,7 +49,7 @@ public class TestPushInequalityFilterExpressionBelowJoinRuleSet @BeforeAll public void setUpBeforeClass() { - ruleSet = new PushInequalityFilterExpressionBelowJoinRuleSet(tester().getMetadata(), tester().getTypeAnalyzer()); + ruleSet = new PushInequalityFilterExpressionBelowJoinRuleSet(tester().getTypeAnalyzer()); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java index 2b9e19caf074..68170d3d98ea 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.java @@ -144,7 +144,7 @@ public void testPushUpdateIntoConnectorUpdateAll() Symbol rowCount = p.symbol("row_count"); // set function call, which represents update all columns statement Expression updateMergeRowExpression = new Row(ImmutableList.of(new FunctionCall( - ruleTester.getMetadata().resolveBuiltinFunction("from_base64", fromTypes(VARCHAR)).toQualifiedName(), + ruleTester.getMetadata().resolveBuiltinFunction("from_base64", fromTypes(VARCHAR)), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("")))))); return p.tableFinish( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index 978dfc3b1f8d..f518c039936a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -148,7 +148,7 @@ public void testPushProjection() .put(dereference, new SubscriptExpression(baseColumn.toSymbolReference(), new Constant(INTEGER, 1L))) .put(constant, new Constant(INTEGER, 5L)) .put(call, new FunctionCall( - ruleTester.getMetadata().resolveBuiltinFunction("starts_with", fromTypes(VARCHAR, VARCHAR)).toQualifiedName(), + ruleTester.getMetadata().resolveBuiltinFunction("starts_with", fromTypes(VARCHAR, VARCHAR)), ImmutableList.of(new Constant(VARCHAR, Slices.utf8Slice("abc")), new Constant(VARCHAR, Slices.utf8Slice("ab"))))) .buildOrThrow(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java index c79d7026eea2..d44d5e627d41 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceJoinOverConstantWithProject.java @@ -49,7 +49,7 @@ public class TestReplaceJoinOverConstantWithProject @Test public void testDoesNotFireOnJoinWithEmptySource() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -57,7 +57,7 @@ public void testDoesNotFireOnJoinWithEmptySource() p.values(0, p.symbol("b")))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -69,7 +69,7 @@ public void testDoesNotFireOnJoinWithEmptySource() @Test public void testDoesNotFireOnJoinWithCondition() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -78,7 +78,7 @@ public void testDoesNotFireOnJoinWithCondition() new EquiJoinClause(p.symbol("a"), p.symbol("b")))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -91,7 +91,7 @@ public void testDoesNotFireOnJoinWithCondition() @Test public void testDoesNotFireOnValuesWithMultipleRows() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -103,7 +103,7 @@ public void testDoesNotFireOnValuesWithMultipleRows() @Test public void testDoesNotFireOnValuesWithNoOutputs() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -115,7 +115,7 @@ public void testDoesNotFireOnValuesWithNoOutputs() @Test public void testDoesNotFireOnValuesWithNonRowExpression() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -127,7 +127,7 @@ public void testDoesNotFireOnValuesWithNonRowExpression() @Test public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( LEFT, @@ -137,7 +137,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.values(10, p.symbol("b"))))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( RIGHT, @@ -147,7 +147,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.values(1, p.symbol("b")))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( FULL, @@ -157,7 +157,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() p.values(10, p.symbol("b"))))) .doesNotFire(); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( FULL, @@ -171,7 +171,7 @@ public void testDoesNotFireOnOuterJoinWhenSourcePossiblyEmpty() @Test public void testReplaceInnerJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -185,7 +185,7 @@ public void testReplaceInnerJoinWithProject() "c", PlanMatchPattern.expression(new SymbolReference("c"))), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -203,7 +203,7 @@ public void testReplaceInnerJoinWithProject() @Test public void testReplaceLeftJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( LEFT, @@ -217,7 +217,7 @@ public void testReplaceLeftJoinWithProject() "c", PlanMatchPattern.expression(new SymbolReference("c"))), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( LEFT, @@ -235,7 +235,7 @@ public void testReplaceLeftJoinWithProject() @Test public void testReplaceRightJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( RIGHT, @@ -249,7 +249,7 @@ public void testReplaceRightJoinWithProject() "c", PlanMatchPattern.expression(new SymbolReference("c"))), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( RIGHT, @@ -267,7 +267,7 @@ public void testReplaceRightJoinWithProject() @Test public void testReplaceFullJoinWithProject() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( FULL, @@ -281,7 +281,7 @@ public void testReplaceFullJoinWithProject() "c", PlanMatchPattern.expression(new SymbolReference("c"))), values("c"))); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( FULL, @@ -299,7 +299,7 @@ public void testReplaceFullJoinWithProject() @Test public void testRemoveOutputDuplicates() { - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -322,10 +322,10 @@ public void testRemoveOutputDuplicates() public void testNonDeterministicValues() { FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, @@ -334,10 +334,10 @@ public void testNonDeterministicValues() .doesNotFire(); FunctionCall uuidFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("uuid", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("uuid", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new ReplaceJoinOverConstantWithProject(tester().getMetadata())) + tester().assertThat(new ReplaceJoinOverConstantWithProject()) .on(p -> p.join( INNER, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java index 0670563c460b..2e6996191396 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyFilterPredicate.java @@ -55,7 +55,7 @@ public class TestSimplifyFilterPredicate public void testSimplifyIfExpression() { // true result iff the condition is true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), TRUE_LITERAL, FALSE_LITERAL), p.values(p.symbol("a")))) @@ -65,7 +65,7 @@ public void testSimplifyIfExpression() values("a"))); // true result iff the condition is true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), TRUE_LITERAL, new Constant(UnknownType.UNKNOWN, null)), p.values(p.symbol("a")))) @@ -75,7 +75,7 @@ public void testSimplifyIfExpression() values("a"))); // true result iff the condition is null or false - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), FALSE_LITERAL, TRUE_LITERAL), p.values(p.symbol("a")))) @@ -85,7 +85,7 @@ public void testSimplifyIfExpression() values("a"))); // true result iff the condition is null or false - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), new Constant(UnknownType.UNKNOWN, null), TRUE_LITERAL), p.values(p.symbol("a")))) @@ -95,7 +95,7 @@ public void testSimplifyIfExpression() values("a"))); // always true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), TRUE_LITERAL, TRUE_LITERAL), p.values(p.symbol("a")))) @@ -105,7 +105,7 @@ public void testSimplifyIfExpression() values("a"))); // always false - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), FALSE_LITERAL, FALSE_LITERAL), p.values(p.symbol("a")))) @@ -115,7 +115,7 @@ public void testSimplifyIfExpression() values("a"))); // both results equal - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), new ComparisonExpression(GREATER_THAN, new SymbolReference("b"), new Constant(INTEGER, 0L)), new ComparisonExpression(GREATER_THAN, new SymbolReference("b"), new Constant(INTEGER, 0L))), p.values(p.symbol("a"), p.symbol("b")))) @@ -126,9 +126,9 @@ public void testSimplifyIfExpression() // both results are equal non-deterministic expressions FunctionCall randomFunction = new FunctionCall( - tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()).toQualifiedName(), + tester().getMetadata().resolveBuiltinFunction("random", ImmutableList.of()), ImmutableList.of()); - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression( new SymbolReference("a"), @@ -138,7 +138,7 @@ public void testSimplifyIfExpression() .doesNotFire(); // always null (including the default) -> simplified to FALSE - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), new Constant(UnknownType.UNKNOWN, null)), p.values(p.symbol("a")))) @@ -148,7 +148,7 @@ public void testSimplifyIfExpression() values("a"))); // condition is true -> first branch - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(TRUE_LITERAL, new SymbolReference("a"), new NotExpression(new SymbolReference("a"))), p.values(p.symbol("a")))) @@ -158,7 +158,7 @@ public void testSimplifyIfExpression() values("a"))); // condition is true -> second branch - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(FALSE_LITERAL, new SymbolReference("a"), new NotExpression(new SymbolReference("a"))), p.values(p.symbol("a")))) @@ -168,7 +168,7 @@ public void testSimplifyIfExpression() values("a"))); // condition is true, no second branch -> the result is null, simplified to FALSE - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(FALSE_LITERAL, new SymbolReference("a")), p.values(p.symbol("a")))) @@ -178,7 +178,7 @@ public void testSimplifyIfExpression() values("a"))); // not known result (`b`) - cannot optimize - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), TRUE_LITERAL, new SymbolReference("b")), p.values(p.symbol("a"), p.symbol("b")))) @@ -189,7 +189,7 @@ public void testSimplifyIfExpression() public void testSimplifyNullIfExpression() { // NULLIF(x, y) returns true if and only if: x != y AND x = true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new NullIfExpression(new SymbolReference("a"), new SymbolReference("b")), p.values(p.symbol("a"), p.symbol("b")))) @@ -206,7 +206,7 @@ public void testSimplifyNullIfExpression() @Test public void testSimplifySearchedCaseExpression() { - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), @@ -217,7 +217,7 @@ public void testSimplifySearchedCaseExpression() .doesNotFire(); // all results true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), @@ -231,7 +231,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // all results not true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), @@ -245,7 +245,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // all results not true (including default null result) - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), @@ -259,7 +259,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // one result true, and remaining results not true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), @@ -273,7 +273,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // first result true, and remaining results not true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), TRUE_LITERAL), @@ -287,7 +287,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // all results not true, and default true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("a"), new Constant(INTEGER, 0L)), FALSE_LITERAL), @@ -310,7 +310,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // all conditions not true - return the default - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(FALSE_LITERAL, new SymbolReference("a")), @@ -324,7 +324,7 @@ public void testSimplifySearchedCaseExpression() values("a", "b"))); // all conditions not true, no default specified - return false - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(FALSE_LITERAL, new SymbolReference("a")), @@ -338,7 +338,7 @@ public void testSimplifySearchedCaseExpression() values("a"))); // not true conditions preceding true condition - return the result associated with the true condition - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(FALSE_LITERAL, new SymbolReference("a")), @@ -352,7 +352,7 @@ public void testSimplifySearchedCaseExpression() values("a", "b"))); // remove not true condition and move the result associated with the first true condition to default - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(FALSE_LITERAL, new SymbolReference("a")), @@ -366,7 +366,7 @@ public void testSimplifySearchedCaseExpression() values("a", "b"))); // move the result associated with the first true condition to default - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("b"), new Constant(INTEGER, 0L)), new SymbolReference("a")), @@ -384,7 +384,7 @@ public void testSimplifySearchedCaseExpression() values("a", "b"))); // cannot remove any clause - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SearchedCaseExpression(ImmutableList.of( new WhenClause(new ComparisonExpression(LESS_THAN, new SymbolReference("b"), new Constant(INTEGER, 0L)), new SymbolReference("a")), @@ -397,7 +397,7 @@ public void testSimplifySearchedCaseExpression() @Test public void testSimplifySimpleCaseExpression() { - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SimpleCaseExpression( new SymbolReference("a"), @@ -409,7 +409,7 @@ public void testSimplifySimpleCaseExpression() .doesNotFire(); // comparison with null returns null - no WHEN branch matches, return default value - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SimpleCaseExpression( new Constant(UnknownType.UNKNOWN, null), @@ -424,7 +424,7 @@ public void testSimplifySimpleCaseExpression() values("a", "b"))); // comparison with null returns null - no WHEN branch matches, the result is default null, simplified to FALSE - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SimpleCaseExpression( new Constant(UnknownType.UNKNOWN, null), @@ -439,7 +439,7 @@ public void testSimplifySimpleCaseExpression() values("a"))); // all results true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SimpleCaseExpression( new SymbolReference("a"), @@ -454,7 +454,7 @@ public void testSimplifySimpleCaseExpression() values("a", "b"))); // all results not true - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SimpleCaseExpression( new SymbolReference("a"), @@ -469,7 +469,7 @@ public void testSimplifySimpleCaseExpression() values("a", "b"))); // all results not true (including default null result) - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new SimpleCaseExpression( new SymbolReference("a"), @@ -487,7 +487,7 @@ public void testSimplifySimpleCaseExpression() @Test public void testCastNull() { - tester().assertThat(new SimplifyFilterPredicate(tester().getMetadata())) + tester().assertThat(new SimplifyFilterPredicate()) .on(p -> p.filter( new IfExpression(new SymbolReference("a"), new Cast(new Cast(new Constant(BOOLEAN, null), BIGINT), BOOLEAN), FALSE_LITERAL), p.values(p.symbol("a", BOOLEAN)))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java index 7bae7598d8d8..5d4345dfab85 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestExpressionEquivalence.java @@ -138,8 +138,8 @@ public void testEquivalent() new ComparisonExpression(EQUAL, new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2021-05-10 12:34:56.123456789 +8")), new Constant(createTimestampWithTimeZoneType(9), DateTimes.parseTimestampWithTimeZone(9, "2020-05-10 12:34:56.123456789 +8")))); assertEquivalent( - new FunctionCall(MOD.toQualifiedName(), ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), - new FunctionCall(MOD.toQualifiedName(), ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))); + new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), + new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L)))); assertEquivalent( new SymbolReference("a_bigint"), @@ -263,8 +263,8 @@ public void testNotEquivalent() new ComparisonExpression(GREATER_THAN_OR_EQUAL, new Constant(INTEGER, 5L), new Constant(INTEGER, 6L))); assertNotEquivalent( - new FunctionCall(MOD.toQualifiedName(), ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), - new FunctionCall(MOD.toQualifiedName(), ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)))); + new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 4L), new Constant(INTEGER, 5L))), + new FunctionCall(MOD, ImmutableList.of(new Constant(INTEGER, 5L), new Constant(INTEGER, 4L)))); assertNotEquivalent( new SymbolReference("a_bigint"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java index 44a94206142f..39c43f6808cb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java @@ -583,7 +583,7 @@ public void testNotMergeDifferentNullOrdering() private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern) { List optimizers = ImmutableList.of( - new UnaliasSymbolReferences(getPlanTester().getPlannerContext().getMetadata()), + new UnaliasSymbolReferences(), new IterativeOptimizer( getPlanTester().getPlannerContext(), new RuleStatsRecorder(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java index c1b74572b270..609aae1344d6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java @@ -184,7 +184,6 @@ public void testUnmatchedDynamicFilter() ordersTableScanNode, builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), @@ -218,7 +217,6 @@ public void testRemoveDynamicFilterNotAboveTableScan() INNER, builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(lineitemOrderKeySymbol)), @@ -254,13 +252,10 @@ public void testNestedDynamicFilterDisjunctionRewrite() ordersTableScanNode, builder.filter( combineConjuncts( - metadata, combineDisjuncts( - metadata, new IsNullPredicate(new SymbolReference("LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), combineDisjuncts( - metadata, new IsNotNullPredicate(new SymbolReference("LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference()))), lineitemTableScanNode), @@ -291,13 +286,10 @@ public void testNestedDynamicFilterConjunctionRewrite() ordersTableScanNode, builder.filter( combineDisjuncts( - metadata, combineConjuncts( - metadata, new IsNullPredicate(new SymbolReference("LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), combineConjuncts( - metadata, new IsNotNullPredicate(new SymbolReference("LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference()))), lineitemTableScanNode), @@ -318,7 +310,6 @@ public void testNestedDynamicFilterConjunctionRewrite() .right( filter( combineDisjuncts( - metadata, new IsNullPredicate(new SymbolReference("LINEITEM_OK")), new IsNotNullPredicate(new SymbolReference("LINEITEM_OK"))), tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))))); @@ -412,7 +403,6 @@ public void testDynamicFilterConsumedOnFilteringSourceSideInSemiJoin() ordersTableScanNode, builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), @@ -438,7 +428,6 @@ public void testUnmatchedDynamicFilterInSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), ordersTableScanNode), @@ -465,7 +454,6 @@ public void testRemoveDynamicFilterNotAboveTableScanWithSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(ordersOrderKeySymbol)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java index f5be3a460a63..fd2f07b0141e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java @@ -341,7 +341,7 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter { PlannerContext plannerContext = getPlanTester().getPlannerContext(); List optimizers = ImmutableList.of( - new UnaliasSymbolReferences(getPlanTester().getPlannerContext().getMetadata()), + new UnaliasSymbolReferences(), new PredicatePushDown( getPlanTester().getPlannerContext(), new IrTypeAnalyzer(plannerContext), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestSetFlattening.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestSetFlattening.java index 65a690859e2a..a0b0dfe029d8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestSetFlattening.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestSetFlattening.java @@ -133,7 +133,7 @@ public void testDoesNotFlattenDifferentSetOperations() protected void assertPlan(String sql, PlanMatchPattern pattern) { List optimizers = ImmutableList.of( - new UnaliasSymbolReferences(getPlanTester().getPlannerContext().getMetadata()), + new UnaliasSymbolReferences(), new IterativeOptimizer( getPlanTester().getPlannerContext(), new RuleStatsRecorder(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java index 5ac50a333659..8242712ee886 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java @@ -67,7 +67,7 @@ public void testDynamicFilterIdUnAliased() String probeTable = "supplier"; String buildTable = "nation"; assertOptimizedPlan( - new UnaliasSymbolReferences(getPlanTester().getPlannerContext().getMetadata()), + new UnaliasSymbolReferences(), (p, session, metadata) -> { ColumnHandle column = new TpchColumnHandle("nationkey", BIGINT); Symbol buildColumnSymbol = p.symbol("nationkey"); @@ -123,7 +123,7 @@ probeColumn2, new TpchColumnHandle("suppkey", BIGINT))))), public void testGroupIdGroupingSetsDeduplicated() { assertOptimizedPlan( - new UnaliasSymbolReferences(getPlanTester().getPlannerContext().getMetadata()), + new UnaliasSymbolReferences(), (p, session, metadata) -> { Symbol symbol = p.symbol("symbol"); Symbol alias1 = p.symbol("alias1"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java index 78e85f6e6f24..5f9fe55884d8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java @@ -21,6 +21,7 @@ import io.airlift.json.ObjectMapperProvider; import io.trino.block.BlockJsonSerde; import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.block.Block; import io.trino.spi.block.TestingBlockEncodingSerde; import io.trino.spi.type.TestingTypeManager; @@ -67,6 +68,9 @@ public class TestPatternRecognitionNodeSerialization { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + private static final JsonCodec VALUE_POINTER_CODEC; private static final JsonCodec EXPRESSION_AND_VALUE_POINTERS_CODEC; private static final JsonCodec MEASURE_CODEC; @@ -130,7 +134,7 @@ public void testExpressionAndValuePointersRoundtrip() assertJsonRoundTrip(EXPRESSION_AND_VALUE_POINTERS_CODEC, new ExpressionAndValuePointers( new IfExpression( new ComparisonExpression(GREATER_THAN, new SymbolReference("classifier"), new SymbolReference("x")), - new FunctionCall("rand", ImmutableList.of()), + new FunctionCall(RANDOM, ImmutableList.of()), new ArithmeticUnaryExpression(MINUS, new SymbolReference("match_number"))), ImmutableList.of( new ExpressionAndValuePointers.Assignment( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java index 983689ec5dca..16c8894ae129 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestDynamicFiltersChecker.java @@ -143,7 +143,6 @@ public void testUnmatchedDynamicFilter() ordersTableScanNode, builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), @@ -169,7 +168,6 @@ public void testDynamicFilterNotAboveTableScan() INNER, builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(lineitemOrderKeySymbol)), @@ -197,13 +195,10 @@ public void testUnmatchedNestedDynamicFilter() ordersTableScanNode, builder.filter( combineConjuncts( - metadata, combineDisjuncts( - metadata, new IsNullPredicate(new SymbolReference("LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), combineDisjuncts( - metadata, new IsNotNullPredicate(new SymbolReference("LINEITEM_OK")), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference()))), lineitemTableScanNode), @@ -285,13 +280,11 @@ public void testDynamicFilterConsumedOnFilteringSourceSideInSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), ordersTableScanNode), builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("LINEITEM_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, lineitemOrderKeySymbol.toSymbolReference())), lineitemTableScanNode), @@ -316,7 +309,6 @@ public void testUnmatchedDynamicFilterInSemiJoin() builder.semiJoin( builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), ordersTableScanNode), @@ -339,7 +331,6 @@ public void testDynamicFilterNotAboveTableScanWithSemiJoin() PlanNode root = builder.semiJoin( builder.filter( combineConjuncts( - metadata, new ComparisonExpression(GREATER_THAN, new SymbolReference("ORDERS_OK"), new Constant(INTEGER, 0L)), createDynamicFilterExpression(metadata, new DynamicFilterId("DF"), BIGINT, ordersOrderKeySymbol.toSymbolReference())), builder.values(ordersOrderKeySymbol)), diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java index 6347347429c7..adcf6a8d88b3 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/AbstractTestExtractSpatial.java @@ -75,6 +75,6 @@ protected FunctionCall toPointCall(Expression x, Expression y) private FunctionCall functionCall(String name, List types, List arguments) { - return new FunctionCall(tester().getMetadata().resolveBuiltinFunction(name, fromTypes(types)).toQualifiedName(), arguments); + return new FunctionCall(tester().getMetadata().resolveBuiltinFunction(name, fromTypes(types)), arguments); } } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java index ea2ea470e738..0de93ce3c453 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialInnerJoin.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FunctionCall; @@ -25,13 +27,14 @@ import io.trino.sql.planner.iterative.rule.ExtractSpatialJoins.ExtractSpatialInnerJoin; import io.trino.sql.planner.iterative.rule.test.RuleBuilder; import io.trino.sql.planner.iterative.rule.test.RuleTester; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.plugin.geospatial.GeometryType.GEOMETRY; import static io.trino.plugin.geospatial.SphericalGeographyType.SPHERICAL_GEOGRAPHY; +import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN; import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN; import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; @@ -45,6 +48,11 @@ public class TestExtractSpatialInnerJoin extends AbstractTestExtractSpatial { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin()); + private static final ResolvedFunction ST_CONTAINS = FUNCTIONS.resolveFunction("st_contains", fromTypes(GEOMETRY, GEOMETRY)); + private static final ResolvedFunction ST_POINT = FUNCTIONS.resolveFunction("st_point", fromTypes(DOUBLE, DOUBLE)); + private static final ResolvedFunction ST_GEOMETRY_FROM_TEXT = FUNCTIONS.resolveFunction("st_geometryfromtext", fromTypes(VARCHAR)); + @Test public void testDoesNotFire() { @@ -150,7 +158,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), values(ImmutableMap.of("a", 0)), values(ImmutableMap.of("b", 0)))); @@ -172,7 +180,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))))), values(ImmutableMap.of("a", 0, "name_1", 1)), values(ImmutableMap.of("b", 0, "name_2", 1)))); @@ -194,7 +202,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a1"), new SymbolReference("b1"))), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a2"), new SymbolReference("b2"))))), + new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a1"), new SymbolReference("b1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a2"), new SymbolReference("b2"))))), values(ImmutableMap.of("a1", 0, "a2", 1)), values(ImmutableMap.of("b1", 0, "b2", 1)))); } @@ -215,8 +223,8 @@ public void testPushDownFirstArgument() }) .matches( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0))), values(ImmutableMap.of("point", 0)))); @@ -250,9 +258,9 @@ public void testPushDownSecondArgument() }) .matches( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("polygon"), new SymbolReference("st_point"))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("polygon"), new SymbolReference("st_point"))), values(ImmutableMap.of("polygon", 0)), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); assertRuleApplication() @@ -286,10 +294,10 @@ public void testPushDownBothArguments() }) .matches( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); } @@ -309,10 +317,10 @@ public void testPushDownOppositeOrder() p.values(wkt))); }) .matches( - spatialJoin(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + spatialJoin(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0))))); } @@ -337,10 +345,10 @@ public void testPushDownAnd() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0, "name_1", 1))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1, "name_2", 2))))); // Multiple spatial functions - only the first one is being processed @@ -361,8 +369,8 @@ public void testPushDownAnd() }) .matches( spatialJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("geometry1"))), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt2"))), new SymbolReference("geometry2"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt1"))))), + new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("geometry1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt2"))), new SymbolReference("geometry2"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt1"))))), values(ImmutableMap.of("wkt1", 0, "wkt2", 1))), values(ImmutableMap.of("geometry1", 0, "geometry2", 1)))); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java index 18115fc26c52..9aae96c0f853 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestExtractSpatialLeftJoin.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.ComparisonExpression; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FunctionCall; @@ -25,13 +27,14 @@ import io.trino.sql.planner.iterative.rule.ExtractSpatialJoins; import io.trino.sql.planner.iterative.rule.test.RuleBuilder; import io.trino.sql.planner.iterative.rule.test.RuleTester; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.plugin.geospatial.GeometryType.GEOMETRY; import static io.trino.plugin.geospatial.SphericalGeographyType.SPHERICAL_GEOGRAPHY; +import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ComparisonExpression.Operator.NOT_EQUAL; import static io.trino.sql.ir.LogicalExpression.Operator.AND; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -43,6 +46,11 @@ public class TestExtractSpatialLeftJoin extends AbstractTestExtractSpatial { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin()); + private static final ResolvedFunction ST_CONTAINS = FUNCTIONS.resolveFunction("st_contains", fromTypes(GEOMETRY, GEOMETRY)); + private static final ResolvedFunction ST_GEOMETRY_FROM_TEXT = FUNCTIONS.resolveFunction("st_geometryfromtext", fromTypes(VARCHAR)); + private static final ResolvedFunction ST_POINT = FUNCTIONS.resolveFunction("st_point", fromTypes(DOUBLE, DOUBLE)); + @Test public void testDoesNotFire() { @@ -152,7 +160,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), values(ImmutableMap.of("a", 0)), values(ImmutableMap.of("b", 0)))); @@ -173,7 +181,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))))), values(ImmutableMap.of("a", 0, "name_1", 1)), values(ImmutableMap.of("b", 0, "name_2", 1)))); @@ -194,7 +202,7 @@ public void testConvertToSpatialJoin() }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a1"), new SymbolReference("b1"))), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("a2"), new SymbolReference("b2"))))), + new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a1"), new SymbolReference("b1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("a2"), new SymbolReference("b2"))))), values(ImmutableMap.of("a1", 0, "a2", 1)), values(ImmutableMap.of("b1", 0, "b2", 1)))); } @@ -214,8 +222,8 @@ public void testPushDownFirstArgument() }) .matches( spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0))), values(ImmutableMap.of("point", 0)))); @@ -247,9 +255,9 @@ public void testPushDownSecondArgument() }) .matches( spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("polygon"), new SymbolReference("st_point"))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("polygon"), new SymbolReference("st_point"))), values(ImmutableMap.of("polygon", 0)), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); assertRuleApplication() @@ -281,10 +289,10 @@ public void testPushDownBothArguments() }) .matches( spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))))); } @@ -304,9 +312,9 @@ public void testPushDownOppositeOrder() }) .matches( spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0))))); } @@ -330,10 +338,10 @@ public void testPushDownAnd() }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt"))))), + new LogicalExpression(AND, ImmutableList.of(new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_1"), new SymbolReference("name_2")), new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt"))))), values(ImmutableMap.of("wkt", 0, "name_1", 1))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), values(ImmutableMap.of("lat", 0, "lng", 1, "name_2", 2))))); // Multiple spatial functions - only the first one is being processed @@ -353,8 +361,8 @@ public void testPushDownAnd() }) .matches( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("geometry1"))), new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt2"))), new SymbolReference("geometry2"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new SymbolReference("wkt1"))))), + new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("geometry1"))), new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt2"))), new SymbolReference("geometry2"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new SymbolReference("wkt1"))))), values(ImmutableMap.of("wkt1", 0, "wkt2", 1))), values(ImmutableMap.of("geometry1", 0, "geometry2", 1)))); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java index 4c5ab2f0cb15..b5baa6fb01f1 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinChildrenColumns.java @@ -25,7 +25,6 @@ import io.trino.sql.planner.iterative.rule.PruneSpatialJoinChildrenColumns; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.SpatialJoinNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -59,12 +58,12 @@ public void testPruneOneChild() ImmutableList.of(a, b, r), new ComparisonExpression( LESS_THAN_OR_EQUAL, - new FunctionCall(TEST_ST_DISTANCE_FUNCTION.toQualifiedName(), ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), + new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), r.toSymbolReference())); }) .matches( spatialJoin( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("st_distance"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), + new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), Optional.empty(), Optional.of(ImmutableList.of("a", "b", "r")), strictProject( @@ -90,12 +89,12 @@ public void testPruneBothChildren() ImmutableList.of(a, b, r), new ComparisonExpression( LESS_THAN_OR_EQUAL, - new FunctionCall(TEST_ST_DISTANCE_FUNCTION.toQualifiedName(), ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), + new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), r.toSymbolReference())); }) .matches( spatialJoin( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("st_distance"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), + new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), Optional.empty(), Optional.of(ImmutableList.of("a", "b", "r")), strictProject( @@ -120,7 +119,7 @@ public void testDoNotPruneOneOutputOrFilterSymbols() p.values(a), p.values(b, r, output), ImmutableList.of(output), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("st_distance"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r"))); + new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r"))); }) .doesNotFire(); } @@ -140,7 +139,7 @@ public void testDoNotPrunePartitionSymbols() p.values(a, leftPartitionSymbol), p.values(b, r, rightPartitionSymbol), ImmutableList.of(a, b, r), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("st_distance"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), + new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), Optional.of(leftPartitionSymbol), Optional.of(rightPartitionSymbol), Optional.of("some nice kdb tree")); diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java index e59cf66efc00..fbb3702084ad 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestPruneSpatialJoinColumns.java @@ -26,7 +26,6 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.SpatialJoinNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import java.util.Optional; @@ -61,14 +60,14 @@ public void notAllOutputsReferenced() ImmutableList.of(a, b, r), new ComparisonExpression( LESS_THAN_OR_EQUAL, - new FunctionCall(TEST_ST_DISTANCE_FUNCTION.toQualifiedName(), ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), + new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(a.toSymbolReference(), b.toSymbolReference())), r.toSymbolReference()))); }) .matches( strictProject( ImmutableMap.of("a", PlanMatchPattern.expression(new SymbolReference("a"))), spatialJoin( - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("st_distance"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), + new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")), Optional.empty(), Optional.of(ImmutableList.of("a")), values("a"), @@ -90,7 +89,7 @@ public void allOutputsReferenced() p.values(a), p.values(b, r), ImmutableList.of(a, b, r), - new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("st_distance"), ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")))); + new ComparisonExpression(LESS_THAN_OR_EQUAL, new FunctionCall(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new SymbolReference("a"), new SymbolReference("b"))), new SymbolReference("r")))); }) .doesNotFire(); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java index 6d61dedbf4a3..7043bc5cb91a 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.ir.Constant; import io.trino.sql.ir.FunctionCall; import io.trino.sql.ir.SymbolReference; @@ -23,11 +25,11 @@ import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.iterative.rule.test.RuleBuilder; import io.trino.sql.planner.plan.AggregationNode; -import io.trino.sql.tree.QualifiedName; import org.junit.jupiter.api.Test; import static io.trino.plugin.geospatial.GeometryType.GEOMETRY; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -37,6 +39,9 @@ public class TestRewriteSpatialPartitioningAggregation extends BaseRuleTest { + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin()); + private static final ResolvedFunction ST_ENVELOPE = FUNCTIONS.resolveFunction("st_envelope", fromTypes(GEOMETRY)); + public TestRewriteSpatialPartitioningAggregation() { super(new GeoPlugin()); @@ -68,7 +73,7 @@ public void test() ImmutableMap.of("sp", aggregationFunction("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))), project( ImmutableMap.of("partition_count", expression(new Constant(INTEGER, 100L)), - "envelope", expression(new FunctionCall(QualifiedName.of("st_envelope"), ImmutableList.of(new SymbolReference("geometry"))))), + "envelope", expression(new FunctionCall(ST_ENVELOPE, ImmutableList.of(new SymbolReference("geometry"))))), values("geometry")))); assertRuleApplication() @@ -82,7 +87,7 @@ public void test() ImmutableMap.of("sp", aggregationFunction("spatial_partitioning", ImmutableList.of("envelope", "partition_count"))), project( ImmutableMap.of("partition_count", expression(new Constant(INTEGER, 100L)), - "envelope", expression(new FunctionCall(QualifiedName.of("st_envelope"), ImmutableList.of(new SymbolReference("geometry"))))), + "envelope", expression(new FunctionCall(ST_ENVELOPE, ImmutableList.of(new SymbolReference("geometry"))))), values("geometry")))); } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java index c5ce7eb92ff2..51c7d7d44307 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialJoinPlanning.java @@ -20,6 +20,8 @@ import io.trino.geospatial.KdbTree; import io.trino.geospatial.KdbTreeUtils; import io.trino.geospatial.Rectangle; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.plugin.memory.MemoryConnectorFactory; import io.trino.plugin.tpch.TpchConnectorFactory; import io.trino.spi.TrinoException; @@ -38,7 +40,6 @@ import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.plan.ExchangeNode; -import io.trino.sql.tree.QualifiedName; import io.trino.testing.PlanTester; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -51,6 +52,7 @@ import static io.trino.SystemSessionProperties.SPATIAL_PARTITIONING_TABLE_NAME; import static io.trino.geospatial.KdbTree.Node.newLeaf; import static io.trino.plugin.geospatial.GeometryType.GEOMETRY; +import static io.trino.plugin.geospatial.KdbTreeType.KDB_TREE; import static io.trino.spi.StandardErrorCode.INVALID_SPATIAL_PARTITIONING; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -84,7 +86,18 @@ public class TestSpatialJoinPlanning extends BasePlanTest { private static final String KDB_TREE_JSON = KdbTreeUtils.toJson(new KdbTree(newLeaf(new Rectangle(0, 0, 10, 10), 0))); - private static final Expression KDB_TREE_LITERAL = new Constant(KdbTreeType.KDB_TREE, KdbTreeUtils.fromJson(KDB_TREE_JSON)); + private static final Expression KDB_TREE_LITERAL = new Constant(KDB_TREE, KdbTreeUtils.fromJson(KDB_TREE_JSON)); + + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin()); + private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", fromTypes()); + private static final ResolvedFunction SPATIAL_PARTITIONS = FUNCTIONS.resolveFunction("spatial_partitions", fromTypes(KDB_TREE, GEOMETRY)); + private static final ResolvedFunction ST_CONTAINS = FUNCTIONS.resolveFunction("st_contains", fromTypes(GEOMETRY, GEOMETRY)); + private static final ResolvedFunction ST_INTERSECTS = FUNCTIONS.resolveFunction("st_intersects", fromTypes(GEOMETRY, GEOMETRY)); + private static final ResolvedFunction ST_WITHIN = FUNCTIONS.resolveFunction("st_within", fromTypes(GEOMETRY, GEOMETRY)); + private static final ResolvedFunction ST_POINT = FUNCTIONS.resolveFunction("st_point", fromTypes(DOUBLE, DOUBLE)); + private static final ResolvedFunction ST_GEOMETRY_FROM_TEXT = FUNCTIONS.resolveFunction("st_geometryfromtext", fromTypes(VARCHAR)); + private static final ResolvedFunction LENGTH = FUNCTIONS.resolveFunction("length", fromTypes(VARCHAR)); + private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction("concat", fromTypes(VARCHAR, VARCHAR)); @Override protected PlanTester createPlanTester() @@ -111,11 +124,11 @@ public void testSpatialJoinContains() "WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // Verify that projections generated by the ExtractSpatialJoins rule @@ -124,15 +137,15 @@ public void testSpatialJoinContains() "FROM (SELECT length(name), * FROM points), (SELECT length(name), * FROM polygons) " + "WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( - spatialJoin(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + spatialJoin(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), project(ImmutableMap.of( - "st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))), - "length", expression(new FunctionCall(QualifiedName.of("length"), ImmutableList.of(new SymbolReference("name"))))), + "st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))), + "length", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference("name"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name"))), anyTree( project(ImmutableMap.of( - "st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR)))), - "length_2", expression(new FunctionCall(QualifiedName.of("length"), ImmutableList.of(new SymbolReference("name_2"))))), + "st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR)))), + "length_2", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference("name_2"))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name"))))))); // distributed @@ -142,18 +155,18 @@ public void testSpatialJoinContains() withSpatialPartitioning("kdb_tree"), anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("partitions_a", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_point"))))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("partitions_a", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_point"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name")))))), anyTree( unnest( - project(ImmutableMap.of("partitions_b", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_geometryfromtext"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("partitions_b", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_geometryfromtext"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name"))))))))); } @@ -166,11 +179,11 @@ public void testSpatialJoinWithin() "WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_within"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + new FunctionCall(ST_WITHIN, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // Verify that projections generated by the ExtractSpatialJoins rule @@ -179,15 +192,15 @@ public void testSpatialJoinWithin() "FROM (SELECT length(name), * FROM points), (SELECT length(name), * FROM polygons) " + "WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( - spatialJoin(new FunctionCall(QualifiedName.of("st_within"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + spatialJoin(new FunctionCall(ST_WITHIN, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), project(ImmutableMap.of( - "st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))), - "length", expression(new FunctionCall(QualifiedName.of("length"), ImmutableList.of(new SymbolReference("name"))))), + "st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))), + "length", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference("name"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name"))), anyTree( project(ImmutableMap.of( - "st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR)))), - "length_2", expression(new FunctionCall(QualifiedName.of("length"), ImmutableList.of(new SymbolReference("name_2"))))), + "st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR)))), + "length_2", expression(new FunctionCall(LENGTH, ImmutableList.of(new SymbolReference("name_2"))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name"))))))); // distributed @@ -197,18 +210,18 @@ public void testSpatialJoinWithin() withSpatialPartitioning("kdb_tree"), anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_within"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + new FunctionCall(ST_WITHIN, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("partitions_a", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_point"))))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + project(ImmutableMap.of("partitions_a", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_point"))))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")))))), anyTree( unnest( - project(ImmutableMap.of("partitions_b", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_geometryfromtext"))))), - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("partitions_b", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("st_geometryfromtext"))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))))); } @@ -288,11 +301,11 @@ public void testSpatialJoinIntersects() "WHERE ST_Intersects(ST_GeometryFromText(a.wkt), ST_GeometryFromText(b.wkt))", anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_intersects"), ImmutableList.of(new SymbolReference("geometry_a"), new SymbolReference("geometry_b"))), - project(ImmutableMap.of("geometry_a", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VARCHAR))))), + new FunctionCall(ST_INTERSECTS, ImmutableList.of(new SymbolReference("geometry_a"), new SymbolReference("geometry_b"))), + project(ImmutableMap.of("geometry_a", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name"))), anyTree( - project(ImmutableMap.of("geometry_b", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VARCHAR))))), + project(ImmutableMap.of("geometry_b", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name"))))))); // distributed @@ -302,15 +315,15 @@ public void testSpatialJoinIntersects() withSpatialPartitioning("default.kdb_tree"), anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_intersects"), ImmutableList.of(new SymbolReference("geometry_a"), new SymbolReference("geometry_b"))), Optional.of(KDB_TREE_JSON), Optional.empty(), + new FunctionCall(ST_INTERSECTS, ImmutableList.of(new SymbolReference("geometry_a"), new SymbolReference("geometry_b"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("partitions_a", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("geometry_a"))))), - project(ImmutableMap.of("geometry_a", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VARCHAR))))), + project(ImmutableMap.of("partitions_a", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("geometry_a"))))), + project(ImmutableMap.of("geometry_a", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name")))))), anyTree( - project(ImmutableMap.of("partitions_b", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("geometry_b"))))), - project(ImmutableMap.of("geometry_b", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VARCHAR))))), + project(ImmutableMap.of("partitions_b", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("geometry_b"))))), + project(ImmutableMap.of("geometry_b", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name")))))))); } @@ -322,7 +335,7 @@ public void testNotContains() "WHERE NOT ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( filter( - new NotExpression(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))), new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))))), + new NotExpression(new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))), new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))))), join(INNER, builder -> builder .left(tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))) .right( @@ -348,13 +361,13 @@ public void testNotIntersects() .left( project( ImmutableMap.of( - "wkt_a", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), + "wkt_a", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), "name_a", expression(new Constant(createVarcharType(1), Slices.utf8Slice("a")))), singleRow())) .right( any(project( ImmutableMap.of( - "wkt_b", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), + "wkt_b", expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.0)), new Constant(createVarcharType(45), Slices.utf8Slice("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")))), Optional.empty())), "name_b", expression(new Constant(createVarcharType(1), Slices.utf8Slice("a")))), singleRow()))))))); } @@ -368,7 +381,7 @@ public void testContainsWithEquiClause() anyTree( join(INNER, builder -> builder .equiCriteria("name_a", "name_b") - .filter(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))), new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))))) + .filter(new FunctionCall(ST_CONTAINS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))), new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")))))) .left( anyTree( tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")))) @@ -386,7 +399,7 @@ public void testIntersectsWithEquiClause() anyTree( join(INNER, builder -> builder .equiCriteria("name_a", "name_b") - .filter(new FunctionCall(QualifiedName.of("st_intersects"), ImmutableList.of(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VARCHAR))), new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VARCHAR)))))) + .filter(new FunctionCall(ST_INTERSECTS, ImmutableList.of(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VARCHAR))), new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VARCHAR)))))) .left( anyTree( tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name")))) @@ -403,11 +416,11 @@ public void testSpatialLeftJoins() "ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))", anyTree( spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // deterministic extra join predicate @@ -416,11 +429,11 @@ public void testSpatialLeftJoins() "ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) AND a.name <> b.name", anyTree( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_a"), new SymbolReference("name_b")))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), new ComparisonExpression(NOT_EQUAL, new SymbolReference("name_a"), new SymbolReference("name_b")))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // non-deterministic extra join predicate @@ -429,11 +442,11 @@ public void testSpatialLeftJoins() "ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) AND rand() < 0.5", anyTree( spatialLeftJoin( - new LogicalExpression(AND, ImmutableList.of(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), new ComparisonExpression(LESS_THAN, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new Constant(DOUBLE, 0.5)))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + new LogicalExpression(AND, ImmutableList.of(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), new ComparisonExpression(LESS_THAN, new FunctionCall(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 0.5)))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"))))))); // filter over join @@ -443,13 +456,13 @@ public void testSpatialLeftJoins() "WHERE concat(a.name, b.name) is null", anyTree( filter( - new IsNullPredicate(new FunctionCall(QualifiedName.of("concat"), ImmutableList.of(new Cast(new SymbolReference("name_a"), VARCHAR), new Cast(new SymbolReference("name_b"), VARCHAR)))), + new IsNullPredicate(new FunctionCall(CONCAT, ImmutableList.of(new Cast(new SymbolReference("name_a"), VARCHAR), new Cast(new SymbolReference("name_b"), VARCHAR)))), spatialLeftJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), - project(ImmutableMap.of("st_point", expression(new FunctionCall(QualifiedName.of("st_point"), ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point"))), + project(ImmutableMap.of("st_point", expression(new FunctionCall(ST_POINT, ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat"))))), tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"))), anyTree( - project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), + project(ImmutableMap.of("st_geometryfromtext", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("wkt"), VARCHAR))))), tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")))))))); } @@ -462,19 +475,19 @@ public void testDistributedSpatialJoinOverUnion() "WHERE ST_Contains(ST_GeometryFromText(a.name), ST_GeometryFromText(b.name))", withSpatialPartitioning("kdb_tree"), anyTree( - spatialJoin(new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("g1"), new SymbolReference("g3"))), Optional.of(KDB_TREE_JSON), Optional.empty(), + spatialJoin(new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("g1"), new SymbolReference("g3"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest(exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, - project(ImmutableMap.of("p1", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g1"))))), - project(ImmutableMap.of("g1", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_a1"), VARCHAR))))), + project(ImmutableMap.of("p1", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g1"))))), + project(ImmutableMap.of("g1", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("name_a1"), VARCHAR))))), tableScan("region", ImmutableMap.of("name_a1", "name")))), - project(ImmutableMap.of("p2", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g2"))))), - project(ImmutableMap.of("g2", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_a2"), VARCHAR))))), + project(ImmutableMap.of("p2", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g2"))))), + project(ImmutableMap.of("g2", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("name_a2"), VARCHAR))))), tableScan("nation", ImmutableMap.of("name_a2", "name"))))))), anyTree( unnest( - project(ImmutableMap.of("p3", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g3"))))), - project(ImmutableMap.of("g3", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_b"), VARCHAR))))), + project(ImmutableMap.of("p3", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g3"))))), + project(ImmutableMap.of("g3", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("name_b"), VARCHAR))))), tableScan("customer", ImmutableMap.of("name_b", "name"))))))))); // union on the right side @@ -484,21 +497,21 @@ public void testDistributedSpatialJoinOverUnion() withSpatialPartitioning("kdb_tree"), anyTree( spatialJoin( - new FunctionCall(QualifiedName.of("st_contains"), ImmutableList.of(new SymbolReference("g1"), new SymbolReference("g2"))), + new FunctionCall(ST_CONTAINS, ImmutableList.of(new SymbolReference("g1"), new SymbolReference("g2"))), Optional.of(KDB_TREE_JSON), Optional.empty(), anyTree( unnest( - project(ImmutableMap.of("p1", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g1"))))), - project(ImmutableMap.of("g1", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_a"), VARCHAR))))), + project(ImmutableMap.of("p1", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g1"))))), + project(ImmutableMap.of("g1", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("name_a"), VARCHAR))))), tableScan("customer", ImmutableMap.of("name_a", "name")))))), anyTree( unnest(exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, - project(ImmutableMap.of("p2", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g2"))))), - project(ImmutableMap.of("g2", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_b1"), VARCHAR))))), + project(ImmutableMap.of("p2", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g2"))))), + project(ImmutableMap.of("g2", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("name_b1"), VARCHAR))))), tableScan("region", ImmutableMap.of("name_b1", "name")))), - project(ImmutableMap.of("p3", expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g3"))))), - project(ImmutableMap.of("g3", expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_b2"), VARCHAR))))), + project(ImmutableMap.of("p3", expression(new FunctionCall(SPATIAL_PARTITIONS, ImmutableList.of(KDB_TREE_LITERAL, new SymbolReference("g3"))))), + project(ImmutableMap.of("g3", expression(new FunctionCall(ST_GEOMETRY_FROM_TEXT, ImmutableList.of(new Cast(new SymbolReference("name_b2"), VARCHAR))))), tableScan("nation", ImmutableMap.of("name_b2", "name")))))))))); } @@ -537,6 +550,6 @@ private static String doubleLiteral(double value) private FunctionCall functionCall(String name, List types, List arguments) { - return new FunctionCall(getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction(name, fromTypes(types)).toQualifiedName(), arguments); + return new FunctionCall(getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction(name, fromTypes(types)), arguments); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java index 0cb938e9f6fa..4cfdc07ea2a7 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; import io.trino.plugin.hive.TestingHiveConnectorFactory; import io.trino.plugin.hive.metastore.Database; import io.trino.plugin.hive.metastore.HiveMetastore; @@ -32,7 +34,6 @@ import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy; import io.trino.sql.planner.assertions.BasePlanTest; -import io.trino.sql.tree.QualifiedName; import io.trino.testing.PlanTester; import io.trino.type.LikePattern; import org.junit.jupiter.api.AfterAll; @@ -53,7 +54,9 @@ import static io.trino.plugin.hive.TestingHiveUtils.getConnectorService; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.ArithmeticBinaryExpression.Operator.MODULUS; import static io.trino.sql.ir.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL; @@ -83,6 +86,10 @@ public class TestHivePlans .setSchema(SCHEMA_NAME) .build(); + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction LIKE = FUNCTIONS.resolveFunction("$like", fromTypes(VARCHAR, LIKE_PATTERN)); + private static final ResolvedFunction SUBSTRING = FUNCTIONS.resolveFunction("substring", fromTypes(VARCHAR, BIGINT)); + private File baseDir; @Override @@ -147,7 +154,7 @@ public void testPruneSimplePartitionLikeFilter() "SELECT * FROM table_str_partitioned WHERE str_part LIKE 't%'", output( filter( - new FunctionCall(QualifiedName.of("$like"), ImmutableList.of(new SymbolReference("STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), + new FunctionCall(LIKE, ImmutableList.of(new SymbolReference("STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), tableScan("table_str_partitioned", Map.of("INT_COL", "int_col", "STR_PART", "str_part"))))); } @@ -169,12 +176,12 @@ public void testPrunePartitionLikeFilter() .left( exchange(REMOTE, REPARTITION, filter( - new FunctionCall(QualifiedName.of("$like"), ImmutableList.of(new SymbolReference("L_STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), + new FunctionCall(LIKE, ImmutableList.of(new SymbolReference("L_STR_PART"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))), tableScan("table_str_partitioned", Map.of("L_INT_COL", "int_col", "L_STR_PART", "str_part"))))) .right(exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new LogicalExpression(AND, ImmutableList.of(new InPredicate(new SymbolReference("R_STR_COL"), ImmutableList.of(new Constant(createVarcharType(5), Slices.utf8Slice("three")), new Constant(createVarcharType(5), Slices.utf8Slice("two")))), new FunctionCall(QualifiedName.of("$like"), ImmutableList.of(new SymbolReference("R_STR_COL"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))))), + new LogicalExpression(AND, ImmutableList.of(new InPredicate(new SymbolReference("R_STR_COL"), ImmutableList.of(new Constant(createVarcharType(5), Slices.utf8Slice("three")), new Constant(createVarcharType(5), Slices.utf8Slice("two")))), new FunctionCall(LIKE, ImmutableList.of(new SymbolReference("R_STR_COL"), new Constant(LIKE_PATTERN, LikePattern.compile("t%", Optional.empty())))))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); } @@ -245,7 +252,7 @@ public void testSubsumePartitionPartWhenOtherFilterNotConvertibleToTupleDomain() .left( exchange(REMOTE, REPARTITION, filter( - new ComparisonExpression(NOT_EQUAL, new FunctionCall(QualifiedName.of("substring"), ImmutableList.of(new SymbolReference("L_STR_COL"), new Constant(BIGINT, 2L))), new Constant(createVarcharType(5), Slices.utf8Slice("hree"))), + new ComparisonExpression(NOT_EQUAL, new FunctionCall(SUBSTRING, ImmutableList.of(new SymbolReference("L_STR_COL"), new Constant(BIGINT, 2L))), new Constant(createVarcharType(5), Slices.utf8Slice("hree"))), tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL,