Commit a15b228
Incorrect output from averages with filters in partial only mode (#612)
Signed-off-by: Kuhu Shukla <[email protected]>

Co-authored-by: Kuhu Shukla <[email protected]>
Kuhu Shukla and kuhushukla authored Aug 27, 2020
1 parent 72a7d5d commit a15b228
Showing 4 changed files with 64 additions and 21 deletions.
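
The failure mode behind #612: when the plugin replaces only the partial half of a hash aggregate, the GPU builds the partial (sum, count) buffers for an average while the CPU still executes the final merge. Below is a minimal repro sketch, not taken from this commit: the data, table name, and local-mode setup are illustrative only, while spark.rapids.sql.hashAgg.replaceMode=partial is the "partial only" mode named in the title and exercised by the tests in this commit via _no_nans_float_conf_partial and partialOnlyConf.

import org.apache.spark.sql.SparkSession

object FilteredAvgRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("filtered-avg-repro")
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      // "partial" keeps only the partial aggregation stage on the GPU,
      // leaving the final merge of (sum, count) to the CPU
      .config("spark.rapids.sql.hashAgg.replaceMode", "partial")
      .getOrCreate()
    import spark.implicits._

    Seq((1L, 15L), (1L, 30L)).toDF("a", "b").createOrReplaceTempView("t")
    // Before this fix, rows removed by the FILTER clause fed nulls into the
    // partial average buffers instead of zero-initialized (0.0, 0), so the
    // merged answer could differ from a CPU-only run.
    spark.sql("SELECT a, avg(b) FILTER (WHERE b > 20) FROM t GROUP BY a").show()
    spark.stop()
  }
}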
19 changes: 1 addition & 18 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -280,28 +280,11 @@ def test_hash_multiple_filters(data_gen, conf):
         "hash_agg_table",
         'select count(a) filter (where c > 50),' +
         'count(b) filter (where c > 100),' +
-        # Uncomment after https://github.com/NVIDIA/spark-rapids/issues/155 is fixed
-        # 'avg(b) filter (where b > 20),' +
+        'avg(b) filter (where b > 20),' +
         'min(a), max(b) filter (where c > 250) from hash_agg_table group by a',
         conf)
 
 
-@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/155')
-@ignore_order
-@allow_non_gpu(
-    'HashAggregateExec', 'AggregateExpression',
-    'AttributeReference', 'Alias', 'Sum', 'Count', 'Max', 'Min', 'Average', 'Cast',
-    'KnownFloatingPointNormalized', 'NormalizeNaNAndZero', 'GreaterThan', 'Literal', 'If',
-    'EqualTo', 'First', 'SortAggregateExec', 'Coalesce')
-@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
-def test_hash_multiple_filters_fail(data_gen):
-    assert_gpu_and_cpu_are_equal_sql(
-        lambda spark : gen_df(spark, data_gen, length=100),
-        "hash_agg_table",
-        'select avg(b) filter (where b > 20) from hash_agg_table group by a',
-        _no_nans_float_conf_partial)
-
-
 @ignore_order
 @allow_non_gpu('HashAggregateExec', 'AggregateExpression', 'AttributeReference', 'Alias', 'Max',
                'KnownFloatingPointNormalized', 'NormalizeNaNAndZero')
@@ -20,7 +20,7 @@ import ai.rapids.cudf
 import com.nvidia.spark.rapids._
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes, Literal}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Complete, Final, Partial, PartialMerge}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, LongType, NumericType, StructType}
@@ -76,8 +76,19 @@ case class GpuAggregateExpression(origAggregateFunction: GpuAggregateFunction,
 case class WrappedAggFunction(aggregateFunction: GpuAggregateFunction, filter: Expression)
     extends GpuDeclarativeAggregate {
   override val inputProjection: Seq[GpuExpression] = {
-    val caseWhenExpressions = aggregateFunction.inputProjection.map {ip =>
-      GpuCaseWhen(Seq((filter, ip)))
+    val caseWhenExpressions = aggregateFunction.inputProjection.map { ip =>
+      // special case average with null result from the filter as expected values should be
+      // (0.0,0) for (sum, count)
+      val initialValue: Expression =
+        origAggregateFunction match {
+          case _ : GpuAverage => ip.dataType match {
+            case doubleType: DoubleType => GpuLiteral(0D, doubleType)
+            case _ : LongType => GpuLiteral(0L, LongType)
+          }
+          case _ => GpuLiteral(null, ip.dataType)
+        }
+      val filterConditional = GpuCaseWhen(Seq((filter, ip)))
+      GpuCaseWhen(Seq((GpuIsNotNull(filterConditional), filterConditional)), Some(initialValue))
     }
     caseWhenExpressions
   }
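
For reference, the rewritten inputProjection produces, per input column, the expression shape sketched below in plain CPU Spark with org.apache.spark.sql.functions; the helper name filteredInput is invented here for illustration and is not the plugin's own API.

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.when

object FilteredInputSketch {
  // Each input projection column `ip` of the wrapped aggregate becomes:
  //   CASE WHEN isnotnull(CASE WHEN filter THEN ip END)
  //        THEN (CASE WHEN filter THEN ip END)
  //        ELSE initialValue END
  def filteredInput(ip: Column, filter: Column, initialValue: Column): Column = {
    val filterConditional = when(filter, ip) // null for rows the filter removes
    when(filterConditional.isNotNull, filterConditional).otherwise(initialValue)
  }
}

For GpuAverage the initial value is 0.0 for the sum column and 0L for the count column, so filtered-out rows contribute nothing to the partial buffers; for every other aggregate it stays null, matching the old behavior.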
@@ -1635,4 +1635,45 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
       .set(RapidsConf.ENABLE_FLOAT_AGG.key, "true")) {
     frame => frame.groupBy(col("double")).agg(sum(col("int")))
   }
+
+  testSparkResultsAreEqual("Agg expression with filter avg with nulls", nullDf, execsAllowedNonGpu =
+    Seq("HashAggregateExec", "AggregateExpression", "AttributeReference", "Alias", "Average",
+      "Count", "Cast"),
+    conf = partialOnlyConf, repart = 2) {
+    frame => frame.createOrReplaceTempView("testTable")
+      frame.sparkSession.sql(
+        s"""
+           | SELECT
+           |   avg(more_longs) filter (where more_longs > 2)
+           | FROM testTable
+           | group by longs
+           |""".stripMargin)
+  }
+
+  testSparkResultsAreEqual("Agg expression with filter count with nulls",
+    nullDf, execsAllowedNonGpu = Seq("HashAggregateExec", "AggregateExpression",
+      "AttributeReference", "Alias", "Count", "Cast"),
+    conf = partialOnlyConf, repart = 2) {
+    frame => frame.createOrReplaceTempView("testTable")
+      frame.sparkSession.sql(
+        s"""
+           | SELECT
+           |   count(more_longs) filter (where more_longs > 2)
+           | FROM testTable
+           | group by longs
+           |""".stripMargin)
+  }
+
+  testSparkResultsAreEqual("Agg expression with filter sum with nulls", nullDf, execsAllowedNonGpu =
+    Seq("HashAggregateExec", "AggregateExpression", "AttributeReference", "Alias", "Sum", "Cast"),
+    conf = partialOnlyConf, repart = 2) {
+    frame => frame.createOrReplaceTempView("testTable")
+      frame.sparkSession.sql(
+        s"""
+           | SELECT
+           |   sum(more_longs) filter (where more_longs > 2)
+           | FROM testTable
+           | group by longs
+           |""".stripMargin)
+  }
 }
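
With the nullDf helper added below, each of these queries has a single group (longs = 100) in which only more_longs = 15 satisfies more_longs > 2, since the null row fails the predicate; the expected answers are therefore 15.0 for the filtered average, 1 for the filtered count, and 15 for the filtered sum. Running them under partialOnlyConf with repart = 2 exercises exactly the partial-aggregation-on-GPU, final-merge-on-CPU split that produced the incorrect averages.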
@@ -1234,6 +1234,14 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
     ).toDF("doubles")
   }
 
+  def nullDf(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq[(java.lang.Long, java.lang.Long)](
+      (100L, 15L),
+      (100L, null)
+    ).toDF("longs", "more_longs")
+  }
+
   def mixedDoubleDf(session: SparkSession): DataFrame = {
     import session.sqlContext.implicits._
     Seq[(java.lang.Double, java.lang.Double)](