Skip to content

Commit

Permalink
Use gather over partitioned exchange for small join build side
Browse files Browse the repository at this point in the history
A non-partitioned lookup source performs better than a
partitioned one, so for a small build side the overall
join performance is better even if the lookup source is
created using a single thread.
  • Loading branch information
lukasz-stec authored and sopel39 committed May 18, 2022
1 parent be260e5 commit c92ea3b
Show file tree
Hide file tree
Showing 12 changed files with 625 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public final class SystemSessionProperties
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
public static final String JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT = "join_partitioned_build_min_row_count";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -823,6 +824,12 @@ public SystemSessionProperties(
ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD,
"Ratio between aggregation output and input rows above which partial aggregation might be adaptively turned off",
optimizerConfig.getAdaptivePartialAggregationUniqueRowsRatioThreshold(),
false),
longProperty(
JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT,
"Minimum number of join build side rows required to use partitioned join lookup",
optimizerConfig.getJoinPartitionedBuildMinRowCount(),
value -> validateNonNegativeLongValue(value, JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT),
false));
}

Expand Down Expand Up @@ -1204,6 +1211,13 @@ private static Integer validateIntegerValue(Object value, String property, int l
return intValue;
}

/**
 * Validates that a long-valued session property is not negative.
 *
 * @param value the property value to check
 * @param property the property name, used only to build the error message
 * @throws TrinoException with {@code INVALID_SESSION_PROPERTY} when {@code value} is below zero
 */
private static void validateNonNegativeLongValue(Long value, String property)
{
    boolean nonNegative = value >= 0;
    if (nonNegative) {
        return;
    }
    throw new TrinoException(INVALID_SESSION_PROPERTY, format("%s must be equal or greater than 0", property));
}

private static double validateDoubleRange(Object value, String property, double lowerBoundIncluded, double upperBoundIncluded)
{
double doubleValue = (double) value;
Expand Down Expand Up @@ -1479,4 +1493,9 @@ public static double getAdaptivePartialAggregationUniqueRowsRatioThreshold(Sessi
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD, Double.class);
}

/**
 * Reads the {@code join_partitioned_build_min_row_count} system property from the given session.
 *
 * @param session the session to read the property from
 * @return the property value as a primitive {@code long}
 */
