Skip to content

Commit

Permalink
Use ResolvedFunction in FunctionCall
Browse files Browse the repository at this point in the history
Replace the legacy use of QualifiedName to encode
a resolved function with a direct reference to
the ResolvedFunction instance.
  • Loading branch information
martint committed Mar 17, 2024
1 parent 5385cec commit 048f88f
Show file tree
Hide file tree
Showing 172 changed files with 982 additions and 1,176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ private Expression simplifyExpression(Session session, Expression predicate, Typ
{
// TODO reuse io.trino.sql.planner.iterative.rule.SimplifyExpressions.rewrite

Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, types, predicate);
Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(types, predicate);
IrExpressionInterpreter interpreter = new IrExpressionInterpreter(predicate, plannerContext, session, expressionTypes);
Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);

Expand Down Expand Up @@ -445,7 +445,7 @@ private Type getType(Expression expression)
return requireNonNull(types.get(symbol), () -> format("No type for symbol %s", symbol));
}

return typeAnalyzer.getType(session, types, expression);
return typeAnalyzer.getType(types, expression);
}

private SymbolStatsEstimate getExpressionStats(Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ protected SymbolStatsEstimate visitConstant(Constant node, Void context)
@Override
protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context)
{
Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, types, node);
Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(types, node);
IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, plannerContext, session, expressionTypes);
Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);

Expand Down Expand Up @@ -148,7 +148,7 @@ protected SymbolStatsEstimate visitCast(Cast node, Void context)
double lowValue = sourceStats.getLowValue();
double highValue = sourceStats.getHighValue();

