Skip to content

Commit

Permalink
Implement dynamic filtering for semi-joins
Browse files Browse the repository at this point in the history
  • Loading branch information
lxynov authored and sopel39 committed Sep 10, 2020
1 parent 61c72c4 commit af4e856
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 15 deletions.
29 changes: 26 additions & 3 deletions presto-main/src/main/java/io/prestosql/operator/JoinUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.util.MorePredicates;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static io.prestosql.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.prestosql.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.prestosql.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static io.prestosql.util.MorePredicates.isInstanceOfAny;

/**
* This class must be public as it is accessed via join compiler reflection.
Expand All @@ -55,12 +59,21 @@ public static List<Page> channelsToPages(List<List<Block>> channels)
return pagesBuilder.build();
}

public static boolean isBuildSideReplicated(JoinNode joinNode)
public static boolean isBuildSideReplicated(PlanNode node)
{
return PlanNodeSearcher.searchFrom(joinNode.getRight())
checkArgument(isInstanceOfAny(JoinNode.class, SemiJoinNode.class).test(node));
if (node instanceof JoinNode) {
return PlanNodeSearcher.searchFrom(((JoinNode) node).getRight())
.recurseOnlyWhen(
MorePredicates.<PlanNode>isInstanceOfAny(ProjectNode.class)
.or(JoinUtils::isLocalRepartitionExchange))
.where(JoinUtils::isRemoteReplicatedExchange)
.matches();
}
return PlanNodeSearcher.searchFrom(((SemiJoinNode) node).getFilteringSource())
.recurseOnlyWhen(
MorePredicates.<PlanNode>isInstanceOfAny(ProjectNode.class)
.or(JoinUtils::isLocalRepartitionExchange))
.or(JoinUtils::isLocalGatherExchange))
.where(JoinUtils::isRemoteReplicatedExchange)
.matches();
}
Expand All @@ -84,4 +97,14 @@ private static boolean isLocalRepartitionExchange(PlanNode node)
ExchangeNode exchangeNode = (ExchangeNode) node;
return exchangeNode.getScope() == LOCAL && exchangeNode.getType() == REPARTITION;
}

private static boolean isLocalGatherExchange(PlanNode node)
{
if (!(node instanceof ExchangeNode)) {
return false;
}

ExchangeNode exchangeNode = (ExchangeNode) node;
return exchangeNode.getScope() == LOCAL && exchangeNode.getType() == GATHER;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
Expand All @@ -42,6 +43,7 @@
import io.prestosql.sql.planner.plan.DynamicFilterId;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
Expand Down Expand Up @@ -70,10 +72,12 @@
import static io.airlift.concurrent.MoreFutures.unmodifiableFuture;
import static io.airlift.concurrent.MoreFutures.whenAnyComplete;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.prestosql.operator.JoinUtils.isBuildSideReplicated;
import static io.prestosql.spi.connector.DynamicFilter.EMPTY;
import static io.prestosql.spi.predicate.Domain.union;
import static io.prestosql.sql.DynamicFilters.extractDynamicFilters;
import static io.prestosql.sql.planner.ExpressionExtractor.extractExpressions;
import static io.prestosql.util.MorePredicates.isInstanceOfAny;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
Expand Down Expand Up @@ -324,22 +328,33 @@ private static Set<DynamicFilterId> getLazyDynamicFilters(PlanFragment plan)
private static Set<DynamicFilterId> getReplicatedDynamicFilters(PlanNode planNode)
{
return PlanNodeSearcher.searchFrom(planNode)
.where(JoinNode.class::isInstance)
.<JoinNode>findAll().stream()
.filter(JoinUtils::isBuildSideReplicated)
.flatMap(node -> node.getDynamicFilters().keySet().stream())
.where(isInstanceOfAny(JoinNode.class, SemiJoinNode.class))
.findAll().stream()
.filter((JoinUtils::isBuildSideReplicated))
.flatMap(node -> getDynamicFiltersProducedInPlanNode(node).stream())
.collect(toImmutableSet());
}

private static Set<DynamicFilterId> getProducedDynamicFilters(PlanNode planNode)
{
return PlanNodeSearcher.searchFrom(planNode)
.where(JoinNode.class::isInstance)
.<JoinNode>findAll().stream()
.flatMap(node -> node.getDynamicFilters().keySet().stream())
.where(isInstanceOfAny(JoinNode.class, SemiJoinNode.class))
.findAll().stream()
.flatMap(node -> getDynamicFiltersProducedInPlanNode(node).stream())
.collect(toImmutableSet());
}

private static Set<DynamicFilterId> getDynamicFiltersProducedInPlanNode(PlanNode planNode)
{
if (planNode instanceof JoinNode) {
return ((JoinNode) planNode).getDynamicFilters().keySet();
}
if (planNode instanceof SemiJoinNode) {
return ((SemiJoinNode) planNode).getDynamicFilterId().map(ImmutableSet::of).orElse(ImmutableSet.of());
}
throw new IllegalStateException("getDynamicFiltersProducedInPlanNode called with neither JoinNode nor SemiJoinNode");
}

private static Set<DynamicFilterId> getConsumedDynamicFilters(PlanNode planNode)
{
return extractExpressions(planNode).stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2223,18 +2223,45 @@ private Map<Symbol, Integer> createJoinSourcesLayout(Map<Symbol, Integer> lookup
@Override
public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanContext context)
{
node.getDynamicFilterId().ifPresent(filterId -> {
// Register locally if the table scan is on the same node (e.g., in case of broadcast semi-joins)
if (getConsumedDynamicFilterIds(node.getSource()).contains(filterId)) {
context.getDynamicFiltersCollector().register(ImmutableSet.of(filterId));
}
});
// Plan probe
PhysicalOperation probeSource = node.getSource().accept(this, context);

// Plan build
LocalExecutionPlanContext buildContext = context.createSubContext();
PhysicalOperation buildSource = node.getFilteringSource().accept(this, buildContext);
checkState(buildSource.getPipelineExecutionStrategy() == probeSource.getPipelineExecutionStrategy(), "build and probe have different pipelineExecutionStrategy");
checkArgument(buildContext.getDriverInstanceCount().orElse(1) == 1, "Expected local execution to not be parallel");
int partitionCount = buildContext.getDriverInstanceCount().orElse(1);
checkArgument(partitionCount == 1, "Expected local execution to not be parallel");

int probeChannel = probeSource.getLayout().get(node.getSourceJoinSymbol());
int buildChannel = buildSource.getLayout().get(node.getFilteringSourceJoinSymbol());

ImmutableList.Builder<OperatorFactory> buildOperatorFactories = new ImmutableList.Builder<>();
buildOperatorFactories.addAll(buildSource.getOperatorFactories());

node.getDynamicFilterId().ifPresent(filterId -> {
// Add a DynamicFilterSourceOperatorFactory to build operator factories
log.debug("[Semi-join] Dynamic filter: %s", filterId);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)),
partitionCount);
addSuccessCallback(filterConsumer.getDynamicFilterDomains(), context::addDynamicFilter);
buildOperatorFactories.add(new DynamicFilterSourceOperatorFactory(
buildContext.getNextOperatorId(),
node.getId(),
filterConsumer.getTupleDomainConsumer(),
ImmutableList.of(new DynamicFilterSourceOperator.Channel(filterId, buildSource.getTypes().get(buildChannel), buildChannel)),
getDynamicFilteringMaxPerDriverRowCount(context.getSession()),
getDynamicFilteringMaxPerDriverSize(context.getSession())));
});

Optional<Integer> buildHashChannel = node.getFilteringSourceHashSymbol().map(channelGetter(buildSource));
Optional<Integer> probeHashChannel = node.getSourceHashSymbol().map(channelGetter(probeSource));

Expand All @@ -2246,14 +2273,12 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
buildHashChannel,
10_000,
joinCompiler);
buildOperatorFactories.add(setBuilderOperatorFactory);
SetSupplier setProvider = setBuilderOperatorFactory.getSetProvider();
context.addDriverFactory(
buildContext.isInputDriver(),
false,
ImmutableList.<OperatorFactory>builder()
.addAll(buildSource.getOperatorFactories())
.add(setBuilderOperatorFactory)
.build(),
buildOperatorFactories.build(),
buildContext.getDriverInstanceCount(),
buildSource.getPipelineExecutionStrategy());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,70 @@ public void testJoinDynamicFilteringBlockProbeSide()
ImmutableSet.of(1, ORDERS_COUNT, PART_COUNT));
}

@Test
public void testSemiJoinDynamicFilteringNone()
{
// Probe-side is not scanned at all, due to dynamic filtering:
assertDynamicFiltering(
"SELECT * FROM lineitem WHERE lineitem.orderkey IN (SELECT orders.orderkey FROM orders WHERE orders.totalprice < 0)",
withBroadcastJoin(),
0,
ImmutableSet.of(0, ORDERS_COUNT));
}

@Test
public void testSemiJoinLargeBuildSideNoDynamicFiltering()
{
// Probe-side is fully scanned because the build-side is too large for dynamic filtering:
assertDynamicFiltering(
"SELECT * FROM lineitem WHERE lineitem.orderkey IN (SELECT orders.orderkey FROM orders)",
withBroadcastJoin(),
toIntExact(LINEITEM_COUNT),
ImmutableSet.of(LINEITEM_COUNT, ORDERS_COUNT));
}

@Test
public void testPartitionedSemiJoinNoDynamicFiltering()
{
// Probe-side is fully scanned, because local dynamic filtering does not work for partitioned joins:
assertDynamicFiltering(
"SELECT * FROM lineitem WHERE lineitem.orderkey IN (SELECT orders.orderkey FROM orders WHERE orders.totalprice < 0)",
withPartitionedJoin(),
0,
ImmutableSet.of(LINEITEM_COUNT, ORDERS_COUNT));
}

@Test
public void testSemiJoinDynamicFilteringSingleValue()
{
// Join lineitem with a single row of orders
assertDynamicFiltering(
"SELECT * FROM lineitem WHERE lineitem.orderkey IN (SELECT orders.orderkey FROM orders WHERE orders.comment = 'nstructions sleep furiously among ')",
withBroadcastJoin(),
6,
ImmutableSet.of(6, ORDERS_COUNT));

// Join lineitem with a single row of part
assertDynamicFiltering(
"SELECT l.comment FROM lineitem l WHERE l.partkey IN (SELECT p.partkey FROM part p WHERE p.comment = 'onic deposits')",
withBroadcastJoin(),
39,
ImmutableSet.of(39, PART_COUNT));
}

@Test
public void testSemiJoinDynamicFilteringBlockProbeSide()
{
// Wait for both build sides to finish before starting the scan of 'lineitem' table (should be very selective given the dynamic filters).
assertDynamicFiltering(
"SELECT t.comment FROM " +
"(SELECT * FROM lineitem l WHERE l.orderkey IN (SELECT o.orderkey FROM orders o WHERE o.comment = 'nstructions sleep furiously among ')) t " +
"WHERE t.partkey IN (SELECT p.partkey FROM part p WHERE p.comment = 'onic deposits')",
withBroadcastJoinNonReordering(),
1,
ImmutableSet.of(1, ORDERS_COUNT, PART_COUNT));
}

private void assertDynamicFiltering(String selectQuery, Session session, int expectedRowCount, Set<Integer> expectedOperatorRowsRead)
{
DistributedQueryRunner runner = (DistributedQueryRunner) getQueryRunner();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,85 @@ public void testJoinWithMultipleDynamicFiltersOnProbe()
singleValue(BIGINT, 2L))));
}

@Test(timeOut = 30_000)
public void testSemiJoinWithEmptyBuildSide()
{
assertQueryDynamicFilters(
"SELECT * FROM lineitem WHERE lineitem.suppkey IN (SELECT supplier.suppkey FROM tpch.tiny.supplier WHERE supplier.name = 'abc')",
TupleDomain.none());
}

@Test(timeOut = 30_000)
public void testBroadcastSemiJoinWithEmptyBuildSide()
{
assertQueryDynamicFilters(
withBroadcastJoin(),
"SELECT * FROM lineitem WHERE lineitem.suppkey IN (SELECT supplier.suppkey FROM tpch.tiny.supplier WHERE supplier.name = 'abc')",
TupleDomain.none());
}

@Test(timeOut = 30_000)
public void testSemiJoinWithLargeBuildSide()
{
assertQueryDynamicFilters(
"SELECT * FROM lineitem WHERE lineitem.orderkey IN (SELECT orders.orderkey FROM tpch.tiny.orders)",
TupleDomain.all());
}

@Test(timeOut = 30_000)
public void testBroadcastSemiJoinWithLargeBuildSide()
{
assertQueryDynamicFilters(
withBroadcastJoin(),
"SELECT * FROM lineitem WHERE lineitem.orderkey IN (SELECT orders.orderkey FROM tpch.tiny.orders)",
TupleDomain.all());
}

@Test(timeOut = 30_000)
public void testSemiJoinWithSelectiveBuildSide()
{
assertQueryDynamicFilters(
"SELECT * FROM lineitem WHERE lineitem.suppkey IN (SELECT supplier.suppkey FROM tpch.tiny.supplier WHERE supplier.name = 'Supplier#000000001')",
TupleDomain.withColumnDomains(ImmutableMap.of(
SUPP_KEY_HANDLE,
singleValue(BIGINT, 1L))));
}

@Test(timeOut = 30_000)
public void testBroadcastSemiJoinWithSelectiveBuildSide()
{
assertQueryDynamicFilters(
withBroadcastJoin(),
"SELECT * FROM lineitem WHERE lineitem.suppkey IN (SELECT supplier.suppkey FROM tpch.tiny.supplier WHERE supplier.name = 'Supplier#000000001')",
TupleDomain.withColumnDomains(ImmutableMap.of(
SUPP_KEY_HANDLE,
singleValue(BIGINT, 1L))));
}

@Test(timeOut = 30_000)
public void testSemiJoinWithNonSelectiveBuildSide()
{
assertQueryDynamicFilters(
"SELECT * FROM lineitem WHERE lineitem.suppkey IN (SELECT supplier.suppkey FROM tpch.tiny.supplier)",
TupleDomain.withColumnDomains(ImmutableMap.of(
SUPP_KEY_HANDLE,
Domain.create(ValueSet.ofRanges(Range.range(BIGINT, 1L, true, 100L, true)), false))));
}

@Test(timeOut = 30_000)
public void testSemiJoinWithMultipleDynamicFiltersOnProbe()
{
// supplier names Supplier#000000001 and Supplier#000000002 match suppkey 1 and 2
assertQueryDynamicFilters(
"SELECT * FROM (" +
"SELECT lineitem.suppkey FROM lineitem WHERE lineitem.suppkey IN " +
"(SELECT supplier.suppkey FROM tpch.tiny.supplier WHERE supplier.name IN ('Supplier#000000001', 'Supplier#000000002'))) t " +
"WHERE t.suppkey IN (SELECT partsupp.suppkey FROM tpch.tiny.partsupp WHERE partsupp.suppkey IN (2, 3))",
TupleDomain.withColumnDomains(ImmutableMap.of(
SUPP_KEY_HANDLE,
singleValue(BIGINT, 2L))));
}

protected TupleDomain<ColumnHandle> getExpectedDynamicFilter(ConnectorSession session)
{
return expectedDynamicFilter.get(session.getSource().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ public void testBroadcastJoinWithLargeBuildSide()
// for broadcast joins lazy dynamic filters are non blocking
}

@Test(enabled = false)
@Override
public void testBroadcastSemiJoinWithSelectiveBuildSide()
{
// for broadcast semi-joins lazy dynamic filters are non blocking
}

@Test(enabled = false)
@Override
public void testBroadcastSemiJoinWithEmptyBuildSide()
{
// for broadcast semi-joins lazy dynamic filters are non blocking
}

@Test(enabled = false)
@Override
public void testBroadcastSemiJoinWithLargeBuildSide()
{
// for broadcast semi-joins lazy dynamic filters are non blocking
}

private class TestPlugin
implements Plugin
{
Expand Down

0 comments on commit af4e856

Please sign in to comment.