Skip to content

Commit

Permalink
Add broadcast hash join conditional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowe committed Jan 7, 2022
1 parent 619a7b5 commit 8728ff4
Showing 1 changed file with 74 additions and 9 deletions.
83 changes: 74 additions & 9 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,12 +58,18 @@
nested_3d_struct_gens = StructGen([['child0', nested_2d_struct_gens]], nullable=False)
struct_gens = [basic_struct_gen, basic_struct_gen_with_no_null_child, nested_2d_struct_gens, nested_3d_struct_gens]

double_gen = [pytest.param(DoubleGen(), marks=[incompat])]

basic_nested_gens = single_level_array_gens + map_string_string_gen + [all_basic_struct_gen]

# data types supported by AST expressions
ast_gen = [boolean_gen, byte_gen, short_gen, int_gen, long_gen, timestamp_gen]
# data types supported by AST expressions in joins
join_ast_gen = [
boolean_gen, byte_gen, short_gen, int_gen, long_gen, date_gen, timestamp_gen
]

# data types not supported by AST expressions in joins
join_no_ast_gen = [
pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat]),
string_gen, null_gen, decimal_gen_default, decimal_gen_64bit
]

_sortmerge_join_conf = {'spark.sql.autoBroadcastJoinThreshold': '-1',
'spark.sql.join.preferSortMergeJoin': 'True',
Expand Down Expand Up @@ -349,7 +355,7 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', ast_gen, ids=idfn)
@pytest.mark.parametrize('data_gen', join_ast_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross'], ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_right_broadcast_nested_loop_join_with_ast_condition(data_gen, join_type, batch_size):
Expand All @@ -366,7 +372,7 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', ast_gen, ids=idfn)
@pytest.mark.parametrize('data_gen', join_ast_gen, ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_left_broadcast_nested_loop_join_with_ast_condition(data_gen, batch_size):
def do_join(spark):
Expand Down Expand Up @@ -497,15 +503,74 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Inner', 'Cross'], ids=idfn)
@pytest.mark.parametrize('data_gen', join_ast_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'Inner', 'Cross'], ids=idfn)
def test_broadcast_join_with_conditionals(data_gen, join_type):
    """Compare CPU and GPU results for a broadcast hash join whose condition
    combines an equality key (a == r_a) with a non-equi clause (b >= r_b)."""
    def do_join(spark):
        left, right = create_df(spark, data_gen, 500, 250)
        join_cond = (left.a == right.r_a) & (left.b >= right.r_b)
        return left.join(broadcast(right), join_cond, join_type)
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@allow_non_gpu('BroadcastExchangeExec', 'BroadcastHashJoinExec', 'Cast', 'GreaterThan')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_join_with_condition_join_type_fallback(data_gen, join_type):
    """Verify conditional LeftSemi/LeftAnti broadcast hash joins fall back to the CPU.

    Only LeftSemi and LeftAnti are parametrized and the whole
    BroadcastHashJoinExec is expected on the CPU, so the fallback here is
    driven by the join type, not by the condition's operators.
    """
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        # NOTE(review): the condition below contains no cast or log; the
        # previous "AST does not support cast or logarithm yet" comment looks
        # copy-pasted from the AST-op test. 'Cast' in @allow_non_gpu suggests
        # Spark may still insert one into the plan -- TODO confirm.
        return left.join(broadcast(right),
                         (left.a == right.r_a) & (left.b > right.r_b), join_type)
    conf = allow_negative_scale_of_decimal_conf
    assert_gpu_fallback_collect(do_join, 'BroadcastHashJoinExec', conf=conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@allow_non_gpu('BroadcastExchangeExec', 'BroadcastHashJoinExec', 'Cast', 'GreaterThan', 'Log', 'SortMergeJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_join_with_condition_ast_op_fallback(data_gen, join_type):
    """Verify conditional joins using operators unsupported by AST fall back to the CPU.

    The join condition includes log(), which AST cannot compile yet, so the
    conditional join must run on the CPU even though the column types are
    AST-supported.
    """
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        # AST does not support cast or logarithm yet
        return left.join(broadcast(right),
                         (left.a == right.r_a) & (left.b > f.log(right.r_b)), join_type)
    conf = allow_negative_scale_of_decimal_conf
    # The expected CPU exec depends on join type: Right and FullOuter are
    # checked against SortMergeJoinExec, the rest against BroadcastHashJoinExec.
    # Local renamed from 'exec' to avoid shadowing the builtin.
    fallback_exec = 'SortMergeJoinExec' if join_type in ['Right', 'FullOuter'] else 'BroadcastHashJoinExec'
    assert_gpu_fallback_collect(do_join, fallback_exec, conf=conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@allow_non_gpu('BroadcastExchangeExec', 'BroadcastHashJoinExec', 'Cast', 'GreaterThan', 'SortMergeJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', join_no_ast_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_join_with_condition_ast_type_fallback(data_gen, join_type):
    """Verify conditional joins over non-AST-supported column types fall back to the CPU.

    join_no_ast_gen supplies data types AST cannot handle, so the conditional
    join cannot be compiled for the GPU even though the operators themselves
    (equality, greater-than) are AST-supported.
    """
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        # The condition uses only AST-supported operators (no cast or log);
        # the fallback is triggered by the unsupported data types in data_gen.
        return left.join(broadcast(right),
                         (left.a == right.r_a) & (left.b > right.r_b), join_type)
    conf = allow_negative_scale_of_decimal_conf
    # The test expects SortMergeJoinExec on the CPU for Right/FullOuter and
    # BroadcastHashJoinExec for the other join types.
    exec = 'SortMergeJoinExec' if join_type in ['Right', 'FullOuter'] else 'BroadcastHashJoinExec'
    assert_gpu_fallback_collect(do_join, exec, conf=conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', join_no_ast_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Inner', 'Cross'], ids=idfn)
def test_broadcast_join_with_condition_post_filter(data_gen, join_type):
    """Inner/Cross conditional joins over non-AST column types must still
    match between CPU and GPU -- presumably the condition is applied as a
    post-join filter (per the test name); verify against the plugin docs."""
    def do_join(spark):
        left, right = create_df(spark, data_gen, 500, 250)
        cond = (left.a == right.r_a) & (left.b > right.r_b)
        return left.join(broadcast(right), cond, join_type)
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
Expand Down

0 comments on commit 8728ff4

Please sign in to comment.