Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pushing dereferences within lambdas into table scan #23148

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions core/trino-main/src/main/java/io/trino/FeaturesConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ public class FeaturesConfig

private boolean faultTolerantExecutionExchangeEncryptionEnabled = true;

private boolean pushFieldDereferenceLambdaIntoScanEnabled;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be with other optimizer properties in OptimizerConfig

TBH: I think it should be connector property. We already have io.trino.plugin.iceberg.IcebergConfig#projectionPushdownEnabled so it might not be needed at all


public enum DataIntegrityVerification
{
NONE,
Expand Down Expand Up @@ -517,4 +519,17 @@ public void applyFaultTolerantExecutionDefaults()
{
exchangeCompressionCodec = LZ4;
}

public boolean isPushFieldDereferenceLambdaIntoScanEnabled()
{
return pushFieldDereferenceLambdaIntoScanEnabled;
}

@Config("experimental.enable-push-field-dereference-lambda-into-scan.enabled")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it should really be experimental.

TBH I think it should be connector property

@ConfigDescription("Enables pushing field dereferences in lambda into table scan")
public FeaturesConfig setPushFieldDereferenceLambdaIntoScanEnabled(boolean pushFieldDereferenceLambdaIntoScanEnabled)
{
this.pushFieldDereferenceLambdaIntoScanEnabled = pushFieldDereferenceLambdaIntoScanEnabled;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ public final class SystemSessionProperties
public static final String IDLE_WRITER_MIN_DATA_SIZE_THRESHOLD = "idle_writer_min_data_size_threshold";
public static final String CLOSE_IDLE_WRITERS_TRIGGER_DURATION = "close_idle_writers_trigger_duration";
public static final String COLUMNAR_FILTER_EVALUATION_ENABLED = "columnar_filter_evaluation_enabled";
public static final String ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN = "enable_push_field_dereference_lambda_into_scan";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -1103,7 +1104,12 @@ public SystemSessionProperties(
ALLOW_UNSAFE_PUSHDOWN,
"Allow pushing down expressions that may fail for some inputs",
optimizerConfig.isUnsafePushdownAllowed(),
true));
true),
booleanProperty(
ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN,
"Enable pushing field dereferences in lambda into scan",
featuresConfig.isPushFieldDereferenceLambdaIntoScanEnabled(),
false));
}

@Override
Expand Down Expand Up @@ -1982,4 +1988,9 @@ public static boolean isUnsafePushdownAllowed(Session session)
{
return session.getSystemProperty(ALLOW_UNSAFE_PUSHDOWN, Boolean.class);
}

