diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index cb4d539c750..8f9bffe82d7 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -14,7 +14,7 @@
 import pytest
 from _pytest.mark.structures import ParameterSet
-from pyspark.sql.functions import broadcast, col, lit
+from pyspark.sql.functions import broadcast, col
 from pyspark.sql.types import *
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
 from conftest import is_databricks_runtime, is_emr_runtime
@@ -930,25 +930,86 @@ def do_join(spark):
         })
     assert_gpu_and_cpu_are_equal_collect(do_join, conf=_all_conf)
 
-bloom_filter_confs = { "spark.sql.autoBroadcastJoinThreshold": "1",
-                       "spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold": 1,
-                       "spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold": "100GB",
-                       "spark.sql.optimizer.runtime.bloomFilter.enabled": "true"}
+bloom_filter_confs = {
+    "spark.sql.autoBroadcastJoinThreshold": "1",
+    "spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold": 1,
+    "spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold": "100GB",
+    "spark.sql.optimizer.runtime.bloomFilter.enabled": "true"
+}
 
-def check_bloom_filter_join_multi_column(confs, expected_classes):
+bloom_filter_exprs_enabled = {
+    "spark.rapids.sql.expression.BloomFilterMightContain": "true",
+    "spark.rapids.sql.expression.BloomFilterAggregate": "true"
+}
+
+def check_bloom_filter_join(confs, expected_classes, is_multi_column):
     def do_join(spark):
-        left = spark.range(100000).withColumn("second_id", col("id") % 5)
-        right = spark.range(10).withColumn("id2", col("id").cast("string")).withColumn("second_id", col("id") % 5)
-        return right.filter("cast(id2 as bigint) % 3 = 0").join(left, (left.id == right.id) & (left.second_id == right.second_id), "inner")
-    all_confs = copy_and_update(confs, bloom_filter_confs)
+        if is_multi_column:
+            left = spark.range(100000).withColumn("second_id", col("id") % 5)
+            right = spark.range(10).withColumn("id2", col("id").cast("string")).withColumn("second_id", col("id") % 5)
+            return right.filter("cast(id2 as bigint) % 3 = 0").join(left, (left.id == right.id) & (left.second_id == right.second_id), "inner")
+        else:
+            left = spark.range(100000)
+            right = spark.range(10).withColumn("id2", col("id").cast("string"))
+            return right.filter("cast(id2 as bigint) % 3 = 0").join(left, left.id == right.id, "inner")
+    all_confs = copy_and_update(bloom_filter_confs, confs)
     assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, expected_classes, conf=all_confs)
 
+@allow_non_gpu("FilterExec", "ObjectHashAggregateExec", "ShuffleExchangeExec")
+@ignore_order(local=True)
+@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
+@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
+@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
+def test_bloom_filter_disabled_by_default(is_multi_column):
+    check_bloom_filter_join(confs={},
+                            expected_classes="BloomFilterMightContain,BloomFilterAggregate",
+                            is_multi_column=is_multi_column)
+
+@ignore_order(local=True)
+@pytest.mark.parametrize("batch_size", ['1g', '1000'], ids=idfn)
+@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
+@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
+@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
+def test_bloom_filter_join(batch_size, is_multi_column):
+    conf = copy_and_update(bloom_filter_exprs_enabled,
+                           {"spark.rapids.sql.batchSizeBytes": batch_size})
+    check_bloom_filter_join(confs=conf,
+                            expected_classes="GpuBloomFilterMightContain,GpuBloomFilterAggregate",
+                            is_multi_column=is_multi_column)
+
+@allow_non_gpu("FilterExec", "ShuffleExchangeExec")
+@ignore_order(local=True)
+@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
+@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
+@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
+def test_bloom_filter_join_cpu_probe(is_multi_column):
+    conf = copy_and_update(bloom_filter_exprs_enabled,
+                           {"spark.rapids.sql.expression.BloomFilterMightContain": "false"})
+    check_bloom_filter_join(confs=conf,
+                            expected_classes="BloomFilterMightContain,GpuBloomFilterAggregate",
+                            is_multi_column=is_multi_column)
 
 @allow_non_gpu("ObjectHashAggregateExec", "ShuffleExchangeExec")
 @ignore_order(local=True)
+@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
 @pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
 @pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
