From a42e328b7ce8c01a53722cb0f06d9b58ddc450f1 Mon Sep 17 00:00:00 2001
From: Gera Shegalov
Date: Tue, 13 Dec 2022 10:00:38 -0800
Subject: [PATCH] Add a shim for Databricks 11.3 spark330db [databricks] (#7152)

* introduce non330db directories
* ShimExtractValue
* GpuPredicateHelper now extends and shims PredicateHelper
* Allow passing TEST_PARALLEL to test.sh to be able to run integration tests on a small instance
* No need to override getSparkShimVersion using the same implementation in every shim

Fixes #6879

Signed-off-by: Gera Shegalov
Signed-off-by: Ahmed Hussein (amahussein)
Co-authored-by: Ahmed Hussein (amahussein)
Signed-off-by: Niranjan Artal
Co-authored-by: Niranjan Artal
---
 jenkins/databricks/build.sh | 4 +-
 jenkins/databricks/test.sh | 2 +-
 pom.xml | 50 ++-
 sql-plugin/pom.xml | 18 +-
 .../spark/rapids/shims/ShimLeafExecNode.scala | 4 +-
 .../rapids/shims/ShimExtractValue.scala} | 6 +-
 .../spark/rapids/shims/SparkShims.scala | 4 -
 .../rapids/shims/ShimPredicateHelper.scala | 39 +++
 .../rapids/shims/CastingConfigShim.scala | 0
 .../shims/DecimalArithmeticOverrides.scala | 0
 .../spark/rapids/shims/GetMapValueMeta.scala | 0
 .../rapids/shims/ParquetStringPredShims.scala | 0
 .../ShimFilePartitionReaderFactory.scala | 0
 .../spark/rapids/shims/TypeUtilsShims.scala | 0
 .../spark/rapids/shims/SparkShims.scala | 4 -
 .../spark/rapids/shims/SparkShims.scala | 4 -
 .../spark/rapids/shims/SparkShims.scala | 4 -
 .../spark/rapids/shims/SparkShims.scala | 6 +-
 .../rapids/shims/ShimPredicateHelper.scala | 33 ++
 .../spark/rapids/shims/SparkShims.scala | 4 +-
 .../rapids/shims/AnsiCastRuleShims.scala} | 2 +-
 .../rapids/DataSourceStrategyUtils.scala | 0
 .../rapids/shims/GpuCheckDeltaInvariant.scala | 6 +-
 .../shims/GpuDeltaInvariantCheckerExec.scala | 0
 .../tahoe/rapids/shims/GpuDeltaLog.scala | 0
 .../shims/GpuOptimisticTransactionBase.scala} | 33 +-
 .../rapids/shims/GpuWriteIntoDelta.scala | 0
 .../delta/shims/DeltaProviderImpl.scala | 0
 .../delta/shims/DeltaProviderShims.scala | 0
 .../delta/shims/GpuDeltaDataSource.scala | 0
 .../rapids/delta/shims/package-shims.scala | 7 +-
 .../nvidia/spark/rapids/shims/AQEUtils.scala | 0
 .../rapids/shims/AggregationTagging.scala | 0
 .../spark/rapids/shims/DeltaLakeUtils.scala | 0
 .../shims/ShimBroadcastExchangeLike.scala | 0
 .../rapids/shims/Spark321PlusDBShims.scala | 291 ++++++++++++++++++
 .../shims/GpuSubqueryBroadcastMeta.scala | 0
 ...ReuseGpuBroadcastExchangeAndSubquery.scala | 0
 .../rapids/shims/GpuShuffleExchangeExec.scala | 0
 .../shims/GpuFlatMapGroupsInPandasExec.scala | 0
 .../spark/rapids/shims/SparkShims.scala | 2 +-
 .../spark/rapids/shims/SparkShims.scala | 2 +-
 .../shims/GpuOptimisticTransaction.scala | 65 ++++
 .../delta/shims/DeltaShims321PlusDB.scala | 38 +++
 .../spark/rapids/shims/SparkShims.scala | 274 +----------------
 .../spark/rapids/shims/SparkShims.scala | 2 +-
 .../spark/rapids/shims/SparkShims.scala | 2 +-
 .../rapids/shims/Spark330PlusNonDBShims.scala | 19 ++
 .../rapids/shims/Spark330PlusShims.scala | 0
 .../com/nvidia/spark/rapids/SparkShims.scala | 6 +-
 .../spark/rapids/shims/SparkShims.scala | 6 +-
 .../shims/GpuOptimisticTransaction.scala | 49 +++
 .../delta/shims/DeltaShims321PlusDB.scala | 39 +++
 .../spark/rapids/shims/SparkShims.scala | 53 ++++
 .../spark330db/SparkShimServiceProvider.scala | 34 ++
 .../shims/SparkDateTimeExceptionShims.scala | 35 +++
 .../shims/SparkUpgradeExceptionShims.scala | 33 ++
 .../parquet/rapids/shims/ParquetCVShims.scala | 0
 .../sql/rapids/shims/RapidsErrorUtils.scala | 0
 .../shims/SparkUpgradeExceptionShims.scala | 0
 .../rapids/shims/Spark331PlusShims.scala | 2 +-
 .../com/nvidia/spark/rapids/SparkShims.scala | 6 +-
 .../spark/rapids/shims/SparkShims.scala | 19 ++
 .../rapids/shims/CastingConfigShim.scala | 0
 .../shims/DecimalArithmeticOverrides.scala | 0
 .../spark/rapids/shims/GetMapValueMeta.scala | 0
 .../rapids/shims/ParquetStringPredShims.scala | 0
 .../ShimFilePartitionReaderFactory.scala | 0
 .../spark/rapids/shims/TypeUtilsShims.scala | 0
 .../parquet/rapids/shims/ParquetCVShims.scala | 0
 .../rapids/DataSourceStrategyUtils.scala | 0
 .../sql/rapids/shims/RapidsErrorUtils.scala | 5 +-
 .../shims/SparkDateTimeExceptionShims.scala | 34 ++
 .../com/nvidia/spark/rapids/SparkShims.scala | 2 +-
 .../spark/rapids/basicPhysicalOperators.scala | 14 +-
 .../sql/rapids/complexTypeExtractors.scala | 33 +-
 .../apache/spark/sql/rapids/predicates.scala | 11 +-
 .../sql/rapids/shims/GpuFileScanRDD.scala | 0
 78 files changed, 907 insertions(+), 399 deletions(-)
 rename sql-plugin/src/main/{332/scala/com/nvidia/spark/rapids/SparkShims.scala => 311+-non330db/scala/com/nvidia/spark/rapids/shims/ShimExtractValue.scala} (77%)
 create mode 100644 sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala
 rename sql-plugin/src/main/{311until340-all => 311until340-non330db}/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala (100%)
 rename sql-plugin/src/main/{311until340-all => 311until340-non330db}/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala (100%)
 rename sql-plugin/src/main/{311until340-all => 311until340-non330db}/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala (100%)
 rename sql-plugin/src/main/{311until340-all => 311until340-non330db}/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala (100%)
 rename sql-plugin/src/main/{311until340-all => 311until340-non330db}/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala (100%)
 rename sql-plugin/src/main/{311until340-all => 311until340-non330db}/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala (100%)
 create mode 100644 sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala
 rename sql-plugin/src/main/{320until340-all/scala/com/nvidia/spark/rapids/shims/Spark320until340Shims.scala => 320until340-non330db/scala/com/nvidia/spark/rapids/shims/AnsiCastRuleShims.scala} (98%)
 rename sql-plugin/src/main/{320until340-all => 320until340-non330db}/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala (100%)
 rename sql-plugin/src/main/{321db => 321+-db}/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala (96%)
 rename sql-plugin/src/main/{321db => 321+-db}/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaInvariantCheckerExec.scala (100%)
 rename sql-plugin/src/main/{321db => 321+-db}/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaLog.scala (100%)
 rename sql-plugin/src/main/{321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala => 321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransactionBase.scala} (89%)
 rename sql-plugin/src/main/{321db => 321+-db}/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuWriteIntoDelta.scala (100%)
 rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderImpl.scala (100%)
 rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderShims.scala (100%)
 rename
sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/delta/shims/GpuDeltaDataSource.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala (86%) rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/shims/AggregationTagging.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/shims/DeltaLakeUtils.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala (100%) create mode 100644 sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala rename sql-plugin/src/main/{321db => 321+-db}/scala/org/apache/spark/rapids/execution/shims/GpuSubqueryBroadcastMeta.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/org/apache/spark/rapids/execution/shims/ReuseGpuBroadcastExchangeAndSubquery.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala (100%) rename sql-plugin/src/main/{321db => 321+-db}/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala (100%) create mode 100644 sql-plugin/src/main/321db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala create mode 100644 sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala create mode 100644 sql-plugin/src/main/330+-nondb/scala/com/nvidia/spark/rapids/shims/Spark330PlusNonDBShims.scala rename sql-plugin/src/main/{330+ => 330+-nondb}/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala (100%) create mode 100644 sql-plugin/src/main/330db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala create mode 100644 sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala create mode 100644 sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala create mode 100644 sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/spark330db/SparkShimServiceProvider.scala create mode 100644 sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala create mode 100644 sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala rename sql-plugin/src/main/{330until340 => 330until340-nondb}/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala (100%) rename sql-plugin/src/main/{330until340 => 330until340-nondb}/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala (100%) rename sql-plugin/src/main/{330until340 => 330until340-nondb}/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala (100%) create mode 100644 sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/shims/SparkShims.scala rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala (100%) rename sql-plugin/src/main/{340+ => 
340+-and-330db}/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala (100%) rename sql-plugin/src/main/{340+ => 340+-and-330db}/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala (94%) create mode 100644 sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala rename sql-plugin/src/main/{311+-db => }/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala (100%) diff --git a/jenkins/databricks/build.sh b/jenkins/databricks/build.sh index b19e3495ccc..e436adcb9ae 100755 --- a/jenkins/databricks/build.sh +++ b/jenkins/databricks/build.sh @@ -323,9 +323,9 @@ install_dependencies() initialize if [[ $SKIP_DEP_INSTALL == "1" ]] then - echo "SKIP_DEP_INSTALL is set to $SKIP_DEP_INSTALL. Skipping dependencies." + echo "!!!! SKIP_DEP_INSTALL is set to $SKIP_DEP_INSTALL. Skipping install-file for dependencies." else - # Install required dependencies. + echo "!!!! Installing dependendecies. Set SKIP_DEP_INSTALL=1 to speed up reruns of build.sh"# Install required dependencies. install_dependencies fi # Build the RAPIDS plugin by running package command for databricks diff --git a/jenkins/databricks/test.sh b/jenkins/databricks/test.sh index 75a59714cf0..6a1cf915ef1 100755 --- a/jenkins/databricks/test.sh +++ b/jenkins/databricks/test.sh @@ -144,7 +144,7 @@ export PYSP_TEST_spark_eventLog_enabled=true mkdir -p /tmp/spark-events ## limit parallelism to avoid OOM kill -export TEST_PARALLEL=4 +export TEST_PARALLEL=${TEST_PARALLEL:-4} if [ -d "$LOCAL_JAR_PATH" ]; then if [[ $TEST_MODE == "DEFAULT" ]]; then ## Run tests with jars in the LOCAL_JAR_PATH dir downloading from the dependency repo diff --git a/pom.xml b/pom.xml index 1f9eeaeb092..89818d5c347 100644 --- a/pom.xml +++ b/pom.xml @@ -449,6 +449,7 @@ ${spark330.version} 1.12.2 ${spark330.sources} + ${spark330.iceberg.version} ${project.basedir}/src/test/${buildver}/scala @@ -477,6 +478,7 @@ ${spark331.version} 1.12.2 ${spark331.sources} + ${spark330.iceberg.version} ${project.basedir}/src/test/${buildver}/scala @@ -505,6 +507,7 @@ ${spark332.version} 1.12.2 ${spark332.sources} + ${spark330.iceberg.version} ${project.basedir}/src/test/${buildver}/scala @@ -533,6 +536,7 @@ ${spark340.version} 1.12.3 ${spark340.sources} + ${spark330.iceberg.version} ${project.basedir}/src/test/${buildver}/scala @@ -561,6 +565,7 @@ ${spark330cdh.version} 1.10.99.7.1.8.0-801 ${spark330cdh.sources} + ${spark330.iceberg.version} ${project.basedir}/src/test/${buildver}/scala @@ -586,6 +591,48 @@ aggregator + + + + release330db + + + buildver + 330db + + + + + ${spark.version} + + 3.4.4 + spark330db + spark330db + + ${spark330db.version} + ${spark330db.version} + 3.3.1 + true + 1.12.0 + ${spark330db.sources} + ${spark330.iceberg.version} + ${project.basedir}/src/test/${buildver}/scala + + + dist + integration_tests + shuffle-plugin + sql-plugin + tests + udf-compiler + aggregator + + udf-compiler @@ -683,6 +730,7 @@ 3.3.2-SNAPSHOT 3.4.0-SNAPSHOT 3.3.0.3.3.7180.0-274 + 3.3.0-databricks 3.6.0 4.3.0 3.2.0 @@ -700,7 +748,7 @@ 3.1.0 false true - + 0.14.1 + + @@ -615,7 +619,7 @@ - + @@ -665,10 +669,20 
@@ + + + + + + + + + + diff --git a/sql-plugin/src/main/311+-db/scala/com/nvidia/spark/rapids/shims/ShimLeafExecNode.scala b/sql-plugin/src/main/311+-db/scala/com/nvidia/spark/rapids/shims/ShimLeafExecNode.scala index 877442f6d90..0bc732cf80c 100644 --- a/sql-plugin/src/main/311+-db/scala/com/nvidia/spark/rapids/shims/ShimLeafExecNode.scala +++ b/sql-plugin/src/main/311+-db/scala/com/nvidia/spark/rapids/shims/ShimLeafExecNode.scala @@ -24,7 +24,7 @@ trait ShimLeafExecNode extends LeafExecNode { // For AQE support in Databricks, all Exec nodes implement computeStats(). This is actually // a recursive call to traverse the entire physical plan to aggregate this number. For the // end of the computation, this means that all LeafExecNodes must implement this method to - // avoid a stack overflow. For now, based on feedback from Databricks, Long.MaxValue is + // avoid a stack overflow. For now, based on feedback from Databricks, Long.MaxValue is // sufficient to satisfy this computation. override def computeStats(): Statistics = { Statistics( @@ -40,4 +40,6 @@ trait ShimDataSourceV2ScanExecBase extends DataSourceV2ScanExecBase { sizeInBytes = Long.MaxValue ) } + + def ordering: Option[Seq[org.apache.spark.sql.catalyst.expressions.SortOrder]] = None } \ No newline at end of file diff --git a/sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/311+-non330db/scala/com/nvidia/spark/rapids/shims/ShimExtractValue.scala similarity index 77% rename from sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/SparkShims.scala rename to sql-plugin/src/main/311+-non330db/scala/com/nvidia/spark/rapids/shims/ShimExtractValue.scala index 2ced31cd978..39bbd62bbd0 100644 --- a/sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/311+-non330db/scala/com/nvidia/spark/rapids/shims/ShimExtractValue.scala @@ -16,8 +16,6 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ +import org.apache.spark.sql.catalyst.expressions._ -object SparkShimImpl extends Spark331PlusShims with Spark320until340Shims { - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion -} +trait ShimExtractValue extends ExtractValue diff --git a/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index fcc1b32208b..5fc95e7ccb5 100644 --- a/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,15 +16,11 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters object SparkShimImpl extends Spark31XShims { - - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion - override def hasCastFloatTimestampUpcast: Boolean = false override def reproduceEmptyStringBug: Boolean = true diff --git a/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala b/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala new file mode 100644 index 00000000000..3767a4369e1 --- /dev/null +++ b/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.rapids._ + +trait ShimPredicateHelper extends PredicateHelper { + // SPARK-30027 from 3.2.0 + // If one expression and its children are null intolerant, it is null intolerant. + protected def isNullIntolerant(expr: Expression): Boolean = expr match { + case e: NullIntolerant => e.children.forall(isNullIntolerant) + case _ => false + } + + override protected def splitConjunctivePredicates( + condition: Expression + ): Seq[Expression] = { + condition match { + case GpuAnd(cond1, cond2) => + splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2) + case other => super.splitConjunctivePredicates(condition) + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala b/sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala similarity index 100% rename from sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala rename to sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala b/sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala similarity index 100% rename from sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala rename to sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala b/sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala similarity index 100% rename from sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala rename to sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala b/sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala similarity index 100% rename from sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala rename to sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala b/sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala similarity index 100% rename from 
sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala rename to sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala diff --git a/sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala b/sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala similarity index 100% rename from sql-plugin/src/main/311until340-all/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala rename to sql-plugin/src/main/311until340-non330db/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala diff --git a/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index c36e4e32ee6..27ab277ab4f 100644 --- a/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,15 +16,11 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters object SparkShimImpl extends Spark31XShims { - - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion - override def hasCastFloatTimestampUpcast: Boolean = true override def reproduceEmptyStringBug: Boolean = true diff --git a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 09fe4b6999f..fec2de36959 100644 --- a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,16 +16,12 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters object SparkShimImpl extends Spark31XdbShims { - - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion - override def getParquetFilters( schema: MessageType, pushDownDate: Boolean, diff --git a/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 16a8ab81ddc..53152532145 100644 --- a/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,16 +16,12 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters object SparkShimImpl extends Spark31XShims { - - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion - override def getParquetFilters( schema: MessageType, pushDownDate: Boolean, diff --git a/sql-plugin/src/main/314/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/314/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index f01f7ac5a63..5876c0d44e7 100644 --- a/sql-plugin/src/main/314/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ 
b/sql-plugin/src/main/314/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,16 +16,12 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters object SparkShimImpl extends Spark31XShims { - - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion - override def getParquetFilters( schema: MessageType, pushDownDate: Boolean, @@ -45,4 +41,6 @@ object SparkShimImpl extends Spark31XShims { override def hasCastFloatTimestampUpcast: Boolean = true override def isCastingStringToNegDecimalScaleSupported: Boolean = true + + override def reproduceEmptyStringBug: Boolean = true } diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala new file mode 100644 index 00000000000..6839eaeffc7 --- /dev/null +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimPredicateHelper.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.rapids._ + +trait ShimPredicateHelper extends PredicateHelper { + // SPARK-30027 provides isNullIntolerant + override protected def splitConjunctivePredicates( + condition: Expression + ): Seq[Expression] = { + condition match { + case GpuAnd(cond1, cond2) => + splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2) + case other => super.splitConjunctivePredicates(condition) + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 3be7128e20e..5b64f8e2ef0 100644 --- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,7 +16,6 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils @@ -25,8 +24,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters object SparkShimImpl extends Spark320PlusShims with Spark320PlusNonDBShims with Spark31Xuntil33XShims - with Spark320until340Shims { - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion + with AnsiCastRuleShims { override def getParquetFilters( schema: MessageType, diff --git a/sql-plugin/src/main/320until340-all/scala/com/nvidia/spark/rapids/shims/Spark320until340Shims.scala b/sql-plugin/src/main/320until340-non330db/scala/com/nvidia/spark/rapids/shims/AnsiCastRuleShims.scala similarity index 98% rename from sql-plugin/src/main/320until340-all/scala/com/nvidia/spark/rapids/shims/Spark320until340Shims.scala rename to sql-plugin/src/main/320until340-non330db/scala/com/nvidia/spark/rapids/shims/AnsiCastRuleShims.scala index fff960e704c..546af5ef948 100644 --- a/sql-plugin/src/main/320until340-all/scala/com/nvidia/spark/rapids/shims/Spark320until340Shims.scala +++ b/sql-plugin/src/main/320until340-non330db/scala/com/nvidia/spark/rapids/shims/AnsiCastRuleShims.scala @@ -19,7 +19,7 @@ import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Expression} -trait Spark320until340Shims extends SparkShims { +trait AnsiCastRuleShims extends SparkShims { override def ansiCastRule: ExprRule[ _ <: Expression] = { GpuOverrides.expr[AnsiCast]( diff --git a/sql-plugin/src/main/320until340-all/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala b/sql-plugin/src/main/320until340-non330db/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala similarity index 100% rename from sql-plugin/src/main/320until340-all/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala rename to sql-plugin/src/main/320until340-non330db/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala diff --git a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala similarity index 96% rename from sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala rename to 
sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala index f00a8b80aa0..c55433ecb32 100644 --- a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala +++ b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuCheckDeltaInvariant.scala @@ -24,8 +24,8 @@ package com.databricks.sql.transaction.tahoe.rapids.shims import ai.rapids.cudf.{ColumnVector, Scalar} import com.databricks.sql.transaction.tahoe.constraints.{CheckDeltaInvariant, Constraint} import com.databricks.sql.transaction.tahoe.constraints.Constraints.{Check, NotNull} -import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException import com.nvidia.spark.rapids.{DataFromReplacementRule, ExprChecks, GpuBindReferences, GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuOverrides, RapidsConf, RapidsMeta, TypeSig, UnaryExprMeta} +import com.nvidia.spark.rapids.delta.shims.InvariantViolationExceptionShim import com.nvidia.spark.rapids.shims.ShimUnaryExpression import org.apache.spark.internal.Logging @@ -66,7 +66,7 @@ case class GpuCheckDeltaInvariant( constraint match { case n: NotNull => if (col.getBase.hasNulls) { - throw InvariantViolationException(n) + throw InvariantViolationExceptionShim(n) } case c: Check => if (col.getBase.hasNulls || hasFalse(col.getBase)) { @@ -118,7 +118,7 @@ case class GpuCheckDeltaInvariant( val hostBatch = new ColumnarBatch(filteredHostCols.toArray, filteredHostCols(0).getBase.getRowCount.toInt) val row = hostBatch.getRow(0) - throw InvariantViolationException(check, columnExtractors.mapValues(_.eval(row)).toMap) + throw InvariantViolationExceptionShim(check, columnExtractors.mapValues(_.eval(row)).toMap) } } } diff --git a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaInvariantCheckerExec.scala b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaInvariantCheckerExec.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaInvariantCheckerExec.scala rename to sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaInvariantCheckerExec.scala diff --git a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaLog.scala b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaLog.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaLog.scala rename to sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuDeltaLog.scala diff --git a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransactionBase.scala similarity index 89% rename from sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala rename to sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransactionBase.scala index fc0dfbe28d7..02a159c6a23 100644 --- a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala +++ 
b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransactionBase.scala @@ -28,7 +28,6 @@ import scala.collection.mutable.ListBuffer import ai.rapids.cudf.ColumnView import com.databricks.sql.transaction.tahoe._ import com.databricks.sql.transaction.tahoe.actions.FileAction -import com.databricks.sql.transaction.tahoe.commands.cdc.CDCReader import com.databricks.sql.transaction.tahoe.constraints.{Constraint, Constraints, DeltaInvariantCheckerExec} import com.databricks.sql.transaction.tahoe.metering.DeltaLogging import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException @@ -45,11 +44,11 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormatWriter} -import org.apache.spark.sql.functions.{col, to_json} +import org.apache.spark.sql.functions.to_json import org.apache.spark.sql.rapids.{BasicColumnarWriteJobStatsTracker, ColumnarWriteJobStatsTracker, GpuFileFormatWriter} import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null import org.apache.spark.sql.rapids.delta.GpuIdentityColumn -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.{Clock, SerializableConfiguration} /** @@ -64,7 +63,7 @@ import org.apache.spark.util.{Clock, SerializableConfiguration} * @param snapshot The snapshot that this transaction is reading at. * @param rapidsConf RAPIDS Accelerator config settings. */ -class GpuOptimisticTransaction +abstract class GpuOptimisticTransactionBase (deltaLog: DeltaLog, snapshot: Snapshot, rapidsConf: RapidsConf) (implicit clock: Clock) extends OptimisticTransaction(deltaLog, snapshot)(clock) @@ -80,28 +79,6 @@ class GpuOptimisticTransaction this(deltaLog, deltaLog.update(), rapidsConf) } - /** - * Returns a tuple of (data, partition schema). For CDC writes, a `__is_cdc` column is added to - * the data and `__is_cdc=true/false` is added to the front of the partition schema. - */ - protected def performCDCPartition(inputData: Dataset[_]): (DataFrame, StructType) = { - // If this is a CDC write, we need to generate the CDC_PARTITION_COL in order to properly - // dispatch rows between the main table and CDC event records. This is a virtual partition - // and will be stripped out later in [[DelayedCommitProtocolEdge]]. - // Note that the ordering of the partition schema is relevant - CDC_PARTITION_COL must - // come first in order to ensure CDC data lands in the right place. 
- if (CDCReader.isCDCEnabledOnTable(metadata) && - inputData.schema.fieldNames.contains(CDCReader.CDC_TYPE_COLUMN_NAME)) { - val augmentedData = inputData.withColumn( - CDCReader.CDC_PARTITION_COL, col(CDCReader.CDC_TYPE_COLUMN_NAME).isNotNull) - val partitionSchema = StructType( - StructField(CDCReader.CDC_PARTITION_COL, StringType) +: metadata.physicalPartitionSchema) - (augmentedData, partitionSchema) - } else { - (inputData.toDF(), metadata.physicalPartitionSchema) - } - } - /** * Adds checking of constraints on the table * @param plan Plan to generate the table to check against constraints @@ -185,6 +162,8 @@ class GpuOptimisticTransaction writeFiles(inputData, None, additionalConstraints) } + private[shims] def shimPerformCDCPartition(inputData: Dataset[_]): (DataFrame, StructType) + override def writeFiles( inputData: Dataset[_], writeOptions: Option[DeltaOptions], @@ -192,7 +171,7 @@ class GpuOptimisticTransaction hasWritten = true val spark = inputData.sparkSession - val (data, partitionSchema) = performCDCPartition(inputData) + val (data, partitionSchema) = shimPerformCDCPartition(inputData) val outputPath = deltaLog.dataPath val (queryExecution, output, generatedColumnConstraints, dataHighWaterMarks) = diff --git a/sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuWriteIntoDelta.scala b/sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuWriteIntoDelta.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuWriteIntoDelta.scala rename to sql-plugin/src/main/321+-db/scala/com/databricks/sql/transaction/tahoe/rapids/shims/GpuWriteIntoDelta.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderImpl.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderImpl.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderImpl.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderImpl.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderShims.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderShims.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderShims.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/DeltaProviderShims.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/GpuDeltaDataSource.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/GpuDeltaDataSource.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/GpuDeltaDataSource.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/GpuDeltaDataSource.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala similarity index 86% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala index 42acb42e51e..467f16986aa 100644 --- 
a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala +++ b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala @@ -17,22 +17,17 @@ package com.nvidia.spark.rapids.delta.shims import com.databricks.sql.expressions.JoinedProjection -import com.databricks.sql.transaction.tahoe.{DeltaColumnMapping, DeltaUDF} +import com.databricks.sql.transaction.tahoe.DeltaColumnMapping import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields import com.databricks.sql.transaction.tahoe.util.JsonUtils import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types.StructField object ShimDeltaColumnMapping { def getPhysicalName(field: StructField): String = DeltaColumnMapping.getPhysicalName(field) } -object ShimDeltaUDF { - def stringStringUdf(f: String => String): UserDefinedFunction = DeltaUDF.stringStringUdf(f) -} - object ShimJoinedProjection { def bind( leftAttributes: Seq[Attribute], diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/AggregationTagging.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/AggregationTagging.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/AggregationTagging.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/AggregationTagging.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/DeltaLakeUtils.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/DeltaLakeUtils.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/DeltaLakeUtils.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/DeltaLakeUtils.scala diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala rename to sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala diff --git a/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala new file mode 100644 index 00000000000..4678723c279 --- /dev/null +++ b/sql-plugin/src/main/321+-db/scala/com/nvidia/spark/rapids/shims/Spark321PlusDBShims.scala @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +import com.databricks.sql.execution.window.RunningWindowFunctionExec +import com.databricks.sql.optimizer.PlanDynamicPruningFilters +import com.nvidia.spark.rapids._ +import org.apache.hadoop.fs.FileStatus + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.python._ +import org.apache.spark.sql.execution.window._ +import org.apache.spark.sql.rapids.GpuFileSourceScanExec +import org.apache.spark.sql.rapids.execution._ +import org.apache.spark.sql.rapids.execution.shims.{GpuSubqueryBroadcastMeta,ReuseGpuBroadcastExchangeAndSubquery} +import org.apache.spark.sql.rapids.shims._ +import org.apache.spark.sql.types._ + + +trait Spark321PlusDBShims extends SparkShims + with Spark321PlusShims { + override def isCastingStringToNegDecimalScaleSupported: Boolean = true + + override def getFileScanRDD( + sparkSession: SparkSession, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readDataSchema: StructType, + metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = { + new GpuFileScanRDD(sparkSession, readFunction, filePartitions) + } + + override def broadcastModeTransform(mode: BroadcastMode, rows: Array[InternalRow]): Any = { + // In some cases we can be asked to transform when there's no task context, which appears to + // be new behavior since Databricks 10.4. A task memory manager must be passed, so if one is + // not available we construct one from the main memory manager using a task attempt ID of 0. 
+ val memoryManager = Option(TaskContext.get).map(_.taskMemoryManager()).getOrElse { + new TaskMemoryManager(SparkEnv.get.memoryManager, 0) + } + mode.transform(rows, memoryManager) + } + + override def newBroadcastQueryStageExec( + old: BroadcastQueryStageExec, + newPlan: SparkPlan): BroadcastQueryStageExec = + BroadcastQueryStageExec(old.id, newPlan, old.originalPlan, old.isSparkExchange) + + override def filesFromFileIndex(fileCatalog: PartitioningAwareFileIndex): Seq[FileStatus] = { + fileCatalog.allFiles().map(_.toFileStatus) + } + + override def neverReplaceShowCurrentNamespaceCommand: ExecRule[_ <: SparkPlan] = null + + override def getWindowExpressions(winPy: WindowInPandasExec): Seq[NamedExpression] = + winPy.projectList + + override def isWindowFunctionExec(plan: SparkPlan): Boolean = + plan.isInstanceOf[WindowExecBase] || plan.isInstanceOf[RunningWindowFunctionExec] + + override def applyShimPlanRules(plan: SparkPlan, conf: RapidsConf): SparkPlan = { + if (plan.conf.adaptiveExecutionEnabled) { + plan // AQE+DPP cooperation ensures the optimization runs early + } else { + val sparkSession = plan.session + val rules = Seq( + PlanDynamicPruningFilters(sparkSession) + ) + rules.foldLeft(plan) { case (sp, rule) => + rule.apply(sp) + } + } + } + + override def applyPostShimPlanRules(plan: SparkPlan): SparkPlan = { + val rules = Seq( + ReuseGpuBroadcastExchangeAndSubquery + ) + rules.foldLeft(plan) { case (sp, rule) => + rule.apply(sp) + } + } + + private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = { + Seq( + GpuOverrides.exec[SubqueryBroadcastExec]( + "Plan to collect and transform the broadcast key values", + ExecChecks(TypeSig.all, TypeSig.all), + (s, conf, p, r) => new GpuSubqueryBroadcastMeta(s, conf, p, r)), + GpuOverrides.exec[FileSourceScanExec]( + "Reading data from files, often from Hive tables", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(), TypeSig.all), + (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) { + + // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart + // if possible. Instead regarding filters as childExprs of current Meta, we create + // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of + // FileSourceScan is independent from the replacement of the partitionFilters. It is + // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters + // are on the GPU. And vice versa. + private lazy val partitionFilters = { + val convertBroadcast = (bc: SubqueryBroadcastExec) => { + val meta = GpuOverrides.wrapAndTagPlan(bc, conf) + meta.tagForExplain() + val converted = meta.convertIfNeeded() + // Because the PlanSubqueries rule is not called (and does not work as expected), + // we might actually have to fully convert the subquery plan as the plugin would + // intend (in this case calling GpuTransitionOverrides to insert GpuCoalesceBatches, + // etc.) to match the other side of the join to reuse the BroadcastExchange. 
+ // This happens when SubqueryBroadcast has the original (Gpu)BroadcastExchangeExec + converted match { + case e: GpuSubqueryBroadcastExec => e.child match { + // If the GpuBroadcastExchange is here, then we will need to run the transition + // overrides here + case _: GpuBroadcastExchangeExec => + var updated = ApplyColumnarRulesAndInsertTransitions(Seq(), true) + .apply(converted) + updated = (new GpuTransitionOverrides()).apply(updated) + updated match { + case h: GpuBringBackToHost => + h.child.asInstanceOf[BaseSubqueryExec] + case c2r: GpuColumnarToRowExec => + c2r.child.asInstanceOf[BaseSubqueryExec] + case _: GpuSubqueryBroadcastExec => + updated.asInstanceOf[BaseSubqueryExec] + } + // Otherwise, if this SubqueryBroadcast is using a ReusedExchange, then we don't + // do anything further + case _: ReusedExchangeExec => + converted.asInstanceOf[BaseSubqueryExec] + } + case _ => + converted.asInstanceOf[BaseSubqueryExec] + } + } + + wrapped.partitionFilters.map { filter => + filter.transformDown { + case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) => + inSub.plan match { + case bc: SubqueryBroadcastExec => + dpe.copy(inSub.copy(plan = convertBroadcast(bc))) + case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) => + dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc)))) + case _ => + dpe + } + } + } + } + + // partition filters and data filters are not run on the GPU + override val childExprs: Seq[ExprMeta[_]] = Seq.empty + + override def tagPlanForGpu(): Unit = { + // this is very specific check to have any of the Delta log metadata queries + // fallback and run on the CPU since there is some incompatibilities in + // Databricks Spark and Apache Spark. + if (wrapped.relation.fileFormat.isInstanceOf[JsonFileFormat] && + wrapped.relation.location.getClass.getCanonicalName() == + "com.databricks.sql.transaction.tahoe.DeltaLogFileIndex") { + this.entirePlanWillNotWork("Plans that read Delta Index JSON files can not run " + + "any part of the plan on the GPU!") + } + GpuFileSourceScanExec.tagSupport(this) + } + + override def convertToCpu(): SparkPlan = { + wrapped.copy(partitionFilters = partitionFilters) + } + + override def convertToGpu(): GpuExec = { + val sparkSession = wrapped.relation.sparkSession + val options = wrapped.relation.options + val (location, alluxioPathsToReplaceMap) = + if (AlluxioCfgUtils.enabledAlluxioReplacementAlgoConvertTime(conf)) { + val shouldReadFromS3 = wrapped.relation.location match { + // Only handle InMemoryFileIndex + // + // skip handle `MetadataLogFileIndex`, from the description of this class: + // it's about the files generated by the `FileStreamSink`. + // The streaming data source is not in our scope. + // + // For CatalogFileIndex and FileIndex of `delta` data source, + // need more investigation. + case inMemory: InMemoryFileIndex => + // List all the partitions to reduce overhead, pass in 2 empty filters. + // Subsequent process will do the right partition pruning. + // This operation is fast, because it lists files from the caches and the caches + // already exist in the `InMemoryFileIndex`. 
+ val pds = inMemory.listFiles(Seq.empty, Seq.empty) + AlluxioUtils.shouldReadDirectlyFromS3(conf, pds) + case _ => + false + } + + if (!shouldReadFromS3) { + // it's convert time algorithm and some paths are not large tables + AlluxioUtils.replacePathIfNeeded( + conf, + wrapped.relation, + partitionFilters, + wrapped.dataFilters) + } else { + // convert time algorithm and read large files + (wrapped.relation.location, None) + } + } else { + // it's not convert time algorithm or read large files, do not replace + (wrapped.relation.location, None) + } + + val newRelation = HadoopFsRelation( + location, + wrapped.relation.partitionSchema, + wrapped.relation.dataSchema, + wrapped.relation.bucketSpec, + GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat), + options)(sparkSession) + + GpuFileSourceScanExec( + newRelation, + wrapped.output, + wrapped.requiredSchema, + partitionFilters, + wrapped.optionalBucketSet, + // TODO: Does Databricks have coalesced bucketing implemented? + None, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + queryUsesInputFile = false, + alluxioPathsToReplaceMap)(conf) + } + }), + GpuOverrides.exec[RunningWindowFunctionExec]( + "Databricks-specific window function exec, for \"running\" windows, " + + "i.e. (UNBOUNDED PRECEDING TO CURRENT ROW)", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all, + Map("partitionSpec" -> + InputCheck(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.all))), + (runningWindowFunctionExec, conf, p, r) => + new GpuRunningWindowExecMeta(runningWindowFunctionExec, conf, p, r) + ) + ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap + } + + override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = + super.getExecs ++ shimExecs + + /** + * Case class ShuffleQueryStageExec holds an additional field shuffleOrigin + * affecting the unapply method signature + */ + override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = { + case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _, _) => e + case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _, _) => e + } + + override def reproduceEmptyStringBug: Boolean = true +} \ No newline at end of file diff --git a/sql-plugin/src/main/321db/scala/org/apache/spark/rapids/execution/shims/GpuSubqueryBroadcastMeta.scala b/sql-plugin/src/main/321+-db/scala/org/apache/spark/rapids/execution/shims/GpuSubqueryBroadcastMeta.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/org/apache/spark/rapids/execution/shims/GpuSubqueryBroadcastMeta.scala rename to sql-plugin/src/main/321+-db/scala/org/apache/spark/rapids/execution/shims/GpuSubqueryBroadcastMeta.scala diff --git a/sql-plugin/src/main/321db/scala/org/apache/spark/rapids/execution/shims/ReuseGpuBroadcastExchangeAndSubquery.scala b/sql-plugin/src/main/321+-db/scala/org/apache/spark/rapids/execution/shims/ReuseGpuBroadcastExchangeAndSubquery.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/org/apache/spark/rapids/execution/shims/ReuseGpuBroadcastExchangeAndSubquery.scala rename to sql-plugin/src/main/321+-db/scala/org/apache/spark/rapids/execution/shims/ReuseGpuBroadcastExchangeAndSubquery.scala diff --git a/sql-plugin/src/main/321db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala 
b/sql-plugin/src/main/321+-db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala rename to sql-plugin/src/main/321+-db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala diff --git a/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/321+-db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala similarity index 100% rename from sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala rename to sql-plugin/src/main/321+-db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala diff --git a/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index d0475439f74..ec502eedbf1 100644 --- a/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -21,7 +21,7 @@ import com.nvidia.spark.rapids._ object SparkShimImpl extends Spark321PlusShims with Spark320PlusNonDBShims with Spark31Xuntil33XShims - with Spark320until340Shims { + with AnsiCastRuleShims { override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def reproduceEmptyStringBug: Boolean = true diff --git a/sql-plugin/src/main/321cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/321cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 8e5d50a4d10..1604df81f1f 100644 --- a/sql-plugin/src/main/321cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/321cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -21,7 +21,7 @@ import com.nvidia.spark.rapids._ object SparkShimImpl extends Spark321PlusShims with Spark320PlusNonDBShims with Spark31Xuntil33XShims - with Spark320until340Shims { + with AnsiCastRuleShims { override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def reproduceEmptyStringBug: Boolean = true diff --git a/sql-plugin/src/main/321db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala b/sql-plugin/src/main/321db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala new file mode 100644 index 00000000000..a5a11359c7a --- /dev/null +++ b/sql-plugin/src/main/321db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * This file was derived from OptimisticTransaction.scala and TransactionalWrite.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids.shims + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.commands.cdc.CDCReader +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ +import org.apache.spark.util.Clock + + +class GpuOptimisticTransaction( + deltaLog: DeltaLog, + snapshot: Snapshot, + rapidsConf: RapidsConf)(implicit clock: Clock) + extends GpuOptimisticTransactionBase(deltaLog, snapshot, rapidsConf)(clock) { + + def this(deltaLog: DeltaLog, rapidsConf: RapidsConf)(implicit clock: Clock) { + this(deltaLog, deltaLog.update(), rapidsConf) + } + + /** + * Returns a tuple of (data, partition schema). For CDC writes, a `__is_cdc` column is added to + * the data and `__is_cdc=true/false` is added to the front of the partition schema. + */ + override def shimPerformCDCPartition(inputData: Dataset[_]): (DataFrame, StructType) = { + // If this is a CDC write, we need to generate the CDC_PARTITION_COL in order to properly + // dispatch rows between the main table and CDC event records. This is a virtual partition + // and will be stripped out later in [[DelayedCommitProtocolEdge]]. + // Note that the ordering of the partition schema is relevant - CDC_PARTITION_COL must + // come first in order to ensure CDC data lands in the right place. + if (CDCReader.isCDCEnabledOnTable(metadata) && + inputData.schema.fieldNames.contains(CDCReader.CDC_TYPE_COLUMN_NAME)) { + val augmentedData = inputData.withColumn( + CDCReader.CDC_PARTITION_COL, col(CDCReader.CDC_TYPE_COLUMN_NAME).isNotNull) + val partitionSchema = StructType( + StructField(CDCReader.CDC_PARTITION_COL, StringType) +: metadata.physicalPartitionSchema) + (augmentedData, partitionSchema) + } else { + (inputData.toDF(), metadata.physicalPartitionSchema) + } + } +} diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala b/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala new file mode 100644 index 00000000000..d007a02c779 --- /dev/null +++ b/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.DeltaUDF +import com.databricks.sql.transaction.tahoe.constraints.Constraints._ +import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException + +import org.apache.spark.sql.expressions.UserDefinedFunction + +object InvariantViolationExceptionShim { + def apply(c: Check, m: Map[String, Any]): InvariantViolationException = { + InvariantViolationException(c, m) + } + + def apply(c: NotNull): InvariantViolationException = { + InvariantViolationException(c) + } +} + +object ShimDeltaUDF { + def stringStringUdf(f: String => String): UserDefinedFunction = { + DeltaUDF.stringStringUdf(f) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index aa33d1e4933..69effebab6b 100644 --- a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,279 +16,9 @@ package com.nvidia.spark.rapids.shims -import com.databricks.sql.execution.window.RunningWindowFunctionExec -import com.databricks.sql.optimizer.PlanDynamicPruningFilters -import com.nvidia.spark.rapids._ -import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.json.JsonFileFormat -import org.apache.spark.sql.execution.exchange.ReusedExchangeExec -import org.apache.spark.sql.execution.python.WindowInPandasExec -import org.apache.spark.sql.execution.window.WindowExecBase -import org.apache.spark.sql.rapids.GpuFileSourceScanExec -import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuSubqueryBroadcastExec} -import org.apache.spark.sql.rapids.execution.shims.{GpuSubqueryBroadcastMeta, ReuseGpuBroadcastExchangeAndSubquery} -import org.apache.spark.sql.rapids.shims.GpuFileScanRDD -import org.apache.spark.sql.types._ - -object SparkShimImpl extends Spark321PlusShims with Spark320until340Shims { - - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion - - override def isCastingStringToNegDecimalScaleSupported: Boolean = true - - override def getFileScanRDD( - sparkSession: SparkSession, - readFunction: PartitionedFile => Iterator[InternalRow], - filePartitions: Seq[FilePartition], - readDataSchema: StructType, - metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = { - new GpuFileScanRDD(sparkSession, readFunction, filePartitions) - } - - override def broadcastModeTransform(mode: BroadcastMode, rows: Array[InternalRow]): Any = { - // In some cases we can be asked to transform when there's no task context, which appears to - // be new behavior since Databricks 10.4. A task memory manager must be passed, so if one is - // not available we construct one from the main memory manager using a task attempt ID of 0. 
- val memoryManager = Option(TaskContext.get).map(_.taskMemoryManager()).getOrElse { - new TaskMemoryManager(SparkEnv.get.memoryManager, 0) - } - mode.transform(rows, memoryManager) - } - - override def newBroadcastQueryStageExec( - old: BroadcastQueryStageExec, - newPlan: SparkPlan): BroadcastQueryStageExec = - BroadcastQueryStageExec(old.id, newPlan, old.originalPlan, old.isSparkExchange) - - override def filesFromFileIndex(fileCatalog: PartitioningAwareFileIndex): Seq[FileStatus] = { - fileCatalog.allFiles().map(_.toFileStatus) - } - - override def neverReplaceShowCurrentNamespaceCommand: ExecRule[_ <: SparkPlan] = null - - override def getWindowExpressions(winPy: WindowInPandasExec): Seq[NamedExpression] = - winPy.projectList - - override def isWindowFunctionExec(plan: SparkPlan): Boolean = - plan.isInstanceOf[WindowExecBase] || plan.isInstanceOf[RunningWindowFunctionExec] - - override def applyShimPlanRules(plan: SparkPlan, conf: RapidsConf): SparkPlan = { - if (plan.conf.adaptiveExecutionEnabled) { - plan // AQE+DPP cooperation ensures the optimization runs early - } else { - val sparkSession = plan.session - val rules = Seq( - PlanDynamicPruningFilters(sparkSession) - ) - rules.foldLeft(plan) { case (sp, rule) => - rule.apply(sp) - } - } - } - - override def applyPostShimPlanRules(plan: SparkPlan): SparkPlan = { - val rules = Seq( - ReuseGpuBroadcastExchangeAndSubquery - ) - rules.foldLeft(plan) { case (sp, rule) => - rule.apply(sp) - } - } - - private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = { - Seq( - GpuOverrides.exec[SubqueryBroadcastExec]( - "Plan to collect and transform the broadcast key values", - ExecChecks(TypeSig.all, TypeSig.all), - (s, conf, p, r) => new GpuSubqueryBroadcastMeta(s, conf, p, r)), - GpuOverrides.exec[FileSourceScanExec]( - "Reading data from files, often from Hive tables", - ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + - TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(), TypeSig.all), - (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) { - - // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart - // if possible. Instead regarding filters as childExprs of current Meta, we create - // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of - // FileSourceScan is independent from the replacement of the partitionFilters. It is - // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters - // are on the GPU. And vice versa. - private lazy val partitionFilters = { - val convertBroadcast = (bc: SubqueryBroadcastExec) => { - val meta = GpuOverrides.wrapAndTagPlan(bc, conf) - meta.tagForExplain() - val converted = meta.convertIfNeeded() - // Because the PlanSubqueries rule is not called (and does not work as expected), - // we might actually have to fully convert the subquery plan as the plugin would - // intend (in this case calling GpuTransitionOverrides to insert GpuCoalesceBatches, - // etc.) to match the other side of the join to reuse the BroadcastExchange. 
- // This happens when SubqueryBroadcast has the original (Gpu)BroadcastExchangeExec - converted match { - case e: GpuSubqueryBroadcastExec => e.child match { - // If the GpuBroadcastExchange is here, then we will need to run the transition - // overrides here - case _: GpuBroadcastExchangeExec => - var updated = ApplyColumnarRulesAndInsertTransitions(Seq(), true) - .apply(converted) - updated = (new GpuTransitionOverrides()).apply(updated) - updated match { - case h: GpuBringBackToHost => - h.child.asInstanceOf[BaseSubqueryExec] - case c2r: GpuColumnarToRowExec => - c2r.child.asInstanceOf[BaseSubqueryExec] - case _: GpuSubqueryBroadcastExec => - updated.asInstanceOf[BaseSubqueryExec] - } - // Otherwise, if this SubqueryBroadcast is using a ReusedExchange, then we don't - // do anything further - case _: ReusedExchangeExec => - converted.asInstanceOf[BaseSubqueryExec] - } - case _ => - converted.asInstanceOf[BaseSubqueryExec] - } - } - - wrapped.partitionFilters.map { filter => - filter.transformDown { - case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) => - inSub.plan match { - case bc: SubqueryBroadcastExec => - dpe.copy(inSub.copy(plan = convertBroadcast(bc))) - case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) => - dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc)))) - case _ => - dpe - } - } - } - } - - // partition filters and data filters are not run on the GPU - override val childExprs: Seq[ExprMeta[_]] = Seq.empty - - override def tagPlanForGpu(): Unit = { - // this is very specific check to have any of the Delta log metadata queries - // fallback and run on the CPU since there is some incompatibilities in - // Databricks Spark and Apache Spark. - if (wrapped.relation.fileFormat.isInstanceOf[JsonFileFormat] && - wrapped.relation.location.getClass.getCanonicalName() == - "com.databricks.sql.transaction.tahoe.DeltaLogFileIndex") { - this.entirePlanWillNotWork("Plans that read Delta Index JSON files can not run " + - "any part of the plan on the GPU!") - } - GpuFileSourceScanExec.tagSupport(this) - } - - override def convertToCpu(): SparkPlan = { - wrapped.copy(partitionFilters = partitionFilters) - } - - override def convertToGpu(): GpuExec = { - val sparkSession = wrapped.relation.sparkSession - val options = wrapped.relation.options - val (location, alluxioPathsToReplaceMap) = - if (AlluxioCfgUtils.enabledAlluxioReplacementAlgoConvertTime(conf)) { - val shouldReadFromS3 = wrapped.relation.location match { - // Only handle InMemoryFileIndex - // - // skip handle `MetadataLogFileIndex`, from the description of this class: - // it's about the files generated by the `FileStreamSink`. - // The streaming data source is not in our scope. - // - // For CatalogFileIndex and FileIndex of `delta` data source, - // need more investigation. - case inMemory: InMemoryFileIndex => - // List all the partitions to reduce overhead, pass in 2 empty filters. - // Subsequent process will do the right partition pruning. - // This operation is fast, because it lists files from the caches and the caches - // already exist in the `InMemoryFileIndex`. 
- val pds = inMemory.listFiles(Seq.empty, Seq.empty) - AlluxioUtils.shouldReadDirectlyFromS3(conf, pds) - case _ => - false - } - - if (!shouldReadFromS3) { - // it's convert time algorithm and some paths are not large tables - AlluxioUtils.replacePathIfNeeded( - conf, - wrapped.relation, - partitionFilters, - wrapped.dataFilters) - } else { - // convert time algorithm and read large files - (wrapped.relation.location, None) - } - } else { - // it's not convert time algorithm or read large files, do not replace - (wrapped.relation.location, None) - } - - val newRelation = HadoopFsRelation( - location, - wrapped.relation.partitionSchema, - wrapped.relation.dataSchema, - wrapped.relation.bucketSpec, - GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat), - options)(sparkSession) - - GpuFileSourceScanExec( - newRelation, - wrapped.output, - wrapped.requiredSchema, - partitionFilters, - wrapped.optionalBucketSet, - // TODO: Does Databricks have coalesced bucketing implemented? - None, - wrapped.dataFilters, - wrapped.tableIdentifier, - wrapped.disableBucketedScan, - queryUsesInputFile = false, - alluxioPathsToReplaceMap)(conf) - } - }), - GpuOverrides.exec[RunningWindowFunctionExec]( - "Databricks-specific window function exec, for \"running\" windows, " + - "i.e. (UNBOUNDED PRECEDING TO CURRENT ROW)", - ExecChecks( - (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + - TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), - TypeSig.all, - Map("partitionSpec" -> - InputCheck(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, - TypeSig.all))), - (runningWindowFunctionExec, conf, p, r) => - new GpuRunningWindowExecMeta(runningWindowFunctionExec, conf, p, r) - ) - ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap - } - - override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = - super.getExecs ++ shimExecs - - /** - * Case class ShuffleQueryStageExec holds an additional field shuffleOrigin - * affecting the unapply method signature - */ - override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = { - case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _, _) => e - case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _, _) => e - } - - override def reproduceEmptyStringBug: Boolean = true -} +object SparkShimImpl extends Spark321PlusDBShims with AnsiCastRuleShims // Fallback to the default definition of `deterministic` trait GpuDeterministicFirstLastCollectShim extends Expression diff --git a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 1352c75099f..1bd4cd77818 100644 --- a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -21,6 +21,6 @@ import com.nvidia.spark.rapids._ object SparkShimImpl extends Spark321PlusShims with Spark320PlusNonDBShims with Spark31Xuntil33XShims - with Spark320until340Shims { + with AnsiCastRuleShims { override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion } diff --git a/sql-plugin/src/main/323/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/323/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 06a2b6ef29c..9f6b3f70935 100644 --- a/sql-plugin/src/main/323/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ 
b/sql-plugin/src/main/323/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -21,6 +21,6 @@ import com.nvidia.spark.rapids._ object SparkShimImpl extends Spark321PlusShims with Spark320PlusNonDBShims with Spark31Xuntil33XShims - with Spark320until340Shims { + with AnsiCastRuleShims { override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion } diff --git a/sql-plugin/src/main/330+-nondb/scala/com/nvidia/spark/rapids/shims/Spark330PlusNonDBShims.scala b/sql-plugin/src/main/330+-nondb/scala/com/nvidia/spark/rapids/shims/Spark330PlusNonDBShims.scala new file mode 100644 index 00000000000..1974e17ab52 --- /dev/null +++ b/sql-plugin/src/main/330+-nondb/scala/com/nvidia/spark/rapids/shims/Spark330PlusNonDBShims.scala @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +trait Spark330PlusNonDBShims extends Spark330PlusShims diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala b/sql-plugin/src/main/330+-nondb/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala similarity index 100% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala rename to sql-plugin/src/main/330+-nondb/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala diff --git a/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala index 5f4ea0b35f4..f793678350e 100644 --- a/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -16,8 +16,4 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ - -object SparkShimImpl extends Spark330PlusShims with Spark320until340Shims { - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion -} +object SparkShimImpl extends Spark330PlusShims with AnsiCastRuleShims diff --git a/sql-plugin/src/main/330cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/330cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index c384a3da842..db0f502051a 100644 --- a/sql-plugin/src/main/330cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala +++ b/sql-plugin/src/main/330cdh/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -16,8 +16,4 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ - -object SparkShimImpl extends Spark330PlusShims with Spark320until340Shims { - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion -} +object SparkShimImpl extends Spark330PlusNonDBShims with AnsiCastRuleShims diff --git a/sql-plugin/src/main/330db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala b/sql-plugin/src/main/330db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala new file mode 100644 index 00000000000..1dfe75c5e2b --- 
/dev/null +++ b/sql-plugin/src/main/330db/com/databricks/sql/transaction/tahoe/rapids/shims/GpuOptimisticTransaction.scala @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * This file was derived from OptimisticTransaction.scala and TransactionalWrite.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids.shims + +import com.databricks.sql.transaction.tahoe._ +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Clock + + +class GpuOptimisticTransaction( + deltaLog: DeltaLog, + snapshot: Snapshot, + rapidsConf: RapidsConf)(implicit clock: Clock) + extends GpuOptimisticTransactionBase(deltaLog, snapshot, rapidsConf)(clock) { + + def this(deltaLog: DeltaLog, rapidsConf: RapidsConf)(implicit clock: Clock) { + this(deltaLog, deltaLog.update(), rapidsConf) + } + + /** + * Returns a tuple of (data, partition schema). For CDC writes, a `__is_cdc` column is added to + * the data and `__is_cdc=true/false` is added to the front of the partition schema. + */ + override def shimPerformCDCPartition(inputData: Dataset[_]): (DataFrame, StructType) = { + performCDCPartition(inputData: Dataset[_]) + } +} diff --git a/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala b/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala new file mode 100644 index 00000000000..0fe7267e489 --- /dev/null +++ b/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/delta/shims/DeltaShims321PlusDB.scala @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.DeltaUDF +import com.databricks.sql.transaction.tahoe.constraints.Constraints._ +import com.databricks.sql.transaction.tahoe.schema.DeltaInvariantViolationException +import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException + +import org.apache.spark.sql.expressions.UserDefinedFunction + +object InvariantViolationExceptionShim { + def apply(c: Check, m: Map[String, Any]): InvariantViolationException = { + DeltaInvariantViolationException(c, m) + } + + def apply(c: NotNull): InvariantViolationException = { + DeltaInvariantViolationException(c) + } +} + +object ShimDeltaUDF { + def stringStringUdf(f: String => String): UserDefinedFunction ={ + DeltaUDF.stringFromString(f) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala new file mode 100644 index 00000000000..f29bb9fc2c8 --- /dev/null +++ b/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids._ +import org.apache.parquet.schema.MessageType + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters + +object SparkShimImpl extends Spark321PlusDBShims { + // AnsiCast is removed from Spark3.4.0 + override def ansiCastRule: ExprRule[_ <: Expression] = null + + override def getParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + lookupFileMeta: String => String, + dateTimeRebaseModeFromConf: String): ParquetFilters = { + val datetimeRebaseMode = DataSourceUtils + .datetimeRebaseSpec(lookupFileMeta, dateTimeRebaseModeFromConf) + new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, + pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) + } +} + +trait ShimExtractValue extends ExtractValue { + override def nodePatternsInternal(): Seq[TreePattern] = Seq.empty +} + +// Fallback to the default definition of `deterministic` +trait GpuDeterministicFirstLastCollectShim extends Expression \ No newline at end of file diff --git a/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/spark330db/SparkShimServiceProvider.scala b/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/spark330db/SparkShimServiceProvider.scala new file mode 100644 index 00000000000..a7e819e148b --- /dev/null +++ b/sql-plugin/src/main/330db/scala/com/nvidia/spark/rapids/shims/spark330db/SparkShimServiceProvider.scala @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims.spark330db + +import com.nvidia.spark.rapids.{DatabricksShimVersion, ShimVersion} + +import org.apache.spark.SparkEnv + +object SparkShimServiceProvider { + val VERSION = DatabricksShimVersion(3, 3, 0) +} + +class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + + override def getShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + + def matchesVersion(version: String): Boolean = { + SparkEnv.get.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "").startsWith("11.3.") + } +} diff --git a/sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala b/sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala new file mode 100644 index 00000000000..0210302a4f9 --- /dev/null +++ b/sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.{QueryContext, SparkDateTimeException} + +object SparkDateTimeExceptionShims { + + def newSparkDateTimeException( + errorClass: String, + messageParameters: Map[String, String], + context: Array[QueryContext], + summary: String): SparkDateTimeException = { + new SparkDateTimeException( + errorClass, + None, + Array.empty, + context, + summary) + } +} diff --git a/sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala b/sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala new file mode 100644 index 00000000000..3675e5f23a0 --- /dev/null +++ b/sql-plugin/src/main/330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.SparkUpgradeException + +object SparkUpgradeExceptionShims { + + def newSparkUpgradeException( + version: String, + message: String, + cause: Throwable): SparkUpgradeException = { + new SparkUpgradeException( + "INCONSISTENT_BEHAVIOR_CROSS_VERSION", + None, + Array(version, message), + cause) + } +} diff --git a/sql-plugin/src/main/330until340/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala b/sql-plugin/src/main/330until340-nondb/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala similarity index 100% rename from sql-plugin/src/main/330until340/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala rename to sql-plugin/src/main/330until340-nondb/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala diff --git a/sql-plugin/src/main/330until340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/330until340-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala similarity index 100% rename from sql-plugin/src/main/330until340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala rename to sql-plugin/src/main/330until340-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala diff --git a/sql-plugin/src/main/330until340/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala b/sql-plugin/src/main/330until340-nondb/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala similarity index 100% rename from sql-plugin/src/main/330until340/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala rename to sql-plugin/src/main/330until340-nondb/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala diff --git a/sql-plugin/src/main/331+/scala/com/nvidia/spark/rapids/shims/Spark331PlusShims.scala b/sql-plugin/src/main/331+/scala/com/nvidia/spark/rapids/shims/Spark331PlusShims.scala index 868fe0dbddc..d2973274373 100644 --- a/sql-plugin/src/main/331+/scala/com/nvidia/spark/rapids/shims/Spark331PlusShims.scala +++ b/sql-plugin/src/main/331+/scala/com/nvidia/spark/rapids/shims/Spark331PlusShims.scala @@ -21,7 +21,7 @@ import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuCast, GpuExpression, Gp import org.apache.spark.sql.catalyst.expressions.{CheckOverflowInTableInsert, Expression} import org.apache.spark.sql.rapids.GpuCheckOverflowInTableInsert -trait Spark331PlusShims extends Spark330PlusShims { +trait Spark331PlusShims extends Spark330PlusNonDBShims { override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( // Add expression CheckOverflowInTableInsert starting Spark-3.3.1+ diff --git a/sql-plugin/src/main/331/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/331/scala/com/nvidia/spark/rapids/SparkShims.scala index 2ced31cd978..3a1e198a741 100644 --- a/sql-plugin/src/main/331/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/331/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -16,8 +16,4 @@ package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ - -object SparkShimImpl extends Spark331PlusShims with Spark320until340Shims { - override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion -} +object SparkShimImpl extends Spark331PlusShims with AnsiCastRuleShims diff --git 
a/sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/shims/SparkShims.scala b/sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/shims/SparkShims.scala new file mode 100644 index 00000000000..3a1e198a741 --- /dev/null +++ b/sql-plugin/src/main/332/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +object SparkShimImpl extends Spark331PlusShims with AnsiCastRuleShims diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala b/sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala rename to sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/CastingConfigShim.scala diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala b/sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala rename to sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/DecimalArithmeticOverrides.scala diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala b/sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala rename to sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/GetMapValueMeta.scala diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala b/sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala rename to sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/ParquetStringPredShims.scala diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala b/sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala rename to sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/ShimFilePartitionReaderFactory.scala diff --git a/sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala b/sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala similarity index 100% rename from 
sql-plugin/src/main/340+/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala rename to sql-plugin/src/main/340+-and-330db/scala/com/nvidia/spark/rapids/shims/TypeUtilsShims.scala diff --git a/sql-plugin/src/main/340+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala b/sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala rename to sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ParquetCVShims.scala diff --git a/sql-plugin/src/main/340+/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala b/sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala similarity index 100% rename from sql-plugin/src/main/340+/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala rename to sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/execution/datasources/rapids/DataSourceStrategyUtils.scala diff --git a/sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala similarity index 94% rename from sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala rename to sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index bf436326e8b..f8a0bd4881e 100644 --- a/sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/340+-and-330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -52,7 +52,7 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext) } - def cannotChangeDecimalPrecisionError( + def cannotChangeDecimalPrecisionError( value: Decimal, toType: DecimalType, context: SQLQueryContext = null): ArithmeticException = { @@ -72,7 +72,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus { val errorClass = "CAST_INVALID_INPUT" val messageParameters = Map("expression" -> infOrNan, "sourceType" -> "DOUBLE", "targetType" -> "TIMESTAMP", "ansiConfig" -> SQLConf.ANSI_ENABLED.key) - new SparkDateTimeException(errorClass, messageParameters, Array.empty, "") + SparkDateTimeExceptionShims.newSparkDateTimeException(errorClass, messageParameters, + Array.empty, "") } def sqlArrayIndexNotStartAtOneError(): RuntimeException = { diff --git a/sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala b/sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala new file mode 100644 index 00000000000..da586510b11 --- /dev/null +++ b/sql-plugin/src/main/340+/scala/org/apache/spark/sql/rapids/shims/SparkDateTimeExceptionShims.scala @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.{QueryContext, SparkDateTimeException} + +object SparkDateTimeExceptionShims { + + def newSparkDateTimeException( + errorClass: String, + messageParameters: Map[String, String], + context: Array[QueryContext], + summary: String): SparkDateTimeException = { + new SparkDateTimeException( + errorClass, + messageParameters, + context, + summary) + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index cafe88c7244..240aef9e796 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -56,7 +56,7 @@ case class DatabricksShimVersion( } trait SparkShims { - def getSparkShimVersion: ShimVersion + def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion def parquetRebaseReadKey: String def parquetRebaseWriteKey: String def avroRebaseReadKey: String diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index a6d0fdf9ad4..ba4333211af 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -23,16 +23,16 @@ import ai.rapids.cudf import ai.rapids.cudf._ import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.{ShimLeafExecNode, ShimSparkPlan, ShimUnaryExecNode} +import com.nvidia.spark.rapids.shims._ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, AttributeSeq, Descending, Expression, NamedExpression, NullIntolerant, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, RangePartitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.execution.{ProjectExec, SampleExec, SparkPlan} -import org.apache.spark.sql.rapids.{GpuPartitionwiseSampledRDD, GpuPoissonSampler, GpuPredicateHelper} +import org.apache.spark.sql.rapids.{GpuPartitionwiseSampledRDD, GpuPoissonSampler} import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types.{DataType, LongType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -424,7 +424,7 @@ case class GpuFilterExec( condition: Expression, child: SparkPlan, override val coalesceAfter: Boolean = true) - extends ShimUnaryExecNode with GpuPredicateHelper with GpuExec { + extends ShimUnaryExecNode with ShimPredicateHelper with GpuExec { override lazy val additionalMetrics: Map[String, GpuMetric] = Map( OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_OP_TIME)) @@ 
-435,12 +435,6 @@ case class GpuFilterExec( case _ => false } - // If one expression and its children are null intolerant, it is null intolerant. - private def isNullIntolerant(expr: Expression): Boolean = expr match { - case e: NullIntolerant => e.children.forall(isNullIntolerant) - case _ => false - } - // The columns that will filtered out by `IsNotNull` could be considered as not nullable. private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index a44f3f84b80..2198839e7ed 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -21,18 +21,21 @@ import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuBinaryExpression, Gp import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked import com.nvidia.spark.rapids.BoolUtils.isAnyValidTrue import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.ShimUnaryExpression +import com.nvidia.spark.rapids.shims._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayStructFields, ImplicitCastInputTypes, NullIntolerant} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{quoteIdentifier, TypeUtils} import org.apache.spark.sql.rapids.shims.RapidsErrorUtils import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, MapType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends ShimUnaryExpression with GpuExpression with ExtractValue with NullIntolerant { + extends ShimUnaryExpression + with GpuExpression + with ShimExtractValue + with NullIntolerant { lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType] @@ -79,7 +82,9 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GpuGetArrayItem(child: Expression, ordinal: Expression, failOnError: Boolean) - extends GpuBinaryExpression with ExpectsInputTypes with ExtractValue { + extends GpuBinaryExpression + with ExpectsInputTypes + with ShimExtractValue { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) @@ -259,17 +264,17 @@ case class GpuArrayContains(left: Expression, right: Expression) /** * Helper function to account for `libcudf`'s `listContains()` semantics. - * - * If a list row contains at least one null element, and is found not to contain - * the search key, `libcudf` returns false instead of null. SparkSQL expects to + * + * If a list row contains at least one null element, and is found not to contain + * the search key, `libcudf` returns false instead of null. SparkSQL expects to * return null in those cases. 
- * - * This method determines the result's validity mask by ORing the output of + * + * This method determines the result's validity mask by ORing the output of * `listContains()` with the NOT of `listContainsNulls()`. - * A result row is thus valid if either the search key is found in the list, + * A result row is thus valid if either the search key is found in the list, * or if the list does not contain any null elements. */ - private def orNotContainsNull(containsResult: ColumnVector, + private def orNotContainsNull(containsResult: ColumnVector, inputListsColumn:ColumnVector): ColumnVector = { val notContainsNull = withResource(inputListsColumn.listContainsNulls) { _.not @@ -297,7 +302,7 @@ case class GpuArrayContains(left: Expression, right: Expression) override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { val inputListsColumn = lhs.getBase - withResource(inputListsColumn.listContainsColumn(rhs.getBase)) { + withResource(inputListsColumn.listContainsColumn(rhs.getBase)) { orNotContainsNull(_, inputListsColumn) } } @@ -327,7 +332,9 @@ case class GpuGetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends GpuUnaryExpression with ExtractValue with NullIntolerant { + containsNull: Boolean) extends GpuUnaryExpression + with ShimExtractValue + with NullIntolerant { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala index a7d0c2e79b0..9590f161cad 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala @@ -21,20 +21,11 @@ import ai.rapids.cudf.ast.BinaryOperator import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant, Predicate} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, FloatType} import org.apache.spark.sql.vectorized.ColumnarBatch -trait GpuPredicateHelper { - protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { - condition match { - case GpuAnd(cond1, cond2) => - splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2) - case other => other :: Nil - } - } -} case class GpuNot(child: Expression) extends CudfUnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { diff --git a/sql-plugin/src/main/311+-db/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala similarity index 100% rename from sql-plugin/src/main/311+-db/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala rename to sql-plugin/src/main/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala
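Note on the predicates.scala and basicPhysicalOperators.scala hunks above: GpuFilterExec now mixes in ShimPredicateHelper in place of the removed GpuPredicateHelper, but the ShimPredicateHelper source itself is not reproduced in these hunks. A minimal sketch follows, assuming the shim simply extends Spark's own PredicateHelper and adds splitting of GpuAnd conjuncts; the package, trait name, and method body below are illustrative assumptions, not the committed implementation.

package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.catalyst.expressions.{Expression, PredicateHelper}
import org.apache.spark.sql.rapids.GpuAnd

trait ShimPredicateHelper extends PredicateHelper {
  // Spark's PredicateHelper.splitConjunctivePredicates only recognizes Catalyst's And,
  // so split the GPU variant first and delegate everything else to Spark.
  override protected def splitConjunctivePredicates(
      condition: Expression): Seq[Expression] = condition match {
    case GpuAnd(lhs, rhs) =>
      splitConjunctivePredicates(lhs) ++ splitConjunctivePredicates(rhs)
    case other => super.splitConjunctivePredicates(other)
  }
}

Under that assumption, GpuFilterExec can keep calling splitConjunctivePredicates on conditions that may contain GpuAnd nodes, while the standalone splitting logic previously duplicated in GpuPredicateHelper is no longer needed.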