Skip to content

Commit

Permalink
Support split non-AST-able join condition for BroadcastNestedLoopJoin (
Browse files Browse the repository at this point in the history
…#9635)

Signed-off-by: Ferdinand Xu <[email protected]>
  • Loading branch information
winningsix authored Nov 13, 2023
1 parent f4a898c commit c20a843
Show file tree
Hide file tree
Showing 10 changed files with 666 additions and 68 deletions.
34 changes: 32 additions & 2 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
from _pytest.mark.structures import ParameterSet
from pyspark.sql.functions import broadcast, col
from pyspark.sql.functions import array_contains, broadcast, col
from pyspark.sql.types import *
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime
Expand Down Expand Up @@ -397,17 +397,47 @@ def do_join(spark):
return left.join(broadcast(right), left.a > f.log(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Cross', 'Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet which is supposed to be extracted into child
# nodes. And this test doesn't cover other join types due to:
# (1) build right are not supported for Right
# (2) FullOuter: currently is not supported
# Those fallback reasons are not due to AST. Additionally, this test case changes test_broadcast_nested_loop_join_with_condition_fallback:
# (1) adapt double to integer since AST current doesn't support it.
# (2) switch to right side build to pass checks of 'Left', 'LeftSemi', 'LeftAnti' join types
return left.join(broadcast(right), f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})

@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet
# AST does not support double type which is not split-able into child nodes.
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_array_contains(data_gen, join_type):
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen))
def do_join(spark):
left, right = create_df(spark, arr_gen, 50, 25)
# Array_contains will be pushed down into project child nodes
return broadcast(left).join(right, array_contains(left.a, literal.cast(data_gen.data_type)) < array_contains(right.r_a, literal.cast(data_gen.data_type)))
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
Expand Down
122 changes: 122 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}


object AstUtil {

/**
* Check whether it can be split into non-ast sub-expression if needed
*
* @return true when: 1) If all ast-able in expr; 2) all non-ast-able tree nodes don't contain
* attributes from both join sides. In such case, it's not able
* to push down into single child.
*/
def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[ExprId],
right: Seq[ExprId]): Boolean = {
if (!expr.canSelfBeAst) {
// It needs to be split since not ast-able. Check itself and childerns to ensure
// pushing-down can be made, which doesn't need attributions from both sides.
val exprRef = expr.wrapped.asInstanceOf[Expression]
val leftTree = exprRef.references.exists(r => left.contains(r.exprId))
val rightTree = exprRef.references.exists(r => right.contains(r.exprId))
// Can't extract a condition involving columns from both sides
!(rightTree && leftTree)
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(
canExtractNonAstConditionIfNeed(_, left, right))
}
}

/**
* Extract non-AST functions from join conditions and update the original join condition. Based
* on the attributes, it decides which side the split condition belongs to. The replaced
* condition is wrapped with GpuAlias with new intermediate attributes.
*
* @param condition to be split if needed
* @param left attributions from left child
* @param right attributions from right child
* @param skipCheck whether skip split-able check
* @return a tuple of [[Expression]] for remained expressions, List of [[NamedExpression]] for
* left child if any, List of [[NamedExpression]] for right child if any
*/
def extractNonAstFromJoinCond(condition: Option[BaseExprMeta[_]],
left: AttributeSeq, right: AttributeSeq, skipCheck: Boolean):
(Option[Expression], List[NamedExpression], List[NamedExpression]) = {
// Choose side with smaller key size. Use expr ID to check the side which project expr
// belonging to.
val (exprIds, isLeft) = if (left.attrs.size < right.attrs.size) {
(left.attrs.map(_.exprId), true)
} else {
(right.attrs.map(_.exprId), false)
}
// List of expression pushing down to left side child
val leftExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// List of expression pushing down to right side child
val rightExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// Substitution map used to replace targeted expressions based on semantic equality
val substitutionMap = mutable.HashMap.empty[GpuExpressionEquals, Expression]

// 1st step to construct 1) left expr list; 2) right expr list; 3) substitutionMap
// No need to consider common sub-expressions here since project node will use tiered execution
condition.foreach(c =>
if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs.map(_.exprId), right.attrs
.map(_.exprId))) {
splitNonAstInternal(c, exprIds, leftExprs, rightExprs, substitutionMap, isLeft)
})

// 2nd step to replace expression pushing down to child plans in depth first fashion
(condition.map(
_.convertToGpu().mapChildren(
GpuEquivalentExpressions.replaceWithSemanticCommonRef(_,
substitutionMap))), leftExprs.toList, rightExprs.toList)
}

private[this] def splitNonAstInternal(condition: BaseExprMeta[_], childAtt: Seq[ExprId],
left: ListBuffer[NamedExpression], right: ListBuffer[NamedExpression],
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression], isLeft: Boolean): Unit = {
for (child <- condition.childExprs) {
if (!child.canSelfBeAst) {
val exprRef = child.wrapped.asInstanceOf[Expression]
val gpuProj = child.convertToGpu()
val alias = substitutionMap.get(GpuExpressionEquals(gpuProj)) match {
case Some(_) => None
case None =>
if (exprRef.references.exists(r => childAtt.contains(r.exprId)) ^ isLeft) {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_r_${left.size}")()
right += alias
Some(alias)
} else {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_l_${left.size}")()
left += alias
Some(alias)
}
}
alias.foreach(a => substitutionMap.put(GpuExpressionEquals(gpuProj), a.toAttribute))
} else {
splitNonAstInternal(child, childAtt, left, right, substitutionMap, isLeft)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,15 @@ abstract class BaseExprMeta[INPUT <: Expression](
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}

/**
* Check whether this node itself can be converted to AST. It will not recursively check its
* children. It's used to check join condition AST-ability in top-down fashion.
*/
lazy val canSelfBeAst = {
tagForAst()
cannotBeAstReasons.isEmpty
}

final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,22 @@ class GpuEquivalentExpressions {
}

object GpuEquivalentExpressions {
/**
* Recursively replaces semantic equal expression with its proxy expression in `substitutionMap`.
*/
def replaceWithSemanticCommonRef(
expr: Expression,
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = {
expr match {
case e: AttributeReference => e
case _ =>
substitutionMap.get(GpuExpressionEquals(expr)) match {
case Some(attr) => attr
case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap))
}
}
}

/**
* Recursively replaces expression with its proxy expression in `substitutionMap`.
*/
Expand Down
Loading

0 comments on commit c20a843

Please sign in to comment.