-def test_multi_column_bloom_filter_join():
-    check_bloom_filter_join_multi_column(confs={},
-                                         expected_classes="GpuBloomFilterMightContain")
+def test_bloom_filter_join_cpu_build(is_multi_column):
+    conf = copy_and_update(bloom_filter_exprs_enabled,
+                           {"spark.rapids.sql.expression.BloomFilterAggregate": "false"})
+    check_bloom_filter_join(confs=conf,
+                            expected_classes="GpuBloomFilterMightContain,BloomFilterAggregate",
+                            is_multi_column=is_multi_column)
+
+@allow_non_gpu("ObjectHashAggregateExec", "ProjectExec", "ShuffleExchangeExec")
+@ignore_order(local=True)
+@pytest.mark.parametrize("agg_replace_mode", ["partial", "final"])
+@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
+@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
+@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
+def test_bloom_filter_join_split_cpu_build(agg_replace_mode, is_multi_column):
+    conf = copy_and_update(bloom_filter_exprs_enabled,
+                           {"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode})
+    check_bloom_filter_join(confs=conf,
+                            expected_classes="GpuBloomFilterMightContain,BloomFilterAggregate,GpuBloomFilterAggregate",
+                            is_multi_column=is_multi_column)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
index 4b6813d7469..596858ca5e3 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
@@ -174,13 +174,13 @@ object ExecutionPlanCaptureCallback extends AdaptiveSparkPlanHelper {
       "Plan does not contain an ansi cast")
   }
 
-  private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
+  def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = {
     !exp.getClass.getCanonicalName.equals("com.nvidia.spark.rapids.GpuExpression") &&
       PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass ||
       exp.children.exists(didFallBack(_, fallbackCpuClass))
   }
 
-  private def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
+  def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = {
     val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(plan)
     !executedPlan.getClass.getCanonicalName.equals("com.nvidia.spark.rapids.GpuExec") &&
       PlanUtils.sameClass(executedPlan, fallbackCpuClass) ||
diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilterAggregate.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilterAggregate.scala
new file mode 100644
index 00000000000..f791dca620a
--- /dev/null
+++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilterAggregate.scala
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "330"}
+{"spark": "330cdh"}
+{"spark": "330db"}
+{"spark": "331"}
+{"spark": "332"}
+{"spark": "332db"}
+{"spark": "333"}
+{"spark": "340"}
+{"spark": "341"}
+spark-rapids-shim-json-lines ***/
+package com.nvidia.spark.rapids
+
+import ai.rapids.cudf.{ColumnVector, GroupByAggregation, Scalar}
+import com.nvidia.spark.rapids.Arm.closeOnExcept
+import com.nvidia.spark.rapids.GpuBloomFilterAggregate.optimalNumOfHashFunctions
+import com.nvidia.spark.rapids.jni.BloomFilter
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.{CudfAggregate, GpuAggregateFunction}
+import org.apache.spark.sql.types.{BinaryType, DataType}
+
+case class GpuBloomFilterAggregate(
+    child: Expression,
+    estimatedNumItemsRequested: Long,
+    numBitsRequested: Long) extends GpuAggregateFunction {
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = BinaryType
+
+  override def prettyName: String = "bloom_filter_agg"
+
+  private val estimatedNumItems: Long =
+    Math.min(estimatedNumItemsRequested, SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
+
+  private val numBits: Long = Math.min(numBitsRequested,
+    SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
+
+  private lazy val numHashes: Int = optimalNumOfHashFunctions(estimatedNumItems, numBits)
+
+  override def children: Seq[Expression] = Seq(child)
+
+  override lazy val initialValues: Seq[Expression] = Seq(GpuLiteral(null, BinaryType))
+
+  override val inputProjection: Seq[Expression] = Seq(child)
+
+  override val updateAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterUpdate(numHashes, numBits))
+
+  override val mergeAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterMerge())
+
+  private lazy val bloomAttr: AttributeReference = AttributeReference("bloomFilter", dataType)()
+
+  override def aggBufferAttributes: Seq[AttributeReference] = Seq(bloomAttr)
+
+  override val evaluateExpression: Expression = bloomAttr
+}
+
+object GpuBloomFilterAggregate {
+  /**
+   * From Spark's BloomFilter.optimalNumOfHashFunctions
+   *
+   * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the
+   * expected insertions and total number of bits in the Bloom filter.
+   *
+   * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula.
+   *
+   * @param n expected insertions (must be positive)
+   * @param m total number of bits in Bloom filter (must be positive)
+   */
+  private def optimalNumOfHashFunctions(n: Long, m: Long): Int = {
+    // (m / n) * log(2), but avoid truncation due to division!
+    Math.max(1, Math.round(m.toDouble / n * Math.log(2)).toInt)
+  }
+}
+
+case class GpuBloomFilterUpdate(numHashes: Int, numBits: Long) extends CudfAggregate {
+  override val reductionAggregate: ColumnVector => Scalar = (col: ColumnVector) => {
+    closeOnExcept(BloomFilter.create(numHashes, numBits)) { bloomFilter =>
+      BloomFilter.put(bloomFilter, col)
+      bloomFilter
+    }
+  }
+
+  override lazy val groupByAggregate: GroupByAggregation =
+    throw new UnsupportedOperationException("group by aggregations are not supported")
+
+  override def dataType: DataType = BinaryType
+
+  override val name: String = "gpu_bloom_filter_update"
+}
+
+case class GpuBloomFilterMerge() extends CudfAggregate {
+  override val reductionAggregate: ColumnVector => Scalar = (col: ColumnVector) => {
+    BloomFilter.merge(col)
+  }
+
+  override lazy val groupByAggregate: GroupByAggregation =
+    throw new UnsupportedOperationException("group by aggregations are not supported")
+
+  override def dataType: DataType = BinaryType
+
+  override val name: String = "gpu_bloom_filter_merge"
+}
diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala
index bc69ab738d2..2e3573546ed 100644
--- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala
+++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala
@@ -30,6 +30,7 @@ package com.nvidia.spark.rapids.shims
 
 import com.nvidia.spark.rapids._
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 
 object BloomFilterShims {
@@ -45,7 +46,25 @@ object BloomFilterShims {
        (a, conf, p, r) => new BinaryExprMeta[BloomFilterMightContain](a, conf, p, r) {
          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
            GpuBloomFilterMightContain(lhs, rhs)
-        })
+        }).disabledByDefault("Bloom filter join acceleration is experimental"),
+      GpuOverrides.expr[BloomFilterAggregate](
+        "Bloom filter build",
+        ExprChecksImpl(Map(
+          (ReductionAggExprContext,
+            ContextChecks(TypeSig.BINARY, TypeSig.BINARY,
+              Seq(ParamCheck("child", TypeSig.LONG, TypeSig.LONG),
+                ParamCheck("estimatedItems",
+                  TypeSig.lit(TypeEnum.LONG), TypeSig.lit(TypeEnum.LONG)),
+                ParamCheck("numBits",
+                  TypeSig.lit(TypeEnum.LONG), TypeSig.lit(TypeEnum.LONG))))))),
+        (a, conf, p, r) => new ExprMeta[BloomFilterAggregate](a, conf, p, r) {
+          override def convertToGpu(): GpuExpression = {
+            GpuBloomFilterAggregate(
+              childExprs.head.convertToGpu(),
+              a.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
+              a.numBitsExpression.eval().asInstanceOf[Number].longValue)
+          }
+        }).disabledByDefault("Bloom filter join acceleration is experimental")
     ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
   }
 }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index c283bd7afd1..c9a58051365 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.{CreateViewCommand, ExecutedCommandExec}
 import org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
 import org.apache.spark.sql.types._
