diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 16ed1d7ec45..a2577e8393d 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -14513,95 +14513,10 @@ are limited.
UDT |
-ApproximatePercentile |
-`percentile_approx`, `approx_percentile` |
-Approximate percentile |
-This is not 100% compatible with the Spark version because the GPU implementation of approx_percentile is not bit-for-bit compatible with Apache Spark. To enable it, set spark.rapids.sql.incompatibleOps.enabled to true |
-reduction |
-input |
- |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
- |
-NS |
- |
- |
- |
- |
- |
- |
- |
-
-
-percentage |
- |
- |
- |
- |
- |
- |
-NS |
- |
- |
- |
- |
- |
- |
- |
-NS |
- |
- |
- |
-
-
-accuracy |
- |
- |
- |
-NS |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
-
-
-result |
- |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
- |
-NS |
- |
- |
- |
-NS |
- |
- |
- |
-
-
+ApproximatePercentile |
+`percentile_approx`, `approx_percentile` |
+Approximate percentile |
+This is not 100% compatible with the Spark version because the GPU implementation of approx_percentile is not bit-for-bit compatible with Apache Spark. To enable it, set spark.rapids.sql.incompatibleOps.enabled to true |
aggregation |
input |
|
@@ -14687,19 +14602,19 @@ are limited.
|
-window |
+reduction |
input |
|
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
+S |
+S |
+S |
+S |
+S |
+S |
NS |
NS |
|
-NS |
+S |
|
|
|
@@ -14716,7 +14631,7 @@ are limited.
|
|
|
-NS |
+S |
|
|
|
@@ -14724,7 +14639,7 @@ are limited.
|
|
|
-NS |
+S |
|
|
|
@@ -14734,7 +14649,7 @@ are limited.
|
|
|
-NS |
+S |
|
|
|
@@ -14753,20 +14668,20 @@ are limited.
result |
|
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
+S |
+S |
+S |
+S |
+S |
+S |
NS |
NS |
|
-NS |
+S |
|
|
|
-NS |
+PS: unsupported child types DATE, TIMESTAMP |
|
|
|
@@ -14905,32 +14820,6 @@ are limited.
|
-Expression |
-SQL Function(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
CollectList |
`collect_list` |
Collect a list of non-unique elements, not supported in reduction |
@@ -15064,6 +14953,32 @@ are limited.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
CollectSet |
`collect_set` |
Collect a set of unique elements, not supported in reduction |
@@ -15330,32 +15245,6 @@ are limited.
|
-Expression |
-SQL Function(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
First |
`first_value`, `first` |
first aggregate operator |
@@ -15489,6 +15378,32 @@ are limited.
NS |
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Last |
`last`, `last_value` |
last aggregate operator |
@@ -15755,32 +15670,6 @@ are limited.
NS |
-Expression |
-SQL Function(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
Min |
`min` |
Min aggregate operator |
@@ -15914,6 +15803,32 @@ are limited.
NS |
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
PivotFirst |
|
PivotFirst operator |
@@ -16179,32 +16094,6 @@ are limited.
|
-Expression |
-SQL Function(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
StddevSamp |
`stddev_samp`, `std`, `stddev` |
Aggregation computing sample standard deviation |
@@ -16338,6 +16227,32 @@ are limited.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Sum |
`sum` |
Sum aggregate operator |
@@ -16604,32 +16519,6 @@ are limited.
|
-Expression |
-SQL Function(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
VarianceSamp |
`var_samp`, `variance` |
Aggregation computing sample variance |
@@ -16763,6 +16652,32 @@ are limited.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
NormalizeNaNAndZero |
|
Normalize NaN and zero |
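The documentation changes above record the substance of this patch: ApproximatePercentile's `reduction` context moves from NS (not supported) to S for numeric inputs, so `approx_percentile` without a `GROUP BY` can now stay on the GPU. A minimal usage sketch of what that enables, assuming a session with the RAPIDS Accelerator on the classpath (the table and column names are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object ApproxPercentileReductionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      // assumes the RAPIDS Accelerator jar is deployed with the application
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      // required because the GPU result is not bit-for-bit identical to Spark's
      .config("spark.rapids.sql.incompatibleOps.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    (1 to 100).map(_.toDouble).toDF("v").createOrReplaceTempView("t")
    // a reduction: no grouping keys, a single row out
    spark.sql("SELECT approx_percentile(v, array(0.25, 0.5, 0.75)) FROM t")
      .show(truncate = false)
  }
}
```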
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index 9ce6c721d8f..4ab98da2eae 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -1271,6 +1271,14 @@ def do_it(spark):
return df.groupBy('a').agg(f.min(df.b[1]["a"]))
assert_gpu_and_cpu_are_equal_collect(do_it)
+@incompat
+@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
+def test_hash_groupby_approx_percentile_reduction(aqe_enabled):
+ conf = {'spark.sql.adaptive.enabled': aqe_enabled}
+ compare_percentile_approx(
+ lambda spark: gen_df(spark, [('v', DoubleGen())], length=100),
+ [0.05, 0.25, 0.5, 0.75, 0.95], conf, reduction = True)
+
@incompat
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
def test_hash_groupby_approx_percentile_byte(aqe_enabled):
@@ -1405,11 +1413,11 @@ def test_hash_groupby_approx_percentile_decimal128_single():
# results due to the different algorithms being used. Instead we compute an exact percentile on the CPU and then
# compute approximate percentiles on CPU and GPU and assert that the GPU numbers are accurate within some percentage
# of the CPU numbers
-def compare_percentile_approx(df_fun, percentiles, conf = {}):
+def compare_percentile_approx(df_fun, percentiles, conf = {}, reduction = False):
# create SQL statements for exact and approx percentiles
- p_exact_sql = create_percentile_sql("percentile", percentiles)
- p_approx_sql = create_percentile_sql("approx_percentile", percentiles)
+ p_exact_sql = create_percentile_sql("percentile", percentiles, reduction)
+ p_approx_sql = create_percentile_sql("approx_percentile", percentiles, reduction)
def run_exact(spark):
df = df_fun(spark)
@@ -1436,8 +1444,9 @@ def run_approx(spark):
gpu_approx_result = approx_gpu[i]
# assert that keys match
- assert cpu_exact_result['k'] == cpu_approx_result['k']
- assert cpu_exact_result['k'] == gpu_approx_result['k']
+ if not reduction:
+ assert cpu_exact_result['k'] == cpu_approx_result['k']
+ assert cpu_exact_result['k'] == gpu_approx_result['k']
# extract the percentile result column
exact_percentile = cpu_exact_result['the_percentile']
@@ -1472,13 +1481,22 @@ def run_approx(spark):
else:
assert abs(cpu_delta / gpu_delta) - 1 < 0.001
-def create_percentile_sql(func_name, percentiles):
- if isinstance(percentiles, list):
- return """select k, {}(v, array({})) as the_percentile from t group by k order by k""".format(
- func_name, ",".join(str(i) for i in percentiles))
+def create_percentile_sql(func_name, percentiles, reduction):
+ if reduction:
+ if isinstance(percentiles, list):
+ return """select {}(v, array({})) as the_percentile from t""".format(
+ func_name, ",".join(str(i) for i in percentiles))
+ else:
+ return """select {}(v, {}) as the_percentile from t""".format(
+ func_name, percentiles)
else:
- return """select k, {}(v, {}) as the_percentile from t group by k order by k""".format(
- func_name, percentiles)
+ if isinstance(percentiles, list):
+ return """select k, {}(v, array({})) as the_percentile from t group by k order by k""".format(
+ func_name, ",".join(str(i) for i in percentiles))
+ else:
+ return """select k, {}(v, {}) as the_percentile from t group by k order by k""".format(
+ func_name, percentiles)
+
@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_strings_with_extra_nulls], ids=idfn)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala
index f70068cfb39..fae1d8335c1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala
@@ -17,7 +17,7 @@
package com.nvidia.spark.rapids
import ai.rapids.cudf
-import ai.rapids.cudf.GroupByAggregation
+import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation}
import com.nvidia.spark.rapids.GpuCast.doCast
import com.nvidia.spark.rapids.shims.ShimExpression
@@ -178,8 +178,12 @@ case class ApproxPercentileFromTDigestExpr(
class CudfTDigestUpdate(accuracyExpression: GpuLiteral)
extends CudfAggregate {
- override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
- throw new UnsupportedOperationException("TDigest is not yet supported in reduction")
+
+ override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
+ (col: cudf.ColumnVector) =>
+ col.reduce(ReductionAggregation.createTDigest(CudfTDigest.accuracy(accuracyExpression)),
+ DType.STRUCT)
+
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.createTDigest(CudfTDigest.accuracy(accuracyExpression))
override val name: String = "CudfTDigestUpdate"
@@ -189,8 +193,10 @@ class CudfTDigestUpdate(accuracyExpression: GpuLiteral)
class CudfTDigestMerge(accuracyExpression: GpuLiteral)
extends CudfAggregate {
- override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
- throw new UnsupportedOperationException("TDigest is not yet supported in reduction")
+ override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
+ (col: cudf.ColumnVector) =>
+ col.reduce(ReductionAggregation.mergeTDigest(CudfTDigest.accuracy(accuracyExpression)))
+
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.mergeTDigest(CudfTDigest.accuracy(accuracyExpression))
override val name: String = "CudfTDigestMerge"
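The two overrides above replace the old `UnsupportedOperationException` stubs with real cudf reductions: the update phase builds a t-digest over the whole column, and the merge phase combines partial t-digests from upstream tasks. Below is a standalone sketch of the update-side call, using only the cudf Java API surface visible in this patch; the `delta` value of 100 is illustrative, not the plugin's default accuracy, and running it requires a GPU with the cudf bindings:

```scala
import ai.rapids.cudf.{ColumnVector, DType, ReductionAggregation, Scalar}

object TDigestReductionSketch {
  // mirrors CudfTDigestUpdate.reductionAggregate: reduce a whole column to a
  // single STRUCT scalar holding the t-digest summary
  def tdigestOf(col: ColumnVector, delta: Int): Scalar =
    col.reduce(ReductionAggregation.createTDigest(delta), DType.STRUCT)

  def main(args: Array[String]): Unit = {
    val col = ColumnVector.fromDoubles(1.0, 2.0, 3.0, 4.0, 5.0)
    try {
      val digest = tdigestOf(col, 100)
      try println(s"t-digest scalar type: ${digest.getType}")
      finally digest.close()
    } finally col.close()
  }
}
```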
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index e437867f635..dfc3b2e4e8f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3308,7 +3308,7 @@ object GpuOverrides extends Logging {
}),
expr[ApproximatePercentile](
"Approximate percentile",
- ExprChecks.groupByOnly(
+ ExprChecks.reductionAndGroupByAgg(
// note that output can be single number or array depending on whether percentiles param
// is a single number or an array
TypeSig.gpuNumeric +
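Switching from `ExprChecks.groupByOnly` to `ExprChecks.reductionAndGroupByAgg` is what advertises the new reduction support to the planner (and to the generated docs above). A hedged way to see the effect end to end: a reduction-shaped query should now plan a GPU aggregate instead of falling back. The `"GpuHashAggregate"` substring probe below is an illustrative check, not the plugin's own test idiom:

```scala
import org.apache.spark.sql.SparkSession

object ReductionPlanProbe {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .config("spark.rapids.sql.incompatibleOps.enabled", "true")
      // keep the final physical plan visible without running the query
      .config("spark.sql.adaptive.enabled", "false")
      .getOrCreate()
    import spark.implicits._

    (1 to 100).map(_.toDouble).toDF("v").createOrReplaceTempView("t")
    val df = spark.sql("SELECT approx_percentile(v, array(0.5)) FROM t")
    val plan = df.queryExecution.executedPlan.toString
    // before this patch the aggregate fell back to the CPU for reductions
    assert(plan.contains("GpuHashAggregate"), s"expected a GPU aggregate in:\n$plan")
  }
}
```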
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala
index c4837252816..8f3dc2305ce 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ApproximatePercentileSuite.scala
@@ -94,8 +94,12 @@ class ApproximatePercentileSuite extends SparkQueryCompareTestSuite {
"FROM salaries GROUP BY dept")
}
- test("fall back to CPU for reduction") {
- sqlFallbackTest("SELECT approx_percentile(salary, array(0.5)) FROM salaries")
+ testSparkResultsAreEqual("approx percentile reduction",
+ df => salaries(df, DataTypes.DoubleType, 100),
+    maxFloatDiff = 25.0, // approx percentile on GPU uses a different algorithm than Spark's
+ incompat = true) { df =>
+ df.createOrReplaceTempView("salaries")
+ df.sparkSession.sql("SELECT approx_percentile(salary, array(0.5)) FROM salaries")
}
def sqlFallbackTest(sql: String) {