diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f60fdd1f3e7..f9fb7ac056d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,6 +55,27 @@ You can build against different versions of the CUDA Toolkit by using one of the ## Code contributions +### Source code layout + +Conventional code locations in Maven modules are found under `src/main/`. In addition to +that, and in order to support multiple versions of Apache Spark with a minimum amount of duplicated +source code, we maintain Spark-version-specific locations within non-shim modules where necessary. This allows +us to switch between incompatible parent classes without copying the shared code into +dedicated shim modules. + +Thus, the conventional source code root directories `src/main/` contain only the files that +are source-compatible with all supported Spark releases, both upstream and vendor-specific. + +The version-specific directory names have one of the following forms, depending on the use case: +- `src/main/312/scala` contains Scala source code for a single Spark version, 3.1.2 in this case +- `src/main/312+-apache/scala` contains Scala source code for *upstream* **Apache** Spark builds only, + beginning with version 3.1.2; the `+` + signifies that there is no upper version boundary + among the supported versions +- `src/main/302until312-all` contains code that applies to all shims from 3.0.2 *inclusive* up to +3.1.2 *exclusive* +- `src/main/302to312-cdh` contains code that applies to Cloudera CDH shims from 3.0.2 *inclusive* + through 3.1.2 *inclusive* + ### Your first issue 1. Read the [Developer Overview](docs/dev/README.md) to understand how the RAPIDS Accelerator diff --git a/pom.xml b/pom.xml index b437e59ea61..5f235789413 100644 --- a/pom.xml +++ b/pom.xml @@ -80,7 +80,6 @@ udf-compiler udf-examples - default @@ -91,8 +90,8 @@ api_validation tools - - + + no-buildver-default @@ -114,8 +113,7 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala - ${project.basedir}/src/main/spark30+all/java + ${project.basedir}/src/main/301until320-all/scala @@ -123,9 +121,8 @@ - - - + + buildver-default @@ -144,8 +141,7 @@ generate-sources - ${project.basedir}/src/main/spark${buildver}/scala - ${project.basedir}/src/main/spark${buildver}/java + ${project.basedir}/src/main/${buildver}/scala @@ -153,8 +149,8 @@ - - + + release301 @@ -178,7 +174,7 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala + ${project.basedir}/src/main/301until320-all/scala @@ -214,7 +210,7 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala + ${project.basedir}/src/main/301until320-all/scala @@ -243,7 +239,7 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala + ${project.basedir}/src/main/301until320-all/scala @@ -283,7 +279,7 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala + ${project.basedir}/src/main/301until320-all/scala @@ -319,7 +315,7 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala + ${project.basedir}/src/main/301until320-all/scala @@ -355,9 +351,9 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala - ${project.basedir}/src/main/spark31+all/scala - ${project.basedir}/src/main/spark31+apache/scala + ${project.basedir}/src/main/301until320-all/scala + ${project.basedir}/src/main/311+-all/scala + ${project.basedir}/src/main/311+-apache/scala @@ -448,9 +444,9 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala - ${project.basedir}/src/main/spark31+all/scala - ${project.basedir}/src/main/spark31+apache/scala + 
${project.basedir}/src/main/301until320-all/scala + ${project.basedir}/src/main/311+-all/scala + ${project.basedir}/src/main/311+-apache/scala @@ -486,9 +482,9 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala - ${project.basedir}/src/main/spark31+all/scala - ${project.basedir}/src/main/spark31+apache/scala + ${project.basedir}/src/main/301until320-all/scala + ${project.basedir}/src/main/311+-all/scala + ${project.basedir}/src/main/311+-apache/scala @@ -524,7 +520,7 @@ generate-sources - ${project.basedir}/src/main/spark31+all/scala + ${project.basedir}/src/main/311+-all/scala @@ -557,8 +553,8 @@ generate-sources - ${project.basedir}/src/main/spark30+all/scala - ${project.basedir}/src/main/spark31+all/scala + ${project.basedir}/src/main/301until320-all/scala + ${project.basedir}/src/main/311+-all/scala diff --git a/shims/spark301/pom.xml b/shims/spark301/pom.xml index 02628b548b6..dfde7c38ca1 100644 --- a/shims/spark301/pom.xml +++ b/shims/spark301/pom.xml @@ -82,4 +82,41 @@ provided + + + + no-buildver-default + + + !buildver + + + + ${project.basedir}/../../sql-plugin + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-default + add-source + generate-sources + + + ${spark-rapids.sql-plugin.root}/src/main/301until320-all/scala + + + + + + + + + diff --git a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala index 637108fb580..60a7cdda6ae 100644 --- a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala +++ b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark301 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -50,3 +50,30 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) } } + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): org.apache.spark.shuffle.ShuffleReader[K,C] = { + self.getReader(handle, startPartition, endPartition, context, metrics) + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K,C] = { + self.getReaderForRange(handle, startMapIndex, endMapIndex, startPartition, endPartition, + context, metrics) + } +} \ No newline at end of file diff --git a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/spark301/RapidsShuffleManager.scala b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/spark301/RapidsShuffleManager.scala index b5628fdc47f..8359be2eb9c 100644 --- a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/spark301/RapidsShuffleManager.scala +++ b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/spark301/RapidsShuffleManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark301 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark301.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark301.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/spark301db/RapidsShuffleManager.scala b/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/spark301db/RapidsShuffleManager.scala index 52e5fc2a807..b35ac4f2a54 100644 --- a/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/spark301db/RapidsShuffleManager.scala +++ b/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/spark301db/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark301db import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark301db.RapidsShuffleInternalManager +import org.apache.spark.rapids.shims.v2.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. 
*/ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark301emr/pom.xml b/shims/spark301emr/pom.xml index 5b23e26f197..6693094061e 100644 --- a/shims/spark301emr/pom.xml +++ b/shims/spark301emr/pom.xml @@ -82,4 +82,42 @@ provided + + + + + no-buildver-default + + + !buildver + + + + ${project.basedir}/../../sql-plugin + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-default + add-source + generate-sources + + + ${spark-rapids.sql-plugin.root}/src/main/301until320-all/scala + + + + + + + + + diff --git a/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/shims/spark301emr/RapidsShuffleInternalManager.scala b/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/shims/spark301emr/RapidsShuffleInternalManager.scala index 6a93ca4130a..f9c91380970 100644 --- a/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/shims/spark301emr/RapidsShuffleInternalManager.scala +++ b/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/shims/spark301emr/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark301emr import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. @@ -50,3 +50,31 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): org.apache.spark.shuffle.ShuffleReader[K,C] = { + self.getReader(handle, startPartition, endPartition, context, metrics) + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K,C] = { + self.getReaderForRange(handle, startMapIndex, endMapIndex, startPartition, endPartition, + context, metrics) + } +} \ No newline at end of file diff --git a/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/spark301emr/RapidsShuffleManager.scala b/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/spark301emr/RapidsShuffleManager.scala index c90e5bcf1b2..c17f14fbac0 100644 --- a/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/spark301emr/RapidsShuffleManager.scala +++ b/shims/spark301emr/src/main/scala/com/nvidia/spark/rapids/spark301emr/RapidsShuffleManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark301emr import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark301emr.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark301emr.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark302/pom.xml b/shims/spark302/pom.xml index 20cc7af983e..a0d77a8e539 100644 --- a/shims/spark302/pom.xml +++ b/shims/spark302/pom.xml @@ -82,4 +82,41 @@ provided + + + + no-buildver-default + + + !buildver + + + + ${project.basedir}/../../sql-plugin + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-default + add-source + generate-sources + + + ${spark-rapids.sql-plugin.root}/src/main/301until320-all/scala + + + + + + + + + diff --git a/shims/spark302/src/main/scala/com/nvidia/spark/rapids/shims/spark302/RapidsShuffleInternalManager.scala b/shims/spark302/src/main/scala/com/nvidia/spark/rapids/shims/spark302/RapidsShuffleInternalManager.scala index f07de9f473e..d59761f3ce9 100644 --- a/shims/spark302/src/main/scala/com/nvidia/spark/rapids/shims/spark302/RapidsShuffleInternalManager.scala +++ b/shims/spark302/src/main/scala/com/nvidia/spark/rapids/shims/spark302/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark302 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -50,3 +50,30 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) } } + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): org.apache.spark.shuffle.ShuffleReader[K,C] = { + self.getReader(handle, startPartition, endPartition, context, metrics) + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K,C] = { + self.getReaderForRange(handle, startMapIndex, endMapIndex, startPartition, endPartition, + context, metrics) + } +} \ No newline at end of file diff --git a/shims/spark302/src/main/scala/com/nvidia/spark/rapids/spark302/RapidsShuffleManager.scala b/shims/spark302/src/main/scala/com/nvidia/spark/rapids/spark302/RapidsShuffleManager.scala index 0c8dfc22950..0153eac46ee 100644 --- a/shims/spark302/src/main/scala/com/nvidia/spark/rapids/spark302/RapidsShuffleManager.scala +++ b/shims/spark302/src/main/scala/com/nvidia/spark/rapids/spark302/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark302 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark302.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark302.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark303/pom.xml b/shims/spark303/pom.xml index db2badfc6fe..d8de0bc44ea 100644 --- a/shims/spark303/pom.xml +++ b/shims/spark303/pom.xml @@ -82,4 +82,41 @@ provided + + + + no-buildver-default + + + !buildver + + + + ${project.basedir}/../../sql-plugin + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-default + add-source + generate-sources + + + ${spark-rapids.sql-plugin.root}/src/main/301until320-all/scala + + + + + + + + + diff --git a/shims/spark303/src/main/scala/com/nvidia/spark/rapids/shims/spark303/RapidsShuffleInternalManager.scala b/shims/spark303/src/main/scala/com/nvidia/spark/rapids/shims/spark303/RapidsShuffleInternalManager.scala index 9d4b4e9ebdc..022aa65f887 100644 --- a/shims/spark303/src/main/scala/com/nvidia/spark/rapids/shims/spark303/RapidsShuffleInternalManager.scala +++ b/shims/spark303/src/main/scala/com/nvidia/spark/rapids/shims/spark303/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark303 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -50,3 +50,31 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): org.apache.spark.shuffle.ShuffleReader[K,C] = { + self.getReader(handle, startPartition, endPartition, context, metrics) + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K,C] = { + self.getReaderForRange(handle, startMapIndex, endMapIndex, startPartition, endPartition, + context, metrics) + } +} \ No newline at end of file diff --git a/shims/spark303/src/main/scala/com/nvidia/spark/rapids/spark303/RapidsShuffleManager.scala b/shims/spark303/src/main/scala/com/nvidia/spark/rapids/spark303/RapidsShuffleManager.scala index 2b8b404f40b..f3c277cb61a 100644 --- a/shims/spark303/src/main/scala/com/nvidia/spark/rapids/spark303/RapidsShuffleManager.scala +++ b/shims/spark303/src/main/scala/com/nvidia/spark/rapids/spark303/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark303 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark303.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark303.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark304/pom.xml b/shims/spark304/pom.xml index 3de5548ff18..d2fb5071448 100644 --- a/shims/spark304/pom.xml +++ b/shims/spark304/pom.xml @@ -82,4 +82,41 @@ provided + + + + no-buildver-default + + + !buildver + + + + ${project.basedir}/../../sql-plugin + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-default + add-source + generate-sources + + + ${spark-rapids.sql-plugin.root}/src/main/301until320-all/scala + + + + + + + + + diff --git a/shims/spark304/src/main/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala b/shims/spark304/src/main/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala index 3c7b100dd92..d53ecafe1c5 100644 --- a/shims/spark304/src/main/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala +++ b/shims/spark304/src/main/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark304 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -50,3 +50,31 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): org.apache.spark.shuffle.ShuffleReader[K,C] = { + self.getReader(handle, startPartition, endPartition, context, metrics) + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K,C] = { + self.getReaderForRange(handle, startMapIndex, endMapIndex, startPartition, endPartition, + context, metrics) + } +} \ No newline at end of file diff --git a/shims/spark304/src/main/scala/com/nvidia/spark/rapids/spark304/RapidsShuffleManager.scala b/shims/spark304/src/main/scala/com/nvidia/spark/rapids/spark304/RapidsShuffleManager.scala index 5c7cdbbfe20..4fbb4684bf5 100644 --- a/shims/spark304/src/main/scala/com/nvidia/spark/rapids/spark304/RapidsShuffleManager.scala +++ b/shims/spark304/src/main/scala/com/nvidia/spark/rapids/spark304/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark304 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark304.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark304.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark311/pom.xml b/shims/spark311/pom.xml index bae129ce488..623b4dbc729 100644 --- a/shims/spark311/pom.xml +++ b/shims/spark311/pom.xml @@ -57,12 +57,8 @@ generate-sources - ${spark-rapids.sql-plugin.root}/src/main/spark311/scala - ${spark-rapids.sql-plugin.root}/src/main/spark311/java - ${spark-rapids.sql-plugin.root}/src/main/spark31+all/scala - ${spark-rapids.sql-plugin.root}/src/main/spark31+all/java - ${spark-rapids.sql-plugin.root}/src/main/spark31+apache/scala - ${spark-rapids.sql-plugin.root}/src/main/spark31+apache/java + ${spark-rapids.sql-plugin.root}/src/main/311+-all/scala + ${spark-rapids.sql-plugin.root}/src/main/311+-apache/scala diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala index 018f5421b7e..c517f71d9f9 100644 --- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala +++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark311 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark311.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark311.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. 
*/ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark311/src/main/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala b/shims/spark311/src/main/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala index 72870b7eea6..1a32ba393a5 100644 --- a/shims/spark311/src/main/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala +++ b/shims/spark311/src/main/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark311 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. @@ -40,3 +40,21 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) metrics) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { + self.getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, + metrics) + } +} \ No newline at end of file diff --git a/shims/spark311cdh/pom.xml b/shims/spark311cdh/pom.xml index 94662a046dd..3f60e0b7eb2 100644 --- a/shims/spark311cdh/pom.xml +++ b/shims/spark311cdh/pom.xml @@ -57,10 +57,8 @@ generate-sources - ${spark-rapids.sql-plugin.root}/src/main/spark311cdh/scala - ${spark-rapids.sql-plugin.root}/src/main/spark311cdh/java - ${spark-rapids.sql-plugin.root}/src/main/spark31+all/scala - ${spark-rapids.sql-plugin.root}/src/main/spark31+all/java + ${spark-rapids.sql-plugin.root}/src/main/311cdh/scala + ${spark-rapids.sql-plugin.root}/src/main/311+-all/scala diff --git a/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala b/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala index 21b7e8c734f..6748239f37b 100644 --- a/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala +++ b/shims/spark311cdh/src/main/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala @@ -17,10 +17,11 @@ package com.nvidia.spark.rapids.spark311cdh import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark311cdh.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark311cdh.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. 
*/ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } + diff --git a/shims/spark311cdh/src/main/scala/org/apache/spark/sql/rapids/shims/spark311cdh/RapidsShuffleInternalManager.scala b/shims/spark311cdh/src/main/scala/org/apache/spark/sql/rapids/shims/spark311cdh/RapidsShuffleInternalManager.scala index e5c9b132b9e..50f3a8581d8 100644 --- a/shims/spark311cdh/src/main/scala/org/apache/spark/sql/rapids/shims/spark311cdh/RapidsShuffleInternalManager.scala +++ b/shims/spark311cdh/src/main/scala/org/apache/spark/sql/rapids/shims/spark311cdh/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark311cdh import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. @@ -40,3 +40,20 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) metrics) } } + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { + self.getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, + metrics) + } +} diff --git a/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/spark311db/RapidsShuffleManager.scala b/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/spark311db/RapidsShuffleManager.scala index 07e761b1b70..c95830e86e0 100644 --- a/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/spark311db/RapidsShuffleManager.scala +++ b/shims/spark311db/src/main/scala/com/nvidia/spark/rapids/spark311db/RapidsShuffleManager.scala @@ -22,5 +22,4 @@ import org.apache.spark.sql.rapids.shims.spark311db.RapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. 
*/ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { -} + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) diff --git a/shims/spark312/pom.xml b/shims/spark312/pom.xml index 5d6e6f23d9a..7e0f33c1a81 100644 --- a/shims/spark312/pom.xml +++ b/shims/spark312/pom.xml @@ -57,12 +57,8 @@ generate-sources - ${spark-rapids.sql-plugin.root}/src/main/spark312/scala - ${spark-rapids.sql-plugin.root}/src/main/spark312/java - ${spark-rapids.sql-plugin.root}/src/main/spark31+all/scala - ${spark-rapids.sql-plugin.root}/src/main/spark31+all/java - ${spark-rapids.sql-plugin.root}/src/main/spark31+apache/scala - ${spark-rapids.sql-plugin.root}/src/main/spark31+apache/java + ${spark-rapids.sql-plugin.root}/src/main/311+-all/scala + ${spark-rapids.sql-plugin.root}/src/main/311+-apache/scala diff --git a/shims/spark312/src/main/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala b/shims/spark312/src/main/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala index cc33d09137d..2383d7fd4cf 100644 --- a/shims/spark312/src/main/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala +++ b/shims/spark312/src/main/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark312 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark312.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark312.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark312/src/main/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala b/shims/spark312/src/main/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala index b25f33239f0..9b4ec767d8d 100644 --- a/shims/spark312/src/main/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala +++ b/shims/spark312/src/main/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark312 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -40,3 +40,21 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) metrics) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { + self.getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, + metrics) + } +} \ No newline at end of file diff --git a/shims/spark313/pom.xml b/shims/spark313/pom.xml index 37276779ce9..dd3e117427d 100644 --- a/shims/spark313/pom.xml +++ b/shims/spark313/pom.xml @@ -82,4 +82,42 @@ provided + + + + no-buildver-default + + + !buildver + + + + ${project.basedir}/../../sql-plugin + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-default + add-source + generate-sources + + + ${spark-rapids.sql-plugin.root}/src/main/311+-all/scala + ${spark-rapids.sql-plugin.root}/src/main/311+-apache/scala + + + + + + + + + diff --git a/shims/spark313/src/main/scala/com/nvidia/spark/rapids/spark313/RapidsShuffleManager.scala b/shims/spark313/src/main/scala/com/nvidia/spark/rapids/spark313/RapidsShuffleManager.scala index 697b9dfdfee..47dda977ff7 100644 --- a/shims/spark313/src/main/scala/com/nvidia/spark/rapids/spark313/RapidsShuffleManager.scala +++ b/shims/spark313/src/main/scala/com/nvidia/spark/rapids/spark313/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark313 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark313.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark313.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark313/src/main/scala/org/apache/spark/sql/rapids/shims/spark313/RapidsShuffleInternalManager.scala b/shims/spark313/src/main/scala/org/apache/spark/sql/rapids/shims/spark313/RapidsShuffleInternalManager.scala index dc6b4630f5a..71e8a89888a 100644 --- a/shims/spark313/src/main/scala/org/apache/spark/sql/rapids/shims/spark313/RapidsShuffleInternalManager.scala +++ b/shims/spark313/src/main/scala/org/apache/spark/sql/rapids/shims/spark313/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark313 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle.{ShuffleHandle, ShuffleReader, ShuffleReadMetricsReporter} -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -39,3 +39,21 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) metrics) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { + self.getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, + metrics) + } +} \ No newline at end of file diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala index 2dffdb89576..ccab7bb3c33 100644 --- a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala +++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids.spark320 import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.shims.spark320.RapidsShuffleInternalManager +import org.apache.spark.sql.rapids.shims.spark320.ProxyRapidsShuffleInternalManager /** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ sealed class RapidsShuffleManager( conf: SparkConf, - isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { + isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } diff --git a/shims/spark320/src/main/scala/org/apache/spark/sql/rapids/shims/spark320/RapidsShuffleInternalManager.scala b/shims/spark320/src/main/scala/org/apache/spark/sql/rapids/shims/spark320/RapidsShuffleInternalManager.scala index 0fed708abd8..9180ce6f460 100644 --- a/shims/spark320/src/main/scala/org/apache/spark/sql/rapids/shims/spark320/RapidsShuffleInternalManager.scala +++ b/shims/spark320/src/main/scala/org/apache/spark/sql/rapids/shims/spark320/RapidsShuffleInternalManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims.spark320 import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.shuffle._ -import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase +import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleInternalManagerBase} /** * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark. 
@@ -40,3 +40,21 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) metrics) } } + + +class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) { + + def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { + self.getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, + metrics) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala similarity index 100% rename from sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala diff --git a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala similarity index 97% rename from sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala index 509cc3770e5..407a9a6c47d 100644 --- a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala @@ -24,5 +24,5 @@ class ShimDataSourceRDD( sc: SparkContext, @transient private val inputPartitions: Seq[InputPartition], partitionReaderFactory: PartitionReaderFactory, - columnarReads: Boolean, + columnarReads: Boolean ) extends DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads) diff --git a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala similarity index 100% rename from sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala diff --git a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala similarity index 100% rename from sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala diff --git a/sql-plugin/src/main/spark30+all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala similarity index 90% rename from sql-plugin/src/main/spark30+all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala rename to sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala index f1f0e4a6c1a..48671f5171b 100644 --- a/sql-plugin/src/main/spark30+all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala +++ 
b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.rapids.shims.v2 import com.nvidia.spark.rapids.ShuffleBufferCatalog -import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.sql.rapids.GpuShuffleBlockResolverBase -import org.apache.spark.storage.ShuffleBlockId class GpuShuffleBlockResolver(resolver: IndexShuffleBlockResolver, catalog: ShuffleBufferCatalog) extends GpuShuffleBlockResolverBase(resolver, catalog) + diff --git a/sql-plugin/src/main/spark30+all/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala similarity index 100% rename from sql-plugin/src/main/spark30+all/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala diff --git a/sql-plugin/src/main/spark31+all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala similarity index 100% rename from sql-plugin/src/main/spark31+all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala rename to sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala diff --git a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala similarity index 100% rename from sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala diff --git a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala similarity index 100% rename from sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala diff --git a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala similarity index 98% rename from sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala index 166913dc21c..7da80e47c79 100644 --- a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala @@ -91,7 +91,8 @@ case class GpuInMemoryTableScanExec( override def outputOrdering: Seq[SortOrder] = relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) - lazy val 
enableAccumulatorsForTest: Boolean = sparkSession.sqlContext.conf.inMemoryTableScanStatisticsEnabled + lazy val enableAccumulatorsForTest: Boolean = sparkSession.sqlContext + .conf.inMemoryTableScanStatisticsEnabled // Accumulators used for testing purposes lazy val readPartitions = sparkSession.sparkContext.longAccumulator diff --git a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala similarity index 100% rename from sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala diff --git a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala similarity index 100% rename from sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala diff --git a/sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/ShuffleManagerShim.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/ShuffleManagerShim.scala similarity index 100% rename from sql-plugin/src/main/spark31+all/scala/org/apache/spark/sql/rapids/shims/v2/ShuffleManagerShim.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/ShuffleManagerShim.scala diff --git a/sql-plugin/src/main/spark31+apache/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/311+-apache/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala similarity index 100% rename from sql-plugin/src/main/spark31+apache/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/311+-apache/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala diff --git a/sql-plugin/src/main/spark311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala similarity index 100% rename from sql-plugin/src/main/spark311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala similarity index 100% rename from sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala rename to sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala similarity 
index 97% rename from sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala rename to sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala index c68ad51c589..a69dc595a2d 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala @@ -25,6 +25,6 @@ class ShimDataSourceRDD( sc: SparkContext, @transient private val inputPartitions: Seq[InputPartition], partitionReaderFactory: PartitionReaderFactory, - columnarReads: Boolean, + columnarReads: Boolean ) extends DataSourceRDD(sc, inputPartitions, partitionReaderFactory, columnarReads, Map.empty[String, SQLMetric]) diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala similarity index 99% rename from sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala rename to sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index bf97aad0764..dc9f17de2f6 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -18,7 +18,6 @@ package com.nvidia.spark.rapids.shims.v2 import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuOverrides.exec -import com.nvidia.spark.rapids.shims._ import org.apache.hadoop.fs.FileStatus import org.apache.parquet.schema.MessageType diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala similarity index 100% rename from sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala rename to sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala b/sql-plugin/src/main/320/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala similarity index 100% rename from sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala rename to sql-plugin/src/main/320/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/320/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala similarity index 100% rename from sql-plugin/src/main/spark320/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/320/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala b/sql-plugin/src/main/320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala similarity index 100% rename from sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala rename to sql-plugin/src/main/320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala 
b/sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala similarity index 100% rename from sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/SQLPlugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/SQLPlugin.scala index 9d4fc345f93..078dc5b7277 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/SQLPlugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/SQLPlugin.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark -import com.nvidia.spark.rapids.{RapidsDriverPlugin, RapidsExecutorPlugin} +import com.nvidia.spark.rapids.ShimLoader import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin} import org.apache.spark.internal.Logging @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging * To enable this plugin, set the config "spark.plugins" to `com.nvidia.spark.SQLPlugin` */ class SQLPlugin extends SparkPlugin with Logging { - override def driverPlugin(): DriverPlugin = new RapidsDriverPlugin - override def executorPlugin(): ExecutorPlugin = new RapidsExecutorPlugin + override def driverPlugin(): DriverPlugin = ShimLoader.newDriverPlugin() + + override def executorPlugin(): ExecutorPlugin = ShimLoader.newExecutorPlugin() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 59c9c57ae64..8cf92979b0b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -16,7 +16,6 @@ package com.nvidia.spark.rapids -import java.util import java.util.Properties import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} @@ -29,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext} import org.apache.spark.internal.Logging import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.{DataFrame, SparkSessionExtensions} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ @@ -43,32 +42,14 @@ class PluginException(msg: String) extends RuntimeException(msg) case class CudfVersionMismatchException(errorMsg: String) extends PluginException(errorMsg) case class ColumnarOverrideRules() extends ColumnarRule with Logging { - val overrides: Rule[SparkPlan] = GpuOverrides() - val overrideTransitions: Rule[SparkPlan] = new GpuTransitionOverrides() + lazy val overrides: Rule[SparkPlan] = GpuOverrides() + lazy val overrideTransitions: Rule[SparkPlan] = new GpuTransitionOverrides() override def preColumnarTransitions : Rule[SparkPlan] = overrides override def postColumnarTransitions: Rule[SparkPlan] = overrideTransitions } -/** - * Extension point to enable GPU SQL processing. 
- */ -class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging { - override def apply(extensions: SparkSessionExtensions): Unit = { - val pluginProps = RapidsPluginUtils.loadProps(RapidsPluginUtils.PLUGIN_PROPS_FILENAME) - logInfo(s"RAPIDS Accelerator build: $pluginProps") - val cudfProps = RapidsPluginUtils.loadProps(RapidsPluginUtils.CUDF_PROPS_FILENAME) - logInfo(s"cudf build: $cudfProps") - val pluginVersion = pluginProps.getProperty("version", "UNKNOWN") - val cudfVersion = cudfProps.getProperty("version", "UNKNOWN") - logWarning(s"RAPIDS Accelerator $pluginVersion using cudf $cudfVersion." + - s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false") - extensions.injectColumnar(_ => ColumnarOverrideRules()) - extensions.injectQueryStagePrepRule(_ => GpuQueryStagePrepOverrides()) - } -} - object RapidsPluginUtils extends Logging { val CUDF_PROPS_FILENAME = "cudf-java-version-info.properties" val PLUGIN_PROPS_FILENAME = "rapids4spark-version-info.properties" @@ -82,6 +63,17 @@ object RapidsPluginUtils extends Logging { private val KRYO_REGISTRATOR_KEY = "spark.kryo.registrator" private val KRYO_REGISTRATOR_NAME = classOf[GpuKryoRegistrator].getName + { + val pluginProps = loadProps(RapidsPluginUtils.PLUGIN_PROPS_FILENAME) + logInfo(s"RAPIDS Accelerator build: $pluginProps") + val cudfProps = loadProps(RapidsPluginUtils.CUDF_PROPS_FILENAME) + logInfo(s"cudf build: $cudfProps") + val pluginVersion = pluginProps.getProperty("version", "UNKNOWN") + val cudfVersion = cudfProps.getProperty("version", "UNKNOWN") + logWarning(s"RAPIDS Accelerator $pluginVersion using cudf $cudfVersion." + + s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false") + } + def fixupConfigs(conf: SparkConf): Unit = { // First add in the SQL executor plugin because that is what we need at a minimum if (conf.contains(SQL_PLUGIN_CONF_KEY)) { @@ -151,13 +143,15 @@ class RapidsDriverPlugin extends DriverPlugin with Logging { } } - override def init(sc: SparkContext, pluginContext: PluginContext): util.Map[String, String] = { + override def init( + sc: SparkContext, pluginContext: PluginContext): java.util.Map[String, String] = { val sparkConf = pluginContext.conf RapidsPluginUtils.fixupConfigs(sparkConf) val conf = new RapidsConf(sparkConf) - if (conf.shimsProviderOverride.isDefined) { + if (conf.shimsProviderOverride.isDefined) { // TODO test it, probably not working yet ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get) } + if (GpuShuffleEnv.isRapidsShuffleAvailable && conf.shuffleTransportEarlyStart) { rapidsShuffleHeartbeatManager = @@ -177,7 +171,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { override def init( pluginContext: PluginContext, - extraConf: util.Map[String, String]): Unit = { + extraConf: java.util.Map[String, String]): Unit = { try { val conf = new RapidsConf(extraConf.asScala.toMap) if (conf.shimsProviderOverride.isDefined) { @@ -198,6 +192,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { conf.shuffleTransportEarlyStart) { logInfo("Initializing shuffle manager heartbeats") rapidsShuffleHeartbeatEndpoint = new RapidsShuffleHeartbeatEndpoint(pluginContext, conf) + rapidsShuffleHeartbeatEndpoint.registerShuffleHeartbeat() } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala index a561b31ebd0..fd851390e20 100644 --- 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsShuffleHeartbeatManager.scala @@ -21,12 +21,12 @@ import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import scala.collection.mutable.ArrayBuffer import com.google.common.util.concurrent.ThreadFactoryBuilder -import java.util import org.apache.commons.lang3.mutable.MutableLong +import org.apache.spark.SparkEnv import org.apache.spark.api.plugin.PluginContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsShuffleInternalManagerBase} +import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase import org.apache.spark.storage.BlockManagerId /** @@ -78,7 +78,8 @@ class RapidsShuffleHeartbeatManager(heartbeatIntervalMillis: Long, private[this] var executors = new ArrayBuffer[ExecutorRegistration]() // A mapping of executor IDs to its registration, populated in `registerExecutor` - private[this] val executorRegistrations = new util.HashMap[BlockManagerId, ExecutorRegistration] + private[this] val executorRegistrations = + new java.util.HashMap[BlockManagerId, ExecutorRegistration] // Min-heap with the root node being the executor that least recently heartbeated private[this] val leastRecentHeartbeat = @@ -235,11 +236,13 @@ class RapidsShuffleHeartbeatEndpoint(pluginContext: PluginContext, conf: RapidsC } } - GpuShuffleEnv.mgr.foreach { mgr => - if (mgr.isDriver) { + def registerShuffleHeartbeat(): Unit = { + val rapidsShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[Proxy].self + .asInstanceOf[RapidsShuffleInternalManagerBase] + if (rapidsShuffleManager.isDriver) { logDebug("Local mode detected. Skipping shuffle heartbeat registration.") } else { - executorService.submit(new InitializeShuffleManager(pluginContext, mgr)) + executorService.submit(new InitializeShuffleManager(pluginContext, rapidsShuffleManager)) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SQLExecPlugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SQLExecPlugin.scala new file mode 100644 index 00000000000..fb0a8d83306 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SQLExecPlugin.scala @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} + +/** + * Extension point to enable GPU SQL processing. 
+ */ +class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectColumnar(columnarOverrides) + extensions.injectQueryStagePrepRule(queryStagePrepOverrides) + } + + private def columnarOverrides(sparkSession: SparkSession): ColumnarRule = { + ShimLoader.newColumnarOverrideRules() + } + + private def queryStagePrepOverrides(sparkSession: SparkSession): Rule[SparkPlan] = { + ShimLoader.newGpuQueryStagePrepOverrides() + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala index 1ce7fed3456..8a5157cae83 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala @@ -16,53 +16,221 @@ package com.nvidia.spark.rapids -import java.util.ServiceLoader +import java.net.URL import scala.collection.JavaConverters._ -import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION} +import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf} +import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin} import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} +import org.apache.spark.sql.rapids.VisibleShuffleManager +import org.apache.spark.util.{MutableURLClassLoader, ParentClassLoader} +/* + Plugin jar uses non-standard class file layout. It consists of three types of areas, + "parallel worlds" in the JDK's com.sun.istack.internal.tools.ParallelWorldClassLoader parlance + + 1. a few publicly documented classes in the conventional layout at the top + 2. a large fraction of classes whose bytecode is identical under all supported Spark versions + in spark3xx-common + 3. a smaller fraction of classes that differ under one of the supported Spark versions + + com/nvidia/spark/SQLPlugin.class + + spark3xx-common/com/nvidia/spark/rapids/CastExprMeta.class + + spark301/org/apache/spark/sql/rapids/GpuUnaryMinus.class + spark311/org/apache/spark/sql/rapids/GpuUnaryMinus.class + spark320/org/apache/spark/sql/rapids/GpuUnaryMinus.class + + Each shim can see a consistent parallel world without conflicts by referencing + only one conflicting directory. + + E.g., Spark 3.2.0 Shim will use + + jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark3xx-common/ + jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark320/ + + Spark 3.1.1 will use + + jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark3xx-common/ + jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark311/ + + Using these jar URLs allows referencing different bytecode produced from identical sources + by incompatible Scala / Spark dependencies.
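[Editor's note] A minimal sketch, not part of this patch, of how such parallel-world URLs behave. It assumes the plugin jar from the example above plus Spark already on the application classpath, and uses a plain java.net.URLClassLoader where the plugin uses Spark's MutableURLClassLoader:

import java.net.{URL, URLClassLoader}

object ParallelWorldSketch {
  def main(args: Array[String]): Unit = {
    val jar = "jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/"
    val parent = getClass.getClassLoader // assumed to already provide Spark classes
    // each loader sees the common world plus exactly one version-specific world
    val loader311 = new URLClassLoader(
      Array(new URL(jar + "spark3xx-common/"), new URL(jar + "spark311/")), parent)
    val loader320 = new URLClassLoader(
      Array(new URL(jar + "spark3xx-common/"), new URL(jar + "spark320/")), parent)
    // the same fully qualified name resolves to different bytecode in each world
    val c311 = loader311.loadClass("org.apache.spark.sql.rapids.GpuUnaryMinus")
    val c320 = loader320.loadClass("org.apache.spark.sql.rapids.GpuUnaryMinus")
    assert(c311 ne c320) // distinct Class objects from distinct loaders
  }
}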
+ */ object ShimLoader extends Logging { - private var shimProviderClass: String = null - private var sparkShims: SparkShims = null + logDebug(s"ShimLoader object instance: ${this} loaded by ${getClass.getClassLoader}") + private val shimRootURL = { + val thisClassFile = getClass.getName.replace(".", "/") + ".class" + val url = getClass.getClassLoader.getResource(thisClassFile) + val urlStr = url.toString + val rootUrlStr = urlStr.substring(0, urlStr.length - thisClassFile.length) + new URL(rootUrlStr) + } + + private val shimCommonURL = new URL(s"${shimRootURL.toString}spark3xx-common/") + + @volatile private var shimProviderClass: String = _ + @volatile private var sparkShims: SparkShims = _ + @volatile private var shimURL: URL = _ + @volatile private var pluginClassLoader: ClassLoader = _ + + // REPL-only logic + @volatile private var tmpClassLoader: MutableURLClassLoader = _ + + private def shimId: String = shimIdFromPackageName(shimProviderClass) + + // defensively call findShimProvider logic on all entry points to avoid uninitialized state + // this won't be necessary if we can upstream the plugin and shuffle + // manager loading changes to Apache Spark + private def initShimProviderIfNeeded(): Unit = { + if (shimURL == null) { + findShimProvider() + } + } + + // Ideally we would like to expose a simple Boolean config instead of having to document + // per-shim ShuffleManager implementations: + // https://github.com/NVIDIA/spark-rapids/blob/branch-21.08/docs/additional-functionality/ + // rapids-shuffle.md#spark-app-configuration + // + // This is not possible at the current stage of the shim layer rewrite because of the combination + // of the following two reasons: + // 1) Spark processes the ShuffleManager config before any of the plugin code is initialized + // 2) We can't combine the implementation of the ShuffleManager trait for different Spark + versions in the same Scala class.
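[Editor's note] The shimRootURL computation above is a general trick: strip a class's own resource path off the end of its resource URL to recover the classpath root it was loaded from. A standalone sketch, with hypothetical names:

import java.net.URL

object ClassRootSketch {
  // works uniformly for file: directories and jar:file:...!/ roots
  def rootOf(cls: Class[_]): URL = {
    val classFile = cls.getName.replace(".", "/") + ".class"
    val urlStr = cls.getClassLoader.getResource(classFile).toString
    new URL(urlStr.substring(0, urlStr.length - classFile.length))
  }

  def main(args: Array[String]): Unit = {
    // prints e.g. file:/path/to/classes/ or jar:file:/path/to/app.jar!/
    println(rootOf(getClass))
  }
}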
A method was changed to final + // https://github.com/apache/spark/blame/v3.2.0-rc2/core/src/main/scala/ + // org/apache/spark/shuffle/ShuffleManager.scala#L57 + // + // ShuffleBlockResolver implementation for 3.1 has MergedBlockMeta in signatures + // missing in the prior versions leading to CNF when loaded in earlier version + // + def getRapidsShuffleManagerClass: String = { + initShimProviderIfNeeded() + s"com.nvidia.spark.rapids.$shimId.RapidsShuffleManager" + } + + def getRapidsShuffleInternalClass: String = { + initShimProviderIfNeeded() + s"org.apache.spark.sql.rapids.shims.$shimId.RapidsShuffleInternalManager" + } + + private def updateSparkClassLoader(): Unit = { + // TODO propose a proper addClassPathURL API to Spark similar to addJar but + // accepting non-file-based URI + val contextClassLoader = Thread.currentThread().getContextClassLoader + Option(contextClassLoader).collect { + case mutable: MutableURLClassLoader => mutable + case replCL if replCL.getClass.getName == "org.apache.spark.repl.ExecutorClassLoader" => + val parentLoaderField = replCL.getClass.getDeclaredMethod("parentLoader") + val parentLoader = parentLoaderField.invoke(replCL).asInstanceOf[ParentClassLoader] + parentLoader.getParent.asInstanceOf[MutableURLClassLoader] + }.foreach { mutable => + // MutableURLClassloader dedupes for us + pluginClassLoader = contextClassLoader + mutable.addURL(shimURL) + mutable.addURL(shimCommonURL) + } + } + + private def getShimClassLoader(): ClassLoader = { + initShimProviderIfNeeded() + if (pluginClassLoader == null) { + updateSparkClassLoader() + } + if (pluginClassLoader == null) { + if (tmpClassLoader == null) { + tmpClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL), + getClass.getClassLoader) + } + tmpClassLoader + } else { + pluginClassLoader + } + } + + private val SERVICE_LOADER_PREFIX = "META-INF/services/" - private def detectShimProvider(): SparkShimServiceProvider = { + private def detectShimProvider(): String = { val sparkVersion = getSparkVersion logInfo(s"Loading shim for Spark version: $sparkVersion") - // This is not ideal, but pass the version in here because otherwise loader that match the - // same version (3.0.1 Apache and 3.0.1 Databricks) would need to know how to differentiate. - val sparkShimLoaders = ServiceLoader.load(classOf[SparkShimServiceProvider]) - .asScala.filter(_.matchesVersion(sparkVersion)) - if (sparkShimLoaders.size > 1) { - throw new IllegalArgumentException(s"Multiple Spark Shim Loaders found: $sparkShimLoaders") + val thisClassLoader = getClass.getClassLoader + + // Emulating service loader manually because we have a non-standard jar layout for classes + // when we pass a classloader to https://docs.oracle.com/javase/8/docs/api/java/util/ + // ServiceLoader.html#load-java.lang.Class-java.lang.ClassLoader- + // it expects META-INF/services at the normal root locations (OK) + // and provider classes under the normal root entry as well. The latter is not OK because we + // want to minimize the use of reflection and prevent leaking the provider to a conventional + // classloader. + // + // Alternatively, we could use a ChildFirstClassloader implementation. However, this means that + // ShimServiceProvider API definition is not shared via parent and we run + // into ClassCastExceptions. 
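[Editor's note] A sketch, not part of the patch, of the manual enumeration detectShimProvider performs: read the META-INF/services descriptors as plain text, so provider class names are obtained without loading or instantiating any provider in this classloader. The service name is illustrative:

import scala.collection.JavaConverters._

object ManualServiceListSketch {
  def providerNames(loader: ClassLoader, serviceName: String): Seq[String] = {
    loader.getResources("META-INF/services/" + serviceName)
      .asScala
      .flatMap(url => scala.io.Source.fromURL(url).getLines())
      .map(_.trim)
      .filter(line => line.nonEmpty && !line.startsWith("#")) // service-file comments
      .toSeq
  }

  def main(args: Array[String]): Unit = {
    // names only -- which jar "world" to load each one from is a separate decision
    providerNames(getClass.getClassLoader,
      "com.nvidia.spark.rapids.SparkShimServiceProvider").foreach(println)
  }
}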
If we find a way to solve this then we can revert to ServiceLoader + + val serviceProviderListPath = SERVICE_LOADER_PREFIX + classOf[SparkShimServiceProvider].getName + val serviceProviderList = thisClassLoader.getResources(serviceProviderListPath) + .asScala.map(scala.io.Source.fromURL) + .flatMap(_.getLines()) + + assert(serviceProviderList.nonEmpty, "Classpath should contain the resource for " + + serviceProviderListPath) + + val shimServiceProviderOpt = serviceProviderList.flatMap { shimServiceProviderStr => + val mask = shimIdFromPackageName(shimServiceProviderStr) + try { + val shimURL = new java.net.URL(s"${shimRootURL.toString}$mask/") + val shimClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL), + thisClassLoader) + val shimClass = shimClassLoader.loadClass(shimServiceProviderStr) + Option( + (instantiateClass(shimClass).asInstanceOf[SparkShimServiceProvider], shimURL) + ) + } catch { + case cnf: ClassNotFoundException => + logWarning(cnf + ": Could not load the provider", cnf) + None + } + }.find { case (shimServiceProvider, _) => + shimServiceProvider.matchesVersion(sparkVersion) + }.map { case (inst, url) => + shimURL = url + // this class will be loaded again by the real executor classloader + inst.getClass.getName } - logInfo(s"Found shims: $sparkShimLoaders") - val loader = sparkShimLoaders.headOption match { - case Some(loader) => loader - case None => + + shimServiceProviderOpt.getOrElse { throw new IllegalArgumentException(s"Could not find Spark Shim Loader for $sparkVersion") } - loader } - private def findShimProvider(): SparkShimServiceProvider = { + // shimId corresponds to spark.version.classifier by convention + // e.g. com.nvidia.spark.rapids.shims.spark320.SparkShimServiceProvider implies + // shimId = "spark320" + private def shimIdFromPackageName(shimServiceProvider: SparkShimServiceProvider) = { + shimServiceProvider.getClass.getPackage.toString.split('.').last + } + + private def shimIdFromPackageName(shimServiceProviderStr: String) = { + shimServiceProviderStr.split('.').takeRight(2).head + } + + private def findShimProvider(): String = { + // TODO restore support for shim provider override if (shimProviderClass == null) { - detectShimProvider() - } else { - logWarning(s"Overriding Spark shims provider to $shimProviderClass. " + - "This may be an untested configuration!") - val providerClass = Class.forName(shimProviderClass) - val constructor = providerClass.getConstructor() - constructor.newInstance().asInstanceOf[SparkShimServiceProvider] + shimProviderClass = detectShimProvider() } + shimProviderClass } def getSparkShims: SparkShims = { if (sparkShims == null) { - val provider = findShimProvider() - sparkShims = provider.buildShim + sparkShims = newInstanceOf[SparkShimServiceProvider](findShimProvider()).buildShim } sparkShims } @@ -76,7 +244,64 @@ object ShimLoader extends Logging { } } + // TODO broken right now, check if this can be supported with parallel worlds + // it implies the prerequisite of having such a class in the conventional root jar entry + // - or the necessity of an additional parameter for specifying the shim subdirectory + // - or enforcing the convention the class file parent directory is the shimid that is also + // a top entry e.g. 
/spark301/com/nvidia/test/shim/spark301/Spark301Shims.class def setSparkShimProviderClass(classname: String): Unit = { shimProviderClass = classname } + + private def newInstanceOf[T](className: String): T = { + val loader = getShimClassLoader() + logDebug(s"Loading $className using $loader with the parent loader ${loader.getParent}") + instantiateClass(loader.loadClass(className)).asInstanceOf[T] + } + + // avoid cached constructors + private def instantiateClass[T](cls: Class[T]): T = { + logDebug(s"Instantiate ${cls.getName} using classloader " + cls.getClassLoader) + cls.getClassLoader match { + case m: MutableURLClassLoader => + logDebug("urls " + m.getURLs.mkString("\n")) + case _ => + } + val constructor = cls.getConstructor() + constructor.newInstance() + } + + + // + // Reflection-based API with Spark to switch the classloader used by the caller + // + + def newInternalShuffleManager(conf: SparkConf, isDriver: Boolean): VisibleShuffleManager = { + val shuffleClassLoader = getShimClassLoader() + val shuffleClassName = getRapidsShuffleInternalClass + val shuffleClass = shuffleClassLoader.loadClass(shuffleClassName) + shuffleClass.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, java.lang.Boolean.valueOf(isDriver)) + .asInstanceOf[VisibleShuffleManager] + } + + def newDriverPlugin(): DriverPlugin = { + newInstanceOf("com.nvidia.spark.rapids.RapidsDriverPlugin") + } + + def newExecutorPlugin(): ExecutorPlugin = { + newInstanceOf("com.nvidia.spark.rapids.RapidsExecutorPlugin") + } + + def newColumnarOverrideRules(): ColumnarRule = { + newInstanceOf("com.nvidia.spark.rapids.ColumnarOverrideRules") + } + + def newGpuQueryStagePrepOverrides(): Rule[SparkPlan] = { + newInstanceOf("com.nvidia.spark.rapids.GpuQueryStagePrepOverrides") + } + + def newUdfLogicalPlanRules(): Rule[LogicalPlan] = { + newInstanceOf("com.nvidia.spark.udf.LogicalPlanRules") + } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index 6c64218f3bb..0875bb29ecf 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -63,17 +63,19 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { } object GpuShuffleEnv extends Logging { - def shutdown() = { - mgr.foreach(_.stop()) - mgr = None - } - - val RAPIDS_SHUFFLE_CLASS: String = ShimLoader.getSparkShims.getRapidsShuffleManagerClass - - var mgr: Option[RapidsShuffleInternalManagerBase] = None + val RAPIDS_SHUFFLE_CLASS: String = ShimLoader.getRapidsShuffleManagerClass + val RAPIDS_SHUFFLE_INTERNAL: String = ShimLoader.getRapidsShuffleInternalClass @volatile private var env: GpuShuffleEnv = _ + def shutdown() = { + // check for nulls in tests + val shuffleManager = Option(SparkEnv.get) + .map(_.shuffleManager) + .collect { case sm: VisibleShuffleManager => sm } + .foreach(_.stop()) + } + // // Functions below get called from the driver or executors // @@ -84,14 +86,14 @@ object GpuShuffleEnv extends Logging { def isRapidsShuffleAvailable: Boolean = { // the driver has `mgr` defined when this is checked - val isRapidsManager = mgr.isDefined + val sparkEnv = SparkEnv.get + val isRapidsManager = sparkEnv.shuffleManager.isInstanceOf[VisibleShuffleManager] + if (isRapidsManager) { + validateRapidsShuffleManager(sparkEnv.shuffleManager.getClass.getName) + } // executors have `env` defined when this is checked // in 
tests - val isConfiguredInEnv = if (env != null) { - env.isRapidsShuffleConfigured - } else { - false - } + val isConfiguredInEnv = Option(env).map(_.isRapidsShuffleConfigured).getOrElse(false) (isConfiguredInEnv || isRapidsManager) && !isExternalShuffleEnabled } @@ -99,27 +101,22 @@ object GpuShuffleEnv extends Logging { conf.shuffleManagerEnabled && isRapidsShuffleAvailable } - def setRapidsShuffleManager( - managerOpt: Option[RapidsShuffleInternalManagerBase] = None): Unit = { - if (managerOpt.isDefined) { - val manager = managerOpt.get - if (manager.getClass.getCanonicalName != GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS) { - throw new IllegalStateException(s"RapidsShuffleManager class mismatch (" + - s"${manager.getClass.getCanonicalName} != ${GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS}). " + - s"Check that configuration setting spark.shuffle.manager is correct for the Spark " + - s"version being used.") - } - logInfo("RapidsShuffleManager is initialized") - } - mgr = managerOpt - } - def getCatalog: ShuffleBufferCatalog = if (env == null) { null } else { env.getCatalog } + private def validateRapidsShuffleManager(shuffManagerClassName: String): Unit = { + val shuffleManagerStr = ShimLoader.getRapidsShuffleManagerClass + if (shuffManagerClassName != shuffleManagerStr) { + throw new IllegalStateException(s"RapidsShuffleManager class mismatch (" + + s"${shuffManagerClassName} != $shuffleManagerStr). " + + s"Check that configuration setting spark.shuffle.manager is correct for the Spark " + + s"version being used.") + } + } + // // Functions below only get called from the executor // diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index 1b86da86ce4..7516dbd2887 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -236,8 +236,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B protected val wrapped = new SortShuffleManager(conf) - GpuShuffleEnv.setRapidsShuffleManager(Some(this)) - private[this] val transportEnabledMessage = if (!rapidsConf.shuffleTransportEnabled) { "Transport disabled (local cached blocks only)" } else { @@ -429,3 +427,41 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B trait VisibleShuffleManager extends ShuffleManager { def isDriver: Boolean } + +abstract class ProxyRapidsShuffleInternalManagerBase( + conf: SparkConf, + override val isDriver: Boolean +) extends VisibleShuffleManager with Proxy { + + // touched in the plugin code after the shim initialization + // is complete + override lazy val self: VisibleShuffleManager = + ShimLoader.newInternalShuffleManager(conf, isDriver) + + + // + // Signatures unchanged since 3.0.1 follow + // + + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter + ): ShuffleWriter[K, V] = { + self.getWriter(handle, mapId, context, metrics) + } + + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C] + ): ShuffleHandle = { + self.registerShuffle(shuffleId, dependency) + } + + override def unregisterShuffle(shuffleId: Int): Boolean = self.unregisterShuffle(shuffleId) + + override def shuffleBlockResolver: ShuffleBlockResolver = self.shuffleBlockResolver + + 
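// [Editor's note] A minimal sketch, outside this patch, of the idiom this proxy
// class relies on: Spark instantiates the facade eagerly from the
// spark.shuffle.manager config string, while the lazy `self` defers constructing
// the shim-specific implementation until the first forwarded call, by which time
// the plugin has installed its classloader. Names below are illustrative.

trait Service { def stop(): Unit }

class LazyDelegatingProxy(load: () => Service) extends Service with Proxy {
  override lazy val self: Service = load() // resolved once, on first use
  override def stop(): Unit = self.stop()  // every member just forwards
}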
override def stop(): Unit = self.stop() +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala index 6bd303dad2c..cc2f5426d8e 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala @@ -140,7 +140,6 @@ object TestUtils extends Assertions with Arm { SparkSession.clearActiveSession() SparkSession.clearDefaultSession() } - GpuShuffleEnv.setRapidsShuffleManager(None) } } } diff --git a/udf-compiler/pom.xml b/udf-compiler/pom.xml index ec17f3c2bcc..759696caa22 100644 --- a/udf-compiler/pom.xml +++ b/udf-compiler/pom.xml @@ -78,6 +78,46 @@ ${spark.version.classifier} provided + + + com.nvidia + rapids-4-spark-shims-aggregator_${scala.binary.version} + ${project.version} + + + + + + org.apache.maven.plugins + maven-jar-plugin + + ${spark.version.classifier} + + + + + + + no-classifier + + !buildver + + + + com.nvidia + rapids-4-spark-sql_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + provided + + + + com.nvidia + rapids-4-spark-shims-aggregator_${scala.binary.version} + ${project.version} + diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala new file mode 100644 index 00000000000..e9732e064c7 --- /dev/null +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.udf + +import ai.rapids.cudf.{NvtxColor, NvtxRange} +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, ScalaUDF} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.rapids.GpuScalaUDF.getRapidsUDFInstance + + +case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { + def replacePartialFunc(plan: LogicalPlan): PartialFunction[Expression, Expression] = { + case d: Expression => { + val nvtx = new NvtxRange("replace UDF", NvtxColor.BLUE) + try { + attemptToReplaceExpression(plan, d) + } finally { + nvtx.close() + } + } + } + + def attemptToReplaceExpression(plan: LogicalPlan, exp: Expression): Expression = { + val conf = new RapidsConf(plan.conf) + // iterating over NamedExpression + exp match { + // Check if this UDF implements RapidsUDF interface. If so, the UDF has already provided a + // columnar execution that could run on GPU, then no need to translate it to Catalyst + // expressions. If not, compile it. 
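// [Editor's note] A sketch of the two kinds of UDF the check below separates;
// it assumes a local SparkSession and is not part of the patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object UdfKindsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // a plain Scala UDF: getRapidsUDFInstance finds no RapidsUDF implementation,
    // so the rule attempts to translate its bytecode into Catalyst expressions
    val plusOne = udf((x: Long) => x + 1L)
    spark.range(3).select(plusOne(col("id"))).show()
    // a UDF whose function object also implements com.nvidia.spark.RapidsUDF
    // would be left untouched here -- it already supplies a columnar GPU path
    spark.stop()
  }
}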
+ case f: ScalaUDF if getRapidsUDFInstance(f.function).isEmpty => + GpuScalaUDFLogical(f).compile(conf.isTestEnabled) + case _ => + if (exp == null) { + exp + } else { + try { + if (exp.children != null && !exp.children.contains(null)) { + exp.withNewChildren(exp.children.map(c => { + if (c != null && c.isInstanceOf[Expression]) { + attemptToReplaceExpression(plan, c) + } else { + c + } + })) + } else { + exp + } + } catch { + case _: NullPointerException => { + exp + } + } + } + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + val conf = new RapidsConf(plan.conf) + if (conf.isUdfCompilerEnabled) { + plan match { + case project: Project => + Project(project.projectList.map(e => attemptToReplaceExpression(plan, e)) + .asInstanceOf[Seq[NamedExpression]], apply(project.child)) + case x => { + x.transformExpressions(replacePartialFunc(plan)) + } + } + } else { + plan + } + } +} diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala index a850f8b8baa..bd46e64faf4 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala @@ -16,83 +16,21 @@ package com.nvidia.spark.udf -import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.RapidsConf +import com.nvidia.spark.rapids.ShimLoader import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, ScalaUDF} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.rapids.GpuScalaUDF.getRapidsUDFInstance -class Plugin extends Function1[SparkSessionExtensions, Unit] with Logging { +class Plugin extends (SparkSessionExtensions => Unit) with Logging { override def apply(extensions: SparkSessionExtensions): Unit = { logWarning("Installing rapids UDF compiler extensions to Spark. The compiler is disabled" + - s" by default. To enable it, set `${RapidsConf.UDF_COMPILER_ENABLED}` to true") - extensions.injectResolutionRule(_ => LogicalPlanRules()) + s" by default. To enable it, set `spark.rapids.sql.udfCompiler.enabled` to true") + extensions.injectResolutionRule(logicalPlanRules) } -} -case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { - def replacePartialFunc(plan: LogicalPlan): PartialFunction[Expression, Expression] = { - case d: Expression => { - val nvtx = new NvtxRange("replace UDF", NvtxColor.BLUE) - try { - attemptToReplaceExpression(plan, d) - } finally { - nvtx.close() - } - } + def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = { + ShimLoader.newUdfLogicalPlanRules() } - - def attemptToReplaceExpression(plan: LogicalPlan, exp: Expression): Expression = { - val conf = new RapidsConf(plan.conf) - // iterating over NamedExpression - exp match { - // Check if this UDF implements RapidsUDF interface. If so, the UDF has already provided a - // columnar execution that could run on GPU, then no need to translate it to Catalyst - // expressions. If not, compile it. 
- case f: ScalaUDF if getRapidsUDFInstance(f.function).isEmpty => - GpuScalaUDFLogical(f).compile(conf.isTestEnabled) - case _ => - if (exp == null) { - exp - } else { - try { - if (exp.children != null && !exp.children.contains(null)) { - exp.withNewChildren(exp.children.map(c => { - if (c != null && c.isInstanceOf[Expression]) { - attemptToReplaceExpression(plan, c) - } else { - c - } - })) - } else { - exp - } - } catch { - case _: NullPointerException => { - exp - } - } - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - val conf = new RapidsConf(plan.conf) - if (conf.isUdfCompilerEnabled) { - plan match { - case project: Project => - Project(project.projectList.map(e => attemptToReplaceExpression(plan, e)) - .asInstanceOf[Seq[NamedExpression]], apply(project.child)) - case x => { - x.transformExpressions(replacePartialFunc(plan)) - } - } - } else { - plan - } - } -} +} \ No newline at end of file
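[Editor's note] A sketch of the injection pattern the rewritten Plugin class uses, runnable against stock Spark; the rule and class names are illustrative. The point is that the extension entry point only registers a factory, so the concrete Rule class is not loaded until a SparkSession is built -- which is what lets ShimLoader select the right parallel world first:

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

class NoopRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class DemoExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // registers a builder; NoopRule is only constructed when a session is built
    extensions.injectResolutionRule(_ => new NoopRule)
  }
}

// usage: spark-shell --conf spark.sql.extensions=DemoExtensions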