From 57e442d97a75034baa6f2251de4a00cba031ffb5 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Tue, 8 Oct 2019 11:40:10 +0300 Subject: [PATCH] Add unit-tests for LocalDynamicFilter and LocalDynamicFilterCollector Following https://github.com/prestosql/presto/pull/1686 --- .../prestosql/testing/LocalQueryRunner.java | 7 +- .../sql/planner/TestLocalDynamicFilter.java | 306 ++++++++++++++++++ .../TestLocalDynamicFiltersCollector.java | 53 +++ .../sql/planner/assertions/BasePlanTest.java | 19 ++ 4 files changed, 384 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFiltersCollector.java diff --git a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java index 142ae78dae81..e12ba08fc3a3 100644 --- a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java @@ -677,13 +677,18 @@ public List createDrivers(Session session, @Language("SQL") String sql, return createDrivers(session, plan, outputFactory, taskContext); } + public SubPlan createSubPlans(Session session, Plan plan, boolean forceSingleNode) + { + return planFragmenter.createSubPlans(session, plan, forceSingleNode, WarningCollector.NOOP); + } + private List createDrivers(Session session, Plan plan, OutputFactory outputFactory, TaskContext taskContext) { if (printPlan) { System.out.println(PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, plan.getStatsAndCosts(), session, 0, false)); } - SubPlan subplan = planFragmenter.createSubPlans(session, plan, true, WarningCollector.NOOP); + SubPlan subplan = createSubPlans(session, plan, true); if (!subplan.getChildren().isEmpty()) { throw new AssertionError("Expected subplan to have no children"); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java new file mode 100644 index 000000000000..01b111783d66 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java @@ -0,0 +1,306 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.ListenableFuture; +import io.prestosql.Session; +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import io.prestosql.sql.planner.optimizations.PlanNodeSearcher; +import io.prestosql.sql.planner.plan.JoinNode; +import org.testng.annotations.Test; + +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING; +import static io.prestosql.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT; +import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertFalse; + +public class TestLocalDynamicFilter + extends BasePlanTest +{ + public TestLocalDynamicFilter() + { + super(ImmutableMap.of( + FORCE_SINGLE_NODE_OUTPUT, "false", + JOIN_DISTRIBUTION_TYPE, "BROADCAST", + ENABLE_DYNAMIC_FILTERING, "true")); + } + + @Test + public void testSimple() + throws ExecutionException, InterruptedException + { + LocalDynamicFilter filter = new LocalDynamicFilter( + ImmutableMultimap.of("123", new Symbol("a")), + ImmutableMap.of("123", 0), + 1); + assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); + Consumer> consumer = filter.getTupleDomainConsumer(); + ListenableFuture> result = filter.getResultFuture(); + assertFalse(result.isDone()); + + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 7L)))); + assertEquals(result.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("a"), Domain.singleValue(INTEGER, 7L)))); + } + + @Test + public void testMultipleProbeSymbols() + throws ExecutionException, InterruptedException + { + LocalDynamicFilter filter = new LocalDynamicFilter( + ImmutableMultimap.of("123", new Symbol("a1"), "123", new Symbol("a2")), + ImmutableMap.of("123", 0), + 1); + assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); + Consumer> consumer = filter.getTupleDomainConsumer(); + ListenableFuture> result = filter.getResultFuture(); + assertFalse(result.isDone()); + + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 7L)))); + assertEquals(result.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("a1"), Domain.singleValue(INTEGER, 7L), + new Symbol("a2"), Domain.singleValue(INTEGER, 7L)))); + } + + @Test + public void testMultiplePartitions() + throws ExecutionException, InterruptedException + { + LocalDynamicFilter filter = new LocalDynamicFilter( + ImmutableMultimap.of("123", new Symbol("a")), + ImmutableMap.of("123", 0), + 2); + assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); + Consumer> consumer = filter.getTupleDomainConsumer(); + ListenableFuture> result = filter.getResultFuture(); + + assertFalse(result.isDone()); + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 10L)))); + + assertFalse(result.isDone()); + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 20L)))); + + assertEquals(result.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("a"), Domain.multipleValues(INTEGER, ImmutableList.of(10L, 20L))))); + } + + @Test + public void testNone() + throws ExecutionException, InterruptedException + { + LocalDynamicFilter filter = new LocalDynamicFilter( + ImmutableMultimap.of("123", new Symbol("a")), + ImmutableMap.of("123", 0), + 1); + assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); + Consumer> consumer = filter.getTupleDomainConsumer(); + ListenableFuture> result = filter.getResultFuture(); + + assertFalse(result.isDone()); + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.none(INTEGER)))); + + assertEquals(result.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("a"), Domain.none(INTEGER)))); + } + + @Test + public void testMultipleColumns() + throws ExecutionException, InterruptedException + { + LocalDynamicFilter filter = new LocalDynamicFilter( + ImmutableMultimap.of("123", new Symbol("a"), "456", new Symbol("b")), + ImmutableMap.of("123", 0, "456", 1), + 1); + assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0, "456", 1)); + Consumer> consumer = filter.getTupleDomainConsumer(); + ListenableFuture> result = filter.getResultFuture(); + assertFalse(result.isDone()); + + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 10L), + "456", Domain.singleValue(INTEGER, 20L)))); + assertEquals(result.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("a"), Domain.singleValue(INTEGER, 10L), + new Symbol("b"), Domain.singleValue(INTEGER, 20L)))); + } + + @Test + public void testMultiplePartitionsAndColumns() + throws ExecutionException, InterruptedException + { + LocalDynamicFilter filter = new LocalDynamicFilter( + ImmutableMultimap.of("123", new Symbol("a"), "456", new Symbol("b")), + ImmutableMap.of("123", 0, "456", 1), + 2); + assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0, "456", 1)); + Consumer> consumer = filter.getTupleDomainConsumer(); + ListenableFuture> result = filter.getResultFuture(); + + assertFalse(result.isDone()); + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 10L), + "456", Domain.singleValue(BIGINT, 100L)))); + + assertFalse(result.isDone()); + consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + "123", Domain.singleValue(INTEGER, 20L), + "456", Domain.singleValue(BIGINT, 200L)))); + + assertEquals(result.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("a"), Domain.multipleValues(INTEGER, ImmutableList.of(10L, 20L)), + new Symbol("b"), Domain.multipleValues(BIGINT, ImmutableList.of(100L, 200L))))); + } + + @Test + public void testCreateSingleColumn() + throws ExecutionException, InterruptedException + { + SubPlan subplan = subplan( + "SELECT count() FROM lineitem, orders WHERE lineitem.orderkey = orders.orderkey " + + "AND orders.custkey < 10", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + false); + JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, 1).get(); + String filterId = Iterables.getOnlyElement(filter.getBuildChannels().keySet()); + Symbol probeSymbol = Iterables.getOnlyElement(joinNode.getCriteria()).getLeft(); + + filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filterId, Domain.singleValue(BIGINT, 3L)))); + assertEquals(filter.getResultFuture().get(), TupleDomain.withColumnDomains(ImmutableMap.of( + probeSymbol, Domain.singleValue(BIGINT, 3L)))); + } + + @Test + public void testCreateDistributedJoin() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "PARTITIONED") + .build(); + SubPlan subplan = subplan( + "SELECT count() FROM nation, region WHERE nation.regionkey = region.regionkey " + + "AND region.comment = 'abc'", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + false, + session); + JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); + assertEquals(joinNode.getDynamicFilters().isEmpty(), false); + assertEquals(LocalDynamicFilter.create(joinNode, 1), Optional.empty()); + } + + @Test + public void testCreateMultipleCriteria() + throws ExecutionException, InterruptedException + { + SubPlan subplan = subplan( + "SELECT count() FROM lineitem, partsupp " + + "WHERE lineitem.partkey = partsupp.partkey AND lineitem.suppkey = partsupp.suppkey " + + "AND partsupp.availqty < 10", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + false); + JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, 1).get(); + List filterIds = filter + .getBuildChannels() + .entrySet() + .stream() + .sorted(Comparator.comparing(e -> e.getValue())) + .map(e -> e.getKey()) + .collect(toImmutableList()); + filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filterIds.get(0), Domain.singleValue(BIGINT, 4L), + filterIds.get(1), Domain.singleValue(BIGINT, 5L)))); + + TupleDomain expected = TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("partkey"), Domain.singleValue(BIGINT, 4L), + new Symbol("suppkey"), Domain.singleValue(BIGINT, 5L))); + assertEquals(filter.getResultFuture().get(), expected); + } + + @Test + public void testCreateMultipleJoins() + throws ExecutionException, InterruptedException + { + SubPlan subplan = subplan( + "SELECT count() FROM lineitem, orders, part " + + "WHERE lineitem.orderkey = orders.orderkey AND lineitem.partkey = part.partkey " + + "AND orders.custkey < 10 AND part.name = 'abc'", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + false); + List joinNodes = searchJoins(subplan.getChildren().get(0).getFragment()).findAll(); + assertEquals(joinNodes.size(), 2); + for (JoinNode joinNode : joinNodes) { + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, 1).get(); + String filterId = Iterables.getOnlyElement(filter.getBuildChannels().keySet()); + Symbol probeSymbol = Iterables.getOnlyElement(joinNode.getCriteria()).getLeft(); + + filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filterId, Domain.singleValue(BIGINT, 6L)))); + assertEquals(filter.getResultFuture().get(), TupleDomain.withColumnDomains(ImmutableMap.of( + probeSymbol, Domain.singleValue(BIGINT, 6L)))); + } + } + + @Test + public void testCreateProbeSideUnion() + throws ExecutionException, InterruptedException + { + SubPlan subplan = subplan( + "WITH union_table(key) AS " + + "((SELECT partkey FROM part) UNION (SELECT suppkey FROM supplier)) " + + "SELECT count() FROM union_table, nation WHERE union_table.key = nation.nationkey " + + "AND nation.comment = 'abc'", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + true); + JoinNode joinNode = searchJoins(subplan.getFragment()).findOnlyElement(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, 1).get(); + String filterId = Iterables.getOnlyElement(filter.getBuildChannels().keySet()); + + filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filterId, Domain.singleValue(BIGINT, 7L)))); + TupleDomain expected = TupleDomain.withColumnDomains(ImmutableMap.of( + new Symbol("partkey"), Domain.singleValue(BIGINT, 7L), + new Symbol("suppkey"), Domain.singleValue(BIGINT, 7L))); + assertEquals(filter.getResultFuture().get(), expected); + } + + private PlanNodeSearcher searchJoins(PlanFragment fragment) + { + return PlanNodeSearcher + .searchFrom(fragment.getRoot()) + .where(node -> node instanceof JoinNode); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFiltersCollector.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFiltersCollector.java new file mode 100644 index 000000000000..77202cc54970 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFiltersCollector.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.predicate.TupleDomain; +import org.testng.annotations.Test; + +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.testing.assertions.Assert.assertEquals; + +public class TestLocalDynamicFiltersCollector +{ + @Test + public void testCollector() + { + Symbol symbol = new Symbol("symbol"); + + LocalDynamicFiltersCollector collector = new LocalDynamicFiltersCollector(); + assertEquals(collector.getPredicate(), TupleDomain.all()); + + collector.intersect(TupleDomain.all()); + assertEquals(collector.getPredicate(), TupleDomain.all()); + + collector.intersect(tupleDomain(symbol, 1L, 2L)); + assertEquals(collector.getPredicate(), tupleDomain(symbol, 1L, 2L)); + + collector.intersect(tupleDomain(symbol, 2L, 3L)); + assertEquals(collector.getPredicate(), tupleDomain(symbol, 2L)); + + collector.intersect(tupleDomain(symbol, 0L)); + assertEquals(collector.getPredicate(), TupleDomain.none()); + } + + private TupleDomain tupleDomain(Symbol symbol, Long... values) + { + return TupleDomain.withColumnDomains(ImmutableMap.of(symbol, Domain.multipleValues(BIGINT, ImmutableList.copyOf(values)))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/BasePlanTest.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/BasePlanTest.java index 731edc877890..d0392d1a12d7 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/BasePlanTest.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/BasePlanTest.java @@ -23,6 +23,7 @@ import io.prestosql.sql.planner.LogicalPlanner; import io.prestosql.sql.planner.Plan; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.SubPlan; import io.prestosql.sql.planner.iterative.IterativeOptimizer; import io.prestosql.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; import io.prestosql.sql.planner.optimizations.PlanOptimizer; @@ -206,6 +207,24 @@ protected Plan plan(String sql, LogicalPlanner.Stage stage, boolean forceSingleN } } + protected SubPlan subplan(String sql, LogicalPlanner.Stage stage, boolean forceSingleNode) + { + return subplan(sql, stage, forceSingleNode, getQueryRunner().getDefaultSession()); + } + + protected SubPlan subplan(String sql, LogicalPlanner.Stage stage, boolean forceSingleNode, Session session) + { + try { + return queryRunner.inTransaction(session, transactionSession -> { + Plan plan = queryRunner.createPlan(transactionSession, sql, stage, forceSingleNode, WarningCollector.NOOP); + return queryRunner.createSubPlans(transactionSession, plan, forceSingleNode); + }); + } + catch (RuntimeException e) { + throw new AssertionError("Planning failed for SQL: " + sql, e); + } + } + public interface LocalQueryRunnerSupplier { LocalQueryRunner get();