Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace lambda with method reference #21296

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ PlanNodeStatsEstimate calculateJoinComplementStats(
.map(drivingClause -> calculateJoinComplementStats(leftStats, rightStats, drivingClause, criteria.size() - 1 + numberOfFilterClauses))
.filter(estimate -> !estimate.isOutputRowCountUnknown())
.max(comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
.map(estimate -> normalizer.normalize(estimate))
.map(normalizer::normalize)
.orElse(PlanNodeStatsEstimate.unknown());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ private void waitForMinimumWorkers()
ListenableFuture<Void> minimumWorkerFuture = clusterSizeMonitor.waitForMinimumWorkers(executionMinCount, getRequiredWorkersMaxWait(session));
// when worker requirement is met, start the execution
addSuccessCallback(minimumWorkerFuture, () -> startExecution(queryExecution), queryExecutor);
addExceptionCallback(minimumWorkerFuture, throwable -> stateMachine.transitionToFailed(throwable), queryExecutor);
addExceptionCallback(minimumWorkerFuture, stateMachine::transitionToFailed, queryExecutor);

// cancel minimumWorkerFuture if query fails for some reason or is cancelled by user
stateMachine.addStateChangeListener(state -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public void writeValuesSorted(BlockBuilder valueBlockBuilder)
for (int i = 0; i < indexes.length; i++) {
indexes[i] = i;
}
IntArrays.quickSort(indexes, (a, b) -> compare(a, b));
IntArrays.quickSort(indexes, this::compare);

for (int index : indexes) {
write(index, null, valueBlockBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public Operator createOperator(DriverContext driverContext)
ListenableFuture<OuterPositionIterator> outerPositionsFuture = joinBridgeManager.getOuterPositionsFuture();
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LookupOuterOperator.class.getSimpleName());
joinBridgeManager.outerOperatorCreated();
return new LookupOuterOperator(operatorContext, outerPositionsFuture, probeOutputTypes, buildOutputTypes, () -> joinBridgeManager.outerOperatorClosed());
return new LookupOuterOperator(operatorContext, outerPositionsFuture, probeOutputTypes, buildOutputTypes, joinBridgeManager::outerOperatorClosed);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public Operator createOperator(DriverContext driverContext)
nestedLoopJoinBridge,
probeChannels,
buildChannels,
() -> joinBridgeManager.probeOperatorClosed());
joinBridgeManager::probeOperatorClosed);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public WorkProcessorOperator create(ProcessorContext processorContext, WorkProce
waitForBuild,
lookupSourceFactory,
joinProbeFactory,
() -> joinBridgeManager.probeOperatorClosed(),
joinBridgeManager::probeOperatorClosed,
processorContext,
Optional.of(sourcePages));
}
Expand All @@ -212,7 +212,7 @@ public AdapterWorkProcessorOperator createAdapterOperator(ProcessorContext proce
waitForBuild,
lookupSourceFactory,
joinProbeFactory,
() -> joinBridgeManager.probeOperatorClosed(),
joinBridgeManager::probeOperatorClosed,
processorContext,
Optional.empty());
}
Expand Down
9 changes: 4 additions & 5 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.google.common.collect.Iterables;
import com.google.common.graph.SuccessorsFunction;
import com.google.common.graph.Traverser;
import io.trino.metadata.Metadata;
import io.trino.spi.type.Type;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -202,14 +201,14 @@ public static Expression combineDisjunctsWithDefault(Collection<Expression> expr
return disjuncts.isEmpty() ? emptyDefault : or(disjuncts);
}

public static Expression filterDeterministicConjuncts(Metadata metadata, Expression expression)
public static Expression filterDeterministicConjuncts(Expression expression)
{
return filterConjuncts(expression, expression1 -> DeterminismEvaluator.isDeterministic(expression1));
return filterConjuncts(expression, DeterminismEvaluator::isDeterministic);
}

public static Expression filterNonDeterministicConjuncts(Metadata metadata, Expression expression)
public static Expression filterNonDeterministicConjuncts(Expression expression)
{
return filterConjuncts(expression, not(testExpression -> DeterminismEvaluator.isDeterministic(testExpression)));
return filterConjuncts(expression, not(DeterminismEvaluator::isDeterministic));
}

public static Expression filterConjuncts(Expression expression, Predicate<Expression> predicate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
Expand Down Expand Up @@ -114,7 +115,7 @@ public Expression toPredicate(TupleDomain<Symbol> tupleDomain)
Map<Symbol, Domain> domains = tupleDomain.getDomains().get();
return domains.entrySet().stream()
.map(entry -> toPredicate(entry.getValue(), entry.getKey().toSymbolReference()))
.collect(collectingAndThen(toImmutableList(), expressions -> combineConjuncts(expressions)));
.collect(collectingAndThen(toImmutableList(), IrUtils::combineConjuncts));
}

private Expression toPredicate(Domain domain, Reference reference)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public Expression visitFilter(FilterNode node, Void context)
Expression predicate = node.getPredicate();