if (isIntegralType(typeAnalyzer.getType(session, types, node))) {
if (isIntegralType(typeAnalyzer.getType(types, node))) {
// todo handle low/high value changes if range gets narrower due to cast (e.g. BIGINT -> SMALLINT)
if (isFinite(lowValue)) {
lowValue = Math.round(lowValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import io.trino.Session;
import io.trino.cost.StatsCalculator.Context;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.SymbolReference;
Expand Down Expand Up @@ -49,13 +48,11 @@ public class SimpleFilterProjectSemiJoinStatsRule
{
private static final Pattern<FilterNode> PATTERN = filter();

private final Metadata metadata;
private final FilterStatsCalculator filterStatsCalculator;

public SimpleFilterProjectSemiJoinStatsRule(Metadata metadata, StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator)
public SimpleFilterProjectSemiJoinStatsRule(StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator)
{
super(normalizer);
this.metadata = requireNonNull(metadata, "metadata is null");
this.filterStatsCalculator = requireNonNull(filterStatsCalculator, "filterStatsCalculator cannot be null");
}

Expand Down Expand Up @@ -135,7 +132,7 @@ private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression pr
}

Expression semiJoinOutputReference = Iterables.getOnlyElement(semiJoinOutputReferences);
Expression remainingPredicate = combineConjuncts(metadata, conjuncts.stream()
Expression remainingPredicate = combineConjuncts(conjuncts.stream()
.filter(conjunct -> conjunct != semiJoinOutputReference)
.collect(toImmutableList()));
boolean negated = semiJoinOutputReference instanceof NotExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public List<ComposableStatsCalculator.Rule<?>> get()

rules.add(new OutputStatsRule());
rules.add(new TableScanStatsRule(normalizer));
rules.add(new SimpleFilterProjectSemiJoinStatsRule(plannerContext.getMetadata(), normalizer, filterStatsCalculator)); // this must be before FilterStatsRule
rules.add(new SimpleFilterProjectSemiJoinStatsRule(normalizer, filterStatsCalculator)); // this must be before FilterStatsRule
rules.add(new FilterProjectAggregationStatsRule(normalizer, filterStatsCalculator)); // this must be before FilterStatsRule
rules.add(new FilterStatsRule(normalizer, filterStatsCalculator));
rules.add(new ValuesStatsRule(plannerContext));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,23 +145,25 @@ public static boolean isResolved(QualifiedName name)
return SerializedResolvedFunction.isSerializedResolvedFunction(name);
}

public CatalogSchemaFunctionName getName()
{
QualifiedName qualifiedName = toQualifiedName();
return SerializedResolvedFunction.fromSerializedName(qualifiedName).functionName();
}

@Deprecated
public QualifiedName toQualifiedName()
{
CatalogSchemaFunctionName name = toCatalogSchemaFunctionName();
return QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getFunctionName());
}

@Deprecated
public CatalogSchemaFunctionName toCatalogSchemaFunctionName()
{
return ResolvedFunctionDecoder.toCatalogSchemaFunctionName(this);
}

public static CatalogSchemaFunctionName extractFunctionName(QualifiedName qualifiedName)
{
checkArgument(isResolved(qualifiedName), "Expected qualifiedName to be a resolved function: %s", qualifiedName);
return SerializedResolvedFunction.fromSerializedName(qualifiedName).functionName();
}

@Override
public boolean equals(Object o)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.IsNull;
import io.trino.spi.function.ScalarFunction;
Expand Down Expand Up @@ -141,7 +140,7 @@ private static Symbol extractSourceSymbol(DynamicFilters.Descriptor descriptor)
public static Expression replaceDynamicFilterId(FunctionCall dynamicFilterFunctionCall, DynamicFilterId newId)
{
return new FunctionCall(
dynamicFilterFunctionCall.getName(),
dynamicFilterFunctionCall.getFunction(),
ImmutableList.of(
dynamicFilterFunctionCall.getArguments().get(0),
dynamicFilterFunctionCall.getArguments().get(1),
Expand Down Expand Up @@ -186,7 +185,7 @@ public static Optional<Descriptor> getDescriptor(Expression expression)

private static boolean isDynamicFilterFunction(FunctionCall functionCall)
{
return isDynamicFilterFunction(ResolvedFunction.extractFunctionName(functionCall.getName()));
return isDynamicFilterFunction(functionCall.getFunction().getName());
}

public static boolean isDynamicFilterFunction(CatalogSchemaFunctionName functionName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static Object evaluateConstant(
io.trino.sql.ir.Expression rewritten = translationMap.rewrite(expression);

IrTypeAnalyzer analyzer = new IrTypeAnalyzer(plannerContext);
Map<io.trino.sql.ir.NodeRef<io.trino.sql.ir.Expression>, Type> types = analyzer.getTypes(session, TypeProvider.empty(), rewritten);
Map<io.trino.sql.ir.NodeRef<io.trino.sql.ir.Expression>, Type> types = analyzer.getTypes(TypeProvider.empty(), rewritten);

Type actualType = types.get(io.trino.sql.ir.NodeRef.of(rewritten));
if (!new TypeCoercion(plannerContext.getTypeManager()::getType).canCoerce(actualType, expectedType)) {
Expand All @@ -73,7 +73,7 @@ public static Object evaluateConstant(

if (!actualType.equals(expectedType)) {
rewritten = new Cast(rewritten, expectedType, false);
types = analyzer.getTypes(session, TypeProvider.empty(), rewritten);
types = analyzer.getTypes(TypeProvider.empty(), rewritten);
}

return new IrExpressionInterpreter(rewritten, plannerContext, session, types).evaluate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Optional;
import java.util.function.Function;

import static io.trino.sql.SqlFormatter.formatName;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

Expand Down Expand Up @@ -93,7 +92,7 @@ protected String visitConstant(Constant node, Void context)
@Override
protected String visitFunctionCall(FunctionCall node, Void context)
{
return formatName(node.getName()) + '(' + joinExpressions(node.getArguments()) + ')';
return node.getFunction().getName().toString() + '(' + joinExpressions(node.getArguments()) + ')';
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ public Expression visitFunctionCall(FunctionCall node, Context<C> context)
List<Expression> arguments = rewrite(node.getArguments(), context);

if (!sameElements(node.getArguments(), arguments)) {
return new FunctionCall(node.getName(), arguments);
return new FunctionCall(node.getFunction(), arguments);
}
return node;
}
Expand Down
34 changes: 11 additions & 23 deletions core/trino-main/src/main/java/io/trino/sql/ir/FunctionCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,31 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.sql.tree.QualifiedName;
import io.trino.metadata.ResolvedFunction;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;

public final class FunctionCall
extends Expression
{
private final QualifiedName name;
private final ResolvedFunction function;
private final List<Expression> arguments;

@JsonCreator
public FunctionCall(String resolvedFunction, List<Expression> arguments)
{
this(
QualifiedName.of(GlobalSystemConnector.NAME, "$resolved", resolvedFunction),
ImmutableList.copyOf(arguments));
}

public FunctionCall(QualifiedName name, List<Expression> arguments)
public FunctionCall(ResolvedFunction function, List<Expression> arguments)
{
this.name = name;
this.function = requireNonNull(function, "function is null");
this.arguments = ImmutableList.copyOf(arguments);
}

@Deprecated
public QualifiedName getName()
{
return name;
}

@JsonProperty
public String getResolvedFunction()
public ResolvedFunction getFunction()
{
return name.getSuffix();
return function;
}

@JsonProperty
Expand Down Expand Up @@ -83,21 +71,21 @@ public boolean equals(Object obj)
return false;
}
FunctionCall o = (FunctionCall) obj;
return Objects.equals(name, o.name) &&
return Objects.equals(function, o.function) &&
Objects.equals(arguments, o.arguments);
}

@Override
public int hashCode()
{
return Objects.hash(name, arguments);
return Objects.hash(function, arguments);
}

@Override
public String toString()
{
return "%s(%s)".formatted(
name.getSuffix(),
function.getName(),
arguments.stream()
.map(Expression::toString)
.collect(Collectors.joining(", ")));
Expand Down
38 changes: 19 additions & 19 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,18 @@ public static Expression combinePredicates(Metadata metadata, LogicalExpression.
public static Expression combinePredicates(Metadata metadata, LogicalExpression.Operator operator, Collection<Expression> expressions)
{
if (operator == LogicalExpression.Operator.AND) {
return combineConjuncts(metadata, expressions);
return combineConjuncts(expressions);
}

return combineDisjuncts(metadata, expressions);
return combineDisjuncts(expressions);
}

public static Expression combineConjuncts(Metadata metadata, Expression... expressions)
public static Expression combineConjuncts(Expression... expressions)
{
return combineConjuncts(metadata, Arrays.asList(expressions));
return combineConjuncts(Arrays.asList(expressions));
}

public static Expression combineConjuncts(Metadata metadata, Collection<Expression> expressions)
public static Expression combineConjuncts(Collection<Expression> expressions)
{
requireNonNull(expressions, "expressions is null");

Expand All @@ -155,7 +155,7 @@ public static Expression combineConjuncts(Metadata metadata, Collection<Expressi
.filter(e -> !e.equals(TRUE_LITERAL))
.collect(toList());

conjuncts = removeDuplicates(metadata, conjuncts);
conjuncts = removeDuplicates(conjuncts);

if (conjuncts.contains(FALSE_LITERAL)) {
return FALSE_LITERAL;
Expand All @@ -180,17 +180,17 @@ public static Expression combineConjunctsWithDuplicates(Collection<Expression> e
return and(conjuncts);
}

public static Expression combineDisjuncts(Metadata metadata, Expression... expressions)
public static Expression combineDisjuncts(Expression... expressions)
{
return combineDisjuncts(metadata, Arrays.asList(expressions));
return combineDisjuncts(Arrays.asList(expressions));
}

public static Expression combineDisjuncts(Metadata metadata, Collection<Expression> expressions)
public static Expression combineDisjuncts(Collection<Expression> expressions)
{
return combineDisjunctsWithDefault(metadata, expressions, FALSE_LITERAL);
return combineDisjunctsWithDefault(expressions, FALSE_LITERAL);
}

public static Expression combineDisjunctsWithDefault(Metadata metadata, Collection<Expression> expressions, Expression emptyDefault)
public static Expression combineDisjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
{
requireNonNull(expressions, "expressions is null");

Expand All @@ -199,7 +199,7 @@ public static Expression combineDisjunctsWithDefault(Metadata metadata, Collecti
.filter(e -> !e.equals(FALSE_LITERAL))
.collect(toList());

disjuncts = removeDuplicates(metadata, disjuncts);
disjuncts = removeDuplicates(disjuncts);

if (disjuncts.contains(TRUE_LITERAL)) {
return TRUE_LITERAL;
Expand All @@ -210,21 +210,21 @@ public static Expression combineDisjunctsWithDefault(Metadata metadata, Collecti

public static Expression filterDeterministicConjuncts(Metadata metadata, Expression expression)
{
return filterConjuncts(metadata, expression, expression1 -> DeterminismEvaluator.isDeterministic(expression1, metadata));
return filterConjuncts(expression, expression1 -> DeterminismEvaluator.isDeterministic(expression1));
}

public static Expression filterNonDeterministicConjuncts(Metadata metadata, Expression expression)
{
return filterConjuncts(metadata, expression, not(testExpression -> DeterminismEvaluator.isDeterministic(testExpression, metadata)));
return filterConjuncts(expression, not(testExpression -> DeterminismEvaluator.isDeterministic(testExpression)));
}

public static Expression filterConjuncts(Metadata metadata, Expression expression, Predicate<Expression> predicate)
public static Expression filterConjuncts(Expression expression, Predicate<Expression> predicate)
{
List<Expression> conjuncts = extractConjuncts(expression).stream()
.filter(predicate)
.collect(toList());

return combineConjuncts(metadata, conjuncts);
return combineConjuncts(conjuncts);
}

@SafeVarargs
Expand Down Expand Up @@ -276,7 +276,7 @@ public static boolean isEffectivelyLiteral(PlannerContext plannerContext, Sessio

private static boolean constantExpressionEvaluatesSuccessfully(PlannerContext plannerContext, Session session, Expression constantExpression)
{
Map<NodeRef<Expression>, Type> types = new IrTypeAnalyzer(plannerContext).getTypes(session, TypeProvider.empty(), constantExpression);
Map<NodeRef<Expression>, Type> types = new IrTypeAnalyzer(plannerContext).getTypes(TypeProvider.empty(), constantExpression);
IrExpressionInterpreter interpreter = new IrExpressionInterpreter(constantExpression, plannerContext, session, types);
Object literalValue = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
return !(literalValue instanceof Expression);
Expand All @@ -286,13 +286,13 @@ private static boolean constantExpressionEvaluatesSuccessfully(PlannerContext pl
* Removes duplicate deterministic expressions. Preserves the relative order
* of the expressions in the list.
*/
private static List<Expression> removeDuplicates(Metadata metadata, List<Expression> expressions)
private static List<Expression> removeDuplicates(List<Expression> expressions)
{
Set<Expression> seen = new HashSet<>();

ImmutableList.Builder<Expression> result = ImmutableList.builder();
for (Expression expression : expressions) {
if (!DeterminismEvaluator.isDeterministic(expression, metadata)) {
if (!DeterminismEvaluator.isDeterministic(expression)) {
result.add(expression);
}
else if (!seen.contains(expression)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ public BuiltinFunctionCallBuilder setArguments(List<Type> types, List<Expression
public FunctionCall build()
{
ResolvedFunction resolvedFunction = metadata.resolveBuiltinFunction(name, TypeSignatureProvider.fromTypeSignatures(argumentTypes));
return new FunctionCall(resolvedFunction.toQualifiedName(), argumentValues);
return new FunctionCall(resolvedFunction, argumentValues);
}
}
Loading

0 comments on commit 048f88f

Please sign in to comment.