Skip to content

Commit

Permalink
Rely on noMoreOperators for dynamic filters collection
Browse files Browse the repository at this point in the history
Relying on noMoreOperators from DynamicFilterSourceOperatorFactory
for detecting completion of dynamic filter collection will simplify
the implementation for fault-tolerant execution where the collection
may take place in a source stage.
  • Loading branch information
raunaqmorarka authored and losipiuk committed May 31, 2022
1 parent 954125e commit c49be84
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.Type;
import io.trino.sql.planner.DynamicFilterSourceConsumer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.type.BlockTypeOperators;
Expand All @@ -32,7 +33,6 @@
import javax.annotation.Nullable;

import java.util.List;
import java.util.function.Consumer;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -77,19 +77,20 @@ public static class DynamicFilterSourceOperatorFactory
{
private final int operatorId;
private final PlanNodeId planNodeId;
private final Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer;
private final DynamicFilterSourceConsumer dynamicPredicateConsumer;
private final List<Channel> channels;
private final int maxDisinctValues;
private final DataSize maxFilterSize;
private final int minMaxCollectionLimit;
private final BlockTypeOperators blockTypeOperators;

private boolean closed;
private int createdOperatorsCount;

public DynamicFilterSourceOperatorFactory(
int operatorId,
PlanNodeId planNodeId,
Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer,
DynamicFilterSourceConsumer dynamicPredicateConsumer,
List<Channel> channels,
int maxDisinctValues,
DataSize maxFilterSize,
Expand All @@ -114,6 +115,7 @@ public DynamicFilterSourceOperatorFactory(
public Operator createOperator(DriverContext driverContext)
{
checkState(!closed, "Factory is already closed");
createdOperatorsCount++;
return new DynamicFilterSourceOperator(
driverContext.addOperatorContext(operatorId, planNodeId, DynamicFilterSourceOperator.class.getSimpleName()),
dynamicPredicateConsumer,
Expand All @@ -130,6 +132,7 @@ public void noMoreOperators()
{
checkState(!closed, "Factory is already closed");
closed = true;
dynamicPredicateConsumer.setPartitionCount(createdOperatorsCount);
}

@Override
Expand All @@ -142,7 +145,7 @@ public OperatorFactory duplicate()
private final OperatorContext context;
private boolean finished;
private Page current;
private final Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer;
private final DynamicFilterSourceConsumer dynamicPredicateConsumer;
private final int maxDistinctValues;
private final long maxFilterSizeInBytes;

Expand All @@ -164,7 +167,7 @@ public OperatorFactory duplicate()

private DynamicFilterSourceOperator(
OperatorContext context,
Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer,
DynamicFilterSourceConsumer dynamicPredicateConsumer,
List<Channel> channels,
PlanNodeId planNodeId,
int maxDistinctValues,
Expand Down Expand Up @@ -270,7 +273,7 @@ private void handleTooLargePredicate()
// The resulting predicate is too large
if (minMaxChannels.isEmpty()) {
// allow all probe-side values to be read.
dynamicPredicateConsumer.accept(TupleDomain.all());
dynamicPredicateConsumer.addPartition(TupleDomain.all());
}
else {
if (minMaxCollectionLimit < 0) {
Expand All @@ -294,7 +297,7 @@ private void handleTooLargePredicate()
private void handleMinMaxCollectionLimitExceeded()
{
// allow all probe-side values to be read.
dynamicPredicateConsumer.accept(TupleDomain.all());
dynamicPredicateConsumer.addPartition(TupleDomain.all());
// Drop references to collected values.
minValues = null;
maxValues = null;
Expand Down Expand Up @@ -387,7 +390,7 @@ public void finish()
}
minValues = null;
maxValues = null;
dynamicPredicateConsumer.accept(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
dynamicPredicateConsumer.addPartition(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
return;
}
for (int channelIndex = 0; channelIndex < channels.size(); ++channelIndex) {
Expand All @@ -397,7 +400,7 @@ public void finish()
}
valueSets = null;
blockBuilders = null;
dynamicPredicateConsumer.accept(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
dynamicPredicateConsumer.addPartition(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
}

private Domain convertToDomain(Type type, Block block)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner;

import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.plan.DynamicFilterId;

/**
 * Receives dynamic filter domains collected by {@code DynamicFilterSourceOperator}
 * instances, one partition per operator instance. The total number of partitions is
 * announced separately via {@link #setPartitionCount(int)} once the operator factory
 * knows no more operators will be created.
 */
public interface DynamicFilterSourceConsumer
{
    /**
     * Adds the dynamic filter domain collected from one build-side partition.
     * May be called concurrently by multiple operator instances; a {@code TupleDomain.all()}
     * argument signals that collection was abandoned and all probe-side values must be allowed.
     */
    void addPartition(TupleDomain<DynamicFilterId> tupleDomain);

    /**
     * Sets the total number of partitions that will be reported through
     * {@link #addPartition}. Called exactly once, from
     * {@code DynamicFilterSourceOperatorFactory#noMoreOperators()}, after the final
     * operator has been created.
     */
    void setPartitionCount(int partitionCount);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,27 @@
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;

import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public class LocalDynamicFilterConsumer
implements DynamicFilterSourceConsumer
{
private static final int PARTITION_COUNT_INITIAL_VALUE = -1;
// Mapping from dynamic filter ID to its build channel indices.
private final Map<DynamicFilterId, Integer> buildChannels;

Expand All @@ -49,39 +53,61 @@ public class LocalDynamicFilterConsumer

private final SettableFuture<TupleDomain<DynamicFilterId>> resultFuture;

// Number of build-side partitions to be collected.
private final int partitionCount;
// Number of build-side partitions to be collected, must be provided by setPartitionCount
@GuardedBy("this")
private int expectedPartitionCount = PARTITION_COUNT_INITIAL_VALUE;

// The resulting predicates from each build-side partition.
@GuardedBy("this")
private final List<TupleDomain<DynamicFilterId>> partitions;

public LocalDynamicFilterConsumer(Map<DynamicFilterId, Integer> buildChannels, Map<DynamicFilterId, Type> filterBuildTypes, int partitionCount)
public LocalDynamicFilterConsumer(Map<DynamicFilterId, Integer> buildChannels, Map<DynamicFilterId, Type> filterBuildTypes)
{
this.buildChannels = requireNonNull(buildChannels, "buildChannels is null");
this.filterBuildTypes = requireNonNull(filterBuildTypes, "filterBuildTypes is null");
verify(buildChannels.keySet().equals(filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys");

this.resultFuture = SettableFuture.create();

this.partitionCount = partitionCount;
this.partitions = new ArrayList<>(partitionCount);
this.partitions = new ArrayList<>();
}

public ListenableFuture<Map<DynamicFilterId, Domain>> getDynamicFilterDomains()
{
return Futures.transform(resultFuture, this::convertTupleDomain, directExecutor());
}

private void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
{
if (resultFuture.isDone()) {
return;
}
TupleDomain<DynamicFilterId> result = null;
synchronized (this) {
// Called concurrently by each DynamicFilterSourceOperator instance (when collection is over).
verify(partitions.size() < partitionCount);
verify(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE || partitions.size() < expectedPartitionCount);
// NOTE: may result in a bit more relaxed constraint if there are multiple columns and multiple rows.
// See the comment at TupleDomain::columnWiseUnion() for more details.
partitions.add(tupleDomain);
if (partitions.size() == partitionCount || tupleDomain.isAll()) {
if (partitions.size() == expectedPartitionCount || tupleDomain.isAll()) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
}
}

if (result != null) {
resultFuture.set(result);
}
}

@Override
public void setPartitionCount(int partitionCount)
{
TupleDomain<DynamicFilterId> result = null;
synchronized (this) {
checkState(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE, "setPartitionCount should be called only once");
expectedPartitionCount = partitionCount;
if (partitions.size() == expectedPartitionCount) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
}
Expand Down Expand Up @@ -109,7 +135,6 @@ private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilte
public static LocalDynamicFilterConsumer create(
JoinNode planNode,
List<Type> buildSourceTypes,
int partitionCount,
Set<DynamicFilterId> collectedFilters)
{
checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty.");
Expand All @@ -134,25 +159,20 @@ public static LocalDynamicFilterConsumer create(
.collect(toImmutableMap(
Map.Entry::getKey,
entry -> buildSourceTypes.get(entry.getValue())));
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes, partitionCount);
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes);
}

public Map<DynamicFilterId, Integer> getBuildChannels()
{
return buildChannels;
}

public Consumer<TupleDomain<DynamicFilterId>> getTupleDomainConsumer()
{
return this::addPartition;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("buildChannels", buildChannels)
.add("partitionCount", partitionCount)
.add("expectedPartitionCount", expectedPartitionCount)
.add("partitions", partitions)
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2553,7 +2553,7 @@ private PhysicalOperation createNestedLoopJoin(JoinNode node, Set<DynamicFilterI
checkArgument(partitionCount == 1, "Expected local execution to not be parallel");

int operatorId = buildContext.getNextOperatorId();
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, partitionCount, localDynamicFilters);
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters);
if (localDynamicFilter.isPresent()) {
buildSource = createDynamicFilterSourceOperatorFactory(operatorId, localDynamicFilter.get(), node, buildSource, buildContext);
}
Expand Down Expand Up @@ -2817,7 +2817,7 @@ private JoinBridgeManager<PartitionedLookupSourceFactory> createLookupSourceFact
buildOutputTypes);

int operatorId = buildContext.getNextOperatorId();
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, partitionCount, localDynamicFilters);
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters);
if (localDynamicFilter.isPresent()) {
buildSource = createDynamicFilterSourceOperatorFactory(operatorId, localDynamicFilter.get(), node, buildSource, buildContext);
}
Expand Down Expand Up @@ -2874,7 +2874,7 @@ private PhysicalOperation createDynamicFilterSourceOperatorFactory(
new DynamicFilterSourceOperatorFactory(
operatorId,
node.getId(),
dynamicFilter.getTupleDomainConsumer(),
dynamicFilter,
filterBuildChannels,
multipleIf(getDynamicFilteringMaxDistinctValuesPerDriver(session, isReplicatedJoin), taskConcurrency, isBuildSideSingle),
multipleIf(getDynamicFilteringMaxSizePerDriver(session, isReplicatedJoin), taskConcurrency, isBuildSideSingle),
Expand All @@ -2899,7 +2899,6 @@ private Optional<LocalDynamicFilterConsumer> createDynamicFilter(
PhysicalOperation buildSource,
JoinNode node,
LocalExecutionPlanContext context,
int partitionCount,
Set<DynamicFilterId> localDynamicFilters)
{
Set<DynamicFilterId> coordinatorDynamicFilters = getCoordinatorDynamicFilters(node.getDynamicFilters().keySet(), node, context.getTaskId());
Expand All @@ -2914,7 +2913,7 @@ private Optional<LocalDynamicFilterConsumer> createDynamicFilter(
buildSource.getPipelineExecutionStrategy() != GROUPED_EXECUTION,
"Dynamic filtering cannot be used with grouped execution");
log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters());
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), partitionCount, collectedDynamicFilters);
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), collectedDynamicFilters);
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
if (!localDynamicFilters.isEmpty()) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
Expand Down Expand Up @@ -3080,8 +3079,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
log.debug("[Semi-join] Dynamic filter: %s", filterId);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)),
partitionCount);
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)));
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
if (isLocalDynamicFilter) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
Expand All @@ -3094,7 +3092,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
new DynamicFilterSourceOperatorFactory(
operatorId,
node.getId(),
filterConsumer.getTupleDomainConsumer(),
filterConsumer,
ImmutableList.of(new DynamicFilterSourceOperator.Channel(filterId, buildSource.getTypes().get(buildChannel), buildChannel)),
getDynamicFilteringMaxDistinctValuesPerDriver(session, isReplicatedJoin),
getDynamicFilteringMaxSizePerDriver(session, isReplicatedJoin),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import io.airlift.units.DataSize;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.DynamicFilterSourceConsumer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingTaskContext;
Expand Down Expand Up @@ -93,7 +95,13 @@ public void setup()
operatorFactory = new DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory(
1,
new PlanNodeId("joinNodeId"),
(tupleDomain -> {}),
new DynamicFilterSourceConsumer() {
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain) {}

@Override
public void setPartitionCount(int partitionCount) {}
},
ImmutableList.of(new DynamicFilterSourceOperator.Channel(new DynamicFilterId("0"), BIGINT, 0)),
maxDistinctValuesCount,
DataSize.ofBytes(Long.MAX_VALUE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.DynamicFilterSourceConsumer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
Expand Down Expand Up @@ -132,19 +133,23 @@ private OperatorFactory createOperatorFactory(
return new DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory(
0,
new PlanNodeId("PLAN_NODE_ID"),
this::consumePredicate,
new DynamicFilterSourceConsumer() {
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
{
partitions.add(tupleDomain);
}

@Override
public void setPartitionCount(int partitionCount) {}
},
ImmutableList.copyOf(buildChannels),
maxFilterDistinctValues,
maxFilterSize,
minMaxCollectionLimit,
blockTypeOperators);
}

private void consumePredicate(TupleDomain<DynamicFilterId> partitionPredicate)
{
partitions.add(partitionPredicate);
}

private Operator createOperator(OperatorFactory operatorFactory)
{
return operatorFactory.createOperator(pipelineContext.addDriverContext());
Expand Down
Loading

0 comments on commit c49be84

Please sign in to comment.