diff --git a/docs/configs.md b/docs/configs.md
index 78622f0a49b..adb03112042 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -252,6 +252,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Sin|`sin`|Sine|true|None|
spark.rapids.sql.expression.Sinh|`sinh`|Hyperbolic sine|true|None|
spark.rapids.sql.expression.Size|`size`, `cardinality`|The size of an array or a map|true|None|
+spark.rapids.sql.expression.SortArray|`sort_array`|Sorts the input array in ascending or descending order|true|None|
spark.rapids.sql.expression.SortOrder| |Sort order|true|None|
spark.rapids.sql.expression.SparkPartitionID|`spark_partition_id`|Returns the current partition id|true|None|
spark.rapids.sql.expression.SpecifiedWindowFrame| |Specification of the width of the group (or "frame") of input rows around which a window function is evaluated|true|None|
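
With this config in place, `sort_array` runs on the GPU by default. A minimal usage sketch, assuming a `spark-shell` where the RAPIDS Accelerator is already on the classpath and enabled; the column name `a` is illustrative:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// The new expression config defaults to true; set it to false to force
// sort_array back onto the CPU.
spark.conf.set("spark.rapids.sql.expression.SortArray", "true")

// Sort each array in the column in ascending order.
Seq(Seq(3, 1, 2), Seq(9, 7, 8)).toDF("a")
  .select(sort_array(col("a"), true))
  .show(false)
```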
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 7ca8d298242..310d60b2613 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -13762,6 +13762,164 @@ Accelerator support is described below.
|
+SortArray |
+`sort_array` |
+Sorts the input array in ascending or descending order |
+None |
+project |
+array |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
+ |
+ |
+ |
+
+
+ascendingOrder |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
+ |
+ |
+ |
+
+
+lambda |
+array |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+ |
+ |
+ |
+
+
+ascendingOrder |
+NS |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+ |
+ |
+ |
+
+
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
SortOrder |
|
Sort order |
@@ -13857,32 +14015,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
SpecifiedWindowFrame |
|
Specification of the width of the group (or "frame") of input rows around which a window function is evaluated |
@@ -14173,6 +14305,32 @@ Accelerator support is described below.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
StringLPad |
`lpad` |
Pad a string on the left |
@@ -14347,32 +14505,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
StringLocate |
`position`, `locate` |
Substring search operator |
@@ -14547,6 +14679,32 @@ Accelerator support is described below.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
StringRPad |
`rpad` |
Pad a string on the right |
@@ -14721,32 +14879,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
StringReplace |
`replace` |
StringReplace operator |
@@ -14921,6 +15053,32 @@ Accelerator support is described below.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
StringSplit |
`split` |
Splits `str` around occurrences that match `regex` |
@@ -15095,32 +15253,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
StringTrim |
`trim` |
StringTrim operator |
@@ -15385,6 +15517,32 @@ Accelerator support is described below.
|
+Expression |
+SQL Function(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
StringTrimRight |
`rtrim` |
StringTrimRight operator |
@@ -15517,32 +15675,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
Substring |
`substr`, `substring` |
Substring operator |
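
The `PS*` entries in the `SortArray` rows added above indicate partial support: arrays of the common primitive types can be sorted on the GPU, while nested element types (BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) cannot. A hedged illustration of that boundary, with made-up column names:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Array of a primitive element type: eligible for GpuSortArray.
Seq(Seq("banana", "apple")).toDF("fruits")
  .select(sort_array(col("fruits")))
  .show(false)

// Array of arrays: the element type is nested, so this expression is
// expected to fall back to Spark's CPU SortArray.
Seq(Seq(Seq(2, 1), Seq(4, 3))).toDF("nested")
  .select(sort_array(col("nested")))
  .show(false)
```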
diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 5083b301ce0..6c870885e6f 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -14,9 +14,10 @@
import pytest
-from asserts import assert_gpu_and_cpu_are_equal_collect
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from pyspark.sql.types import *
+from spark_session import with_cpu_session
from string_test import mk_str_gen
import pyspark.sql.functions as f
@@ -95,3 +96,18 @@ def test_size_of_map(data_gen, size_of_null):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr('size(a)'),
conf={'spark.sql.legacy.sizeOfNull': size_of_null})
+
+@pytest.mark.parametrize('data_gen', non_nested_array_gens, ids=idfn)
+@pytest.mark.parametrize('is_ascending', [True, False], ids=idfn)
+def test_sort_array(data_gen, is_ascending):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, data_gen).select(
+ f.sort_array(f.col('a'), is_ascending)))
+
+@pytest.mark.parametrize('data_gen', non_nested_array_gens, ids=idfn)
+@pytest.mark.parametrize('is_ascending', [True, False], ids=idfn)
+def test_sort_array_lit(data_gen, is_ascending):
+ array_lit = gen_scalar(data_gen)
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, data_gen, length=10).select(
+ f.sort_array(f.lit(array_lit), is_ascending)))
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 223c3ba8716..c22db905880 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2362,6 +2362,20 @@ object GpuOverrides {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuArrayContains(lhs, rhs)
}),
+ expr[SortArray](
+ "Returns a sorted array with the input array and the ascending / descending order",
+ ExprChecks.binaryProjectNotLambda(
+ TypeSig.ARRAY.nested(_commonTypes),
+ TypeSig.ARRAY.nested(TypeSig.all),
+ ("array", TypeSig.ARRAY.nested(_commonTypes),
+ TypeSig.ARRAY.nested(TypeSig.all)),
+ ("ascendingOrder", TypeSig.lit(TypeEnum.BOOLEAN), TypeSig.lit(TypeEnum.BOOLEAN))),
+    (sortExpression, conf, p, r) => new BinaryExprMeta[SortArray](sortExpression, conf, p, r) {
+      override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+        GpuSortArray(lhs, rhs)
+    }),
expr[CreateArray](
" Returns an array with the given elements",
ExprChecks.projectNotLambda(
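
Because the `ascendingOrder` signature above is `TypeSig.lit(TypeEnum.BOOLEAN)`, only a literal boolean order is accepted for GPU execution, mirroring Spark's own CPU-side requirement. A hedged sketch of what that means for queries; the table `tbl` and columns `a`, `b` are hypothetical:

```scala
// Accepted: the order argument is a boolean literal, so the expression
// can be replaced by GpuSortArray.
spark.sql("SELECT sort_array(a, true) FROM tbl")

// Rejected at analysis time by Spark itself, on CPU or GPU: the second
// argument must be a boolean literal, not a column reference.
spark.sql("SELECT sort_array(a, b) FROM tbl")
```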
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index 9be8ed2ef24..8542022fa91 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.rapids
import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
-import com.nvidia.spark.rapids.{GpuBinaryExpression, GpuColumnVector, GpuComplexTypeMergingExpression, GpuExpressionsUtils, GpuScalar, GpuUnaryExpression}
+import com.nvidia.spark.rapids.{GpuBinaryExpression, GpuColumnVector, GpuComplexTypeMergingExpression, GpuExpression, GpuExpressionsUtils, GpuLiteral, GpuScalar, GpuUnaryExpression}
import com.nvidia.spark.rapids.GpuExpressionsUtils.columnarEvalToColumn
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Literal, RowOrdering}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -215,3 +215,58 @@ case class GpuSize(child: Expression, legacySizeOfNull: Boolean)
}
}
}
+
+case class GpuSortArray(base: Expression, ascendingOrder: Expression)
+ extends GpuBinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = base
+
+ override def right: Expression = ascendingOrder
+
+ override def dataType: DataType = base.dataType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
+
+ override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+ case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
+ ascendingOrder match {
+        // unlike Spark's CPU SortArray, the order argument arrives here as a
+        // GpuLiteral rather than a Literal, so match on GpuLiteral
+ case GpuLiteral(_: Boolean, BooleanType) =>
+ TypeCheckResult.TypeCheckSuccess
+ case order =>
+ TypeCheckResult.TypeCheckFailure(
+ s"Sort order in second argument requires a boolean literal, but found $order")
+ }
+ case ArrayType(dt, _) =>
+ val dtSimple = dt.catalogString
+ TypeCheckResult.TypeCheckFailure(
+ s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
+ case dt =>
+ TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input, but found $dt")
+ }
+
+  override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
+    throw new IllegalArgumentException("the ascending / descending order (rhs) of sort_array " +
+      "has to be a scalar for the operator to work")
+
+  override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector =
+    throw new IllegalArgumentException("the ascending / descending order (rhs) of sort_array " +
+      "has to be a scalar for the operator to work")
+
+  override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
+    val isDescending = isDescendingOrder(rhs)
+    // sort with nulls treated as smallest, so nulls come first in ascending
+    // order and last in descending order, matching Spark's sort_array
+    lhs.getBase.listSortRows(isDescending, true)
+  }
+
+  override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
+    val isDescending = isDescendingOrder(rhs)
+    // expand the scalar array literal into a column of numRows rows, then
+    // sort each list entry the same way as the column case above
+    withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { cv =>
+      cv.getBase.listSortRows(isDescending, true)
+    }
+  }
+
+ private def isDescendingOrder(scalar: GpuScalar): Boolean = scalar.getValue match {
+ case ascending: Boolean => !ascending
+    case invalidValue => throw new IllegalArgumentException(
+      s"ascendingOrder must be a boolean literal, but found $invalidValue")
+ }
+}
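
For reference, a hedged sketch of the null ordering that passing `true` as the second argument to `listSortRows` above is meant to preserve: Spark places nulls first when ascending and last when descending.

```scala
// Ascending: nulls sort to the front, e.g. [null, 1, 2]
spark.sql("SELECT sort_array(array(2, null, 1), true)").show(false)

// Descending: nulls sort to the back, e.g. [2, 1, null]
spark.sql("SELECT sort_array(array(2, null, 1), false)").show(false)
```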