Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Desugar "not" during initial planning #21986

Merged
merged 1 commit into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.Symbol;
Expand All @@ -54,12 +53,14 @@
import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount;
import static io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.DynamicFilters.isDynamicFilter;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.ir.IrExpressions.not;
import static io.trino.sql.ir.IrUtils.and;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static java.lang.Double.NaN;
Expand Down Expand Up @@ -139,23 +140,6 @@ protected PlanNodeStatsEstimate visitExpression(Expression node, Void context)
return PlanNodeStatsEstimate.unknown();
}

@Override
protected PlanNodeStatsEstimate visitNot(Not node, Void context)
{
if (node.value() instanceof IsNull inner) {
if (inner.value() instanceof Reference) {
Symbol symbol = Symbol.from(inner.value());
SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
result.setOutputRowCount(input.getOutputRowCount() * (1 - symbolStats.getNullsFraction()));
result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0));
return result.build();
}
return PlanNodeStatsEstimate.unknown();
}
return subtractSubsetStats(input, process(node.value()));
}

@Override
protected PlanNodeStatsEstimate visitLogical(Logical node, Void context)
{
Expand Down Expand Up @@ -372,7 +356,7 @@ protected PlanNodeStatsEstimate visitComparison(Comparison node, Void context)
}

if (left instanceof Reference && left.equals(right)) {
return process(new Not(new IsNull(left)));
return process(not(plannerContext.getMetadata(), new IsNull(left)));
}

SymbolStatsEstimate leftStats = getExpressionStats(left);
Expand Down Expand Up @@ -404,6 +388,22 @@ protected PlanNodeStatsEstimate visitCall(Call node, Void context)
if (isDynamicFilter(node)) {
return process(Booleans.TRUE, context);
}
else if (node.function().name().equals(builtinFunctionName("$not"))) {
Expression argument = node.arguments().getFirst();
if (argument instanceof IsNull inner) {
if (inner.value() instanceof Reference) {
Symbol symbol = Symbol.from(inner.value());
SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
result.setOutputRowCount(input.getOutputRowCount() * (1 - symbolStats.getNullsFraction()));
result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0));
return result.build();
}
return PlanNodeStatsEstimate.unknown();
}
return subtractSubsetStats(input, process(argument));
}

return PlanNodeStatsEstimate.unknown();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import io.trino.Session;
import io.trino.cost.StatsCalculator.Context;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.FilterNode;
Expand All @@ -33,6 +33,7 @@
import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT;
import static io.trino.cost.SemiJoinStatsCalculator.computeAntiJoin;
import static io.trino.cost.SemiJoinStatsCalculator.computeSemiJoin;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.sql.ir.IrUtils.combineConjuncts;
import static io.trino.sql.ir.IrUtils.extractConjuncts;
import static io.trino.sql.planner.plan.Patterns.filter;
Expand Down Expand Up @@ -134,14 +135,17 @@ private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression pr
Expression remainingPredicate = combineConjuncts(conjuncts.stream()
.filter(conjunct -> conjunct != semiJoinOutputReference)
.collect(toImmutableList()));
boolean negated = semiJoinOutputReference instanceof Not;
boolean negated = semiJoinOutputReference instanceof Call call && call.function().name().equals(builtinFunctionName("$not"));
return Optional.of(new SemiJoinOutputFilter(negated, remainingPredicate));
}

private static boolean isSemiJoinOutputReference(Expression conjunct, Symbol semiJoinOutput)
{
Reference semiJoinOutputSymbolReference = semiJoinOutput.toSymbolReference();
return conjunct.equals(semiJoinOutputSymbolReference) || (conjunct instanceof Not not && not.value().equals(semiJoinOutputSymbolReference));
return conjunct.equals(semiJoinOutputSymbolReference) || (
conjunct instanceof Call call &&
call.function().name().equals(builtinFunctionName("$not")) &&
call.arguments().getFirst().equals(semiJoinOutputSymbolReference));
}

private static class SemiJoinOutputFilter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,6 @@ protected Void visitBind(Bind node, C context)
return null;
}

@Override
protected Void visitNot(Not node, C context)
{
process(node.value(), context);
return null;
}