@@ -283,6 +284,7 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
     } finally {
       cpuPlans = ExecutionPlanCaptureCallback.getResultsWithTimeout()
     }
+    cpuPlans = filterCapturedPlans(cpuPlans)
     assert(cpuPlans.nonEmpty, "Did not capture CPU plan")
     assert(cpuPlans.length == 1,
       s"Captured more than one CPU plan: ${cpuPlans.mkString("\n")}")
@@ -301,12 +303,21 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
     } finally {
       gpuPlans = ExecutionPlanCaptureCallback.getResultsWithTimeout()
     }
+    gpuPlans = filterCapturedPlans(gpuPlans)
     assert(gpuPlans.nonEmpty, "Did not capture GPU plan")
     assert(gpuPlans.length == 1,
       s"Captured more than one GPU plan: ${gpuPlans.mkString("\n")}")
 
     (cpuPlans.head, gpuPlans.head)
   }
 
+  // filter out "uninteresting" plans like view creation, etc.
+  protected def filterCapturedPlans(plans: Array[SparkPlan]): Array[SparkPlan] = {
+    plans.filter {
+      case ExecutedCommandExec(_: CreateViewCommand) => false
+      case _ => true
+    }
+  }
+
   def runOnCpuAndGpuWithCapture(df: SparkSession => DataFrame,
       fun: DataFrame => DataFrame,
       conf: SparkConf = new SparkConf(),
@@ -332,6 +343,7 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
     } finally {
       cpuPlans = ExecutionPlanCaptureCallback.getResultsWithTimeout()
     }
