Merge pull request NVIDIA#20 from nvliyuan/0612-base-local-for-case-when-perf

case when improvement: avoid copy_if_else
res-life authored Jun 17, 2024
2 parents c37cfa2 + c3d5401 commit 0591964
Showing 4 changed files with 76 additions and 15 deletions.
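
The commit title refers to cudf's copy_if_else, which the previous GpuCaseWhen path invokes once per WHEN branch: the branches are folded right-to-left, and every step materializes a full intermediate column. A rough CPU-side sketch of that old strategy, using plain Scala arrays to stand in for GPU columns (all names below are illustrative, not plugin APIs):

object NestedCaseWhenSketch {
  // Two-way select, analogous to one cudf copy_if_else call on the GPU.
  def copyIfElse(pred: Array[Boolean], thenVal: String, elseCol: Array[String]): Array[String] =
    pred.zip(elseCol).map { case (p, e) => if (p) thenVal else e }

  // Fold the branches right-to-left; each branch allocates one temporary column.
  def caseWhenNested(
      branches: Seq[(Array[Boolean], String)],
      elseVal: String,
      numRows: Int): Array[String] =
    branches.foldRight(Array.fill(numRows)(elseVal)) {
      case ((pred, thenVal), falseCol) => copyIfElse(pred, thenVal, falseCol)
    }

  def main(args: Array[String]): Unit = {
    val a = Array(true, false, false)
    val b = Array(false, false, true)
    // Prints: aaa, unknown, bbb
    println(caseWhenNested(Seq((a, "aaa"), (b, "bbb")), "unknown", numRows = 3).mkString(", "))
  }
}

The fused path added in this commit replaces this chain of selects with a single index computation plus one gather, as sketched after the conditionalExpressions.scala diff below.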
16 changes: 15 additions & 1 deletion integration_tests/src/main/python/conditionals_test.py
@@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from spark_session import is_before_spark_320, is_jvm_charset_utf8
from pyspark.sql.types import *
@@ -296,3 +296,17 @@ def test_conditional_with_side_effects_unary_minus(data_gen, ansi_enabled):
'CASE WHEN a > -32768 THEN -a ELSE null END'),
conf = {'spark.sql.ansi.enabled': ansi_enabled})

def test_case_when_all_then_values_are_scalars():
data_gen = [
("a", boolean_gen),
("b", boolean_gen),
("c", boolean_gen),
("d", boolean_gen),
("e", boolean_gen)]
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen),
"tab",
"select case when a then 'aaa' when b then 'bbb' when c then 'ccc' " +
"when d then 'ddd' when e then 'eee' else 'unknown' end from tab",
conf = {'spark.rapids.sql.case_when.fuse': 'true'})
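
For reference, the scenario covered by the new test can also be run directly from a Spark session. A minimal Scala sketch, assuming the RAPIDS Accelerator plugin is on the classpath and enabled; the DataFrame contents are made up, and the new config (internal, default true) is set explicitly only for illustration:

import org.apache.spark.sql.SparkSession

object CaseWhenFuseExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("case-when-fuse-example")
      .getOrCreate()
    // Internal config introduced by this commit; defaults to true.
    spark.conf.set("spark.rapids.sql.case_when.fuse", "true")

    import spark.implicits._
    // Five boolean predicate columns, mirroring the new integration test.
    val df = Seq(
      (true, false, false, false, false),
      (false, false, false, false, true),
      (false, false, false, false, false)
    ).toDF("a", "b", "c", "d", "e")
    df.createOrReplaceTempView("tab")

    // All then/else values are string literals, so the fused GPU path can apply.
    spark.sql(
      "select case when a then 'aaa' when b then 'bbb' when c then 'ccc' " +
      "when d then 'ddd' when e then 'eee' else 'unknown' end from tab"
    ).show()
  }
}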

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2023,7 +2023,7 @@ object GpuOverrides extends Logging {
} else {
None
}
GpuCaseWhen(branches, elseValue)
GpuCaseWhen(branches, elseValue, conf.caseWhenFuseEnabled)
}
}),
expr[If](
10 changes: 10 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -2323,6 +2323,14 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
.integerConf
.createWithDefault(1024)

val CASE_WHEN_FUSE =
conf("spark.rapids.sql.case_when.fuse")
.doc("If when branches is greater than 3 and all then values in case when are string " +
"scalar, fuse mode improves the performance. By default this is enabled.")
.internal()
.booleanConf
.createWithDefault(true)

private def printSectionHeader(category: String): Unit =
println(s"\n### $category")

@@ -3142,6 +3150,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val testGetJsonObjectSaveRows: Int = get(TEST_GET_JSON_OBJECT_SAVE_ROWS)

lazy val caseWhenFuseEnabled: Boolean = get(CASE_WHEN_FUSE)

private val optimizerDefaults = Map(
// this is not accurate because CPU projections do have a cost due to appending values
// to each row that is produced, but this needs to be a really small number because
sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
@@ -19,12 +19,14 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, NullPolicy, Scalar, ScanAggregation, ScanType, Table, UnaryOp}
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.CaseWhen
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, Expression}
import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes}
import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes, StringType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String

object GpuExpressionWithSideEffectUtils {

@@ -47,7 +49,7 @@

/**
* Used to shortcircuit predicates and filter conditions.
*
*
* @param nullsAsFalse when true, null values are considered false.
* @param col the input being evaluated.
* @return boolean. When nullsAsFalse is set, it returns True if none of the rows is true;
@@ -182,9 +184,9 @@ case class GpuIf(
predicateExpr: Expression,
trueExpr: Expression,
falseExpr: Expression) extends GpuConditionalExpression {

import GpuExpressionWithSideEffectUtils._

@transient
override lazy val inputTypesForMerging: Seq[DataType] = {
Seq(trueExpr.dataType, falseExpr.dataType)
@@ -314,7 +316,9 @@ case class GpuIf(

case class GpuCaseWhen(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression] = None) extends GpuConditionalExpression with Serializable {
elseValue: Option[Expression] = None,
caseWhenFuseEnabled: Boolean = true)
extends GpuConditionalExpression with Serializable {

import GpuExpressionWithSideEffectUtils._

@@ -359,15 +363,48 @@ case class GpuCaseWhen(
if (branchesWithSideEffects) {
columnarEvalWithSideEffects(batch)
} else {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEvalAny(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
val any = branches.foldRight[Any](elseRet) {
case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
if (caseWhenFuseEnabled && branches.size > 2 &&
inputTypesForMerging.head == StringType &&
(branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[GpuLiteral])
) {
// When the number of branches is > 2, the return type is StringType, and all the
// then/else expressions are scalars, avoid multiple `computeIfElse` calls, each of
// which would materialize a temporary column.

// 1. select first true index from bool columns
val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray
val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ =>
CaseWhen.selectFirstTrueIndex(whenBoolCols)
}

withResource(firstTrueIndex) { _ =>
val thenElseScalars = (branches.map(_._2) ++ elseValue).map(_.columnarEvalAny(batch)
.asInstanceOf[GpuScalar])
withResource(thenElseScalars) { _ =>
// 2. generate a column to store all scalars
val scalarsBytes = thenElseScalars.map(ret => ret.getValue
.asInstanceOf[UTF8String].getBytes)
val scalarCol = ColumnVector.fromUTF8Strings(scalarsBytes: _*)
withResource(scalarCol) { _ =>
// 3. execute final select
val finalRet = CaseWhen.selectFromIndex(scalarCol, firstTrueIndex)
// return final column vector
GpuColumnVector.from(finalRet, dataType)
}
}
}
} else {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEvalAny(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
val any = branches.foldRight[Any](elseRet) {
case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
}
GpuExpressionsUtils.resolveColumnVector(any, batch.numRows())
}
GpuExpressionsUtils.resolveColumnVector(any, batch.numRows())
}
}
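
The three numbered steps in the fused path can be illustrated on the CPU with plain Scala collections. A conceptual sketch only; selectFirstTrueIndex and selectFromIndex below are stand-ins for the spark-rapids-jni CaseWhen kernels of the same name, not their actual implementations:

object FusedCaseWhenSketch {
  // Step 1: for each row, the index of the first predicate column that is true;
  // rows where no predicate is true map to whenBools.length, i.e. the else slot.
  def selectFirstTrueIndex(whenBools: Array[Array[Boolean]], numRows: Int): Array[Int] =
    Array.tabulate(numRows) { row =>
      val idx = whenBools.indexWhere(col => col(row))
      if (idx >= 0) idx else whenBools.length
    }

  // Step 3: gather the then/else scalars by the per-row index, producing the output
  // in one pass instead of one copy_if_else per branch.
  def selectFromIndex(scalars: Array[String], firstTrueIndex: Array[Int]): Array[String] =
    firstTrueIndex.map(i => scalars(i))

  def main(args: Array[String]): Unit = {
    // Predicate columns for "when a" and "when b" over three rows.
    val whenBools = Array(
      Array(true, false, false), // a
      Array(false, false, true)) // b
    // Step 2: the then values followed by the else value, in branch order.
    val scalars = Array("aaa", "bbb", "unknown")
    val index = selectFirstTrueIndex(whenBools, numRows = 3)
    println(selectFromIndex(scalars, index).mkString(", ")) // prints: aaa, unknown, bbb
  }
}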