public static boolean isPushFieldDereferenceLambdaIntoScanEnabled(Session session)
{
return session.getSystemProperty(ENABLE_PUSH_FIELD_DEREFERENCE_LAMBDA_INTO_SCAN, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.plugin.base.expression.ConnectorExpressions;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.spi.expression.ArrayFieldDereference;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.FunctionName;
Expand All @@ -46,9 +47,11 @@
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Row;
import io.trino.sql.tree.QualifiedName;
import io.trino.type.JoniRegexp;
import io.trino.type.JsonPathType;
Expand All @@ -69,6 +72,7 @@
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName;
import static io.trino.metadata.LanguageFunctionManager.isInlineFunction;
import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME;
import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME;
import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME;
import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME;
Expand Down Expand Up @@ -127,15 +131,21 @@ public static Expression translate(Session session, ConnectorExpression expressi

public static Optional<ConnectorExpression> translate(Session session, Expression expression)
{
return new SqlToConnectorExpressionTranslator(session)
return new SqlToConnectorExpressionTranslator(session, false)
.process(expression);
}

public static Optional<ConnectorExpression> translate(Session session, Expression expression, boolean translateArrayFieldReference)
{
return new SqlToConnectorExpressionTranslator(session, translateArrayFieldReference)
.process(expression);
}

public static ConnectorExpressionTranslation translateConjuncts(
Session session,
Expression expression)
{
SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator(session);
SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator(session, false);

List<Expression> conjuncts = extractConjuncts(expression);
List<Expression> remaining = new ArrayList<>();
Expand Down Expand Up @@ -562,10 +572,12 @@ public static class SqlToConnectorExpressionTranslator
extends IrVisitor<Optional<ConnectorExpression>, Void>
{
private final Session session;
private final boolean translateArrayFieldReference;

public SqlToConnectorExpressionTranslator(Session session)
public SqlToConnectorExpressionTranslator(Session session, boolean translateArrayFieldReference)
{
this.session = requireNonNull(session, "session is null");
this.translateArrayFieldReference = translateArrayFieldReference;
}

@Override
Expand Down Expand Up @@ -694,6 +706,37 @@ else if (functionName.equals(builtinFunctionName(MODULUS))) {
new io.trino.spi.expression.Call(node.type(), MODULUS_FUNCTION_NAME, ImmutableList.of(left, right))));
}

// Very narrow case that only tries to extract a particular type of lambda expression
// TODO: Expand the scope
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please create Trino issue which would track the TODO and reference it here.

if (translateArrayFieldReference && functionName.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that should be else if

List<Expression> allNodeArgument = node.arguments();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: just arguments

// at this point, SubscriptExpression should already been pushed down by PushProjectionIntoTableScan,
// if not, it means its referenced by other expressions. we only care about SymbolReference at this moment
List<Expression> inputExpressions = allNodeArgument.stream().filter(Reference.class::isInstance)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check arguments.getFirst() explicitly, see SpecializeTransformWithJsonParse

.collect(toImmutableList());
List<Lambda> lambdaExpressions = allNodeArgument.stream().filter(e -> e instanceof Lambda lambda
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto. Plase check rguments.getFirst() explicitly. Then you would actually have lambda variable within if, so you wouldn't need to use lambdaExpressions.get(0).

There can't be more than single lambda expression anyway

&& lambda.arguments().size() == 1)
.map(Lambda.class::cast)
.collect(toImmutableList());
if (inputExpressions.size() == 1 && lambdaExpressions.size() == 1) {
Optional<ConnectorExpression> inputVariable = process(inputExpressions.get(0));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can just return Optional.empty if inputVariable is empty

if (lambdaExpressions.get(0).body() instanceof Row row) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List<Expression> rowFields = row.items();
List<ConnectorExpression> translatedRowFields =
rowFields.stream().map(e -> process(e)).filter(Optional::isPresent).map(Optional::get).collect(toImmutableList());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

List<ConnectorExpression> translatedRowFields = rowFields.stream()
  .map(this::process)
  .filter(Optional::isPresent)
  .collect(toImmutableList());

if (inputVariable.isPresent() && translatedRowFields.size() == rowFields.size()) {
return Optional.of(new ArrayFieldDereference(node.type(), inputVariable.get(), translatedRowFields));
}
}
if (lambdaExpressions.get(0).body() instanceof FieldReference fieldReference) {
Optional<ConnectorExpression> fieldReferenceConnectorExpr = process(fieldReference);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just: fieldReferenceConnectorExpr -> reference

if (inputVariable.isPresent() && fieldReferenceConnectorExpr.isPresent() && fieldReferenceConnectorExpr.get() instanceof FieldDereference expr) {
return Optional.of(new ArrayFieldDereference(node.type(), inputVariable.get(), ImmutableList.of(expr)));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can return empty if the above condition doesn't match. Same for Row case

}
}
}

ImmutableList.Builder<ConnectorExpression> arguments = ImmutableList.builder();
for (Expression argumentExpression : node.arguments()) {
Optional<ConnectorExpression> argument = process(argumentExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ private PartialTranslator() {}
*/
public static Map<NodeRef<Expression>, ConnectorExpression> extractPartialTranslations(
Expression inputExpression,
Session session)
Session session,
boolean translateArrayFieldReference)
{
requireNonNull(inputExpression, "inputExpression is null");
requireNonNull(session, "session is null");

Map<NodeRef<Expression>, ConnectorExpression> partialTranslations = new HashMap<>();
new Visitor(session, partialTranslations).process(inputExpression);
new Visitor(session, partialTranslations, translateArrayFieldReference).process(inputExpression);
return ImmutableMap.copyOf(partialTranslations);
}

Expand All @@ -53,10 +54,10 @@ private static class Visitor
private final Map<NodeRef<Expression>, ConnectorExpression> translatedSubExpressions;
private final ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator;

Visitor(Session session, Map<NodeRef<Expression>, ConnectorExpression> translatedSubExpressions)
Visitor(Session session, Map<NodeRef<Expression>, ConnectorExpression> translatedSubExpressions, boolean translateArrayFieldReference)
{
this.translatedSubExpressions = requireNonNull(translatedSubExpressions, "translatedSubExpressions is null");
this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(session);
this.translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(session, translateArrayFieldReference);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@
import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughTopN;
import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughTopNRanking;
import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughWindow;
import io.trino.sql.planner.iterative.rule.PushDownFieldReferenceLambdaThroughFilter;
import io.trino.sql.planner.iterative.rule.PushDownFieldReferenceLambdaThroughProject;
import io.trino.sql.planner.iterative.rule.PushDownProjectionsFromPatternRecognition;
import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaIntoTableScan;
import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaThroughFilterIntoTableScan;
import io.trino.sql.planner.iterative.rule.PushFilterIntoValues;
import io.trino.sql.planner.iterative.rule.PushFilterThroughBoolOrAggregation;
import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation;
Expand Down Expand Up @@ -346,7 +350,9 @@ public PlanOptimizers(
new PushDownDereferencesThroughWindow(),
new PushDownDereferencesThroughTopN(),
new PushDownDereferencesThroughRowNumber(),
new PushDownDereferencesThroughTopNRanking());
new PushDownDereferencesThroughTopNRanking(),
new PushDownFieldReferenceLambdaThroughProject(),
new PushDownFieldReferenceLambdaThroughFilter());

Set<Rule<?>> limitPushdownRules = ImmutableSet.of(
new PushLimitThroughOffset(),
Expand Down Expand Up @@ -999,6 +1005,15 @@ public PlanOptimizers(
new PushPartialAggregationThroughExchange(plannerContext),
new PruneJoinColumns(),
new PruneJoinChildrenColumns())));
// This rule does not touch query plans, but only add subfields if necessary. Trigger at the near end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there queries that actually fail to get pushdown if we don't have this rule here? Maybe it's not needed?

// Keeping this as iterative as it could be combined with PushProjectionIntoTableScan in the future
builder.add(new IterativeOptimizer(
plannerContext,
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(new PushFieldReferenceLambdaIntoTableScan(plannerContext),
new PushFieldReferenceLambdaThroughFilterIntoTableScan(plannerContext))));
builder.add(new IterativeOptimizer(
plannerContext,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.RowType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.DefaultTraversalVisitor;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Row;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.FunctionCall;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME;
import static io.trino.sql.planner.SymbolsExtractor.extractAll;

/**
Expand Down Expand Up @@ -132,4 +143,113 @@ private static boolean prefixExists(Expression expression, Set<Expression> expre
verify(current instanceof Reference);
return false;
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: public methods above private methods

// Common methods for subscript lambda pushdown
/**
* Extract the sub-expressions of type subscript lambda {@link FunctionCall} from the {@param expression}
*/
public static Map<Call, Reference> extractSubscriptLambdas(Collection<Expression> expressions)
{
List<Map<Expression, Reference>> referencesAndFieldDereferenceLambdas =
expressions.stream()
.map(expression -> getSymbolReferencesAndSubscriptLambdas(expression))
.collect(toImmutableList());

Set<Reference> symbolReferences =
referencesAndFieldDereferenceLambdas.stream()
.flatMap(m -> m.keySet().stream())
.filter(Reference.class::isInstance)
.map(Reference.class::cast)
.collect(Collectors.toSet());

// Returns the subscript expression and its target input expression
Map<Call, Reference> subscriptLambdas =
referencesAndFieldDereferenceLambdas.stream()
.flatMap(m -> m.entrySet().stream())
.filter(e -> e.getKey() instanceof Call && !symbolReferences.contains(e.getValue()))
.collect(Collectors.toMap(e -> (Call) e.getKey(), e -> e.getValue()));

return subscriptLambdas;
}

/**
* Extract the sub-expressions of type {@link Reference} and subscript lambda {@link FunctionCall} from the {@param expression}
*/
private static Map<Expression, Reference> getSymbolReferencesAndSubscriptLambdas(Expression expression)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it does what it means to. See io.trino.sql.planner.SymbolsExtractor.SymbolBuilderVisitor. Symbols within lambdas need to be handled carefully

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also #23649

{
Map<Expression, Reference> symbolMappings = new HashMap<>();

new DefaultTraversalVisitor<Map<Expression, Reference>>()
{
@Override
protected Void visitReference(Reference node, Map<Expression, Reference> context)
{
context.put(node, node);
return null;
}

@Override
protected Void visitCall(Call node, Map<Expression, Reference> context)
{
Optional<Reference> inputExpression = getSubscriptLambdaInputExpression(node);
if (inputExpression.isPresent()) {
context.put(node, inputExpression.get());
}

return null;
}
}.process(expression, symbolMappings);

return symbolMappings;
}

/**
* Extract the sub-expressions of type {@link Reference} from the {@param expression}
*/
public static List<Reference> getReferences(Expression expression)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this you should use io.trino.sql.planner.SymbolsExtractor#extractUnique(io.trino.sql.ir.Expression) instead. This is because lambdas nested within an expression can reference lamba arguments, see: io.trino.sql.planner.SymbolsExtractor.SymbolBuilderVisitor

{
ImmutableList.Builder<Reference> builder = ImmutableList.builder();

new DefaultTraversalVisitor<ImmutableList.Builder<Reference>>()
{
@Override
protected Void visitReference(Reference node, ImmutableList.Builder<Reference> context)
{
context.add(node);
return null;
}
}.process(expression, builder);

return builder.build();
}

/**
* Common pattern matching util function to look for subscript lambda function
*/
public static Optional<Reference> getSubscriptLambdaInputExpression(Expression expression)
{
if (expression instanceof Call functionCall) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if you revert it:

if (!(expression instanceOf Call functionCall) ||
  !functionCall.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) {
  return Optional.empty();
}

// rest of code

that the happy path does not have extra indents.

CatalogSchemaFunctionName functionName = functionCall.function().name();

if (functionName.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) {
List<Expression> allNodeArgument = functionCall.arguments();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar pattern happens in ConnectorExpressionTranslator. Maybe we can extract something like

record ArrayDereferenceExpression(Expression argument, Lambda lambda)

then we could have method Optional<ArrayDereferenceExpression> getArrayDereferenceExpression(Expression expression)
and use it in rules and in ConnectorExpressionTranslator?

// at this point, FieldDereference expression should already been replaced with reference expression,
// if not, it means its referenced by other expressions. we only care about FieldReference at this moment
List<Reference> inputExpressions = allNodeArgument.stream()
.filter(Reference.class::isInstance)
.map(Reference.class::cast)
.collect(toImmutableList());
List<Lambda> lambdaExpressions = allNodeArgument.stream().filter(e -> e instanceof Lambda lambda
&& lambda.arguments().size() == 1)
.map(Lambda.class::cast)
.collect(toImmutableList());
if (inputExpressions.size() == 1 && lambdaExpressions.size() == 1 &&
((lambdaExpressions.get(0).body() instanceof Row row &&
row.items().stream().allMatch(FieldReference.class::isInstance)) || (lambdaExpressions.get(0).body() instanceof FieldReference))) {
return Optional.of(inputExpressions.get(0));
}
}
}
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols;
import static io.trino.sql.planner.iterative.rule.DereferencePushdown.getSubscriptLambdaInputExpression;
import static io.trino.sql.planner.plan.Patterns.project;
import static io.trino.sql.planner.plan.Patterns.source;
import static java.util.stream.Collectors.toSet;
Expand Down Expand Up @@ -187,6 +188,10 @@ private static Set<Symbol> extractInliningTargets(ProjectNode parent, ProjectNod

return true;
})
.filter(entry -> {
// skip subscript lambdas, otherwise, inlining can cause conflicts with PushdownDereferences
return getSubscriptLambdaInputExpression(child.getAssignments().get(entry.getKey())).isEmpty();
})
.map(Map.Entry::getKey)
.collect(toSet());

Expand Down
Loading
Loading