+    cpuPlans = filterCapturedPlans(cpuPlans)
     assert(cpuPlans.nonEmpty, "Did not capture CPU plan")
     assert(cpuPlans.length == 1,
       s"Captured more than one CPU plan: ${cpuPlans.mkString("\n")}")
@@ -351,6 +363,7 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
     } finally {
       gpuPlans = ExecutionPlanCaptureCallback.getResultsWithTimeout()
     }
+    gpuPlans = filterCapturedPlans(gpuPlans)
     assert(gpuPlans.nonEmpty, "Did not capture GPU plan")
     assert(gpuPlans.length == 1,
       s"Captured more than one GPU plan: ${gpuPlans.mkString("\n")}")
diff --git a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala
index cac6a9a55dc..01dc4eafac1 100644
--- a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala
+++ b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala
@@ -27,13 +27,19 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids
 
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback
 
 class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite {
+  val bloomFilterEnabledConf = new SparkConf()
+    .set("spark.rapids.sql.expression.BloomFilterMightContain", "true")
+    .set("spark.rapids.sql.expression.BloomFilterAggregate", "true")
   val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
   val funcId_might_contain = new FunctionIdentifier("might_contain")
@@ -65,74 +71,145 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite {
       (1L to 100L).map(_ => None)).toDF("col")
   }
 
