-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support pushing dereferences within lambdas into table scan #23148
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,6 +122,8 @@ public class FeaturesConfig | |
|
||
private boolean faultTolerantExecutionExchangeEncryptionEnabled = true; | ||
|
||
private boolean pushFieldDereferenceLambdaIntoScanEnabled; | ||
|
||
public enum DataIntegrityVerification | ||
{ | ||
NONE, | ||
|
@@ -517,4 +519,17 @@ public void applyFaultTolerantExecutionDefaults() | |
{ | ||
exchangeCompressionCodec = LZ4; | ||
} | ||
|
||
public boolean isPushFieldDereferenceLambdaIntoScanEnabled() | ||
{ | ||
return pushFieldDereferenceLambdaIntoScanEnabled; | ||
} | ||
|
||
@Config("experimental.enable-push-field-dereference-lambda-into-scan.enabled") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it should really be experimental. TBH I think it should be connector property |
||
@ConfigDescription("Enables pushing field dereferences in lambda into table scan") | ||
public FeaturesConfig setPushFieldDereferenceLambdaIntoScanEnabled(boolean pushFieldDereferenceLambdaIntoScanEnabled) | ||
{ | ||
this.pushFieldDereferenceLambdaIntoScanEnabled = pushFieldDereferenceLambdaIntoScanEnabled; | ||
return this; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import io.trino.plugin.base.expression.ConnectorExpressions; | ||
import io.trino.security.AllowAllAccessControl; | ||
import io.trino.spi.connector.CatalogSchemaName; | ||
import io.trino.spi.expression.ArrayFieldDereference; | ||
import io.trino.spi.expression.ConnectorExpression; | ||
import io.trino.spi.expression.FieldDereference; | ||
import io.trino.spi.expression.FunctionName; | ||
|
@@ -46,9 +47,11 @@ | |
import io.trino.sql.ir.In; | ||
import io.trino.sql.ir.IrVisitor; | ||
import io.trino.sql.ir.IsNull; | ||
import io.trino.sql.ir.Lambda; | ||
import io.trino.sql.ir.Logical; | ||
import io.trino.sql.ir.NullIf; | ||
import io.trino.sql.ir.Reference; | ||
import io.trino.sql.ir.Row; | ||
import io.trino.sql.tree.QualifiedName; | ||
import io.trino.type.JoniRegexp; | ||
import io.trino.type.JsonPathType; | ||
|
@@ -69,6 +72,7 @@ | |
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; | ||
import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; | ||
import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; | ||
import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME; | ||
import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME; | ||
import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME; | ||
import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME; | ||
|
@@ -127,15 +131,21 @@ public static Expression translate(Session session, ConnectorExpression expressi | |
|
||
public static Optional<ConnectorExpression> translate(Session session, Expression expression) | ||
{ | ||
return new SqlToConnectorExpressionTranslator(session) | ||
return new SqlToConnectorExpressionTranslator(session, false) | ||
.process(expression); | ||
} | ||
|
||
public static Optional<ConnectorExpression> translate(Session session, Expression expression, boolean translateArrayFieldReference) | ||
{ | ||
return new SqlToConnectorExpressionTranslator(session, translateArrayFieldReference) | ||
.process(expression); | ||
} | ||
|
||
public static ConnectorExpressionTranslation translateConjuncts( | ||
Session session, | ||
Expression expression) | ||
{ | ||
SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator(session); | ||
SqlToConnectorExpressionTranslator translator = new SqlToConnectorExpressionTranslator(session, false); | ||
|
||
List<Expression> conjuncts = extractConjuncts(expression); | ||
List<Expression> remaining = new ArrayList<>(); | ||
|
@@ -562,10 +572,12 @@ public static class SqlToConnectorExpressionTranslator | |
extends IrVisitor<Optional<ConnectorExpression>, Void> | ||
{ | ||
private final Session session; | ||
private final boolean translateArrayFieldReference; | ||
|
||
public SqlToConnectorExpressionTranslator(Session session) | ||
public SqlToConnectorExpressionTranslator(Session session, boolean translateArrayFieldReference) | ||
{ | ||
this.session = requireNonNull(session, "session is null"); | ||
this.translateArrayFieldReference = translateArrayFieldReference; | ||
} | ||
|
||
@Override | ||
|
@@ -694,6 +706,37 @@ else if (functionName.equals(builtinFunctionName(MODULUS))) { | |
new io.trino.spi.expression.Call(node.type(), MODULUS_FUNCTION_NAME, ImmutableList.of(left, right)))); | ||
} | ||
|
||
// Very narrow case that only tries to extract a particular type of lambda expression | ||
// TODO: Expand the scope | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please create Trino issue which would track the TODO and reference it here. |
||
if (translateArrayFieldReference && functionName.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that should be |
||
List<Expression> allNodeArgument = node.arguments(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: just |
||
// at this point, SubscriptExpression should already been pushed down by PushProjectionIntoTableScan, | ||
// if not, it means its referenced by other expressions. we only care about SymbolReference at this moment | ||
List<Expression> inputExpressions = allNodeArgument.stream().filter(Reference.class::isInstance) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check |
||
.collect(toImmutableList()); | ||
List<Lambda> lambdaExpressions = allNodeArgument.stream().filter(e -> e instanceof Lambda lambda | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. Plase check There can't be more than single lambda expression anyway |
||
&& lambda.arguments().size() == 1) | ||
.map(Lambda.class::cast) | ||
.collect(toImmutableList()); | ||
if (inputExpressions.size() == 1 && lambdaExpressions.size() == 1) { | ||
Optional<ConnectorExpression> inputVariable = process(inputExpressions.get(0)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can just return |
||
if (lambdaExpressions.get(0).body() instanceof Row row) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: you can use switch pattern matching instead: https://belief-driven-design.com/looking-at-java-21-switch-pattern-matching-14648/ |
||
List<Expression> rowFields = row.items(); | ||
List<ConnectorExpression> translatedRowFields = | ||
rowFields.stream().map(e -> process(e)).filter(Optional::isPresent).map(Optional::get).collect(toImmutableList()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
|
||
if (inputVariable.isPresent() && translatedRowFields.size() == rowFields.size()) { | ||
return Optional.of(new ArrayFieldDereference(node.type(), inputVariable.get(), translatedRowFields)); | ||
} | ||
} | ||
if (lambdaExpressions.get(0).body() instanceof FieldReference fieldReference) { | ||
Optional<ConnectorExpression> fieldReferenceConnectorExpr = process(fieldReference); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just: |
||
if (inputVariable.isPresent() && fieldReferenceConnectorExpr.isPresent() && fieldReferenceConnectorExpr.get() instanceof FieldDereference expr) { | ||
return Optional.of(new ArrayFieldDereference(node.type(), inputVariable.get(), ImmutableList.of(expr))); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can return |
||
} | ||
} | ||
} | ||
|
||
ImmutableList.Builder<ConnectorExpression> arguments = ImmutableList.builder(); | ||
for (Expression argumentExpression : node.arguments()) { | ||
Optional<ConnectorExpression> argument = process(argumentExpression); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,7 +151,11 @@ | |
import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughTopN; | ||
import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughTopNRanking; | ||
import io.trino.sql.planner.iterative.rule.PushDownDereferencesThroughWindow; | ||
import io.trino.sql.planner.iterative.rule.PushDownFieldReferenceLambdaThroughFilter; | ||
import io.trino.sql.planner.iterative.rule.PushDownFieldReferenceLambdaThroughProject; | ||
import io.trino.sql.planner.iterative.rule.PushDownProjectionsFromPatternRecognition; | ||
import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaIntoTableScan; | ||
import io.trino.sql.planner.iterative.rule.PushFieldReferenceLambdaThroughFilterIntoTableScan; | ||
import io.trino.sql.planner.iterative.rule.PushFilterIntoValues; | ||
import io.trino.sql.planner.iterative.rule.PushFilterThroughBoolOrAggregation; | ||
import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation; | ||
|
@@ -346,7 +350,9 @@ public PlanOptimizers( | |
new PushDownDereferencesThroughWindow(), | ||
new PushDownDereferencesThroughTopN(), | ||
new PushDownDereferencesThroughRowNumber(), | ||
new PushDownDereferencesThroughTopNRanking()); | ||
new PushDownDereferencesThroughTopNRanking(), | ||
new PushDownFieldReferenceLambdaThroughProject(), | ||
new PushDownFieldReferenceLambdaThroughFilter()); | ||
|
||
Set<Rule<?>> limitPushdownRules = ImmutableSet.of( | ||
new PushLimitThroughOffset(), | ||
|
@@ -999,6 +1005,15 @@ public PlanOptimizers( | |
new PushPartialAggregationThroughExchange(plannerContext), | ||
new PruneJoinColumns(), | ||
new PruneJoinChildrenColumns()))); | ||
// This rule does not touch query plans, but only add subfields if necessary. Trigger at the near end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there queries that actually fail to get pushdown if we don't have this rule here? Maybe it's not needed? |
||
// Keeping this as iterative as it could be combined with PushProjectionIntoTableScan in the future | ||
builder.add(new IterativeOptimizer( | ||
plannerContext, | ||
ruleStats, | ||
statsCalculator, | ||
costCalculator, | ||
ImmutableSet.of(new PushFieldReferenceLambdaIntoTableScan(plannerContext), | ||
new PushFieldReferenceLambdaThroughFilterIntoTableScan(plannerContext)))); | ||
builder.add(new IterativeOptimizer( | ||
plannerContext, | ||
ruleStats, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,21 +14,32 @@ | |
package io.trino.sql.planner.iterative.rule; | ||
|
||
import com.google.common.collect.ImmutableList; | ||
import io.trino.spi.function.CatalogSchemaFunctionName; | ||
import io.trino.spi.type.RowType; | ||
import io.trino.sql.ir.Call; | ||
import io.trino.sql.ir.DefaultTraversalVisitor; | ||
import io.trino.sql.ir.Expression; | ||
import io.trino.sql.ir.FieldReference; | ||
import io.trino.sql.ir.Lambda; | ||
import io.trino.sql.ir.Reference; | ||
import io.trino.sql.ir.Row; | ||
import io.trino.sql.planner.Symbol; | ||
import io.trino.sql.tree.FunctionCall; | ||
|
||
import java.util.Collection; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
import java.util.Set; | ||
import java.util.stream.Collectors; | ||
|
||
import static com.google.common.base.Verify.verify; | ||
import static com.google.common.collect.ImmutableList.toImmutableList; | ||
import static com.google.common.collect.ImmutableSet.toImmutableSet; | ||
import static com.google.common.collect.Iterables.getOnlyElement; | ||
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; | ||
import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME; | ||
import static io.trino.sql.planner.SymbolsExtractor.extractAll; | ||
|
||
/** | ||
|
@@ -132,4 +143,113 @@ private static boolean prefixExists(Expression expression, Set<Expression> expre | |
verify(current instanceof Reference); | ||
return false; | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: public methods above private methods |
||
// Common methods for subscript lambda pushdown | ||
/** | ||
* Extract the sub-expressions of type subscript lambda {@link FunctionCall} from the {@param expression} | ||
*/ | ||
public static Map<Call, Reference> extractSubscriptLambdas(Collection<Expression> expressions) | ||
{ | ||
List<Map<Expression, Reference>> referencesAndFieldDereferenceLambdas = | ||
expressions.stream() | ||
.map(expression -> getSymbolReferencesAndSubscriptLambdas(expression)) | ||
.collect(toImmutableList()); | ||
|
||
Set<Reference> symbolReferences = | ||
referencesAndFieldDereferenceLambdas.stream() | ||
.flatMap(m -> m.keySet().stream()) | ||
.filter(Reference.class::isInstance) | ||
.map(Reference.class::cast) | ||
.collect(Collectors.toSet()); | ||
|
||
// Returns the subscript expression and its target input expression | ||
Map<Call, Reference> subscriptLambdas = | ||
referencesAndFieldDereferenceLambdas.stream() | ||
.flatMap(m -> m.entrySet().stream()) | ||
.filter(e -> e.getKey() instanceof Call && !symbolReferences.contains(e.getValue())) | ||
.collect(Collectors.toMap(e -> (Call) e.getKey(), e -> e.getValue())); | ||
|
||
return subscriptLambdas; | ||
} | ||
|
||
/** | ||
* Extract the sub-expressions of type {@link Reference} and subscript lambda {@link FunctionCall} from the {@param expression} | ||
*/ | ||
private static Map<Expression, Reference> getSymbolReferencesAndSubscriptLambdas(Expression expression) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it does what it means to. See There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See also #23649 |
||
{ | ||
Map<Expression, Reference> symbolMappings = new HashMap<>(); | ||
|
||
new DefaultTraversalVisitor<Map<Expression, Reference>>() | ||
{ | ||
@Override | ||
protected Void visitReference(Reference node, Map<Expression, Reference> context) | ||
{ | ||
context.put(node, node); | ||
return null; | ||
} | ||
|
||
@Override | ||
protected Void visitCall(Call node, Map<Expression, Reference> context) | ||
{ | ||
Optional<Reference> inputExpression = getSubscriptLambdaInputExpression(node); | ||
if (inputExpression.isPresent()) { | ||
context.put(node, inputExpression.get()); | ||
} | ||
|
||
return null; | ||
} | ||
}.process(expression, symbolMappings); | ||
|
||
return symbolMappings; | ||
} | ||
|
||
/** | ||
* Extract the sub-expressions of type {@link Reference} from the {@param expression} | ||
*/ | ||
public static List<Reference> getReferences(Expression expression) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like this you should use |
||
{ | ||
ImmutableList.Builder<Reference> builder = ImmutableList.builder(); | ||
|
||
new DefaultTraversalVisitor<ImmutableList.Builder<Reference>>() | ||
{ | ||
@Override | ||
protected Void visitReference(Reference node, ImmutableList.Builder<Reference> context) | ||
{ | ||
context.add(node); | ||
return null; | ||
} | ||
}.process(expression, builder); | ||
|
||
return builder.build(); | ||
} | ||
|
||
/** | ||
* Common pattern matching util function to look for subscript lambda function | ||
*/ | ||
public static Optional<Reference> getSubscriptLambdaInputExpression(Expression expression) | ||
{ | ||
if (expression instanceof Call functionCall) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: if you revert it:
that the happy path does not have extra indents. |
||
CatalogSchemaFunctionName functionName = functionCall.function().name(); | ||
|
||
if (functionName.equals(builtinFunctionName(ARRAY_TRANSFORM_NAME))) { | ||
List<Expression> allNodeArgument = functionCall.arguments(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. similar pattern happens in
then we could have method |
||
// at this point, FieldDereference expression should already been replaced with reference expression, | ||
// if not, it means its referenced by other expressions. we only care about FieldReference at this moment | ||
List<Reference> inputExpressions = allNodeArgument.stream() | ||
.filter(Reference.class::isInstance) | ||
.map(Reference.class::cast) | ||
.collect(toImmutableList()); | ||
List<Lambda> lambdaExpressions = allNodeArgument.stream().filter(e -> e instanceof Lambda lambda | ||
&& lambda.arguments().size() == 1) | ||
.map(Lambda.class::cast) | ||
.collect(toImmutableList()); | ||
if (inputExpressions.size() == 1 && lambdaExpressions.size() == 1 && | ||
((lambdaExpressions.get(0).body() instanceof Row row && | ||
row.items().stream().allMatch(FieldReference.class::isInstance)) || (lambdaExpressions.get(0).body() instanceof FieldReference))) { | ||
return Optional.of(inputExpressions.get(0)); | ||
} | ||
} | ||
} | ||
return Optional.empty(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be with other optimizer properties in
OptimizerConfig
TBH: I think it should be connector property. We already have
io.trino.plugin.iceberg.IcebergConfig#projectionPushdownEnabled
so it might not be needed at all