@Override
protected Void visitCase(Case node, C context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
@JsonSubTypes.Type(value = IsNull.class, name = "isnull"),
@JsonSubTypes.Type(value = Lambda.class, name = "lambda"),
@JsonSubTypes.Type(value = Logical.class, name = "logical"),
@JsonSubTypes.Type(value = Not.class, name = "not"),
@JsonSubTypes.Type(value = NullIf.class, name = "nullif"),
@JsonSubTypes.Type(value = Row.class, name = "row"),
@JsonSubTypes.Type(value = Case.class, name = "case"),
Expand All @@ -47,7 +46,7 @@
public sealed interface Expression
permits Array, Between, Bind, Call, Case, Cast, Coalesce,
Comparison, Constant, FieldReference, In, IsNull, Lambda, Logical,
Not, NullIf, Reference, Row, Switch
NullIf, Reference, Row, Switch
{
Type type();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,6 @@ protected String visitLogical(Logical node, Void context)
")";
}

@Override
protected String visitNot(Not node, Void context)
{
return "(NOT " + process(node.value(), context) + ")";
}

@Override
protected String visitComparison(Comparison node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ public Expression rewriteLogical(Logical node, C context, ExpressionTreeRewriter
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteNot(Not node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteIsNull(IsNull node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,6 @@ public Expression visitLogical(Logical node, Context<C> context)
return node;
}

@Override
public Expression visitNot(Not node, Context<C> context)
{
if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteNot(node, context.get(), ExpressionTreeRewriter.this);
if (result != null) {
return result;
}
}

Expression value = rewrite(node.value(), context.get());

if (value != node.value()) {
return new Not(value);
}

return node;
}

@Override
protected Expression visitIsNull(IsNull node, Context<C> context)
{
Expand Down
12 changes: 11 additions & 1 deletion core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
package io.trino.sql.ir;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.sql.PlannerContext;
import io.trino.type.TypeCoercion;

import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.DynamicFilters.isDynamicFilterFunction;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.type.LikeFunctions.LIKE_FUNCTION_NAME;

public class IrExpressions
Expand Down Expand Up @@ -56,7 +59,6 @@ public static boolean mayFail(PlannerContext plannerContext, Expression expressi
case IsNull e -> mayFail(plannerContext, e.value());
case Lambda e -> false;
case Logical e -> e.terms().stream().anyMatch(argument -> mayFail(plannerContext, argument));
case Not e -> mayFail(plannerContext, e.value());
case NullIf e -> mayFail(plannerContext, e.first()) || mayFail(plannerContext, e.second());
case Reference e -> false;
case Row e -> e.items().stream().anyMatch(argument -> mayFail(plannerContext, argument));
Expand Down Expand Up @@ -86,8 +88,16 @@ private static boolean mayFail(ResolvedFunction function)
CatalogSchemaFunctionName name = function.name();
return !name.equals(builtinFunctionName("length")) &&
!name.equals(builtinFunctionName("try_cast")) &&
!name.equals(builtinFunctionName("$not")) &&
!name.equals(builtinFunctionName("substring")) &&
!name.equals(builtinFunctionName(LIKE_FUNCTION_NAME)) &&
!isDynamicFilterFunction(function.name());
}

public static Expression not(Metadata metadata, Expression expression)
{
return new Call(
metadata.resolveBuiltinFunction("$not", fromTypes(BOOLEAN)),
ImmutableList.of(expression));
}
}
5 changes: 0 additions & 5 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ protected R visitNullIf(NullIf node, C context)
return visitExpression(node, context);
}

protected R visitNot(Not node, C context)
{
return visitExpression(node, context);
}

protected R visitCase(Case node, C context)
{
return visitExpression(node, context);
Expand Down
57 changes: 0 additions & 57 deletions core/trino-main/src/main/java/io/trino/sql/ir/Not.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Reference;
import io.trino.sql.tree.QualifiedName;
Expand Down Expand Up @@ -103,6 +102,7 @@
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.DynamicFilters.isDynamicFilterFunction;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.ir.IrExpressions.not;
import static io.trino.sql.ir.IrUtils.combineConjuncts;
import static io.trino.sql.ir.IrUtils.extractConjuncts;
import static io.trino.type.JoniRegexpType.JONI_REGEXP;
Expand Down Expand Up @@ -344,7 +344,7 @@ private Optional<Expression> translateIsNotNull(ConnectorExpression argument)
{
Optional<Expression> translatedArgument = translate(argument);
if (translatedArgument.isPresent()) {
return Optional.of(new Not(new IsNull(translatedArgument.get())));
return Optional.of(not(plannerContext.getMetadata(), new IsNull(translatedArgument.get())));
}

return Optional.empty();
Expand All @@ -364,7 +364,7 @@ private Optional<Expression> translateNot(ConnectorExpression argument)
{
Optional<Expression> translatedArgument = translate(argument);
if (argument.getType().equals(BOOLEAN) && translatedArgument.isPresent()) {
return Optional.of(new Not(translatedArgument.get()));
return Optional.of(not(plannerContext.getMetadata(), translatedArgument.get()));
}
return Optional.empty();
}
Expand Down Expand Up @@ -758,16 +758,6 @@ protected Optional<ConnectorExpression> visitIsNull(IsNull node, Void context)
return Optional.empty();
}

@Override
protected Optional<ConnectorExpression> visitNot(Not node, Void context)
{
Optional<ConnectorExpression> translatedValue = process(node.value());
if (translatedValue.isPresent()) {
return Optional.of(new io.trino.spi.expression.Call(BOOLEAN, NOT_FUNCTION_NAME, List.of(translatedValue.get())));
}
return Optional.empty();
}

private boolean isSpecialType(Type type)
{
return type.equals(JONI_REGEXP) ||
Expand Down
Loading
Loading