diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index 75f72ec6a20..94289193485 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -504,7 +504,7 @@ def do_join(spark):
 # After 3.1.0 is the min spark version we can drop this
 @ignore_order(local=True)
 @pytest.mark.parametrize('data_gen', join_ast_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'Inner', 'Cross'], ids=idfn)
+@pytest.mark.parametrize('join_type', all_join_types, ids=idfn)
 def test_broadcast_join_with_conditionals(data_gen, join_type):
     def do_join(spark):
         left, right = create_df(spark, data_gen, 500, 250)
@@ -512,21 +512,6 @@ def do_join(spark):
             (left.a == right.r_a) & (left.b >= right.r_b), join_type)
     assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)
 
-# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
-# After 3.1.0 is the min spark version we can drop this
-@allow_non_gpu('BroadcastExchangeExec', 'BroadcastHashJoinExec', 'Cast', 'GreaterThan')
-@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
-@pytest.mark.parametrize('join_type', ['LeftSemi', 'LeftAnti'], ids=idfn)
-def test_broadcast_join_with_condition_join_type_fallback(data_gen, join_type):
-    def do_join(spark):
-        left, right = create_df(spark, data_gen, 50, 25)
-        # AST does not support cast or logarithm yet
-        return left.join(broadcast(right),
-            (left.a == right.r_a) & (left.b > right.r_b), join_type)
-    conf = allow_negative_scale_of_decimal_conf
-    assert_gpu_fallback_collect(do_join, 'BroadcastHashJoinExec', conf=conf)
-
 # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
 # After 3.1.0 is the min spark version we can drop this
 @allow_non_gpu('BroadcastExchangeExec', 'BroadcastHashJoinExec', 'Cast', 'GreaterThan', 'Log', 'SortMergeJoinExec')
@@ -571,45 +556,23 @@ def do_join(spark):
             (left.a == right.r_a) & (left.b > right.r_b), join_type)
     assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)
 
-# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
-# After 3.1.0 is the min spark version we can drop this
-@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
-def test_sortmerge_join_with_conditionals(data_gen):
-    def do_join(spark):
-        left, right = create_df(spark, data_gen, 500, 250)
-        return left.join(right, (left.a == right.r_a) & (left.b >= right.r_b), 'Inner')
-    assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)
-
 # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
 # After 3.1.0 is the min spark version we can drop this
 @ignore_order(local=True)
 @pytest.mark.parametrize('data_gen', join_ast_gen, ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
 def test_sortmerge_join_with_condition_ast(data_gen, join_type):
     def do_join(spark):
         left, right = create_df(spark, data_gen, 500, 250)
         return left.join(right, (left.a == right.r_a) & (left.b >= right.r_b), join_type)
     assert_gpu_and_cpu_are_equal_collect(do_join, conf=_sortmerge_join_conf)
 
-# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
-# After 3.1.0 is the min spark version we can drop this
-@allow_non_gpu('GreaterThan', 'ShuffleExchangeExec', 'SortMergeJoinExec')
-@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
-@pytest.mark.parametrize('join_type', ['LeftSemi', 'LeftAnti'], ids=idfn)
-def test_sortmerge_join_with_condition_join_type_fallback(data_gen, join_type):
-    def do_join(spark):
-        left, right = create_df(spark, data_gen, 500, 250)
-        return left.join(right, (left.a == right.r_a) & (left.b >= right.r_b), join_type)
-    assert_gpu_fallback_collect(do_join, 'SortMergeJoinExec', conf=_sortmerge_join_conf)
-
 # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
 # After 3.1.0 is the min spark version we can drop this
 @allow_non_gpu('GreaterThan', 'Log', 'ShuffleExchangeExec', 'SortMergeJoinExec')
 @ignore_order(local=True)
 @pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
-@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter'], ids=idfn)
+@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
 def test_sortmerge_join_with_condition_ast_op_fallback(data_gen, join_type):
     def do_join(spark):
         left, right = create_df(spark, data_gen, 500, 250)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
index d54f3e86bc5..f6a78890fcf 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
@@ -104,25 +104,19 @@ object GpuHashJoin extends Arm {
       conditionMeta: Option[BaseExprMeta[_]]): Unit = {
     val keyDataTypes = (leftKeys ++ rightKeys).map(_.dataType)
 
-    def unSupportNonEqualCondition(): Unit = if (conditionMeta.isDefined) {
-      meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions")
-    }
-    def unSupportStructKeys(): Unit = if (keyDataTypes.exists(_.isInstanceOf[StructType])) {
-      meta.willNotWorkOnGpu(s"$joinType joins currently do not support with struct keys")
-    }
     JoinTypeChecks.tagForGpu(joinType, meta)
     joinType match {
       case _: InnerLike =>
-      case RightOuter | LeftOuter =>
+      case RightOuter | LeftOuter | LeftSemi | LeftAnti =>
         conditionMeta.foreach(meta.requireAstForGpuOn)
-      case LeftSemi | LeftAnti =>
-        unSupportNonEqualCondition()
       case FullOuter =>
         conditionMeta.foreach(meta.requireAstForGpuOn)
         // FullOuter join cannot support with struct keys as two issues below
         // * https://github.com/NVIDIA/spark-rapids/issues/2126
         // * https://github.com/rapidsai/cudf/issues/7947
-        unSupportStructKeys()
+        if (keyDataTypes.exists(_.isInstanceOf[StructType])) {
+          meta.willNotWorkOnGpu(s"$joinType joins currently do not support with struct keys")
+        }
       case _ =>
         meta.willNotWorkOnGpu(s"$joinType currently is not supported")
     }
@@ -475,17 +469,26 @@ class ConditionalHashJoinIterator(
     withResource(GpuColumnVector.from(leftData.getBatch)) { leftTable =>
       withResource(GpuColumnVector.from(rightData.getBatch)) { rightTable =>
         val maps = joinType match {
-          case _: InnerLike => Table.mixedInnerJoinGatherMaps(
-            leftKeys, rightKeys, leftTable, rightTable, compiledCondition, nullEquality)
-          case LeftOuter => Table.mixedLeftJoinGatherMaps(
-            leftKeys, rightKeys, leftTable, rightTable, compiledCondition, nullEquality)
+          case _: InnerLike =>
+            Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, leftTable, rightTable,
+              compiledCondition, nullEquality)
+          case LeftOuter =>
+            Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, leftTable, rightTable,
+              compiledCondition, nullEquality)
           case RightOuter =>
             // Reverse the output of the join, because we expect the right gather map to
            // always be on the right
             Table.mixedLeftJoinGatherMaps(rightKeys, leftKeys, rightTable, leftTable,
               compiledCondition, nullEquality).reverse
-          case FullOuter => Table.mixedFullJoinGatherMaps(
-            leftKeys, rightKeys, leftTable, rightTable, compiledCondition, nullEquality)
+          case FullOuter =>
+            Table.mixedFullJoinGatherMaps(leftKeys, rightKeys, leftTable, rightTable,
+              compiledCondition, nullEquality)
+          case LeftSemi =>
+            Array(Table.mixedLeftSemiJoinGatherMap(leftKeys, rightKeys, leftTable, rightTable,
+              compiledCondition, nullEquality))
+          case LeftAnti =>
+            Array(Table.mixedLeftAntiJoinGatherMap(leftKeys, rightKeys, leftTable, rightTable,
+              compiledCondition, nullEquality))
           case _ =>
            throw new NotImplementedError(s"Joint Type ${joinType.getClass} is not currently" +
              s" supported")