+  private def withExposedSqlFuncs[T](spark: SparkSession)(func: SparkSession => T): T = {
+    try {
+      installSqlFuncs(spark)
+      func(spark)
+    } finally {
+      uninstallSqlFuncs(spark)
+    }
+  }
+
+  private def doBloomFilterTest(numEstimated: Long, numBits: Long): DataFrame => DataFrame = {
+    df =>
+      val table = "bloom_filter_test"
+      val sqlString =
+        s"""
+           |SELECT might_contain(
+           |            (SELECT bloom_filter_agg(col,
+           |              cast($numEstimated as long),
+           |              cast($numBits as long))
+           |             FROM $table),
+           |            col) positive_membership_test,
+           |       might_contain(
+           |            (SELECT bloom_filter_agg(col,
+           |              cast($numEstimated as long),
+           |              cast($numBits as long))
+           |             FROM values (-1L), (100001L), (20000L) as t(col)),
+           |            col) negative_membership_test
+           |FROM $table
+        """.stripMargin
+      df.createOrReplaceTempView(table)
+      withExposedSqlFuncs(df.sparkSession) { spark =>
+        spark.sql(sqlString)
+      }
+  }
+
+  private def getPlanValidator(exec: String): (SparkPlan, SparkPlan) => Unit = {
+    def searchPlan(p: SparkPlan): Boolean = {
+      ExecutionPlanCaptureCallback.didFallBack(p, exec) ||
+        p.children.exists(searchPlan) ||
+        p.subqueries.exists(searchPlan)
+    }
+    (_, gpuPlan) => {
+      val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(gpuPlan)
+      assert(searchPlan(executedPlan), s"Could not find $exec in the GPU plan:\n$executedPlan")
+    }
+  }
+
+  // test with GPU bloom build, GPU bloom probe
+  for (numEstimated <- Seq(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS.defaultValue.get)) {
+    for (numBits <- Seq(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS.defaultValue.get)) {
+      testSparkResultsAreEqual(
+        s"might_contain GPU build GPU probe estimated=$numEstimated numBits=$numBits",
+        buildData,
+        conf = bloomFilterEnabledConf.clone()
+      )(doBloomFilterTest(numEstimated, numBits))
+    }
+  }
+
+  // test with CPU bloom build, GPU bloom probe
+  for (numEstimated <- Seq(4096L, 4194304L, Long.MaxValue,
+    SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS.defaultValue.get)) {
+    for (numBits <- Seq(4096L, 4194304L,
+      SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS.defaultValue.get)) {
+      ALLOW_NON_GPU_testSparkResultsAreEqualWithCapture(
+        s"might_contain CPU build GPU probe estimated=$numEstimated numBits=$numBits",
+        buildData,
+        Seq("ObjectHashAggregateExec", "ShuffleExchangeExec"),
+        conf = bloomFilterEnabledConf.clone()
+          .set("spark.rapids.sql.expression.BloomFilterAggregate", "false")
+      )(doBloomFilterTest(numEstimated, numBits))(getPlanValidator("ObjectHashAggregateExec"))
+    }
+  }
+
+  // test with GPU bloom build, CPU bloom probe
   for (numEstimated <- Seq(4096L, 4194304L, Long.MaxValue,
     SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS.defaultValue.get)) {
-    for (numBits <- Seq(4096L, 4194304L, Long.MaxValue,
+    for (numBits <- Seq(4096L, 4194304L,
       SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS.defaultValue.get)) {
-      ALLOW_NON_GPU_testSparkResultsAreEqual(
-        s"might_contain estimated=$numEstimated numBits=$numBits",
+      ALLOW_NON_GPU_testSparkResultsAreEqualWithCapture(
+        s"might_contain GPU build CPU probe estimated=$numEstimated numBits=$numBits",
         buildData,
-        Seq("ObjectHashAggregateExec", "ShuffleExchangeExec"))(df =>
-        {
-          val table = "bloom_filter_test"
-          val sqlString =
-            s"""
-               |SELECT might_contain(
-               |            (SELECT bloom_filter_agg(col,
-               |              cast($numEstimated as long),
-               |              cast($numBits as long))
-               |             FROM $table),
-               |            col) positive_membership_test,
-               |       might_contain(
-               |            (SELECT bloom_filter_agg(col,
-               |              cast($numEstimated as long),
-               |              cast($numBits as long))
-               |             FROM values (-1L), (100001L), (20000L) as t(col)),
-               |            col) negative_membership_test
-               |FROM $table
-            """.stripMargin
-          df.createOrReplaceTempView(table)
-          try {
-            installSqlFuncs(df.sparkSession)
-            df.sparkSession.sql(sqlString)
-          } finally {
-            uninstallSqlFuncs(df.sparkSession)
-          }
-        })
+        Seq("LocalTableScanExec", "ProjectExec", "ShuffleExchangeExec"),
+        conf = bloomFilterEnabledConf.clone()
+          .set("spark.rapids.sql.expression.BloomFilterMightContain", "false")
+      )(doBloomFilterTest(numEstimated, numBits))(getPlanValidator("ProjectExec"))
+    }
+  }
+
+  // test with partial/final-only GPU bloom build, CPU bloom probe
+  for (mode <- Seq("partial", "final")) {
+    for (numEstimated <- Seq(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS.defaultValue.get)) {
+      for (numBits <- Seq(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS.defaultValue.get)) {
+        ALLOW_NON_GPU_testSparkResultsAreEqualWithCapture(
+          s"might_contain GPU $mode build CPU probe estimated=$numEstimated numBits=$numBits",
+          buildData,
+          Seq("ObjectHashAggregateExec", "ProjectExec", "ShuffleExchangeExec"),
+          conf = bloomFilterEnabledConf.clone()
+            .set("spark.rapids.sql.expression.BloomFilterMightContain", "false")
+            .set("spark.rapids.sql.hashAgg.replaceMode", mode)
+        )(doBloomFilterTest(numEstimated, numBits))(getPlanValidator("ObjectHashAggregateExec"))
+      }
     }
   }
 
   testSparkResultsAreEqual(
     "might_contain with literal bloom filter buffer",
-    spark => spark.range(1, 1).asInstanceOf[DataFrame]) {
+    spark => spark.range(1, 1).asInstanceOf[DataFrame],
+    conf=bloomFilterEnabledConf.clone()) {
     df =>
-      try {
-        installSqlFuncs(df.sparkSession)
-        df.sparkSession.sql(
+      withExposedSqlFuncs(df.sparkSession) { spark =>
+        spark.sql(
           """SELECT might_contain(
             |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
             |cast(201 as long))""".stripMargin)
-      } finally {
-        uninstallSqlFuncs(df.sparkSession)
       }
   }
 
-  ALLOW_NON_GPU_testSparkResultsAreEqual(
+  testSparkResultsAreEqual(
     "might_contain with all NULL inputs",
     spark => spark.range(1, 1).asInstanceOf[DataFrame],
-    Seq("ObjectHashAggregateExec", "ShuffleExchangeExec")) {
+    conf=bloomFilterEnabledConf.clone()) {
     df =>
-      try {
-        installSqlFuncs(df.sparkSession)
-        df.sparkSession.sql(
+      withExposedSqlFuncs(df.sparkSession) { spark =>
+        spark.sql(
           """
            |SELECT might_contain(null, null) both_null,
            |       might_contain(null, 1L) null_bf,
            |       might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
            |            null) null_value
          """.stripMargin)
-      } finally {
-        uninstallSqlFuncs(df.sparkSession)
+      }
+  }
+
+  testSparkResultsAreEqual(
+    "bloom_filter_agg with empty input",
+    spark => spark.range(1, 1).asInstanceOf[DataFrame],
+    conf=bloomFilterEnabledConf.clone()) {
+    df =>
+      withExposedSqlFuncs(df.sparkSession) { spark =>
+        spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)""")
       }
   }
 }