Cross-compile all shims from JDK17 to JDK8 #3

Merged
10 changes: 5 additions & 5 deletions dist/scripts/binary-dedupe.sh
@@ -173,12 +173,12 @@ function verify_same_sha_for_unshimmed() {
# TODO currently RapidsShuffleManager is "removed" from /spark* by construction in
# dist pom.xml via ant. We could delegate this logic to this script
# and make both simpler
if [[ ! "$class_file_quoted" =~ (com/nvidia/spark/rapids/spark[34].*/.*ShuffleManager.class|org/apache/spark/sql/rapids/shims/spark[34].*/ProxyRapidsShuffleInternalManager.class) ]]; then
if [[ ! "$class_file_quoted" =~ com/nvidia/spark/rapids/spark[34].*/.*ShuffleManager.class ]]; then
Owner: I don't understand why we got rid of the ProxyRapidsShuffleInternalManager

Author: tech debt, should have been done in NVIDIA#6030


if ! grep -q "/spark.\+/$class_file_quoted" "$SPARK_SHARED_TXT"; then
echo >&2 "$class_file is not bitwise-identical across shims"
exit 255
fi
fi
}

5 changes: 1 addition & 4 deletions jdk-profiles/pom.xml
@@ -33,10 +33,7 @@
<profile>
<id>jdk9plus</id>
<properties>
<scala.plugin.version>4.6.1</scala.plugin.version>
<maven.compiler.source>${java.specification.version}</maven.compiler.source>
<maven.compiler.release>${maven.compiler.source}</maven.compiler.release>
<maven.compiler.target>${maven.compiler.source}</maven.compiler.target>
<maven.compiler.source>1.8</maven.compiler.source>
</properties>
<activation>
<!-- activate for all java versions after 9 -->
11 changes: 3 additions & 8 deletions pom.xml
@@ -505,8 +505,6 @@
</property>
</activation>
<properties>
<!-- Downgrade scala plugin version due to: https://github.com/sbt/sbt/issues/4305 -->
<scala.plugin.version>3.4.4</scala.plugin.version>
<spark.version.classifier>spark330db</spark.version.classifier>
<spark.version>${spark330db.version}</spark.version>
<spark.test.version>${spark330db.version}</spark.test.version>
@@ -531,8 +529,6 @@
</property>
</activation>
<properties>
<!-- Downgrade scala plugin version due to: https://github.com/sbt/sbt/issues/4305 -->
<scala.plugin.version>3.4.4</scala.plugin.version>
<spark.version.classifier>spark332db</spark.version.classifier>
<spark.version>${spark332db.version}</spark.version>
<spark.test.version>${spark332db.version}</spark.test.version>
@@ -556,8 +552,6 @@
</property>
</activation>
<properties>
<!-- Downgrade scala plugin version due to: https://github.com/sbt/sbt/issues/4305 -->
<scala.plugin.version>3.4.4</scala.plugin.version>
<spark.version.classifier>spark341db</spark.version.classifier>
<spark.version>${spark341db.version}</spark.version>
<spark.test.version>${spark341db.version}</spark.test.version>
@@ -755,7 +749,6 @@
<allowConventionalDistJar>false</allowConventionalDistJar>
<buildver>311</buildver>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.major.version>8</java.major.version>
<spark.version>${spark311.version}</spark.version>
<spark.test.version>${spark.version}</spark.test.version>
@@ -829,7 +822,8 @@
<spark351.version>3.5.1</spark351.version>
<spark400.version>4.0.0-SNAPSHOT</spark400.version>
<mockito.version>3.12.4</mockito.version>
<scala.plugin.version>4.3.0</scala.plugin.version>
<!-- same as Apache Spark 4.0.0 -->
<scala.plugin.version>4.7.1</scala.plugin.version>
<maven.install.plugin.version>3.1.1</maven.install.plugin.version>
<maven.jar.plugin.version>3.3.0</maven.jar.plugin.version>
<scalatest-maven-plugin.version>2.0.2</scalatest-maven-plugin.version>
@@ -1367,6 +1361,7 @@ This will force full Scala code rebuild in downstream modules.
<checkMultipleScalaVersions>true</checkMultipleScalaVersions>
<failOnMultipleScalaVersions>true</failOnMultipleScalaVersions>
<recompileMode>${scala.recompileMode}</recompileMode>
<release>${java.major.version}</release>
<args>
<arg>-unchecked</arg>
<arg>-deprecation</arg>
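Together with pinning maven.compiler.source to 1.8 in the jdk9plus profile, the new release property handed to the scala-maven-plugin is what makes a JDK17 build safe for a JDK8 runtime: unlike -source/-target, the --release flag also compiles against the JDK8 API signatures. A minimal sketch of the failure mode this prevents (hypothetical demo object, not part of this PR):

```scala
import java.nio.ByteBuffer

object ReleaseFlagDemo {
  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(8)
    buf.putInt(42)
    // JDK9+ narrows ByteBuffer.flip() to return ByteBuffer instead of Buffer.
    // Compiled on JDK17 with only -source/-target 8, this call site binds to
    // the narrowed signature and fails with NoSuchMethodError on a JDK8 JVM.
    val flipped: java.nio.Buffer = buf.flip()
    println(flipped.limit())
  }
}
```

With --release 8 the compiler resolves flip() against the JDK8 class library instead, so the emitted call site runs on both JVMs.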
5 changes: 1 addition & 4 deletions scala2.13/jdk-profiles/pom.xml
@@ -33,10 +33,7 @@
<profile>
<id>jdk9plus</id>
<properties>
<scala.plugin.version>4.6.1</scala.plugin.version>
<maven.compiler.source>${java.specification.version}</maven.compiler.source>
<maven.compiler.release>${maven.compiler.source}</maven.compiler.release>
<maven.compiler.target>${maven.compiler.source}</maven.compiler.target>
<maven.compiler.source>1.8</maven.compiler.source>
</properties>
<activation>
<!-- activate for all java versions after 9 -->
11 changes: 3 additions & 8 deletions scala2.13/pom.xml
@@ -505,8 +505,6 @@
</property>
</activation>
<properties>
<!-- Downgrade scala plugin version due to: https://github.com/sbt/sbt/issues/4305 -->
<scala.plugin.version>3.4.4</scala.plugin.version>
<spark.version.classifier>spark330db</spark.version.classifier>
<spark.version>${spark330db.version}</spark.version>
<spark.test.version>${spark330db.version}</spark.test.version>
@@ -531,8 +529,6 @@
</property>
</activation>
<properties>
<!-- Downgrade scala plugin version due to: https://github.com/sbt/sbt/issues/4305 -->
<scala.plugin.version>3.4.4</scala.plugin.version>
<spark.version.classifier>spark332db</spark.version.classifier>
<spark.version>${spark332db.version}</spark.version>
<spark.test.version>${spark332db.version}</spark.test.version>
@@ -556,8 +552,6 @@
</property>
</activation>
<properties>
<!-- Downgrade scala plugin version due to: https://github.com/sbt/sbt/issues/4305 -->
<scala.plugin.version>3.4.4</scala.plugin.version>
<spark.version.classifier>spark341db</spark.version.classifier>
<spark.version>${spark341db.version}</spark.version>
<spark.test.version>${spark341db.version}</spark.test.version>
@@ -755,7 +749,6 @@
<allowConventionalDistJar>false</allowConventionalDistJar>
<buildver>311</buildver>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.major.version>8</java.major.version>
<spark.version>${spark330.version}</spark.version>
<spark.test.version>${spark.version}</spark.test.version>
@@ -829,7 +822,8 @@
<spark351.version>3.5.1</spark351.version>
<spark400.version>4.0.0-SNAPSHOT</spark400.version>
<mockito.version>3.12.4</mockito.version>
<scala.plugin.version>4.3.0</scala.plugin.version>
<!-- same as Apache Spark 4.0.0 -->
<scala.plugin.version>4.7.1</scala.plugin.version>
<maven.install.plugin.version>3.1.1</maven.install.plugin.version>
<maven.jar.plugin.version>3.3.0</maven.jar.plugin.version>
<scalatest-maven-plugin.version>2.0.2</scalatest-maven-plugin.version>
@@ -1367,6 +1361,7 @@ This will force full Scala code rebuild in downstream modules.
<checkMultipleScalaVersions>true</checkMultipleScalaVersions>
<failOnMultipleScalaVersions>true</failOnMultipleScalaVersions>
<recompileMode>${scala.recompileMode}</recompileMode>
<release>${java.major.version}</release>
<args>
<arg>-unchecked</arg>
<arg>-deprecation</arg>
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
* The RAPIDS plugin for Spark.
* To enable this plugin, set the config "spark.plugins" to `com.nvidia.spark.SQLPlugin`
*/
class SQLPlugin extends SparkPlugin with Logging {
class SQLPlugin extends SparkPlugin {
override def driverPlugin(): DriverPlugin = ShimLoader.newDriverPlugin()

override def executorPlugin(): ExecutorPlugin = ShimLoader.newExecutorPlugin()
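As the scaladoc above notes, the plugin is enabled through Spark's standard plugin mechanism; a minimal usage sketch:

```scala
import org.apache.spark.sql.SparkSession

// Activate the RAPIDS plugin via the standard plugin config, exactly as the
// scaladoc above describes.
val spark = SparkSession.builder()
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .getOrCreate()
```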
@@ -27,7 +27,6 @@ import org.apache.commons.lang3.reflect.MethodUtils
import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL, SPARK_REVISION, SPARK_VERSION, SparkConf, SparkEnv}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin}
import org.apache.spark.api.resource.ResourceDiscoveryPlugin
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -57,8 +56,9 @@ import org.apache.spark.util.MutableURLClassLoader
Using these Jar URL's allows referencing different bytecode produced from identical sources
by incompatible Scala / Spark dependencies.
*/
object ShimLoader extends Logging {
logDebug(s"ShimLoader object instance: $this loaded by ${getClass.getClassLoader}")
object ShimLoader {
val log = org.slf4j.LoggerFactory.getLogger(getClass().getName().stripSuffix("$"))
Owner: So much better than what we were previously doing! I don't know why we even bothered with the Logging trait

Author: It was a convenient pattern, but it is an internal class that would now need shimming, and it resulted in a chicken-and-egg problem for shim loading.
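A minimal sketch of the adopted pattern, assuming only slf4j on the classpath (the object name here is hypothetical):

```scala
import org.slf4j.{Logger, LoggerFactory}

// Binding an slf4j logger directly removes the compile-time dependency on
// Spark's internal Logging trait, which would otherwise need shimming and
// would have to be loaded by the very classloader ShimLoader is bootstrapping.
object DirectLoggerExample {
  // stripSuffix("$") drops Scala's companion-object marker so the logger
  // name matches the conventional class name.
  val log: Logger = LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))

  def demo(): Unit =
    log.info(s"instance $this loaded by ${getClass.getClassLoader}")
}
```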

log.debug(s"ShimLoader object instance: $this loaded by ${getClass.getClassLoader}")
private val shimRootURL = {
val thisClassFile = getClass.getName.replace(".", "/") + ".class"
val url = getClass.getClassLoader.getResource(thisClassFile)
@@ -124,11 +124,11 @@ object ShimLoader extends Logging {
// brute-force call addURL using reflection
classLoader match {
case nullClassLoader if nullClassLoader == null =>
logInfo("findURLClassLoader failed to locate a mutable classloader")
log.info("findURLClassLoader failed to locate a mutable classloader")
None
case urlCl: java.net.URLClassLoader =>
// fast path
logInfo(s"findURLClassLoader found a URLClassLoader $urlCl")
log.info(s"findURLClassLoader found a URLClassLoader $urlCl")
Option(urlCl)
case replCl if replCl.getClass.getName == "org.apache.spark.repl.ExecutorClassLoader" ||
replCl.getClass.getName == "org.apache.spark.executor.ExecutorClassLoader" =>
@@ -137,20 +137,21 @@
// https://issues.apache.org/jira/browse/SPARK-18646
val parentLoader = MethodUtils.invokeMethod(replCl, true, "parentLoader")
.asInstanceOf[ClassLoader]
logInfo(s"findURLClassLoader found $replCl, trying parentLoader=$parentLoader")
log.info(s"findURLClassLoader found $replCl, trying parentLoader=$parentLoader")
findURLClassLoader(parentLoader)
case urlAddable: ClassLoader if null != MethodUtils.getMatchingMethod(
urlAddable.getClass, "addURL", classOf[java.net.URL]) =>
// slow defensive path
logInfo(s"findURLClassLoader found a urLAddable classloader $urlAddable")
log.info(s"findURLClassLoader found a urLAddable classloader $urlAddable")
Option(urlAddable)
case root if root.getParent == null || root.getParent == root =>
logInfo(s"findURLClassLoader hit the Boostrap classloader $root, " +
log.info(s"findURLClassLoader hit the Boostrap classloader $root, " +
s"failed to find a mutable classloader!")
None
case cl =>
val parentClassLoader = cl.getParent
logInfo(s"findURLClassLoader found an immutable $cl, trying parent=$parentClassLoader")
log.info(s"findURLClassLoader found an immutable $cl" +
s", trying parent=$parentClassLoader")
findURLClassLoader(parentClassLoader)
}
}
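Condensed to its essentials, the search above walks the classloader chain parent-ward until it finds one it can mutate. The sketch below uses the same commons-lang3 helper but illustrative names, and omits the REPL special case:

```scala
import java.net.{URL, URLClassLoader}
import org.apache.commons.lang3.reflect.MethodUtils

object MutableLoaderFinder {
  // Walk toward the root until we hit a URLClassLoader, any loader exposing
  // an addURL(URL) method we can call reflectively, or the bootstrap loader.
  @annotation.tailrec
  def find(cl: ClassLoader): Option[ClassLoader] = cl match {
    case null => None
    case urlCl: URLClassLoader => Some(urlCl)
    case addable if MethodUtils.getMatchingMethod(
        addable.getClass, "addURL", classOf[URL]) != null => Some(addable)
    case root if root.getParent == null || (root.getParent eq root) => None
    case other => find(other.getParent)
  }

  // Register a jar URL with the first mutable loader in the chain.
  def addUrl(cl: ClassLoader, u: URL): Unit =
    find(cl).foreach { loader =>
      MethodUtils.invokeMethod(loader, true, "addURL", u)
    }
}
```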
@@ -159,15 +160,15 @@
findURLClassLoader(UnshimmedTrampolineUtil.sparkClassLoader).foreach { urlAddable =>
urlsForSparkClassLoader.foreach { url =>
if (!conventionalSingleShimJarDetected) {
logInfo(s"Updating spark classloader $urlAddable with the URLs: " +
log.info(s"Updating spark classloader $urlAddable with the URLs: " +
urlsForSparkClassLoader.mkString(", "))
MethodUtils.invokeMethod(urlAddable, true, "addURL", url)
logInfo(s"Spark classLoader $urlAddable updated successfully")
log.info(s"Spark classLoader $urlAddable updated successfully")
urlAddable match {
case urlCl: java.net.URLClassLoader =>
if (!urlCl.getURLs.contains(shimCommonURL)) {
// infeasible, defensive diagnostics
logWarning(s"Didn't find expected URL $shimCommonURL in the spark " +
log.warn(s"Didn't find expected URL $shimCommonURL in the spark " +
s"classloader $urlCl although addURL succeeded, maybe pushed up to the " +
s"parent classloader ${urlCl.getParent}")
}
@@ -188,7 +189,7 @@
if (tmpClassLoader == null) {
tmpClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL),
getClass.getClassLoader)
logWarning("Found an unexpected context classloader " +
log.warn("Found an unexpected context classloader " +
s"${Thread.currentThread().getContextClassLoader}. We will try to recover from this, " +
"but it may cause class loading problems.")
}
@@ -202,9 +203,9 @@

private def detectShimProvider(): String = {
val sparkVersion = getSparkVersion
logInfo(s"Loading shim for Spark version: $sparkVersion")
logInfo("Complete Spark build info: " + sparkBuildInfo.mkString(", "))
logInfo("Scala version: " + util.Properties.versionString)
log.info(s"Loading shim for Spark version: $sparkVersion")
log.info("Complete Spark build info: " + sparkBuildInfo.mkString(", "))
log.info("Scala version: " + util.Properties.versionString)

val thisClassLoader = getClass.getClassLoader

@@ -225,7 +226,7 @@
val shimServiceProviderOverrideClassName = Option(SparkEnv.get) // Spark-less RapidsConf.help
.flatMap(_.conf.getOption("spark.rapids.shims-provider-override"))
shimServiceProviderOverrideClassName.foreach { shimProviderClass =>
logWarning(s"Overriding Spark shims provider to $shimProviderClass. " +
log.warn(s"Overriding Spark shims provider to $shimProviderClass. " +
"This may be an untested configuration!")
}
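For completeness, the override hook checked above is an ordinary Spark conf; a hedged sketch of using it (the provider class name is illustrative, not a guaranteed artifact of this repo):

```scala
import org.apache.spark.SparkConf

// Pin the shim provider explicitly instead of relying on version detection.
// As the warning above says, this may be an untested configuration.
val conf = new SparkConf()
  .set("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .set("spark.rapids.shims-provider-override",
    "com.nvidia.spark.rapids.shims.spark341.SparkShimServiceProvider")
```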

@@ -252,7 +253,7 @@
val ret = thisClassLoader.loadClass(shimServiceProviderStr)
if (numShimServiceProviders == 1) {
conventionalSingleShimJarDetected = true
logInfo("Conventional shim jar layout for a single Spark verision detected")
log.info("Conventional shim jar layout for a single Spark verision detected")
}
ret
}.getOrElse(shimClassLoader.loadClass(shimServiceProviderStr))
@@ -262,7 +263,7 @@
)
} catch {
case cnf: ClassNotFoundException =>
logDebug(cnf + ": Could not load the provider, likely a dev build", cnf)
log.debug(cnf + ": Could not load the provider, likely a dev build", cnf)
None
}
}.partition { case (inst, _) =>
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,19 +15,18 @@
*/

package com.nvidia.spark.rapids

import org.apache.spark.internal.Logging

/*
* This is specifically for functions dealing with loading classes via reflection. This
* class itself should not contain or import any shimmed/parallel world classes so that
* it can also be called via reflection, like calling getMethod on ShimReflectionUtils.
*/
object ShimReflectionUtils extends Logging {
object ShimReflectionUtils {

val log = org.slf4j.LoggerFactory.getLogger(getClass().getName().stripSuffix("$"))

def loadClass(className: String): Class[_] = {
val loader = ShimLoader.getShimClassLoader()
logDebug(s"Loading $className using $loader with the parent loader ${loader.getParent}")
log.debug(s"Loading $className using $loader with the parent loader ${loader.getParent}")
loader.loadClass(className)
}

@@ -37,10 +36,10 @@ object ShimReflectionUtils extends Logging {

// avoid cached constructors
def instantiateClass[T](cls: Class[T]): T = {
logDebug(s"Instantiate ${cls.getName} using classloader " + cls.getClassLoader)
log.debug(s"Instantiate ${cls.getName} using classloader " + cls.getClassLoader)
cls.getClassLoader match {
case urcCl: java.net.URLClassLoader =>
logDebug("urls " + urcCl.getURLs.mkString("\n"))
log.debug("urls " + urcCl.getURLs.mkString("\n"))
case _ =>
}
val constructor = cls.getConstructor()
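A short usage sketch of the two helpers above (the target class name is hypothetical):

```scala
// Load through the shim classloader, then instantiate via the class's own
// loader so no constructor cached under a different loader is reused.
val cls = ShimReflectionUtils.loadClass("com.nvidia.spark.rapids.ExampleShimmedRule")
val rule = ShimReflectionUtils.instantiateClass(cls)
```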
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,12 +18,11 @@ package com.nvidia.spark.udf

import com.nvidia.spark.rapids.ShimLoader

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

class Plugin extends (SparkSessionExtensions => Unit) with Logging {
class Plugin extends (SparkSessionExtensions => Unit) {
override def apply(extensions: SparkSessionExtensions): Unit = {
extensions.injectResolutionRule(logicalPlanRules)
}
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,15 +16,14 @@

package com.nvidia.spark.rapids

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}

/**
* Extension point to enable GPU SQL processing.
*/
class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
class SQLExecPlugin extends (SparkSessionExtensions => Unit) {
private val strategyRules: Strategy = ShimLoader.newStrategyRules()

override def apply(extensions: SparkSessionExtensions): Unit = {