From 24aa9bb238d11e0156fc3027c929d6c865f8f27a Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 7 Jan 2021 13:21:56 -0600 Subject: [PATCH] Add in support for struct and named_struct (#1458) Signed-off-by: Robert (Bobby) Evans --- docs/configs.md | 1 + docs/supported_ops.md | 132 ++++++++++++++++++ .../src/main/python/struct_test.py | 10 +- .../nvidia/spark/rapids/GpuOverrides.scala | 9 +- .../com/nvidia/spark/rapids/TypeChecks.scala | 57 +++++++- .../spark/sql/rapids/complexTypeCreator.scala | 100 +++++++++++++ 6 files changed, 306 insertions(+), 3 deletions(-) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala diff --git a/docs/configs.md b/docs/configs.md index 1e236a1873f..062135abb38 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -129,6 +129,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.Cos|`cos`|Cosine|true|None| spark.rapids.sql.expression.Cosh|`cosh`|Hyperbolic cosine|true|None| spark.rapids.sql.expression.Cot|`cot`|Cotangent|true|None| +spark.rapids.sql.expression.CreateNamedStruct|`named_struct`, `struct`|Creates a struct with the given field names and values.|true|None| spark.rapids.sql.expression.CurrentRow$| |Special boundary for a window frame, indicating stopping at the current row|true|None| spark.rapids.sql.expression.DateAdd|`date_add`|Returns the date that is num_days after start_date|true|None| spark.rapids.sql.expression.DateDiff|`datediff`|Returns the number of days from startDate to endDate|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 3d4639dd7d7..a9856de427c 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -3283,6 +3283,138 @@ Accelerator support is described below. +CreateNamedStruct +`named_struct`, `struct` +Creates a struct with the given field names and values. +None +project +name + + + + + + + + + +S + + + + + + + + + + +value +S +S +S +S +S +S +S +S +S* +S +S* +S +NS +NS +NS +NS +NS +NS + + +result + + + + + + + + + + + + + + + + +PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) + + + +lambda +name + + + + + + + + + +NS + + + + + + + + + + +value +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +result + + + + + + + + + + + + + + + + +NS + + + CurrentRow$ Special boundary for a window frame, indicating stopping at the current row diff --git a/integration_tests/src/main/python/struct_test.py b/integration_tests/src/main/python/struct_test.py index 604572c0f6e..004550049cd 100644 --- a/integration_tests/src/main/python/struct_test.py +++ b/integration_tests/src/main/python/struct_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,3 +31,11 @@ def test_struct_get_item(data_gen): 'a.first', 'a.second', 'a.third')) + +@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn) +def test_make_struct(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen).selectExpr( + 'struct(a, b)', + 'named_struct("foo", b, "bar", 5, "end", a)')) + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 3a3e5a6079c..4a063a90831 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1891,6 +1891,13 @@ object GpuOverrides { ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)), (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)), + expr[CreateNamedStruct]( + "Creates a struct with the given field names and values.", + CreateNamedStructCheck, + (in, conf, p, r) => new ExprMeta[CreateNamedStruct](in, conf, p, r) { + override def convertToGpu(): GpuExpression = + GpuCreateNamedStruct(childExprs.map(_.convertToGpu())) + }), expr[StringLocate]( "Substring search operator", ExprChecks.projectNotLambda(TypeSig.INT, TypeSig.INT, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 0cef1ddeb21..264415dbb72 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -669,6 +669,61 @@ object WindowSpecCheck extends ExprChecks { } } +/** + * A check for CreateNamedStruct. The parameter values alternate between one type and another. + * If this pattern shows up again we can make this more generic at that point. + */ +object CreateNamedStructCheck extends ExprChecks { + val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING) + val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING) + val valueSig: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + val sparkValueSig: TypeSig = TypeSig.all + val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig) + val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig) + + override def tag(meta: RapidsMeta[_, _, _]): Unit = { + val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] + val context = exprMeta.context + if (context != ProjectExprContext) { + meta.willNotWorkOnGpu(s"this is not supported in the $context context") + } else { + val origExpr = exprMeta.wrapped.asInstanceOf[Expression] + val (nameExprs, valExprs) = origExpr.children.grouped(2).map { + case Seq(name, value) => (name, value) + }.toList.unzip + nameExprs.foreach { expr => + nameSig.tagExprParam(meta, expr, "name") + } + valExprs.foreach { expr => + valueSig.tagExprParam(meta, expr, "value") + } + if (!resultSig.isSupportedByPlugin(origExpr.dataType, meta.conf.decimalTypeEnabled)) { + meta.willNotWorkOnGpu(s"unsupported data types in output: ${origExpr.dataType}") + } + } + } + + override def support(dataType: TypeEnum.Value): + Map[ExpressionContext, Map[String, SupportLevel]] = { + val nameProjectSupport = nameSig.getSupportLevel(dataType, sparkNameSig) + val nameLambdaSupport = TypeSig.none.getSupportLevel(dataType, sparkNameSig) + val valueProjectSupport = valueSig.getSupportLevel(dataType, sparkValueSig) + val valueLambdaSupport = TypeSig.none.getSupportLevel(dataType, sparkValueSig) + val resultProjectSupport = resultSig.getSupportLevel(dataType, sparkResultSig) + val resultLambdaSupport = TypeSig.none.getSupportLevel(dataType, sparkResultSig) + Map((ProjectExprContext, + Map( + ("name", nameProjectSupport), + ("value", valueProjectSupport), + ("result", resultProjectSupport))), + (LambdaExprContext, + Map( + ("name", nameLambdaSupport), + ("value", valueLambdaSupport), + ("result", resultLambdaSupport)))) + } +} + class CastChecks extends ExprChecks { // Don't show this with other operators show it in a different location override val shown: Boolean = false diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala new file mode 100644 index 00000000000..f134ae4577f --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf.ColumnVector +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} +import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableArray, ReallyAGpuExpression} + +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, NamedExpression} +import org.apache.spark.sql.types.{Metadata, StringType, StructField, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class GpuCreateNamedStruct(children: Seq[Expression]) extends GpuExpression { + lazy val (nameExprs, valExprs) = children.grouped(2).map { + case Seq(name, value) => (name, value) + }.toList.unzip + + private lazy val names = nameExprs.map { + case g: GpuExpression => g.columnarEval(null) + case e => e.eval(EmptyRow) + } + + override def nullable: Boolean = false + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { + case (name, expr) => + val metadata = expr match { + case ne: NamedExpression => ne.metadata + case _ => Metadata.empty + } + StructField(name.toString, expr.dataType, expr.nullable, metadata) + } + StructType(fields) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") + } else { + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) + if (invalidNames.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" + + s" position, got: ${invalidNames.mkString(",")}") + } else if (!names.contains(null)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Field name should not be null") + } + } + } + + // There is an alias set at `CreateStruct.create`. If there is an alias, + // this is the struct function explicitly called by a user and we should + // respect it in the SQL string as `struct(...)`. + override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct") + + override def sql: String = getTagValue(FUNC_ALIAS).map { alias => + val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ") + s"$alias($childrenSQL)" + }.getOrElse(super.sql) + + override def columnarEval(batch: ColumnarBatch): Any = { + // The names are only used for the type. Here we really just care about the data + withResource(new Array[ColumnVector](valExprs.size)) { columns => + val numRows = batch.numRows() + valExprs.indices.foreach { index => + valExprs(index).columnarEval(batch) match { + case cv: GpuColumnVector => + columns(index) = cv.getBase + case other => + val dt = dataType.fields(index).dataType + withResource(GpuScalar.from(other, dt)) { scalar => + columns(index) = ColumnVector.fromScalar(scalar, numRows) + } + } + } + GpuColumnVector.from(ColumnVector.makeStruct(numRows, columns: _*), dataType) + } + } +} \ No newline at end of file