Skip to content

Commit

Permalink
Add in support for struct and named_struct (NVIDIA#1458)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Jan 7, 2021
1 parent d0fc637 commit 24aa9bb
Show file tree
Hide file tree
Showing 6 changed files with 306 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Cos"></a>spark.rapids.sql.expression.Cos|`cos`|Cosine|true|None|
<a name="sql.expression.Cosh"></a>spark.rapids.sql.expression.Cosh|`cosh`|Hyperbolic cosine|true|None|
<a name="sql.expression.Cot"></a>spark.rapids.sql.expression.Cot|`cot`|Cotangent|true|None|
<a name="sql.expression.CreateNamedStruct"></a>spark.rapids.sql.expression.CreateNamedStruct|`named_struct`, `struct`|Creates a struct with the given field names and values.|true|None|
<a name="sql.expression.CurrentRow$"></a>spark.rapids.sql.expression.CurrentRow$| |Special boundary for a window frame, indicating stopping at the current row|true|None|
<a name="sql.expression.DateAdd"></a>spark.rapids.sql.expression.DateAdd|`date_add`|Returns the date that is num_days after start_date|true|None|
<a name="sql.expression.DateDiff"></a>spark.rapids.sql.expression.DateDiff|`datediff`|Returns the number of days from startDate to endDate|true|None|
Expand Down
132 changes: 132 additions & 0 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -3283,6 +3283,138 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="6">CreateNamedStruct</td>
<td rowSpan="6">`named_struct`, `struct`</td>
<td rowSpan="6">Creates a struct with the given field names and values.</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>name</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>value</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>name</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>value</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
</tr>
<tr>
<td rowSpan="1">CurrentRow$</td>
<td rowSpan="1"> </td>
<td rowSpan="1">Special boundary for a window frame, indicating stopping at the current row</td>
Expand Down
10 changes: 9 additions & 1 deletion integration_tests/src/main/python/struct_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,3 +31,11 @@ def test_struct_get_item(data_gen):
'a.first',
'a.second',
'a.third'))

@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
def test_make_struct(data_gen):
    """Compare CPU and GPU results for `struct` and `named_struct`.

    Both expressions are built over the `a` and `b` columns produced by
    `binary_op_df` for the given generator; `named_struct` additionally
    mixes in an integer literal field.
    """
    def build_structs(spark):
        # One select producing both the positional and the named struct forms.
        return binary_op_df(spark, data_gen).selectExpr(
            'struct(a, b)',
            'named_struct("foo", b, "bar", 5, "end", a)')

    assert_gpu_and_cpu_are_equal_collect(build_structs)

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1891,6 +1891,13 @@ object GpuOverrides {
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
// Rule entry for CreateNamedStruct. The type checks are supplied by
// CreateNamedStructCheck rather than an inline ExprChecks definition.
expr[CreateNamedStruct](
"Creates a struct with the given field names and values.",
CreateNamedStructCheck,
(in, conf, p, r) => new ExprMeta[CreateNamedStruct](in, conf, p, r) {
// Every child (names and values alike) is converted to its GPU form and
// handed, in the original order, to GpuCreateNamedStruct.
override def convertToGpu(): GpuExpression =
GpuCreateNamedStruct(childExprs.map(_.convertToGpu()))
}),
expr[StringLocate](
"Substring search operator",
ExprChecks.projectNotLambda(TypeSig.INT, TypeSig.INT,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -669,6 +669,61 @@ object WindowSpecCheck extends ExprChecks {
}
}

/**
 * Type checks for CreateNamedStruct, whose children alternate between a field name and the
 * value stored under that name. The alternating layout is handled by hand here; should the
 * same pattern appear for another expression this can be made more generic.
 */
object CreateNamedStructCheck extends ExprChecks {
  // Each signature used for tagging is paired with a `spark*` signature that is passed
  // alongside it when computing support levels in `support`.
  val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
  val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
  val valueSig: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
  val sparkValueSig: TypeSig = TypeSig.all
  val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig)
  val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig)

  /**
   * Tags `meta` as not runnable on the GPU when the expression is used outside of a project
   * context, when any name or value child fails its signature, or when the resulting struct
   * type is not covered by `resultSig`.
   */
  override def tag(meta: RapidsMeta[_, _, _]): Unit = {
    val exprMeta = meta.asInstanceOf[BaseExprMeta[_]]
    val context = exprMeta.context
    if (context == ProjectExprContext) {
      val origExpr = exprMeta.wrapped.asInstanceOf[Expression]
      // Children come as (name, value) pairs; split them up front so that every name is
      // tagged before any value is.
      val (fieldNames, fieldValues) = origExpr.children.grouped(2).toList.map {
        case Seq(fieldName, fieldValue) => (fieldName, fieldValue)
      }.unzip
      fieldNames.foreach(nameExpr => nameSig.tagExprParam(meta, nameExpr, "name"))
      fieldValues.foreach(valueExpr => valueSig.tagExprParam(meta, valueExpr, "value"))
      val outputOk = resultSig.isSupportedByPlugin(origExpr.dataType, meta.conf.decimalTypeEnabled)
      if (!outputOk) {
        meta.willNotWorkOnGpu(s"unsupported data types in output: ${origExpr.dataType}")
      }
    } else {
      meta.willNotWorkOnGpu(s"this is not supported in the $context context")
    }
  }

  /**
   * Support levels of `dataType` for the "name", "value" and "result" slots, keyed by
   * expression context. The lambda context is computed against `TypeSig.none` for every slot.
   */
  override def support(dataType: TypeEnum.Value):
      Map[ExpressionContext, Map[String, SupportLevel]] = {
    def slotLevels(
        nameCheck: TypeSig,
        valueCheck: TypeSig,
        resultCheck: TypeSig): Map[String, SupportLevel] =
      Map(
        "name" -> nameCheck.getSupportLevel(dataType, sparkNameSig),
        "value" -> valueCheck.getSupportLevel(dataType, sparkValueSig),
        "result" -> resultCheck.getSupportLevel(dataType, sparkResultSig))

    Map(
      ProjectExprContext -> slotLevels(nameSig, valueSig, resultSig),
      LambdaExprContext -> slotLevels(TypeSig.none, TypeSig.none, TypeSig.none))
  }
}

