Skip to content

Commit

Permalink
Revert "Support split non-AST-able join condition for BroadcastNested…
Browse files Browse the repository at this point in the history
…LoopJoin (NVIDIA#9635)"

This reverts commit c20a843.
  • Loading branch information
winningsix committed Nov 14, 2023
1 parent c495339 commit e110466
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 666 deletions.
34 changes: 2 additions & 32 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
from _pytest.mark.structures import ParameterSet
from pyspark.sql.functions import array_contains, broadcast, col
from pyspark.sql.functions import broadcast, col
from pyspark.sql.types import *
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime
Expand Down Expand Up @@ -397,47 +397,17 @@ def do_join(spark):
return left.join(broadcast(right), left.a > f.log(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Cross', 'Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet which is supposed to be extracted into child
# nodes. And this test doesn't cover other join types due to:
# (1) build right are not supported for Right
# (2) FullOuter: currently is not supported
# Those fallback reasons are not due to AST. Additionally, this test case changes test_broadcast_nested_loop_join_with_condition_fallback:
# (1) adapt double to integer since AST current doesn't support it.
# (2) switch to right side build to pass checks of 'Left', 'LeftSemi', 'LeftAnti' join types
return left.join(broadcast(right), f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})

@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support double type which is not split-able into child nodes.
# AST does not support cast or logarithm yet
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_array_contains(data_gen, join_type):
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen))
def do_join(spark):
left, right = create_df(spark, arr_gen, 50, 25)
# Array_contains will be pushed down into project child nodes
return broadcast(left).join(right, array_contains(left.a, literal.cast(data_gen.data_type)) < array_contains(right.r_a, literal.cast(data_gen.data_type)))
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
Expand Down
122 changes: 0 additions & 122 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -1115,15 +1115,6 @@ abstract class BaseExprMeta[INPUT <: Expression](
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}

/**
* Check whether this node itself can be converted to AST. It will not recursively check its
* children. It's used to check join condition AST-ability in top-down fashion.
*/
lazy val canSelfBeAst = {
tagForAst()
cannotBeAstReasons.isEmpty
}

final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,6 @@ class GpuEquivalentExpressions {
}

object GpuEquivalentExpressions {
/**
* Recursively replaces semantic equal expression with its proxy expression in `substitutionMap`.
*/
def replaceWithSemanticCommonRef(
expr: Expression,
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = {
expr match {
case e: AttributeReference => e
case _ =>
substitutionMap.get(GpuExpressionEquals(expr)) match {
case Some(attr) => attr
case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap))
}
}
}

/**
* Recursively replaces expression with its proxy expression in `substitutionMap`.
*/
Expand Down
Loading

0 comments on commit e110466

Please sign in to comment.