Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
winningsix committed Nov 9, 2023
1 parent d63bca6 commit 6546c3c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 34 deletions.
25 changes: 14 additions & 11 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,42 +397,45 @@ def do_join(spark):
return left.join(broadcast(right), left.a > f.log(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@allow_non_gpu('BroadcastExchangeExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Cross', 'Right'], ids=idfn)
def test_broadcast_nested_loop_join_with_non_ast_condition_no_fallback(data_gen, join_type):
@pytest.mark.parametrize('join_type', ['Cross', 'Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet which is supposed to be extracted into child
# nodes
return broadcast(left).join(right, f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type)
# nodes. This test does not cover the other join types because:
# (1) Right: building the right side is not supported
# (2) FullOuter: not currently supported
# Those fallback reasons are not due to AST. Additionally, this test case differs from test_broadcast_nested_loop_join_with_condition_fallback:
# (1) adapt double to integer since AST currently doesn't support it.
# (2) switch to a right-side build to pass the checks for the 'Left', 'LeftSemi', 'LeftAnti' join types
return left.join(broadcast(right), f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})

@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet for double type
# AST does not support double type which is not split-able into child nodes.
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

@allow_non_gpu('BroadcastExchangeExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_array_contains_no_fallback(data_gen, join_type):
def test_broadcast_nested_loop_join_with_array_contains(data_gen, join_type):
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen))
def do_join(spark):
left, right = create_df(spark, arr_gen, 50, 25)
# Array_contains should be pushed down since ast doesn't support it.
return broadcast(left).join(right, array_contains(col('a'), literal.cast(data_gen.data_type)))
# Array_contains will be pushed down into project child nodes
return broadcast(left).join(right, array_contains(left.a, literal.cast(data_gen.data_type)) < array_contains(right.r_a, literal.cast(data_gen.data_type)))
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
Expand Down
27 changes: 9 additions & 18 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}


Expand All @@ -32,30 +32,20 @@ object AstUtil {
* attributes from both join sides. In such a case, it cannot
* be pushed down into a single child.
*/
def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[Attribute],
right: Seq[Attribute]): Boolean = {
if (!expr.canSelfBeAst) {
// Returns false directly since can't split non-ast-able root node into child
false
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(canExtractInternal(_, left, right))
}
}

private[this] def canExtractInternal(expr: BaseExprMeta[_], left: Seq[Attribute],
right: Seq[Attribute]): Boolean = {
def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[ExprId],
right: Seq[ExprId]): Boolean = {
if (!expr.canSelfBeAst) {
// It needs to be split since it is not ast-able. Check itself and its children to ensure
// the push-down can be made, i.e. it doesn't need attributes from both sides.
val exprRef = expr.wrapped.asInstanceOf[Expression]
val leftTree = exprRef.references.exists(left.contains(_))
val rightTree = exprRef.references.exists(right.contains(_))
val leftTree = exprRef.references.exists(r => left.contains(r.exprId))
val rightTree = exprRef.references.exists(r => right.contains(r.exprId))
// Can't extract a condition involving columns from both sides
!(rightTree && leftTree)
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(canExtractInternal(_, left, right))
expr.childExprs.isEmpty || expr.childExprs.forall(
canExtractNonAstConditionIfNeed(_, left, right))
}
}

Expand Down Expand Up @@ -91,7 +81,8 @@ object AstUtil {
// 1st step to construct 1) left expr list; 2) right expr list; 3) substitutionMap
// No need to consider common sub-expressions here since project node will use tiered execution
condition.foreach(c =>
if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs, right.attrs)) {
if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs.map(_.exprId), right.attrs
.map(_.exprId))) {
splitNonAstInternal(c, exprIds, leftExprs, rightExprs, substitutionMap, isLeft)
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class GpuBroadcastNestedLoopJoinMetaBase(
val Seq(leftPlan, rightPlan) = childPlans
conditionMeta match {
case Some(e) => isAstCond = AstUtil.canExtractNonAstConditionIfNeed(
e, leftPlan.outputAttributes, rightPlan.outputAttributes)
e, leftPlan.outputAttributes.map(_.exprId), rightPlan.outputAttributes.map(_.exprId))
case None => isAstCond = true
}
taggedForAstCheck = true
Expand Down
11 changes: 7 additions & 4 deletions tests/src/test/scala/com/nvidia/spark/rapids/AstUtilSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class AstUtilSuite extends GpuUnitTests {
when(exprMeta.canSelfBeAst).thenReturn(!containsNonAstAble)
when(exprMeta.wrapped).thenReturn(expr)

AstUtil.canExtractNonAstConditionIfNeed(exprMeta, Seq(l1, l2), Seq(r1, r2))
AstUtil.canExtractNonAstConditionIfNeed(exprMeta, Seq(l1, l2).map(_.exprId), Seq(r1, r2).map
(_.exprId))
}

private[this] def testMultiNodes(containsNonAstAble: Boolean, crossMultiChildPlan: Boolean)
Expand All @@ -71,7 +72,8 @@ class AstUtilSuite extends GpuUnitTests {

when(rootExprMeta.canSelfBeAst).thenReturn(true)

AstUtil.canExtractNonAstConditionIfNeed(rootExprMeta, Seq(l1, l2), Seq(r1, r2))
AstUtil.canExtractNonAstConditionIfNeed(rootExprMeta, Seq(l1, l2).map(_.exprId), Seq(r1, r2)
.map(_.exprId))
}

private[this] def buildLeaf(attributeSet: AttributeSet, containsNonAstAble: Boolean)
Expand Down Expand Up @@ -114,12 +116,13 @@ class AstUtilSuite extends GpuUnitTests {
when(rootExprMeta.childExprs).thenReturn(Seq(leftExprMeta, rightExprMeta))
when(rootExprMeta.canSelfBeAst).thenReturn(true)

AstUtil.canExtractNonAstConditionIfNeed(rootExprMeta, Seq(l1, l2), Seq(r1, r2))
AstUtil.canExtractNonAstConditionIfNeed(rootExprMeta, Seq(l1, l2).map(_.exprId), Seq(r1, r2)
.map(_.exprId))
}

test("Single node tree for ast split if needed") {
for ((canAstSplitIfNeeded, containsNonAstAble, crossMultiChildPlan) <- Seq(
(false, true, true), (false, true, false), (true, false, true), (true, false, false))) {
(false, true, true), (true, true, false), (true, false, true), (true, false, false))) {
assertResult(
canAstSplitIfNeeded)(testSingleNode(containsNonAstAble, crossMultiChildPlan))
}
Expand Down

0 comments on commit 6546c3c

Please sign in to comment.