class CastChecks extends ExprChecks {
// Don't show this with other operators show it in a different location
override val shown: Boolean = false
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableArray, ReallyAGpuExpression}

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, NamedExpression}
import org.apache.spark.sql.types.{Metadata, StringType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
 * GPU version of the named_struct / struct expression.
 *
 * @param children alternating field-name and field-value expressions:
 *                 name0, value0, name1, value1, ...
 */
case class GpuCreateNamedStruct(children: Seq[Expression]) extends GpuExpression {
  // Split the flat child list into its name half and its value half.
  lazy val (nameExprs, valExprs) = children.grouped(2).toList.map {
    case Seq(fieldName, fieldValue) => (fieldName, fieldValue)
  }.unzip

  // Field names are evaluated without any input: GPU expressions through columnarEval with
  // a null batch, anything else through a row-based eval on the empty row.
  private lazy val names = nameExprs.map {
    case gpuName: GpuExpression => gpuName.columnarEval(null)
    case cpuName => cpuName.eval(EmptyRow)
  }

  override def nullable: Boolean = false

  override def foldable: Boolean = valExprs.forall(_.foldable)

  /** Builds the struct field for one evaluated name and its value expression. */
  private def fieldFor(fieldName: Any, valueExpr: Expression): StructField = {
    // Only named expressions carry metadata worth propagating into the field.
    val fieldMetadata = valueExpr match {
      case named: NamedExpression => named.metadata
      case _ => Metadata.empty
    }
    StructField(fieldName.toString, valueExpr.dataType, valueExpr.nullable, fieldMetadata)
  }

  override lazy val dataType: StructType =
    StructType(names.zip(valExprs).map { case (fieldName, valueExpr) =>
      fieldFor(fieldName, valueExpr)
    })

  override def checkInputDataTypes(): TypeCheckResult = {
    // The arity check has to come first: pairing up the children fails on an odd count.
    if (children.size % 2 == 0) {
      val invalidNames = nameExprs.filter(e => !e.foldable || e.dataType != StringType)
      if (invalidNames.nonEmpty) {
        TypeCheckResult.TypeCheckFailure(
          s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" +
            s" position, got: ${invalidNames.mkString(",")}")
      } else if (names.contains(null)) {
        TypeCheckResult.TypeCheckFailure("Field name should not be null")
      } else {
        TypeCheckResult.TypeCheckSuccess
      }
    } else {
      TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
    }
  }

  // When a FUNC_ALIAS tag is present it is used as the function name (and in `sql` only the
  // value children are printed under it); otherwise this reports as named_struct.
  override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct")

  override def sql: String = getTagValue(FUNC_ALIAS) match {
    case Some(alias) =>
      // Odd positions hold the values; the names are left out of the aliased form.
      val childrenSQL = children.zipWithIndex.collect {
        case (child, position) if position % 2 == 1 => child.sql
      }.mkString(", ")
      s"$alias($childrenSQL)"
    case None =>
      super.sql
  }

  /**
   * Evaluates every value expression against `batch` and assembles the results into a single
   * struct column. Values that do not come back as a column are expanded from a scalar to
   * the batch's row count.
   */
  override def columnarEval(batch: ColumnarBatch): Any = {
    // Names only matter for the data type; the data itself comes from the value expressions.
    withResource(new Array[ColumnVector](valExprs.size)) { fieldColumns =>
      val rowCount = batch.numRows()
      valExprs.zipWithIndex.foreach { case (valueExpr, position) =>
        valueExpr.columnarEval(batch) match {
          case evaluated: GpuColumnVector =>
            fieldColumns(position) = evaluated.getBase
          case scalarValue =>
            val fieldType = dataType.fields(position).dataType
            withResource(GpuScalar.from(scalarValue, fieldType)) { gpuScalar =>
              fieldColumns(position) = ColumnVector.fromScalar(gpuScalar, rowCount)
            }
        }
      }
      GpuColumnVector.from(ColumnVector.makeStruct(rowCount, fieldColumns: _*), dataType)
    }
  }
}

0 comments on commit 24aa9bb

Please sign in to comment.