// Remove non-deterministic conjuncts
predicate = filterDeterministicConjuncts(metadata, predicate);
predicate = filterDeterministicConjuncts(predicate);

return combineConjuncts(predicate, underlyingPredicate);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public EqualityInference(Collection<Expression> expressions)
DisjointSet<Expression> equalities = new DisjointSet<>();
expressions.stream()
.flatMap(expression -> extractConjuncts(expression).stream())
.filter(expression -> isInferenceCandidate(expression))
.filter(EqualityInference::isInferenceCandidate)
.forEach(expression -> {
Comparison comparison = (Comparison) expression;
Expression expression1 = comparison.left();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ public LocalExecutionPlan plan(
Function<Page, Page> pagePreprocessor = enforceLoadedLayoutProcessor(outputLayout, physicalOperation.getLayout());

List<Type> outputTypes = outputLayout.stream()
.map(symbol -> symbol.getType())
.map(Symbol::getType)
.collect(toImmutableList());

context.addDriverFactory(
Expand Down Expand Up @@ -3698,7 +3698,7 @@ private List<Type> getSourceOperatorTypes(PlanNode node)
private List<Type> getSymbolTypes(List<Symbol> symbols)
{
return symbols.stream()
.map(symbol -> symbol.getType())
.map(Symbol::getType)
.collect(toImmutableList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static Optional<SortExpressionContext> extractSortExpression(Set<Symbol>
SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols);

List<SortExpressionContext> sortExpressionCandidates = ImmutableList.copyOf(filterConjuncts.stream()
.filter(expression -> DeterminismEvaluator.isDeterministic(expression))
.filter(DeterminismEvaluator::isDeterministic)
.map(visitor::process)
.filter(Optional::isPresent)
.map(Optional::get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.Logical;
import io.trino.sql.planner.DeterminismEvaluator;

import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -168,7 +169,7 @@ private Expression distributeIfPossible(Logical expression)
private Set<Expression> filterDeterministicPredicates(List<Expression> predicates)
{
return predicates.stream()
.filter(expression -> isDeterministic(expression))
.filter(DeterminismEvaluator::isDeterministic)
.collect(toSet());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.Session;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
Expand Down Expand Up @@ -103,7 +104,7 @@ public Optional<PlanNode> visitFilter(FilterNode node, Void context)
public Optional<PlanNode> visitProject(ProjectNode node, Void context)
{
boolean isDeterministic = node.getAssignments().getExpressions().stream()
.allMatch(expression -> isDeterministic(expression));
.allMatch(DeterminismEvaluator::isDeterministic);
if (!isDeterministic) {
// non-deterministic projections could be used in downstream filters which could
// filter duplicate rows probabilistically
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public static Optional<PlanNode> pushAggregationIntoTableScan(

List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
.map(Entry::getValue)
.map(aggregation -> toAggregateFunction(aggregation))
.map(PushAggregationIntoTableScan::toAggregateFunction)
.collect(toImmutableList());

List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
Expand Down Expand Up @@ -48,7 +49,7 @@ public static Optional<PlanNode> pushProjectionThroughJoin(
Lookup lookup,
PlanNodeIdAllocator planNodeIdAllocator)
{
if (!projectNode.getAssignments().getExpressions().stream().allMatch(expression -> isDeterministic(expression))) {
if (!projectNode.getAssignments().getExpressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) {
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ public Result apply(FilterNode filterNode, Captures captures, Context context)
TableScanNode node = captures.get(TABLE_SCAN);
Expression predicate = filterNode.getPredicate();

Expression deterministicPredicate = filterDeterministicConjuncts(plannerContext.getMetadata(), predicate);
Expression nonDeterministicPredicate = filterNonDeterministicConjuncts(plannerContext.getMetadata(), predicate);
Expression deterministicPredicate = filterDeterministicConjuncts(predicate);
Expression nonDeterministicPredicate = filterNonDeterministicConjuncts(predicate);

ExtractionResult decomposedPredicate = getFullyExtractedPredicates(
session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Row;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.AggregationNode;
Expand Down Expand Up @@ -186,7 +187,7 @@ private Optional<TableScanNode> findTableScan(PlanNode source)
}
else if (source instanceof ProjectNode project) {
// verify projections are deterministic
if (!Iterables.all(project.getAssignments().getExpressions(), expression -> isDeterministic(expression))) {
if (!Iterables.all(project.getAssignments().getExpressions(), DeterminismEvaluator::isDeterministic)) {
return Optional.empty();
}
source = project.getSource();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -824,14 +824,14 @@ private OuterJoinPushDownResult processLimitedOuterJoin(
extractConjuncts(inheritedPredicate).stream()
.filter(expression -> !isDeterministic(expression))
.forEach(postJoinConjuncts::add);
inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);
inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);

outerEffectivePredicate = filterDeterministicConjuncts(metadata, outerEffectivePredicate);
innerEffectivePredicate = filterDeterministicConjuncts(metadata, innerEffectivePredicate);
outerEffectivePredicate = filterDeterministicConjuncts(outerEffectivePredicate);
innerEffectivePredicate = filterDeterministicConjuncts(innerEffectivePredicate);
extractConjuncts(joinPredicate).stream()
.filter(expression -> !isDeterministic(expression))
.forEach(joinConjuncts::add);
joinPredicate = filterDeterministicConjuncts(metadata, joinPredicate);
joinPredicate = filterDeterministicConjuncts(joinPredicate);

// Generate equality inferences
EqualityInference inheritedInference = new EqualityInference(inheritedPredicate);
Expand Down Expand Up @@ -956,15 +956,15 @@ private InnerJoinPushDownResult processInnerJoin(
extractConjuncts(inheritedPredicate).stream()
.filter(deterministic -> !isDeterministic(deterministic))
.forEach(joinConjuncts::add);
inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);
inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);

extractConjuncts(joinPredicate).stream()
.filter(expression -> !isDeterministic(expression))
.forEach(joinConjuncts::add);
joinPredicate = filterDeterministicConjuncts(metadata, joinPredicate);
joinPredicate = filterDeterministicConjuncts(joinPredicate);

leftEffectivePredicate = filterDeterministicConjuncts(metadata, leftEffectivePredicate);
rightEffectivePredicate = filterDeterministicConjuncts(metadata, rightEffectivePredicate);
leftEffectivePredicate = filterDeterministicConjuncts(leftEffectivePredicate);
rightEffectivePredicate = filterDeterministicConjuncts(rightEffectivePredicate);

ImmutableSet<Symbol> leftScope = ImmutableSet.copyOf(leftSymbols);
ImmutableSet<Symbol> rightScope = ImmutableSet.copyOf(rightSymbols);
Expand Down Expand Up @@ -1300,9 +1300,9 @@ private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext<Exp
private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
{
Expression inheritedPredicate = context.get();
Expression deterministicInheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);
Expression sourceEffectivePredicate = filterDeterministicConjuncts(metadata, effectivePredicateExtractor.extract(session, node.getSource()));
Expression filteringSourceEffectivePredicate = filterDeterministicConjuncts(metadata, effectivePredicateExtractor.extract(session, node.getFilteringSource()));
Expression deterministicInheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);
Expression sourceEffectivePredicate = filterDeterministicConjuncts(effectivePredicateExtractor.extract(session, node.getSource()));
Expression filteringSourceEffectivePredicate = filterDeterministicConjuncts(effectivePredicateExtractor.extract(session, node.getFilteringSource()));
Expression joinExpression = new Comparison(
EQUAL,
node.getSourceJoinSymbol().toSymbolReference(),
Expand Down Expand Up @@ -1417,7 +1417,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Expression
extractConjuncts(inheritedPredicate).stream()
.filter(expression -> !isDeterministic(expression))
.forEach(postAggregationConjuncts::add);
inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);
inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);

// Sort non-equality predicates by those that can be pushed down and those that cannot
Set<Symbol> groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys());
Expand Down Expand Up @@ -1479,7 +1479,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext<Expression> context)
extractConjuncts(inheritedPredicate).stream()
.filter(expression -> !isDeterministic(expression))
.forEach(postUnnestConjuncts::add);
inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);
inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);

// Sort non-equality predicates by those that can be pushed down and those that cannot
Set<Symbol> replicatedSymbols = ImmutableSet.copyOf(node.getReplicateSymbols());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.PlanNodeIdAllocator;
Expand Down Expand Up @@ -179,11 +180,11 @@ public Expression rewriteUsingBounds(ApplyNode.QuantifiedComparison quantifiedCo
Function<List<Expression>, Expression> quantifier;
if (quantifiedComparison.quantifier() == ALL) {
emptySetResult = TRUE;
quantifier = expressions -> combineConjuncts(expressions);
quantifier = IrUtils::combineConjuncts;
}
else {
emptySetResult = FALSE;
quantifier = expressions -> combineDisjuncts(expressions);
quantifier = IrUtils::combineDisjuncts;
}
Expression comparisonWithExtremeValue = getBoundComparisons(quantifiedComparison, minValue, maxValue);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,6 @@ public void testAggregationWithMoreGroupingSets()
.addSymbolStatistics(new Symbol(UNKNOWN, "y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build())
.addSymbolStatistics(new Symbol(UNKNOWN, "z"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build())
.build())
.check(check -> check.outputRowsCountUnknown());
.check(PlanNodeStatsAssertion::outputRowsCountUnknown);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ public QueryStateMachine build()
if (setAuthorizationUser != null) {
stateMachine.setSetAuthorizationUser(setAuthorizationUser);
}
addPreparedStatements.forEach((key, value) -> stateMachine.addPreparedStatement(key, value));
addPreparedStatements.forEach(stateMachine::addPreparedStatement);
if (transactionId != null) {
stateMachine.setStartedTransactionId(transactionId);
}
Expand Down
Loading
Loading