public static long getJoinPartitionedBuildMinRowCount(Session session)
{
    long minBuildRowCount = session.getSystemProperty(JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT, Long.class);
    return minBuildRowCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class OptimizerConfig
private boolean adaptivePartialAggregationEnabled = true;
private long adaptivePartialAggregationMinRows = 100_000;
private double adaptivePartialAggregationUniqueRowsRatioThreshold = 0.8;
private long joinPartitionedBuildMinRowCount = 1_000_000L;

public enum JoinReorderingStrategy
{
Expand Down Expand Up @@ -713,4 +714,18 @@ public OptimizerConfig setAdaptivePartialAggregationUniqueRowsRatioThreshold(dou
this.adaptivePartialAggregationUniqueRowsRatioThreshold = adaptivePartialAggregationUniqueRowsRatioThreshold;
return this;
}

/**
 * Returns the current value of {@code joinPartitionedBuildMinRowCount}.
 * <p>
 * The {@code @Min(0)} constraint on this getter declares that the value must not be negative.
 * NOTE(review): the enforcement of that constraint happens outside this method, presumably
 * when the configuration object is validated -- confirm.
 *
 * @return the minimum number of join build side rows required to use partitioned join lookup
 */
@Min(0)
public long getJoinPartitionedBuildMinRowCount()
{
return joinPartitionedBuildMinRowCount;
}

/**
 * Sets the minimum number of join build side rows required to use partitioned join lookup.
 * <p>
 * Bound to the {@code optimizer.join-partitioned-build-min-row-count} configuration property.
 * The value is stored as given; no range check is performed in this setter.
 *
 * @param joinPartitionedBuildMinRowCount the new threshold
 * @return this config instance, so that setter calls can be chained
 */
@Config("optimizer.join-partitioned-build-min-row-count")
@ConfigDescription("Minimum number of join build side rows required to use partitioned join lookup")
public OptimizerConfig setJoinPartitionedBuildMinRowCount(long joinPartitionedBuildMinRowCount)
{
this.joinPartitionedBuildMinRowCount = joinPartitionedBuildMinRowCount;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@
import io.trino.sql.planner.iterative.rule.UnwrapRowSubscript;
import io.trino.sql.planner.iterative.rule.UnwrapSingleColumnRowInApply;
import io.trino.sql.planner.iterative.rule.UnwrapTimestampToDateCastInComparison;
import io.trino.sql.planner.iterative.rule.UseNonPartitionedJoinLookupSource;
import io.trino.sql.planner.optimizations.AddExchanges;
import io.trino.sql.planner.optimizations.AddLocalExchanges;
import io.trino.sql.planner.optimizations.BeginTableWrite;
Expand Down Expand Up @@ -919,6 +920,13 @@ public PlanOptimizers(

// Optimizers above this don't understand local exchanges, so be careful moving this.
builder.add(new AddLocalExchanges(plannerContext, typeAnalyzer));
// UseNonPartitionedJoinLookupSource needs to run after AddLocalExchanges since it operates on ExchangeNodes added by this optimizer.
builder.add(new IterativeOptimizer(
plannerContext,
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(new UseNonPartitionedJoinLookupSource())));

// Optimizers above this do not need to care about aggregations with the type other than SINGLE
// This optimizer must be run after all exchange-related optimizers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.cost.StatsProvider;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.operator.join.LookupJoinOperator;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;

import java.util.List;
import java.util.Optional;

import static io.trino.SystemSessionProperties.getJoinPartitionedBuildMinRowCount;
import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.trino.sql.planner.plan.Patterns.Join.right;
import static io.trino.sql.planner.plan.Patterns.exchange;
import static io.trino.sql.planner.plan.Patterns.join;

/**
 * Replaces the local exchange that feeds a join's build side with a local gathering
 * exchange when the build side is estimated to be small.
 * <p>
 * Before:
 * <pre>
 * join
 *     probe
 *     build
 *         localExchange(partitioned)
 * </pre>
 * After:
 * <pre>
 * join
 *     probe
 *     build
 *         localExchange(gather)
 * </pre>
 * The rewrite happens only when the summed row count estimate of the build side's source
 * tables is known and is below the {@code join_partitioned_build_min_row_count} session
 * property. The intent is to improve {@link LookupJoinOperator} performance for small
 * build sides.
 */
public class UseNonPartitionedJoinLookupSource
        implements Rule<JoinNode>
{
    // The capture must be initialized before the pattern that refers to it.
    private static final Capture<ExchangeNode> RIGHT_EXCHANGE_NODE = Capture.newCapture();
    private static final Pattern<JoinNode> JOIN_PATTERN = join()
            .with(right().matching(exchange()
                    .matching(UseNonPartitionedJoinLookupSource::canBeTranslatedToLocalGather)
                    .capturedAs(RIGHT_EXCHANGE_NODE)));

    @Override
    public Pattern<JoinNode> getPattern()
    {
        return JOIN_PATTERN;
    }

    /**
     * The rule is active only when the row count threshold is strictly positive,
     * so a threshold of 0 switches the rule off.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        long threshold = getJoinPartitionedBuildMinRowCount(session);
        return threshold > 0;
    }

    /**
     * Swaps the captured build side exchange for a gathering one, unless the build side
     * row count is unknown or reaches the configured threshold.
     */
    @Override
    public Result apply(JoinNode node, Captures captures, Context context)
    {
        double estimatedBuildRows = getSourceTablesRowCount(node.getRight(), context);

        // NaN: stats are not available or the build side contains a join (or unnest).
        // At or above the threshold: the build side has too many rows.
        // The threshold is read only when the estimate is a number.
        boolean keepExistingExchange = Double.isNaN(estimatedBuildRows)
                || estimatedBuildRows >= getJoinPartitionedBuildMinRowCount(context.getSession());
        if (keepExistingExchange) {
            return Result.empty();
        }

        ExchangeNode gatheringExchange = toGatheringExchange(captures.get(RIGHT_EXCHANGE_NODE));
        return Result.ofPlanNode(node.replaceChildren(ImmutableList.of(node.getLeft(), gatheringExchange)));
    }

    /**
     * Builds a LOCAL GATHER exchange with SINGLE_DISTRIBUTION partitioning that reuses the
     * id, output layout, sources and inputs of the given exchange.
     */
    private static ExchangeNode toGatheringExchange(ExchangeNode exchangeNode)
    {
        PartitioningScheme singlePartitionScheme = new PartitioningScheme(
                Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
                exchangeNode.getPartitioningScheme().getOutputLayout());

        // NOTE(review): the trailing Optional.empty() is presumably the ordering scheme,
        // which the pattern already requires to be absent -- confirm against ExchangeNode.
        return new ExchangeNode(
                exchangeNode.getId(),
                GATHER,
                LOCAL,
                singlePartitionScheme,
                exchangeNode.getSources(),
                exchangeNode.getInputs(),
                Optional.empty());
    }

    /**
     * Accepts only local exchanges that are not already a single gather, carry no ordering
     * scheme, have no bucket-to-partition mapping and do not replicate nulls and any.
     */
    private static boolean canBeTranslatedToLocalGather(ExchangeNode exchangeNode)
    {
        if (exchangeNode.getScope() != LOCAL || isSingleGather(exchangeNode)) {
            return false;
        }
        if (exchangeNode.getOrderingScheme().isPresent()) {
            return false;
        }
        if (exchangeNode.getPartitioningScheme().getBucketToPartition().isPresent()) {
            return false;
        }
        return !exchangeNode.getPartitioningScheme().isReplicateNullsAndAny();
    }

    /**
     * Tells whether the exchange is a GATHER whose partitioning handle is SINGLE_DISTRIBUTION.
     */
    private static boolean isSingleGather(ExchangeNode exchangeNode)
    {
        if (exchangeNode.getType() != GATHER) {
            return false;
        }
        return exchangeNode.getPartitioningScheme().getPartitioning().getHandle() == SINGLE_DISTRIBUTION;
    }

    private static double getSourceTablesRowCount(PlanNode node, Context context)
    {
        Lookup lookup = context.getLookup();
        StatsProvider statsProvider = context.getStatsProvider();
        return getSourceTablesRowCount(node, lookup, statsProvider);
    }

    /**
     * Sums the output row count estimates of every table scan and values node found
     * under {@code node}.
     *
     * @return {@code NaN} when a join or unnest node is found under {@code node};
     *         otherwise the sum of the estimates, which is itself {@code NaN} if any
     *         single estimate is {@code NaN}
     */
    @VisibleForTesting
    static double getSourceTablesRowCount(PlanNode node, Lookup lookup, StatsProvider statsProvider)
    {
        boolean containsJoinOrUnnest = PlanNodeSearcher.searchFrom(node, lookup)
                .whereIsInstanceOfAny(JoinNode.class, UnnestNode.class)
                .matches();
        if (containsJoinOrUnnest) {
            return Double.NaN;
        }

        List<PlanNode> leafSources = PlanNodeSearcher.searchFrom(node, lookup)
                .whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class)
                .findAll();

        return leafSources.stream()
                .mapToDouble(leaf -> statsProvider.getStats(leaf).getOutputRowCount())
                .sum();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ public void testDefaults()
.setForceSingleNodeOutput(true)
.setAdaptivePartialAggregationEnabled(true)
.setAdaptivePartialAggregationMinRows(100_000)
.setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.8));
.setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.8)
.setJoinPartitionedBuildMinRowCount(1_000_000));
}

@Test
Expand Down Expand Up @@ -141,6 +142,7 @@ public void testExplicitPropertyMappings()
.put("adaptive-partial-aggregation.enabled", "false")
.put("adaptive-partial-aggregation.min-rows", "1")
.put("adaptive-partial-aggregation.unique-rows-ratio-threshold", "0.99")
.put("optimizer.join-partitioned-build-min-row-count", "1")
.buildOrThrow();

OptimizerConfig expected = new OptimizerConfig()
Expand Down Expand Up @@ -191,7 +193,8 @@ public void testExplicitPropertyMappings()
.setForceSingleNodeOutput(false)
.setAdaptivePartialAggregationEnabled(false)
.setAdaptivePartialAggregationMinRows(1)
.setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.99);
.setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.99)
.setJoinPartitionedBuildMinRowCount(1);
assertFullMapping(properties, expected);
}
}
Loading

0 comments on commit c92ea3b

Please sign in to comment.