Skip to content

Commit

Permalink
Merge pull request NVIDIA#2165 from NVIDIA/branch-0.5
Browse files Browse the repository at this point in the history
[auto-merge] branch-0.5 to branch-0.6 [skip ci] [bot]
  • Loading branch information
nvauto authored Apr 16, 2021
2 parents f20d87f + e574c45 commit 480c8c2
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
25 changes: 24 additions & 1 deletion integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,30 @@ def test_hash_grpby_avg(data_gen, conf):
conf=conf
)

@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_strings_with_extra_nulls], ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
@pytest.mark.parametrize('ansi_enabled', ['true', 'false'])
def test_hash_grpby_avg_nulls(data_gen, conf, ansi_enabled):
conf.update({'spark.sql.ansi.enabled': ansi_enabled})
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a')
.agg(f.avg('c')),
conf=conf
)

@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_strings_with_extra_nulls], ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
@pytest.mark.parametrize('ansi_enabled', ['true', 'false'])
def test_hash_reduction_avg_nulls(data_gen, conf, ansi_enabled):
conf.update({'spark.sql.ansi.enabled': ansi_enabled})
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100)
.agg(f.avg('c')),
conf=conf
)

# tracks https://github.com/NVIDIA/spark-rapids/issues/154
@approximate_float
@ignore_order
Expand Down Expand Up @@ -302,7 +326,6 @@ def test_hash_query_max_with_multiple_distincts(data_gen, conf, parameterless):
'count(distinct b) from hash_agg_table group by a',
conf)


@ignore_order
@pytest.mark.parametrize('data_gen', _init_list_no_nans, ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,13 @@ case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate
// average = (0 + 1) and not 2 which is the rowcount of the projected column.
override lazy val updateExpressions: Seq[GpuExpression] = Seq(new CudfSum(cudfSum),
new CudfSum(cudfCount))

// NOTE: this sets `failOnErrorOverride=false` in `GpuDivide` to force it not to throw
// divide-by-zero exceptions, even when ansi mode is enabled in Spark.
// This is to conform with Spark's behavior in the Average aggregate function.
override lazy val evaluateExpression: GpuExpression = GpuDivide(
GpuCast(cudfSum, DoubleType),
GpuCast(cudfCount, DoubleType))
GpuCast(cudfCount, DoubleType), failOnErrorOverride = false)

override lazy val initialValues: Seq[GpuLiteral] = Seq(
GpuLiteral(0.0, DoubleType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ object GpuDivModLike {
}

trait GpuDivModLike extends CudfBinaryArithmetic {
lazy val failOnError: Boolean = ShimLoader.getSparkShims.shouldFailDivByZero()
lazy val failOnError: Boolean =
ShimLoader.getSparkShims.shouldFailDivByZero()

override def nullable: Boolean = true

Expand Down Expand Up @@ -330,13 +331,18 @@ object GpuDivideUtil {
}

// This is for doubles and floats...
case class GpuDivide(left: Expression, right: Expression) extends GpuDivModLike {
case class GpuDivide(left: Expression, right: Expression,
failOnErrorOverride: Boolean = ShimLoader.getSparkShims.shouldFailDivByZero())
extends GpuDivModLike {

override lazy val failOnError: Boolean = failOnErrorOverride

override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)

override def symbol: String = "/"

override def binaryOp: BinaryOp = (left.dataType, right.dataType) match {
case (_: DecimalType, _: DecimalType) => BinaryOp.DIV
case (_: DecimalType, _: DecimalType) => BinaryOp.DIV
case _ => BinaryOp.TRUE_DIV
}

Expand Down

0 comments on commit 480c8c2

Please sign in to comment.