Skip to content

Commit

Permalink
Fix tests to skip on Databricks and check for specific classes
Browse files (browse the repository at this point in the history)
  • Loading branch information
jlowe committed Aug 3, 2023
1 parent 26f3f5a commit 9202839
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions integration_tests/src/main/python/join_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,7 +20,7 @@
from conftest import is_databricks_runtime, is_emr_runtime
from data_gen import *
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
from spark_session import with_cpu_session, is_before_spark_330
from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime

pytestmark = [pytest.mark.nightly_resource_consuming_test]

Expand Down Expand Up @@ -931,7 +931,7 @@ def do_join(spark):
assert_gpu_and_cpu_are_equal_collect(do_join, conf=_all_conf)


def check_bloom_filter_join(confs):
def check_bloom_filter_join(confs, expected_classes):
def do_join(spark):
left = spark.range(100000)
right = spark.range(10).withColumn("id2", col("id").cast("string"))
Expand All @@ -940,29 +940,36 @@ def do_join(spark):
"spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold": 1,
"spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold": "100GB",
"spark.sql.optimizer.runtime.bloomFilter.enabled": "true"})
assert_gpu_and_cpu_are_equal_collect(do_join, conf=all_confs)
assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, expected_classes, conf=all_confs)

@ignore_order(local=True)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
def test_bloom_filter_join():
check_bloom_filter_join(confs={})
check_bloom_filter_join(confs={},
expected_classes="GpuBloomFilterMightContain,GpuBloomFilterAggregate")

@allow_non_gpu("FilterExec", "ShuffleExchangeExec")
@ignore_order(local=True)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
def test_bloom_filter_join_cpu_probe():
check_bloom_filter_join(confs={"spark.rapids.sql.expression.BloomFilterMightContain": "false"})
check_bloom_filter_join(confs={"spark.rapids.sql.expression.BloomFilterMightContain": "false"},
expected_classes="BloomFilterMightContain,GpuBloomFilterAggregate")

@allow_non_gpu("ObjectHashAggregateExec", "ShuffleExchangeExec")
@ignore_order(local=True)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
def test_bloom_filter_join_cpu_build():
check_bloom_filter_join(confs={"spark.rapids.sql.expression.BloomFilterAggregate": "false"})
check_bloom_filter_join(confs={"spark.rapids.sql.expression.BloomFilterAggregate": "false"},
expected_classes="GpuBloomFilterMightContain,BloomFilterAggregate")

@allow_non_gpu("ObjectHashAggregateExec", "ProjectExec", "ShuffleExchangeExec")
@ignore_order(local=True)
@pytest.mark.parametrize("agg_replace_mode", ["partial", "final"])
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
def test_bloom_filter_join_cpu_build(agg_replace_mode):
check_bloom_filter_join(confs={"spark.rapids.sql.expression.BloomFilterAggregate": "false",
"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode})
def test_bloom_filter_join_split_cpu_build(agg_replace_mode):
check_bloom_filter_join(confs={"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode},
expected_classes="GpuBloomFilterMightContain,BloomFilterAggregate,GpuBloomFilterAggregate")

0 comments on commit 9202839

Please sign in to comment.