[ML-187] Support Spark 3.1.3 and 3.2.0 and support CDH #197

Merged 7 commits on Apr 7, 2022
15 changes: 13 additions & 2 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/Utils.scala
@@ -19,8 +19,7 @@ package com.intel.oap.mllib
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.{SparkConf, SparkContext}
-
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
 import java.net.InetAddress
 
 object Utils {
@@ -155,4 +154,16 @@ object Utils {
     // Return executor number (exclude driver)
     executorInfos.length - 1
   }
+  def getSparkVersion(): String = {
+    // For example, the CDH Spark version is 3.1.1.3.1.7290.5-2.
+    // The string before the third dot is the Spark version.
+    val array = SPARK_VERSION.split("\\.")
+    val sparkVersion = if (array.size > 3) {
+      val version = array.take(3).mkString(".")
+      version
+    } else {
+      SPARK_VERSION
+    }
+    sparkVersion
+  }
 }
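
To make the normalization concrete, here is a minimal, self-contained sketch (not part of the PR) that applies the same truncation rule to a hard-coded CDH-style version string; SparkVersionDemo and cdhVersion are illustrative names:

object SparkVersionDemo extends App {
  // A CDH build string taken from the comment above; the real code reads
  // org.apache.spark.SPARK_VERSION rather than a literal.
  val cdhVersion = "3.1.1.3.1.7290.5-2"
  val parts = cdhVersion.split("\\.")
  // Keep only the first three dot-separated components.
  val normalized = if (parts.size > 3) parts.take(3).mkString(".") else cdhVersion
  println(normalized) // prints "3.1.1"
}
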
@@ -16,9 +16,11 @@
 
 package com.intel.oap.mllib.classification
 
+import com.intel.oap.mllib.Utils
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.NaiveBayesModel
-import org.apache.spark.ml.classification.spark320.{NaiveBayes => NaiveBayesSpark320}
+import org.apache.spark.ml.classification.spark321.{NaiveBayes => NaiveBayesSpark321}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SPARK_VERSION, SparkException}
@@ -31,8 +33,9 @@ trait NaiveBayesShim extends Logging {
 object NaiveBayesShim extends Logging {
   def create(uid: String): NaiveBayesShim = {
     logInfo(s"Loading NaiveBayes for Spark $SPARK_VERSION")
-    val shim = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" | "3.2.0" => new NaiveBayesSpark320(uid)
+
+    val shim = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new NaiveBayesSpark321(uid)
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     shim
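
All of the shims in this PR share this dispatch shape: the normalized version string, not the raw SPARK_VERSION, drives the match, so a CDH build such as 3.1.1.3.1.7290.5-2 falls into the 3.1.1 case. A minimal runnable sketch of the pattern (Shim, Spark321Shim, and createShim are simplified stand-ins, not the PR's classes):

object ShimDispatchSketch extends App {
  trait Shim { def name: String }
  class Spark321Shim extends Shim { def name = "spark321" }

  // Same truncation rule as Utils.getSparkVersion().
  def normalize(version: String): String = {
    val parts = version.split("\\.")
    if (parts.size > 3) parts.take(3).mkString(".") else version
  }

  def createShim(sparkVersion: String): Shim = normalize(sparkVersion) match {
    case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new Spark321Shim
    case other => throw new RuntimeException(s"Unsupported Spark version $other")
  }

  println(createShim("3.1.1.3.1.7290.5-2").name) // CDH build resolves to spark321
  println(createShim("3.2.1").name)              // vanilla Spark is unchanged
}
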
@@ -16,9 +16,11 @@
 
 package com.intel.oap.mllib.clustering
 
+import com.intel.oap.mllib.Utils
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.clustering.spark320.{KMeans => KMeansSpark320}
+import org.apache.spark.ml.clustering.spark321.{KMeans => KMeansSpark321}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SPARK_VERSION, SparkException}
@@ -31,8 +33,8 @@ trait KMeansShim extends Logging {
 object KMeansShim extends Logging {
   def create(uid: String): KMeansShim = {
     logInfo(s"Loading KMeans for Spark $SPARK_VERSION")
-    val kmeans = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" | "3.2.0" => new KMeansSpark320(uid)
+    val kmeans = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new KMeansSpark321(uid)
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     kmeans
@@ -16,9 +16,11 @@
 
 package com.intel.oap.mllib.feature
 
+import com.intel.oap.mllib.Utils
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.PCAModel
-import org.apache.spark.ml.feature.spark320.{PCA => PCASpark320}
+import org.apache.spark.ml.feature.spark321.{PCA => PCASpark321}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SPARK_VERSION, SparkException}
@@ -31,8 +33,8 @@ trait PCAShim extends Logging {
 object PCAShim extends Logging {
   def create(uid: String): PCAShim = {
     logInfo(s"Loading PCA for Spark $SPARK_VERSION")
-    val pca = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" | "3.2.0" => new PCASpark320(uid)
+    val pca = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new PCASpark321(uid)
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     pca
@@ -16,10 +16,12 @@
 
 package com.intel.oap.mllib.recommendation
 
+import com.intel.oap.mllib.Utils
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.recommendation.ALS.Rating
-import org.apache.spark.ml.recommendation.spark312.{ALS => ALSSpark312}
-import org.apache.spark.ml.recommendation.spark320.{ALS => ALSSpark320}
+import org.apache.spark.ml.recommendation.spark313.{ALS => ALSSpark313}
+import org.apache.spark.ml.recommendation.spark321.{ALS => ALSSpark321}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.{SPARK_VERSION, SparkException}
@@ -46,9 +48,9 @@ trait ALSShim extends Serializable with Logging {
 object ALSShim extends Logging {
   def create(): ALSShim = {
     logInfo(s"Loading ALS for Spark $SPARK_VERSION")
-    val als = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" => new ALSSpark312()
-      case "3.2.0" => new ALSSpark320()
+    val als = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" => new ALSSpark313()
+      case "3.2.0" | "3.2.1" => new ALSSpark321()
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     als
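
ALS and LinearRegression differ from the single-implementation shims above: the 3.1.x and 3.2.x lines map to separate backends (spark313 vs. spark321). A hedged sketch of that two-branch variant, with AlsSpark313, AlsSpark321, and createAls as illustrative stand-ins for the PR's classes:

object TwoBranchDispatchSketch extends App {
  sealed trait AlsShim
  case object AlsSpark313 extends AlsShim // stand-in for the 3.1.x backend
  case object AlsSpark321 extends AlsShim // stand-in for the 3.2.x backend

  // Same truncation rule as Utils.getSparkVersion().
  def normalize(version: String): String = {
    val parts = version.split("\\.")
    if (parts.size > 3) parts.take(3).mkString(".") else version
  }

  def createAls(sparkVersion: String): AlsShim = normalize(sparkVersion) match {
    case "3.1.1" | "3.1.2" | "3.1.3" => AlsSpark313
    case "3.2.0" | "3.2.1" => AlsSpark321
    case other => throw new RuntimeException(s"Unsupported Spark version $other")
  }

  println(createAls("3.1.1.3.1.7290.5-2")) // CDH 3.1.1 build -> AlsSpark313
  println(createAls("3.2.1"))              // -> AlsSpark321
}
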
@@ -16,11 +16,13 @@
 
 package com.intel.oap.mllib.regression
 
+import com.intel.oap.mllib.Utils
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.LinearRegressionModel
-import org.apache.spark.ml.regression.spark312.{LinearRegression => LinearRegressionSpark312}
-import org.apache.spark.ml.regression.spark320.{LinearRegression => LinearRegressionSpark320}
+import org.apache.spark.ml.regression.spark313.{LinearRegression => LinearRegressionSpark313}
+import org.apache.spark.ml.regression.spark321.{LinearRegression => LinearRegressionSpark321}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.{SPARK_VERSION, SparkException}
 
@@ -32,9 +34,9 @@ trait LinearRegressionShim extends Serializable with Logging {
 object LinearRegressionShim extends Logging {
   def create(uid: String): LinearRegressionShim = {
     logInfo(s"Loading LinearRegression for Spark $SPARK_VERSION")
-    val linearRegression = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" => new LinearRegressionSpark312(uid)
-      case "3.2.0" => new LinearRegressionSpark320(uid)
+    val linearRegression = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" => new LinearRegressionSpark313(uid)
+      case "3.2.0" | "3.2.1" => new LinearRegressionSpark321(uid)
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     linearRegression
@@ -16,6 +16,7 @@
 
 package com.intel.oap.mllib.stat
 
+import com.intel.oap.mllib.Utils
 import org.apache.spark.{SPARK_VERSION, SparkException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.recommendation.ALS.Rating
@@ -24,8 +25,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.storage.StorageLevel
 
 import scala.reflect.ClassTag
-
-import org.apache.spark.ml.stat.spark320.{Correlation => CorrelationSpark320 }
+import org.apache.spark.ml.stat.spark321.{Correlation => CorrelationSpark321}
 
 trait CorrelationShim extends Serializable with Logging {
   def corr(dataset: Dataset[_], column: String, method: String): DataFrame
@@ -34,8 +34,8 @@ trait CorrelationShim extends Serializable with Logging {
 object CorrelationShim extends Logging {
   def create(): CorrelationShim = {
     logInfo(s"Loading Correlation for Spark $SPARK_VERSION")
-    val als = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" | "3.2.0" => new CorrelationSpark320()
+    val als = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new CorrelationSpark321()
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     als
@@ -16,14 +16,15 @@
 
 package com.intel.oap.mllib.stat
 
+import com.intel.oap.mllib.Utils
+
 import org.apache.spark.{SPARK_VERSION, SparkException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 
-import org.apache.spark.mllib.stat.spark320.{Statistics => SummarizerSpark320 }
+import org.apache.spark.mllib.stat.spark321.{Statistics => SummarizerSpark321}
 
 trait SummarizerShim extends Serializable with Logging {
   def colStats(X: RDD[Vector]): MultivariateStatisticalSummary
@@ -33,8 +34,8 @@ trait SummarizerShim extends Serializable with Logging {
 object SummarizerShim extends Logging {
   def create(): SummarizerShim = {
     logInfo(s"Loading Summarizer for Spark $SPARK_VERSION")
-    val summarizer = SPARK_VERSION match {
-      case "3.1.1" | "3.1.2" | "3.2.0" => new SummarizerSpark320()
+    val summarizer = Utils.getSparkVersion() match {
+      case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new SummarizerSpark321()
       case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
     }
     summarizer
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.classification.spark320
+package org.apache.spark.ml.classification.spark321
 
 import com.intel.oap.mllib.Utils
 import com.intel.oap.mllib.classification.{NaiveBayesDALImpl, NaiveBayesShim}
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.clustering.spark320
+package org.apache.spark.ml.clustering.spark321
 
 import com.intel.oap.mllib.Utils
 import com.intel.oap.mllib.clustering.{KMeansDALImpl, KMeansShim}
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.feature.spark320
+package org.apache.spark.ml.feature.spark321
 
 import com.intel.oap.mllib.Utils
 import com.intel.oap.mllib.feature.{PCADALImpl, PCAShim}
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.recommendation.spark312
+package org.apache.spark.ml.recommendation.spark313
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.intel.oap.mllib.{Utils => DALUtils}
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.recommendation.spark320
+package org.apache.spark.ml.recommendation.spark321
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.intel.oap.mllib.{Utils => DALUtils}
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.regression.spark312
+package org.apache.spark.ml.regression.spark313
 
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.regression.spark320
+package org.apache.spark.ml.regression.spark321
 
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{
@@ -17,7 +17,7 @@
  */
 // scalastyle:on
 
-package org.apache.spark.ml.stat.spark320
+package org.apache.spark.ml.stat.spark321
 
 import com.intel.oap.mllib.Utils
 import com.intel.oap.mllib.stat.{CorrelationDALImpl, CorrelationShim}
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.mllib.stat.spark320
+package org.apache.spark.mllib.stat.spark321
 
 import com.intel.oap.mllib.Utils
 import com.intel.oap.mllib.stat.{SummarizerDALImpl, SummarizerShim}