diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala b/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala new file mode 100644 index 00000000000..bc7d65daed6 --- /dev/null +++ b/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.util.TypeUtils + +/** + * Reimplement the function `checkForNumericExpr` which has been removed since + * Spark 3.4.0 + */ +object TypeUtilsShims { + val checkForNumericExpr = TypeUtils.checkForNumericExpr _ +} diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala b/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala new file mode 100644 index 00000000000..8e144cf6444 --- /dev/null +++ b/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.{DataType, NullType, NumericType} + +/** + * Reimplement the function `checkForNumericExpr` which has been removed since + * Spark 3.4.0 + */ +object TypeUtilsShims { + def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = { + if (dt.isInstanceOf[NumericType] || dt == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.catalogString}") + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index affe29a776f..ea12c3520ad 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf import ai.rapids.cudf.{Aggregation128Utils, BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NaNEquality, NullEquality, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, Scalar, ScanAggregation} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression} +import com.nvidia.spark.rapids.shims.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression, TypeUtilsShims} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -1081,7 +1081,7 @@ abstract class GpuSum( override def children: Seq[Expression] = child :: Nil override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function gpu sum") + TypeUtilsShims.checkForNumericExpr(child.dataType, "function gpu sum") // GENERAL WINDOW FUNCTION // Spark 3.2.0+ stopped casting the input data to the output type before the sum operation @@ -1572,7 +1572,7 @@ abstract class GpuAverage(child: Expression, sumDataType: DataType) extends GpuA override def children: Seq[Expression] = child :: Nil override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function gpu average") + TypeUtilsShims.checkForNumericExpr(child.dataType, "function gpu average") override def nullable: Boolean = true